# brushnet_unet_forward.py

from typing import Any, Dict, Optional, Tuple, Union

import torch

from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    scale_lora_layers,
    unscale_lora_layers,
)


def brushnet_unet_forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
    mid_block_add_sample: Optional[torch.Tensor] = None,
    up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
) -> Union[UNet2DConditionOutput, Tuple]:
    r"""
    The [`UNet2DConditionModel`] forward method.

    Args:
        sample (`torch.FloatTensor`):
            The noisy input tensor with the following shape `(batch, channel, height, width)`.
        timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.FloatTensor`):
            The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
        class_labels (`torch.Tensor`, *optional*, defaults to `None`):
            Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
        timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
            Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
            through the `self.time_embedding` layer to obtain the timestep embeddings.
        attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
            is kept, otherwise if `0` it is discarded. The mask will be converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        added_cond_kwargs (`dict`, *optional*):
            A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
            are passed along to the UNet blocks.
        down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
            Additional residuals to be added to the UNet long skip connections from down blocks to up blocks, for
            example from ControlNet side model(s).
        mid_block_additional_residual (`torch.Tensor`, *optional*):
            Additional residual to be added to the UNet mid block output, for example from a ControlNet side model.
        down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
            Additional residuals to be added within the UNet down blocks, for example from T2I-Adapter side
            model(s).
        encoder_attention_mask (`torch.Tensor`):
            A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
            `True` the mask is kept, otherwise if `False` it is discarded. The mask will be converted into a bias,
            which adds large negative values to the attention scores corresponding to "discard" tokens.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
            tuple.
        down_block_add_samples (`tuple` of `torch.Tensor`, *optional*):
            BrushNet feature residuals added to the hidden states inside the UNet down blocks, consumed in order.
        mid_block_add_sample (`torch.Tensor`, *optional*):
            BrushNet feature residual added to the UNet mid block output.
        up_block_add_samples (`tuple` of `torch.Tensor`, *optional*):
            BrushNet feature residuals added to the hidden states inside the UNet up blocks, consumed in order.

    Returns:
        [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
            otherwise a `tuple` is returned where the first element is the sample tensor.
    """
    # By default samples have to be at least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    for dim in sample.shape[-2:]:
        if dim % default_overall_up_factor != 0:
            # Forward upsample size to force interpolation output size.
            forward_upsample_size = True
            break

    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch, 1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is not None:
        # assume that mask is expressed as:
        #   (1 = keep,      0 = discard)
        # convert mask into a bias that can be added to attention scores:
        #   (keep = +0,     discard = -10000.0)
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None:
        encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    # 0. center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. time
    t_emb = self.get_time_embed(sample=sample, timestep=timestep)
    emb = self.time_embedding(t_emb, timestep_cond)
    aug_emb = None

    class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
    if class_emb is not None:
        if self.config.class_embeddings_concat:
            emb = torch.cat([emb, class_emb], dim=-1)
        else:
            emb = emb + class_emb

    aug_emb = self.get_aug_embed(
        emb=emb,
        encoder_hidden_states=encoder_hidden_states,
        added_cond_kwargs=added_cond_kwargs,
    )
    if self.config.addition_embed_type == "image_hint":
        aug_emb, hint = aug_emb
        sample = torch.cat([sample, hint], dim=1)

    emb = emb + aug_emb if aug_emb is not None else emb

    if self.time_embed_act is not None:
        emb = self.time_embed_act(emb)

    encoder_hidden_states = self.process_encoder_hidden_states(
        encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
    )

    # 2. pre-process
    sample = self.conv_in(sample)

    # 2.5 GLIGEN position net
    if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        gligen_args = cross_attention_kwargs.pop("gligen")
        cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

    # 3. down
    lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)

    is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
    # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
    is_adapter = down_intrablock_additional_residuals is not None
    is_brushnet = (
        down_block_add_samples is not None
        and mid_block_add_sample is not None
        and up_block_add_samples is not None
    )
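    # BrushNet is active only when residuals for all three stages (down, mid, up)
    # are supplied; they are consumed in order via `pop(0)` below.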
    # maintain backward compatibility for legacy usage, where
    #   T2I-Adapter and ControlNet both use down_block_additional_residuals arg
    #   but can only use one or the other
    if (
        not is_adapter
        and mid_block_additional_residual is None
        and down_block_additional_residuals is not None
    ):
        deprecate(
            "T2I should not use down_block_additional_residuals",
            "1.3.0",
            "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated"
            " and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used"
            " for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead.",
            standard_warn=False,
        )
        down_intrablock_additional_residuals = down_block_additional_residuals
        is_adapter = True

    down_block_res_samples = (sample,)
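    # BrushNet: the first extra feature residual is added right after conv_in,
    # before the hidden states enter the down blocks.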
    if is_brushnet:
        sample = sample + down_block_add_samples.pop(0)
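    # BrushNet: each down block below consumes one extra residual per resnet,
    # plus one more when the block has a downsampler.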
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            # For t2i-adapter CrossAttnDownBlock2D
            additional_residuals = {}
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

            if is_brushnet and len(down_block_add_samples) > 0:
                additional_residuals["down_block_add_samples"] = [
                    down_block_add_samples.pop(0)
                    for _ in range(
                        len(downsample_block.resnets) + (downsample_block.downsamplers is not None)
                    )
                ]

            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
                **additional_residuals,
            )
        else:
            additional_residuals = {}
            if is_brushnet and len(down_block_add_samples) > 0:
                additional_residuals["down_block_add_samples"] = [
                    down_block_add_samples.pop(0)
                    for _ in range(
                        len(downsample_block.resnets) + (downsample_block.downsamplers is not None)
                    )
                ]

            sample, res_samples = downsample_block(
                hidden_states=sample, temb=emb, scale=lora_scale, **additional_residuals
            )
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                sample += down_intrablock_additional_residuals.pop(0)

        down_block_res_samples += res_samples

    if is_controlnet:
        new_down_block_res_samples = ()

        for down_block_res_sample, down_block_additional_residual in zip(
            down_block_res_samples, down_block_additional_residuals
        ):
            down_block_res_sample = down_block_res_sample + down_block_additional_residual
            new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = new_down_block_res_samples

    # 4. mid
    if self.mid_block is not None:
        if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
        else:
            sample = self.mid_block(sample, emb)

        # To support T2I-Adapter-XL
        if (
            is_adapter
            and len(down_intrablock_additional_residuals) > 0
            and sample.shape == down_intrablock_additional_residuals[0].shape
        ):
            sample += down_intrablock_additional_residuals.pop(0)

    if is_controlnet:
        sample = sample + mid_block_additional_residual
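    # BrushNet: add the mid-block residual after any ControlNet residual.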
    if is_brushnet:
        sample = sample + mid_block_add_sample

    # 5. up
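    # BrushNet: each up block below consumes one extra residual per resnet,
    # plus one more when the block has an upsampler.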
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            additional_residuals = {}
            if is_brushnet and len(up_block_add_samples) > 0:
                additional_residuals["up_block_add_samples"] = [
                    up_block_add_samples.pop(0)
                    for _ in range(
                        len(upsample_block.resnets) + (upsample_block.upsamplers is not None)
                    )
                ]

            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                **additional_residuals,
            )
        else:
            additional_residuals = {}
            if is_brushnet and len(up_block_add_samples) > 0:
                additional_residuals["up_block_add_samples"] = [
                    up_block_add_samples.pop(0)
                    for _ in range(
                        len(upsample_block.resnets) + (upsample_block.upsamplers is not None)
                    )
                ]

            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                upsample_size=upsample_size,
                scale=lora_scale,
                **additional_residuals,
            )

    # 6. post-process
    if self.conv_norm_out:
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (sample,)

    return UNet2DConditionOutput(sample=sample)
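

# ---------------------------------------------------------------------------
# Usage sketch (not part of the forward above, and not a diffusers or BrushNet
# API): integrations typically rebind this function as the `forward` of an
# existing `UNet2DConditionModel` so the extra BrushNet residuals
# (`down_block_add_samples`, `mid_block_add_sample`, `up_block_add_samples`)
# can be passed at call time. The helper name below is illustrative only.
def patch_unet_with_brushnet_forward(unet):
    """Bind `brushnet_unet_forward` to `unet` as its instance-level `forward`."""
    import types

    # The instance attribute shadows the class-level forward, so `unet(...)`
    # will dispatch to `brushnet_unet_forward` with `unet` bound as `self`.
    unet.forward = types.MethodType(brushnet_unet_forward, unet)
    return unet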