- """
- Copyright (c) Alibaba, Inc. and its affiliates.
- """
- from functools import partial
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
- conv_nd,
- linear,
- )
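
# Embedding manager used by AnyText: it encodes rendered glyph lines (and optional
# position masks) into token-sized vectors and splices them into the text encoder's
# output wherever the placeholder token (default "*") appears in the prompt.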

def get_clip_token_for_string(tokenizer, string):
    batch_encoding = tokenizer(
        string,
        truncation=True,
        max_length=77,
        return_length=True,
        return_overflowing_tokens=False,
        padding="max_length",
        return_tensors="pt",
    )
    tokens = batch_encoding["input_ids"]
    # 49407 is CLIP's end-of-text (and padding) token id, so a string that maps to a
    # single token leaves exactly two positions that differ from it: the start token
    # and the token itself.
    assert (
        torch.count_nonzero(tokens - 49407) == 2
    ), f"String '{string}' maps to more than a single token. Please use another string"
    return tokens[0, 1]

def get_bert_token_for_string(tokenizer, string):
    token = tokenizer(string)
    assert (
        torch.count_nonzero(token) == 3
    ), f"String '{string}' maps to more than a single token. Please use another string"
    token = token[0, 1]
    return token

def get_clip_vision_emb(encoder, processor, img):
    _img = img.repeat(1, 3, 1, 1) * 255
    inputs = processor(images=_img, return_tensors="pt")
    inputs["pixel_values"] = inputs["pixel_values"].to(img.device)
    outputs = encoder(**inputs)
    emb = outputs.image_embeds
    return emb

def get_recog_emb(encoder, img_list):
    _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list]
    encoder.predictor.eval()
    _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
    return preds_neck
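
# Note: the recognizer's neck features are later flattened and projected by
# `linear(40 * 64, token_dim)` in EmbeddingManager, so each glyph line is assumed to
# yield 40 * 64 neck values; this is inferred from that projection layer, not from
# any documented contract of the OCR model.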

def pad_H(x):
    _, _, H, W = x.shape
    p_top = (W - H) // 2
    p_bot = W - H - p_top
    return F.pad(x, (0, 0, p_top, p_bot))
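
# For example, a batch of glyph-line images shaped (N, 1, 80, 512) would come back
# from pad_H as (N, 1, 512, 512), with the added rows split evenly above and below
# (hypothetical shapes, for illustration only).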

class EncodeNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncodeNet, self).__init__()
        chan = 16
        n_layer = 4  # downsample
        self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
        self.conv_list = nn.ModuleList([])
        _c = chan
        for i in range(n_layer):
            self.conv_list.append(conv_nd(2, _c, _c * 2, 3, padding=1, stride=2))
            _c *= 2
        self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.act(self.conv1(x))
        for layer in self.conv_list:
            x = self.act(layer(x))
        x = self.act(self.conv2(x))
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x
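
# EncodeNet maps a (B, in_channels, H, W) tensor to a (B, out_channels) vector: a stem
# convolution, four stride-2 downsampling convolutions, a projection convolution, and
# a global average pool followed by flattening.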

class EmbeddingManager(nn.Module):
    def __init__(
        self,
        embedder,
        valid=True,
        glyph_channels=20,
        position_channels=1,
        placeholder_string="*",
        add_pos=False,
        emb_type="ocr",
        **kwargs,
    ):
        super().__init__()
        if hasattr(embedder, "tokenizer"):  # using Stable Diffusion's CLIP encoder
            get_token_for_string = partial(
                get_clip_token_for_string, embedder.tokenizer
            )
            token_dim = 768
            if hasattr(embedder, "vit"):
                assert emb_type == "vit"
                self.get_vision_emb = partial(
                    get_clip_vision_emb, embedder.vit, embedder.processor
                )
            self.get_recog_emb = None
        else:  # using LDM's BERT encoder
            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
            token_dim = 1280
        self.token_dim = token_dim
        self.emb_type = emb_type
        self.add_pos = add_pos
        if add_pos:
            self.position_encoder = EncodeNet(position_channels, token_dim)
        if emb_type == "ocr":
            self.proj = linear(40 * 64, token_dim)
        if emb_type == "conv":
            self.glyph_encoder = EncodeNet(glyph_channels, token_dim)
        self.placeholder_token = get_token_for_string(placeholder_string)
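
    # Note: `self.recog` (the OCR recognizer used when emb_type == "ocr") is not set
    # up here; encode_text() below assumes it has been attached to the manager from
    # outside beforehand.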

    def encode_text(self, text_info):
        # text_info is expected to provide "n_lines" (text-line count per sample),
        # "gly_line" (rendered glyph images indexed as [line][sample]) and, when
        # add_pos is set, "positions" (per-line position masks); keys inferred from
        # the indexing below.
        if self.get_recog_emb is None and self.emb_type == "ocr":
            self.get_recog_emb = partial(get_recog_emb, self.recog)

        # Gather every text line (and, optionally, its position mask) across the
        # batch so they can be encoded in a single pass.
        gline_list = []
        pos_list = []
        for i in range(len(text_info["n_lines"])):  # sample index in a batch
            n_lines = text_info["n_lines"][i]
            for j in range(n_lines):  # line
                gline_list += [text_info["gly_line"][j][i : i + 1]]
                if self.add_pos:
                    pos_list += [text_info["positions"][j][i : i + 1]]

        if len(gline_list) > 0:
            if self.emb_type == "ocr":
                recog_emb = self.get_recog_emb(gline_list)
                enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
            elif self.emb_type == "vit":
                enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
            elif self.emb_type == "conv":
                enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
            if self.add_pos:
                # (pos_list is collected above but not used here; the position
                # encoder receives the concatenated glyph lines, as in the code
                # as given.)
                enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
                enc_glyph = enc_glyph + enc_pos

        # Regroup the flat list of per-line embeddings back into one list per sample.
        self.text_embs_all = []
        n_idx = 0
        for i in range(len(text_info["n_lines"])):  # sample index in a batch
            n_lines = text_info["n_lines"][i]
            text_embs = []
            for j in range(n_lines):  # line
                text_embs += [enc_glyph[n_idx : n_idx + 1]]
                n_idx += 1
            self.text_embs_all += [text_embs]
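
    # forward() receives the prompt's token ids together with the frozen text
    # encoder's embeddings and, wherever the placeholder token occurs, substitutes
    # the per-line glyph embeddings cached by encode_text().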

    def forward(
        self,
        tokenized_text,
        embedded_text,
    ):
        b, device = tokenized_text.shape[0], tokenized_text.device
        for i in range(b):
            # Positions in this prompt occupied by the placeholder token.
            idx = tokenized_text[i] == self.placeholder_token.to(device)
            if sum(idx) > 0:
                if i >= len(self.text_embs_all):
                    print("truncation for log images...")
                    break
                text_emb = torch.cat(self.text_embs_all[i], dim=0)
                if sum(idx) != len(text_emb):
                    print("truncation for long caption...")
                embedded_text[i][idx] = text_emb[: sum(idx)]
        return embedded_text

    def embedding_parameters(self):
        return self.parameters()
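
# Typical call order (a sketch inferred from the code above, not an official API):
#   manager = EmbeddingManager(embedder, emb_type="ocr")
#   manager.recog = ocr_model            # attach the OCR recognizer externally
#   manager.encode_text(text_info)       # cache per-line glyph embeddings
#   out = manager(tokenized_text, embedded_text)  # splice them into the prompt embedding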