dit.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def modulate(x, shift, scale):
    return x * (1 + scale) + shift


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
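
# Shape sketch for apply_rotary_emb (illustrative note, not a change to the
# model): x enters as (bsz, seqlen, n_head, head_dim) and freqs_cis[:seqlen]
# has shape (seqlen, head_dim // 2, 2), holding a (cos, sin) pair per
# position and frequency. Each consecutive pair (x[..., 2i], x[..., 2i + 1])
# is multiplied by cos + i*sin as a complex number, i.e. rotated by its
# position-dependent angle -- the standard RoPE formulation.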


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = FeedForward(
            frequency_embedding_size, hidden_size, out_dim=hidden_size
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding
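
    # Worked example: timestep_embedding(torch.tensor([0.0, 1.0]), dim=4)
    # uses freqs = (1.0, 10000 ** -0.5) and returns rows of the form
    # [cos(t), cos(0.01 * t), sin(t), sin(0.01 * t)].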

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> torch.Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)
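
# The cache above has shape (seq_len, n_elem // 2, 2): one (cos, sin) pair
# per (position, frequency). It is kept in bfloat16, using half the memory
# of a float32 buffer.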


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        n_head,
    ):
        super().__init__()
        assert dim % n_head == 0
        self.dim = dim
        self.n_head = n_head
        self.head_dim = dim // n_head
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wv = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)

    def forward(self, q, freqs_cis, kv=None, mask=None):
        bsz, seqlen, _ = q.shape
        if kv is None:
            kv = q
        kv_seqlen = kv.shape[1]
        q = self.wq(q).view(bsz, seqlen, self.n_head, self.head_dim)
        k = self.wk(kv).view(bsz, kv_seqlen, self.n_head, self.head_dim)
        v = self.wv(kv).view(bsz, kv_seqlen, self.n_head, self.head_dim)
        q = apply_rotary_emb(q, freqs_cis[:seqlen])
        k = apply_rotary_emb(k, freqs_cis[:kv_seqlen])
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
        )
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        y = self.wo(y)
        return y


class FeedForward(nn.Module):
    def __init__(self, in_dim, intermediate_size, out_dim=None):
        super().__init__()
        self.w1 = nn.Linear(in_dim, intermediate_size)
        self.w3 = nn.Linear(in_dim, intermediate_size)
        self.w2 = nn.Linear(intermediate_size, out_dim or in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
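
# FeedForward is the gated-SiLU (SwiGLU-style) MLP used in LLaMA-family
# models: w2(silu(w1(x)) * w3(x)).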


class DiTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        use_self_attention=True,
        use_cross_attention=False,
    ):
        super().__init__()
        self.use_self_attention = use_self_attention
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        if use_self_attention:
            self.mix = Attention(hidden_size, num_heads)
        else:
            self.mix = nn.Conv1d(
                hidden_size,
                hidden_size,
                kernel_size=7,
                padding=3,
                bias=True,
                groups=hidden_size,
            )
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = FeedForward(hidden_size, int(hidden_size * mlp_ratio))
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        self.use_cross_attention = use_cross_attention
        if self.use_cross_attention:
            self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            self.norm4 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            self.cross_attn = Attention(hidden_size, num_heads)
            self.adaLN_modulation_cross = nn.Sequential(
                nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True)
            )
            self.adaLN_modulation_cross_condition = nn.Sequential(
                nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
            )

    def forward(
        self,
        x,
        condition,
        freqs_cis,
        self_mask=None,
        cross_condition=None,
        cross_mask=None,
    ):
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(condition).chunk(6, dim=-1)
        # Self-attention
        inp = modulate(self.norm1(x), shift_msa, scale_msa)
        if self.use_self_attention:
            inp = self.mix(inp, freqs_cis=freqs_cis, mask=self_mask)
        else:
            inp = self.mix(inp.mT).mT
        x = x + gate_msa * inp
        # Cross-attention
        if self.use_cross_attention:
            (
                shift_cross,
                scale_cross,
                gate_cross,
            ) = self.adaLN_modulation_cross(condition).chunk(3, dim=-1)
            (
                shift_cross_condition,
                scale_cross_condition,
            ) = self.adaLN_modulation_cross_condition(cross_condition).chunk(2, dim=-1)
            inp = modulate(self.norm3(x), shift_cross, scale_cross)
            inp = self.cross_attn(
                inp,
                freqs_cis=freqs_cis,
                kv=modulate(
                    self.norm4(cross_condition),
                    shift_cross_condition,
                    scale_cross_condition,
                ),
                mask=cross_mask,
            )
            x = x + gate_cross * inp
        # MLP
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
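
    # Note: initialize_weights() zeroes the adaLN_modulation output layer, so
    # gate_msa and gate_mlp start at 0 and each block initially acts as an
    # identity mapping (the adaLN-Zero scheme from the DiT paper).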


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        return self.linear(x)


class DiT(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        diffusion_num_layers,
        channels=160,
        mlp_ratio=4.0,
        max_seq_len=16384,
        condition_dim=512,
        style_dim=None,
        cross_condition_dim=None,
    ):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.time_embedder = TimestepEmbedder(hidden_size)
        self.condition_embedder = FeedForward(
            condition_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
        )
        if cross_condition_dim is not None:
            self.cross_condition_embedder = FeedForward(
                cross_condition_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
            )
        self.use_style = style_dim is not None
        if self.use_style:
            self.style_embedder = FeedForward(
                style_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
            )
        self.diffusion_blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio,
                    use_self_attention=i % 4 == 0,
                    use_cross_attention=cross_condition_dim is not None,
                )
                for i in range(diffusion_num_layers)
            ]
        )
        # Input and output projections
        self.input_embedder = FeedForward(
            channels, int(hidden_size * mlp_ratio), out_dim=hidden_size
        )
        self.final_layer = FinalLayer(hidden_size, channels)
        self.register_buffer(
            "freqs_cis", precompute_freqs_cis(max_seq_len, hidden_size // num_heads)
        )
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize input embedding:
        self.input_embedder.apply(self.init_weight)
        self.time_embedder.mlp.apply(self.init_weight)
        self.condition_embedder.apply(self.init_weight)
        if self.use_style:
            self.style_embedder.apply(self.init_weight)
        if hasattr(self, "cross_condition_embedder"):
            self.cross_condition_embedder.apply(self.init_weight)
        for block in self.diffusion_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
            block.mix.apply(self.init_weight)
        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        self.final_layer.linear.apply(self.init_weight)

    def init_weight(self, m):
        if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear)):
            nn.init.normal_(m.weight, 0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x,
        time,
        condition,
        style=None,
        self_mask=None,
        cross_condition=None,
        cross_mask=None,
    ):
        # Embed inputs
        x = self.input_embedder(x)
        t = self.time_embedder(time)
        condition = self.condition_embedder(condition)
        if self.use_style:
            style = self.style_embedder(style)
        if cross_condition is not None:
            cross_condition = self.cross_condition_embedder(cross_condition)
            cross_condition = t[:, None, :] + cross_condition
        # Merge t, condition, and style
        condition = t[:, None, :] + condition
        if self.use_style:
            condition = condition + style[:, None, :]
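
        # Broadcast the boolean padding masks to (bsz, 1, 1, seq_len) so
        # scaled_dot_product_attention applies them across all heads and
        # query positions (True = attend, False = masked out).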
        if self_mask is not None:
            self_mask = self_mask[:, None, None, :]
        if cross_mask is not None:
            cross_mask = cross_mask[:, None, None, :]
        # DiT
        for block in self.diffusion_blocks:
            x = block(
                x,
                condition,
                self.freqs_cis,
                self_mask=self_mask,
                cross_condition=cross_condition,
                cross_mask=cross_mask,
            )
        x = self.final_layer(x, condition)
        return x


if __name__ == "__main__":
    model = DiT(
        hidden_size=384,
        num_heads=6,
        diffusion_num_layers=12,
        channels=160,
        condition_dim=512,
        style_dim=256,
    )
    bs, seq_len = 8, 1024
    x = torch.randn(bs, seq_len, 160)
    condition = torch.randn(bs, seq_len, 512)
    style = torch.randn(bs, 256)
    mask = torch.ones(bs, seq_len, dtype=torch.bool)
    mask[0, 5:] = False
    time = torch.arange(bs)
    print(time)
    out = model(x, time, condition, style, self_mask=mask)
    print(out.shape)  # torch.Size([8, 1024, 160])
    # Print model size
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {num_params / 1e6:.1f}M")
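
    # A minimal sketch of the cross-attention path. The values below
    # (cross_condition_dim=768, 50 cross tokens) are arbitrary illustrative
    # choices, not from any known configuration.
    cross_model = DiT(
        hidden_size=384,
        num_heads=6,
        diffusion_num_layers=12,
        channels=160,
        condition_dim=512,
        cross_condition_dim=768,
    )
    cross_condition = torch.randn(bs, 50, 768)
    cross_mask = torch.ones(bs, 50, dtype=torch.bool)
    out = cross_model(
        x,
        time,
        condition,
        self_mask=mask,
        cross_condition=cross_condition,
        cross_mask=cross_mask,
    )
    print(out.shape)  # torch.Size([8, 1024, 160])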