attention.py

import math
import os
from inspect import isfunction
from typing import Any, Optional

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn

from sorawm.iopaint.model.anytext.ldm.modules.diffusionmodules.util import checkpoint

# CrossAttn precision handling
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )
        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        return x + h_
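

# Example (illustrative sketch, not part of the original module): the block is
# shape-preserving, and in_channels must be divisible by 32 for the GroupNorm, e.g.:
#
#     sa = SpatialSelfAttention(in_channels=64)
#     y = sa(torch.randn(1, 64, 32, 32))
#     # y.shape == (1, 64, 32, 32)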


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        # force cast to fp32 to avoid overflowing
        if _ATTN_PRECISION == "fp32":
            with torch.autocast(enabled=False, device_type="cuda"):
                q, k = q.float(), k.float()
                sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
        else:
            sim = einsum("b i d, b j d -> b i j", q, k) * self.scale

        del q, k

        if exists(mask):
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = einsum("b i j, b j d -> b i d", sim, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
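

# Example (illustrative sketch, not part of the original module), with assumed
# SD-style sizes: a query of shape (b, n, query_dim) attends over a context of
# shape (b, m, context_dim) and returns (b, n, query_dim):
#
#     attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
#     out = attn(torch.randn(2, 4096, 320), context=torch.randn(2, 77, 768))
#     # out.shape == (2, 4096, 320)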


class SDPACrossAttention(CrossAttention):
    def forward(self, x, context=None, mask=None):
        batch_size, sequence_length, inner_dim = x.shape

        if mask is not None:
            # NOTE: prepare_attention_mask is not defined in this module or in
            # CrossAttention; the masked path assumes it is supplied by a subclass
            # or mixin (it mirrors the diffusers attention-processor API).
            mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
            mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])

        h = self.heads
        q_in = self.to_q(x)
        context = default(context, x)
        k_in = self.to_k(context)
        v_in = self.to_v(context)

        # head_dim is derived from the input width, so this path assumes
        # heads * dim_head == query_dim.
        head_dim = inner_dim // h
        q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
        k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
        v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)

        del q_in, k_in, v_in

        dtype = q.dtype
        if _ATTN_PRECISION == "fp32":
            q, k, v = q.float(), k.float(), v.float()

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, h * head_dim
        )
        hidden_states = hidden_states.to(dtype)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states
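

# Usage note (illustrative, not part of the original module): SDPACrossAttention
# keeps the CrossAttention constructor and call signature, so the example above
# applies unchanged when torch.nn.functional.scaled_dot_product_attention is
# available:
#
#     attn = SDPACrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
#     out = attn(torch.randn(2, 4096, 320), context=torch.randn(2, 77, 768))
#     # out.shape == (2, 4096, 320)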


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
    ):
        super().__init__()

        if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            attn_cls = SDPACrossAttention
        else:
            attn_cls = CrossAttention

        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(
            self._forward, (x, context), self.parameters(), self.checkpoint
        )

    def _forward(self, x, context=None):
        x = (
            self.attn1(
                self.norm1(x), context=context if self.disable_self_attn else None
            )
            + x
        )
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
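

# Example (illustrative sketch, not part of the original module): the block operates
# on flattened token sequences; for 16x16 latent features with dim=320:
#
#     blk = BasicTransformerBlock(
#         dim=320, n_heads=8, d_head=40, context_dim=768, checkpoint=False
#     )
#     y = blk(torch.randn(2, 256, 320), context=torch.randn(2, 77, 768))
#     # y.shape == (2, 256, 320)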


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding) and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape back to an image.
    NEW: use_linear to use a linear projection instead of the 1x1 convs, for efficiency.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,
        use_checkpoint=True,
    ):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            # NOTE: the linear path assumes inner_dim == in_channels (i.e.
            # n_heads * d_head == in_channels), which holds for the configs used here.
            self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
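

# Minimal smoke test (illustrative sketch, not part of the original module); the
# channel, head, and context sizes below are assumptions chosen for a quick CPU run.
if __name__ == "__main__":
    st = SpatialTransformer(
        in_channels=64,
        n_heads=4,
        d_head=16,
        depth=1,
        context_dim=768,
        use_checkpoint=False,
    )
    x = torch.randn(1, 64, 16, 16)
    ctx = torch.randn(1, 77, 768)
    with torch.no_grad():
        y = st(x, context=ctx)
    print(y.shape)  # torch.Size([1, 64, 16, 16])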