embedding_manager.py

  1. """
  2. Copyright (c) Alibaba, Inc. and its affiliates.
  3. """
  4. from functools import partial
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import (
  9. conv_nd,
  10. linear,
  11. )
  12. def get_clip_token_for_string(tokenizer, string):
  13. batch_encoding = tokenizer(
  14. string,
  15. truncation=True,
  16. max_length=77,
  17. return_length=True,
  18. return_overflowing_tokens=False,
  19. padding="max_length",
  20. return_tensors="pt",
  21. )
  22. tokens = batch_encoding["input_ids"]
  23. assert (
  24. torch.count_nonzero(tokens - 49407) == 2
  25. ), f"String '{string}' maps to more than a single token. Please use another string"
  26. return tokens[0, 1]
  27. def get_bert_token_for_string(tokenizer, string):
  28. token = tokenizer(string)
  29. assert (
  30. torch.count_nonzero(token) == 3
  31. ), f"String '{string}' maps to more than a single token. Please use another string"
  32. token = token[0, 1]
  33. return token
  34. def get_clip_vision_emb(encoder, processor, img):
  35. _img = img.repeat(1, 3, 1, 1) * 255
  36. inputs = processor(images=_img, return_tensors="pt")
  37. inputs["pixel_values"] = inputs["pixel_values"].to(img.device)
  38. outputs = encoder(**inputs)
  39. emb = outputs.image_embeds
  40. return emb
  41. def get_recog_emb(encoder, img_list):
  42. _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list]
  43. encoder.predictor.eval()
  44. _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
  45. return preds_neck
  46. def pad_H(x):
  47. _, _, H, W = x.shape
  48. p_top = (W - H) // 2
  49. p_bot = W - H - p_top
  50. return F.pad(x, (0, 0, p_top, p_bot))
  51. class EncodeNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncodeNet, self).__init__()
        chan = 16
        n_layer = 4  # downsample
        self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
        self.conv_list = nn.ModuleList([])
        _c = chan
        for i in range(n_layer):
            self.conv_list.append(conv_nd(2, _c, _c * 2, 3, padding=1, stride=2))
            _c *= 2
        self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.act(self.conv1(x))
        for layer in self.conv_list:
            x = self.act(layer(x))
        x = self.act(self.conv2(x))
        x = self.avgpool(x)
        # Flatten (B, out_channels, 1, 1) -> (B, out_channels).
        x = x.view(x.size(0), -1)
        return x


# Replaces placeholder-token embeddings in a caption with per-line glyph
# embeddings produced by an OCR recognizer, a CLIP ViT, or a small conv encoder.
class EmbeddingManager(nn.Module):
    def __init__(
        self,
        embedder,
        valid=True,
        glyph_channels=20,
        position_channels=1,
        placeholder_string="*",
        add_pos=False,
        emb_type="ocr",
        **kwargs,
    ):
        super().__init__()
        if hasattr(embedder, "tokenizer"):  # using Stable Diffusion's CLIP encoder
            get_token_for_string = partial(
                get_clip_token_for_string, embedder.tokenizer
            )
            token_dim = 768
            if hasattr(embedder, "vit"):
                assert emb_type == "vit"
                self.get_vision_emb = partial(
                    get_clip_vision_emb, embedder.vit, embedder.processor
                )
            # For the OCR path the recognition embedder is built lazily in
            # encode_text, once self.recog has been attached to the manager.
            self.get_recog_emb = None
        else:  # using LDM's BERT encoder
            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
            token_dim = 1280
        self.token_dim = token_dim
        self.emb_type = emb_type

        self.add_pos = add_pos
        if add_pos:
            self.position_encoder = EncodeNet(position_channels, token_dim)
        if emb_type == "ocr":
            self.proj = linear(40 * 64, token_dim)
        if emb_type == "conv":
            self.glyph_encoder = EncodeNet(glyph_channels, token_dim)

        self.placeholder_token = get_token_for_string(placeholder_string)

    def encode_text(self, text_info):
        if self.get_recog_emb is None and self.emb_type == "ocr":
            # self.recog (the OCR recognizer) is expected to be attached to the
            # manager externally before the first call.
            self.get_recog_emb = partial(get_recog_emb, self.recog)

        # Collect every glyph line (and, optionally, its position map) across the batch.
        gline_list = []
        pos_list = []
        for i in range(len(text_info["n_lines"])):  # sample index in a batch
            n_lines = text_info["n_lines"][i]
            for j in range(n_lines):  # line
                gline_list += [text_info["gly_line"][j][i : i + 1]]
                if self.add_pos:
                    pos_list += [text_info["positions"][j][i : i + 1]]

        if len(gline_list) > 0:
            if self.emb_type == "ocr":
                recog_emb = self.get_recog_emb(gline_list)
                enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
            elif self.emb_type == "vit":
                enc_glyph = self.get_vision_emb(pad_H(torch.cat(gline_list, dim=0)))
            elif self.emb_type == "conv":
                enc_glyph = self.glyph_encoder(pad_H(torch.cat(gline_list, dim=0)))
            if self.add_pos:
                enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
                enc_glyph = enc_glyph + enc_pos

        # Regroup the flat list of per-line embeddings back into one list per sample.
        self.text_embs_all = []
        n_idx = 0
        for i in range(len(text_info["n_lines"])):  # sample index in a batch
            n_lines = text_info["n_lines"][i]
            text_embs = []
            for j in range(n_lines):  # line
                text_embs += [enc_glyph[n_idx : n_idx + 1]]
                n_idx += 1
            self.text_embs_all += [text_embs]

    def forward(
        self,
        tokenized_text,
        embedded_text,
    ):
        b, device = tokenized_text.shape[0], tokenized_text.device
        for i in range(b):
            # Positions in the caption occupied by the placeholder token.
            idx = tokenized_text[i] == self.placeholder_token.to(device)
            if sum(idx) > 0:
                if i >= len(self.text_embs_all):
                    print("truncation for log images...")
                    break
                text_emb = torch.cat(self.text_embs_all[i], dim=0)
                if sum(idx) != len(text_emb):
                    print("truncation for long caption...")
                # Overwrite the placeholder embeddings with the glyph embeddings.
                embedded_text[i][idx] = text_emb[: sum(idx)]
        return embedded_text

    def embedding_parameters(self):
        return self.parameters()
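

# The sketch below is not part of the upstream module; it is a minimal,
# self-contained sanity check showing how pad_H and EncodeNet cooperate
# (square-pad a glyph crop, then encode it into a fixed-size vector).
# The tensor shapes are illustrative assumptions, not values required by the
# AnyText pipeline.
if __name__ == "__main__":
    glyphs = torch.randn(2, 1, 32, 128)  # (batch, channels, H, W) with H < W
    padded = pad_H(glyphs)  # height is zero-padded up to the width: (2, 1, 128, 128)
    encoder = EncodeNet(in_channels=1, out_channels=768)
    emb = encoder(padded)
    print(emb.shape)  # expected: torch.Size([2, 768])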