# unet_2d_blocks.py

from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn

from diffusers.models.resnet import ResnetBlock2D
from diffusers.utils import is_torch_version
from diffusers.utils.torch_utils import apply_freeu


class MidBlock2D(nn.Module):
    """UNet mid block built only from `ResnetBlock2D` layers (no cross-attention)."""

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        use_linear_projection: bool = False,
    ):
        super().__init__()

        self.has_cross_attention = False
        resnet_groups = (
            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        )

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]

        for _ in range(num_layers):
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.resnets = nn.ModuleList(resnets)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        lora_scale = 1.0
        hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)

        for resnet in self.resnets[1:]:
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb, scale=lora_scale)

        return hidden_states
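

# --- Illustrative usage (not part of the original module) -------------------
# A minimal smoke-test sketch: instantiate MidBlock2D and run a forward pass.
# The channel counts and shapes below are arbitrary, and it assumes a diffusers
# release contemporary with this file, i.e. one whose ResnetBlock2D.forward
# accepts the `scale` keyword used above.
def _demo_mid_block() -> None:
    mid_block = MidBlock2D(in_channels=64, temb_channels=128, num_layers=2)
    sample = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
    temb = torch.randn(1, 128)           # (batch, temb_channels)
    out = mid_block(sample, temb)
    print(out.shape)  # torch.Size([1, 64, 32, 32]): channels and spatial size preserved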


def DownBlock2D_forward(
    self,
    hidden_states: torch.FloatTensor,
    temb: Optional[torch.FloatTensor] = None,
    scale: float = 1.0,
    down_block_add_samples: Optional[List[torch.FloatTensor]] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
    """DownBlock2D forward that can add extra residuals, popped in order from
    `down_block_add_samples`, after each resnet and after the downsamplers."""
    output_states = ()

    for resnet in self.resnets:
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    use_reentrant=False,
                )
            else:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb
                )
        else:
            hidden_states = resnet(hidden_states, temb, scale=scale)

        if down_block_add_samples is not None:
            hidden_states = hidden_states + down_block_add_samples.pop(0)

        output_states = output_states + (hidden_states,)

    if self.downsamplers is not None:
        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states, scale=scale)

        if down_block_add_samples is not None:
            hidden_states = hidden_states + down_block_add_samples.pop(0)  # todo: add before or after

        output_states = output_states + (hidden_states,)

    return hidden_states, output_states


def CrossAttnDownBlock2D_forward(
    self,
    hidden_states: torch.FloatTensor,
    temb: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    additional_residuals: Optional[torch.FloatTensor] = None,
    down_block_add_samples: Optional[List[torch.FloatTensor]] = None,
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
    """CrossAttnDownBlock2D forward that supports `additional_residuals` on the
    last resnet/attention pair and extra residuals popped from
    `down_block_add_samples` after each stage."""
    output_states = ()

    lora_scale = (
        cross_attention_kwargs.get("scale", 1.0)
        if cross_attention_kwargs is not None
        else 1.0
    )

    blocks = list(zip(self.resnets, self.attentions))

    for i, (resnet, attn) in enumerate(blocks):
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = (
                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            )
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(resnet),
                hidden_states,
                temb,
                **ckpt_kwargs,
            )
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb, scale=lora_scale)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

        # apply additional residuals to the output of the last pair of resnet and attention blocks
        if i == len(blocks) - 1 and additional_residuals is not None:
            hidden_states = hidden_states + additional_residuals

        if down_block_add_samples is not None:
            hidden_states = hidden_states + down_block_add_samples.pop(0)

        output_states = output_states + (hidden_states,)

    if self.downsamplers is not None:
        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states, scale=lora_scale)

        if down_block_add_samples is not None:
            hidden_states = hidden_states + down_block_add_samples.pop(0)  # todo: add before or after

        output_states = output_states + (hidden_states,)

    return hidden_states, output_states


def CrossAttnUpBlock2D_forward(
    self,
    hidden_states: torch.FloatTensor,
    res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
    temb: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    upsample_size: Optional[int] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    return_res_samples: Optional[bool] = False,
    up_block_add_samples: Optional[List[torch.FloatTensor]] = None,
) -> torch.FloatTensor:
    """CrossAttnUpBlock2D forward that can return per-stage outputs
    (`return_res_samples`) and add extra residuals popped from
    `up_block_add_samples` after each stage."""
    lora_scale = (
        cross_attention_kwargs.get("scale", 1.0)
        if cross_attention_kwargs is not None
        else 1.0
    )
    is_freeu_enabled = (
        getattr(self, "s1", None)
        and getattr(self, "s2", None)
        and getattr(self, "b1", None)
        and getattr(self, "b2", None)
    )
    if return_res_samples:
        output_states = ()

    for resnet, attn in zip(self.resnets, self.attentions):
        # pop res hidden states
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]

        # FreeU: Only operate on the first two stages
        if is_freeu_enabled:
            hidden_states, res_hidden_states = apply_freeu(
                self.resolution_idx,
                hidden_states,
                res_hidden_states,
                s1=self.s1,
                s2=self.s2,
                b1=self.b1,
                b2=self.b2,
            )

        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            ckpt_kwargs: Dict[str, Any] = (
                {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            )
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(resnet),
                hidden_states,
                temb,
                **ckpt_kwargs,
            )
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb, scale=lora_scale)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                return_dict=False,
            )[0]

        if return_res_samples:
            output_states = output_states + (hidden_states,)
        if up_block_add_samples is not None:
            hidden_states = hidden_states + up_block_add_samples.pop(0)

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)

        if return_res_samples:
            output_states = output_states + (hidden_states,)
        if up_block_add_samples is not None:
            hidden_states = hidden_states + up_block_add_samples.pop(0)

    if return_res_samples:
        return hidden_states, output_states
    else:
        return hidden_states


def UpBlock2D_forward(
    self,
    hidden_states: torch.FloatTensor,
    res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
    temb: Optional[torch.FloatTensor] = None,
    upsample_size: Optional[int] = None,
    scale: float = 1.0,
    return_res_samples: Optional[bool] = False,
    up_block_add_samples: Optional[List[torch.FloatTensor]] = None,
) -> torch.FloatTensor:
    """UpBlock2D forward that can return per-stage outputs (`return_res_samples`)
    and add extra residuals popped from `up_block_add_samples` after each stage."""
    is_freeu_enabled = (
        getattr(self, "s1", None)
        and getattr(self, "s2", None)
        and getattr(self, "b1", None)
        and getattr(self, "b2", None)
    )
    if return_res_samples:
        output_states = ()

    for resnet in self.resnets:
        # pop res hidden states
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]

        # FreeU: Only operate on the first two stages
        if is_freeu_enabled:
            hidden_states, res_hidden_states = apply_freeu(
                self.resolution_idx,
                hidden_states,
                res_hidden_states,
                s1=self.s1,
                s2=self.s2,
                b1=self.b1,
                b2=self.b2,
            )

        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    use_reentrant=False,
                )
            else:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb
                )
        else:
            hidden_states = resnet(hidden_states, temb, scale=scale)

        if return_res_samples:
            output_states = output_states + (hidden_states,)
        if up_block_add_samples is not None:
            hidden_states = hidden_states + up_block_add_samples.pop(0)  # todo: add before or after

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size, scale=scale)

        if return_res_samples:
            output_states = output_states + (hidden_states,)
        if up_block_add_samples is not None:
            hidden_states = hidden_states + up_block_add_samples.pop(0)  # todo: add before or after

    if return_res_samples:
        return hidden_states, output_states
    else:
        return hidden_states
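

# --- Illustrative usage (not part of the original file) ---------------------
# The four *_forward functions above take `self` as their first argument, which
# suggests they are meant to replace the forward methods of the corresponding
# diffusers blocks so that the extra `*_add_samples` residual lists can be fed
# in. The sketch below is a hypothetical binding via `types.MethodType`; it
# assumes a diffusers version matching the import paths above, i.e. one where
# DownBlock2D lives in `diffusers.models.unet_2d_blocks` and whose resnet and
# downsampler forwards accept the `scale` keyword used in this file.
def _demo_patch_down_block() -> None:
    import types

    from diffusers.models.unet_2d_blocks import DownBlock2D

    block = DownBlock2D(in_channels=64, out_channels=64, temb_channels=128, num_layers=2)
    # Replace the block's forward with the patched version defined above.
    block.forward = types.MethodType(DownBlock2D_forward, block)

    sample = torch.randn(1, 64, 32, 32)
    temb = torch.randn(1, 128)
    # One extra residual per resnet plus one for the downsampler stage,
    # consumed in order via pop(0); the downsampled stage is half-resolution.
    add_samples = [
        torch.zeros(1, 64, 32, 32),
        torch.zeros(1, 64, 32, 32),
        torch.zeros(1, 64, 16, 16),
    ]
    hidden_states, output_states = block(sample, temb, down_block_add_samples=add_samples)
    print(hidden_states.shape, len(output_states))  # torch.Size([1, 64, 16, 16]) 3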