import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from transformers import (
    AutoProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
    T5EncoderModel,
    T5Tokenizer,
)

from sorawm.iopaint.model.anytext.ldm.util import count_params


def _expand_mask(mask, dtype, tgt_len=None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted_mask = 1.0 - expanded_mask
    return inverted_mask.masked_fill(
        inverted_mask.to(torch.bool), torch.finfo(dtype).min
    )


def _build_causal_attention_mask(bsz, seq_len, dtype):
    # lazily create the causal attention mask
    # pytorch uses an additive attention mask; fill with -inf
    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
    mask.fill_(torch.tensor(torch.finfo(dtype).min))
    mask.triu_(1)  # zero out the diagonal and everything below it
    mask = mask.unsqueeze(1)  # add a head dimension: [bsz, 1, seq_len, seq_len]
    return mask
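

# Shape/sign conventions of the two mask helpers above (an illustrative sketch,
# assuming float32 inputs):
#   _expand_mask(torch.tensor([[1, 1, 0]]), torch.float32)
#       -> shape [1, 1, 3, 3]; columns belonging to padded positions (mask == 0)
#          hold torch.finfo(torch.float32).min, real positions hold 0.0, so the
#          result can be added directly to attention logits.
#   _build_causal_attention_mask(1, 3, torch.float32)[0, 0]
#       -> strictly upper-triangular entries are the large negative fill value,
#          the diagonal and below are 0.0, i.e. each token may only attend to
#          itself and earlier positions.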


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class IdentityEncoder(AbstractEncoder):
    def encode(self, x):
        return x


class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        if self.ucg_rate > 0.0 and not disable_dropout:
            mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
            c = c.long()
        c = self.embedding(c)
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        uc_class = (
            self.n_classes - 1
        )  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        uc = torch.ones((bs,), device=device) * uc_class
        uc = {self.key: uc}
        return uc


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(
        self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
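

# Example usage (a sketch, not part of the module; assumes the checkpoint can be
# downloaded from the Hugging Face hub and keeps everything on the CPU):
#   t5 = FrozenT5Embedder("google/t5-v1_1-large", device="cpu")
#   z = t5.encode(["a photo of a cat"])
#   # prompts are padded/truncated to max_length=77 tokens,
#   # so z has shape [batch, 77, d_model] (1024 for t5-v1_1-large)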


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        layer_idx=None,
    ):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(
            input_ids=tokens, output_hidden_states=self.layer == "hidden"
        )
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)
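

# Example usage (a sketch; layer="hidden" with a negative layer_idx reproduces the
# common "CLIP skip" trick, here the penultimate hidden state):
#   clip = FrozenCLIPEmbedder(device="cpu", layer="hidden", layer_idx=-2)
#   z = clip.encode(["a photo of a cat"])
#   # z has shape [batch, 77, 768] for openai/clip-vit-large-patch14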


class FrozenCLIPT5Encoder(AbstractEncoder):
    def __init__(
        self,
        clip_version="openai/clip-vit-large-patch14",
        t5_version="google/t5-v1_1-xl",
        device="cuda",
        clip_max_length=77,
        t5_max_length=77,
    ):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(
            clip_version, device, max_length=clip_max_length
        )
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        print(
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params."
        )

    def encode(self, text):
        return self(text)

    def forward(self, text):
        clip_z = self.clip_encoder.encode(text)
        t5_z = self.t5_encoder.encode(text)
        return [clip_z, t5_z]
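

# Example usage (a sketch; instantiating both encoders is memory heavy):
#   dual = FrozenCLIPT5Encoder(device="cpu")
#   clip_z, t5_z = dual.encode(["a photo of a cat"])
#   # clip_z: [batch, 77, 768]; t5_z: [batch, 77, d_model of google/t5-v1_1-xl]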


class FrozenCLIPEmbedderT3(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        use_vision=False,
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        if use_vision:
            self.vit = CLIPVisionModelWithProjection.from_pretrained(version)
            self.processor = AutoProcessor.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

        # The forward passes of the underlying CLIP text model are re-bound below
        # so that an optional `embedding_manager` hook can modify the token
        # embeddings before they enter the transformer encoder.
        def embedding_forward(
            self,
            input_ids=None,
            position_ids=None,
            inputs_embeds=None,
            embedding_manager=None,
        ):
            seq_length = (
                input_ids.shape[-1]
                if input_ids is not None
                else inputs_embeds.shape[-2]
            )
            if position_ids is None:
                position_ids = self.position_ids[:, :seq_length]
            if inputs_embeds is None:
                inputs_embeds = self.token_embedding(input_ids)
            if embedding_manager is not None:
                inputs_embeds = embedding_manager(input_ids, inputs_embeds)
            position_embeddings = self.position_embedding(position_ids)
            embeddings = inputs_embeds + position_embeddings
            return embeddings

        self.transformer.text_model.embeddings.forward = embedding_forward.__get__(
            self.transformer.text_model.embeddings
        )

        def encoder_forward(
            self,
            inputs_embeds,
            attention_mask=None,
            causal_attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        ):
            # Same layer loop as the stock CLIP encoder, but returns the final
            # hidden-state tensor directly instead of a model-output object.
            output_attentions = (
                output_attentions
                if output_attentions is not None
                else self.config.output_attentions
            )
            output_hidden_states = (
                output_hidden_states
                if output_hidden_states is not None
                else self.config.output_hidden_states
            )
            return_dict = (
                return_dict if return_dict is not None else self.config.use_return_dict
            )
            encoder_states = () if output_hidden_states else None
            all_attentions = () if output_attentions else None
            hidden_states = inputs_embeds
            for idx, encoder_layer in enumerate(self.layers):
                if output_hidden_states:
                    encoder_states = encoder_states + (hidden_states,)
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )
                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_attentions = all_attentions + (layer_outputs[1],)
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            return hidden_states

        self.transformer.text_model.encoder.forward = encoder_forward.__get__(
            self.transformer.text_model.encoder
        )

        def text_encoder_forward(
            self,
            input_ids=None,
            attention_mask=None,
            position_ids=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            embedding_manager=None,
        ):
            output_attentions = (
                output_attentions
                if output_attentions is not None
                else self.config.output_attentions
            )
            output_hidden_states = (
                output_hidden_states
                if output_hidden_states is not None
                else self.config.output_hidden_states
            )
            return_dict = (
                return_dict if return_dict is not None else self.config.use_return_dict
            )
            if input_ids is None:
                raise ValueError("You have to specify input_ids")
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            hidden_states = self.embeddings(
                input_ids=input_ids,
                position_ids=position_ids,
                embedding_manager=embedding_manager,
            )
            bsz, seq_len = input_shape
            # CLIP's text model uses a causal mask, prepare it here.
            # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
            causal_attention_mask = _build_causal_attention_mask(
                bsz, seq_len, hidden_states.dtype
            ).to(hidden_states.device)
            # expand attention_mask
            if attention_mask is not None:
                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
            last_hidden_state = self.encoder(
                inputs_embeds=hidden_states,
                attention_mask=attention_mask,
                causal_attention_mask=causal_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            last_hidden_state = self.final_layer_norm(last_hidden_state)
            return last_hidden_state

        self.transformer.text_model.forward = text_encoder_forward.__get__(
            self.transformer.text_model
        )

        def transformer_forward(
            self,
            input_ids=None,
            attention_mask=None,
            position_ids=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            embedding_manager=None,
        ):
            return self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                embedding_manager=embedding_manager,
            )

        self.transformer.forward = transformer_forward.__get__(self.transformer)

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text, **kwargs):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        z = self.transformer(input_ids=tokens, **kwargs)
        return z

    def encode(self, text, **kwargs):
        return self(text, **kwargs)
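

# Minimal smoke test (a sketch, not part of the original module; assumes the
# openai/clip-vit-large-patch14 weights are available locally or downloadable,
# and that the installed transformers version still exposes the internals
# patched in __init__ above):
if __name__ == "__main__":
    encoder = FrozenCLIPEmbedderT3(device="cpu")
    # Without an embedding_manager the patched model behaves like a plain CLIP
    # text encoder, except that it returns the raw hidden-state tensor.
    z = encoder.encode(["a photo of a cat"])
    print(z.shape)  # expected: torch.Size([1, 77, 768])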