# modules.py
  1. import math
  2. from typing import Optional
  3. import torch
  4. from einops import rearrange
  5. from torch import nn
  6. from torch.nn import functional as F
  7. from transformers.modeling_attn_mask_utils import AttentionMaskConverter
  8. def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
  9. """
  10. Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
  11. This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
  12. and the end index 'end'. The 'theta' parameter scales the frequencies.
  13. The returned tensor contains complex values in complex64 data type.
  14. Args:
  15. dim (int): Dimension of the frequency tensor.
  16. end (int): End index for precomputing frequencies.
  17. theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
  18. Returns:
  19. torch.Tensor: Precomputed frequency tensor with complex exponentials.
  20. """
  21. freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
  22. t = torch.arange(end, device=freqs.device) # type: ignore
  23. freqs = torch.outer(t, freqs).float() # type: ignore
  24. freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
  25. return freqs_cis
  26. def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
  27. """
  28. Reshape frequency tensor for broadcasting it with another tensor.
  29. This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
  30. for the purpose of broadcasting the frequency tensor during element-wise operations.
  31. Args:
  32. freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
  33. x (torch.Tensor): Target tensor for broadcasting compatibility.
  34. Returns:
  35. torch.Tensor: Reshaped frequency tensor.
  36. Raises:
  37. AssertionError: If the frequency tensor doesn't match the expected shape.
  38. AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
  39. """
  40. ndim = x.ndim
  41. assert 0 <= 1 < ndim
  42. assert freqs_cis.shape == (x.shape[1], x.shape[-1])
  43. shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
  44. return freqs_cis.view(*shape)
  45. def apply_rotary_emb(
  46. x: torch.Tensor,
  47. freqs_cis: torch.Tensor,
  48. ) -> tuple[torch.Tensor, torch.Tensor]:
  49. x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
  50. freqs_cis = reshape_for_broadcast(freqs_cis, x_)
  51. return torch.view_as_real(x_ * freqs_cis).flatten(3).type_as(x)
  52. class MultiheadAttention(nn.Module):
  53. def __init__(self, d_model, nhead, dropout=0.1, is_cross_attention=False):
  54. super().__init__()
  55. assert d_model % nhead == 0
  56. self.nhead = nhead
  57. self.d_model = d_model
  58. self.head_dim = d_model // nhead
  59. self.is_cross_attention = is_cross_attention
  60. # Auto fuse linear projection
  61. if is_cross_attention:
  62. self.q_proj = nn.Linear(d_model, d_model)
  63. self.kv_proj = nn.Linear(d_model, d_model * 2)
  64. else:
  65. self.qkv_proj = nn.Linear(d_model, d_model * 3)
  66. self.o_proj = nn.Linear(d_model, d_model)
  67. self.dropout = nn.Dropout(dropout)
  68. def forward(
  69. self,
  70. q,
  71. freqs_cis_q,
  72. kv=None,
  73. freqs_cis_kv=None,
  74. attn_mask=None,
  75. input_pos=None,
  76. kv_cache=None,
  77. ):
  78. if self.is_cross_attention:
  79. q = self.q_proj(q)
  80. if kv is None:
  81. assert self.kv_cache is not None, "kv_cache should be initialized"
  82. k, v = None
  83. else:
  84. # Using kv cache
  85. kv = self.kv_proj(kv)
  86. k, v = torch.chunk(kv, 2, dim=-1)
  87. else:
  88. assert kv is None, f"kv should be None for self attention"
  89. assert (
  90. freqs_cis_kv is None
  91. ), f"freqs_cis_kv should be None for self attention"
  92. q, k, v = torch.chunk(self.qkv_proj(q), 3, dim=-1)
  93. # max_batch_size, max_seq_length, n_heads, head_dim
  94. q = rearrange(q, "b t (h d) -> b t h d", h=self.nhead, d=self.head_dim)
  95. q = apply_rotary_emb(q, freqs_cis_q)
  96. if freqs_cis_kv is None:
  97. freqs_cis_kv = freqs_cis_q
  98. # Only do when self attention or cross attention without kv cache
  99. if k is not None:
  100. assert v is not None, "v should not be None when k is not None"
  101. k = rearrange(k, "b t (h d) -> b t h d", h=self.nhead, d=self.head_dim)
  102. v = rearrange(v, "b t (h d) -> b t h d", h=self.nhead, d=self.head_dim)
  103. k = apply_rotary_emb(k, freqs_cis_kv)
  104. if kv_cache is not None:
  105. if k is None:
  106. assert v is None, "v should be None when k is None"
  107. k, v = kv_cache[0], kv_cache[1]
  108. else:
  109. k = torch.cat([kv_cache[0], k], dim=1)
  110. v = torch.cat([kv_cache[1], v], dim=1)
  111. kv_cache = (k, v)
  112. q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
  113. value = F.scaled_dot_product_attention(
  114. q,
  115. k,
  116. v,
  117. attn_mask=attn_mask,
  118. dropout_p=self.dropout.p if self.training else 0,
  119. )
  120. value = rearrange(value, "b h t d -> b t (h d)")
  121. return self.o_proj(value), kv_cache
  122. class GluMLP(nn.Module):
  123. def __init__(self, hidden_size=1024, intermediate_size=None, activation=nn.SiLU):
  124. super().__init__()
  125. if intermediate_size is None:
  126. intermediate_size = hidden_size * (11 / 3)
  127. intermediate_size = round(intermediate_size / 8) * 8
  128. self.hidden_size = hidden_size
  129. self.intermediate_size = intermediate_size
  130. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  131. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  132. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  133. self.act_fn = activation()
  134. def forward(self, x):
  135. return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  136. class RMSNorm(nn.Module):
  137. def __init__(self, hidden_size, eps=1e-6):
  138. """
  139. RMSNorm is equivalent to T5LayerNorm
  140. """
  141. super().__init__()
  142. self.weight = nn.Parameter(torch.ones(hidden_size))
  143. self.variance_epsilon = eps
  144. def forward(self, hidden_states):
  145. input_dtype = hidden_states.dtype
  146. hidden_states = hidden_states.to(torch.float32)
  147. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  148. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  149. return self.weight * hidden_states.to(input_dtype)
  150. class TransformerEncoderLayer(nn.Module):
  151. def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
  152. super().__init__()
  153. self.attention = MultiheadAttention(hidden_size, nhead, dropout=dropout)
  154. self.ffn = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
  155. self.attention_norm = RMSNorm(hidden_size, eps=1e-6)
  156. self.ffn_norm = RMSNorm(hidden_size, eps=1e-6)
  157. def forward(
  158. self,
  159. x,
  160. freqs_cis,
  161. attn_mask=None,
  162. input_pos=None,
  163. ):
  164. x = (
  165. x
  166. + self.attention(
  167. q=self.attention_norm(x),
  168. freqs_cis_q=freqs_cis,
  169. attn_mask=attn_mask,
  170. input_pos=input_pos,
  171. )[0]
  172. )
  173. return x + self.ffn(self.ffn_norm(x))
  174. class TransformerDecoderLayer(nn.Module):
  175. def __init__(self, hidden_size=1024, intermediate_size=None, nhead=16, dropout=0.1):
  176. super().__init__()
  177. self.self_attention = MultiheadAttention(hidden_size, nhead, dropout=dropout)
  178. self.cross_attention = MultiheadAttention(
  179. hidden_size, nhead, dropout=dropout, is_cross_attention=True
  180. )
  181. self.ffn = GluMLP(hidden_size=hidden_size, intermediate_size=intermediate_size)
  182. self.self_attention_norm = RMSNorm(hidden_size, eps=1e-6)
  183. self.cross_attention_norm = RMSNorm(hidden_size, eps=1e-6)
  184. self.ffn_norm = RMSNorm(hidden_size, eps=1e-6)
  185. def forward(
  186. self,
  187. x,
  188. context,
  189. freqs_cis_q,
  190. freqs_cis_kv,
  191. self_attn_mask=None,
  192. cross_attn_mask=None,
  193. input_pos=None,
  194. ):
  195. x = x + self.self_attention(
  196. q=self.self_attention_norm(x),
  197. freqs_cis_q=freqs_cis_q,
  198. attn_mask=self_attn_mask,
  199. input_pos=input_pos,
  200. )
  201. x = x + self.cross_attention(
  202. q=self.cross_attention_norm(x),
  203. kv=context,
  204. freqs_cis_q=freqs_cis_q,
  205. freqs_cis_kv=freqs_cis_kv,
  206. attn_mask=cross_attn_mask,
  207. input_pos=input_pos,
  208. )
  209. return x + self.ffn(self.ffn_norm(x))
  210. class Transformer(nn.Module):
  211. def __init__(
  212. self,
  213. vocab_size,
  214. codebook_size,
  215. num_codebooks,
  216. hidden_size=1024,
  217. intermediate_size=None,
  218. nhead=16,
  219. num_encoder_layers=12,
  220. num_decoder_layers=12,
  221. dropout=0.1,
  222. max_position=4096,
  223. ):
  224. super().__init__()
  225. self.encoder_embedding = nn.Embedding(vocab_size, hidden_size)
  226. self.decoder_embeddings = nn.ModuleList(
  227. [nn.Embedding(codebook_size, hidden_size) for _ in range(num_codebooks)]
  228. )
  229. self.decoder_head = nn.Linear(hidden_size, codebook_size * num_codebooks)
  230. self.codebook_size = codebook_size
  231. self.num_codebooks = num_codebooks
  232. self.nhead = nhead
  233. self.encoder = nn.ModuleList(
  234. [
  235. TransformerEncoderLayer(
  236. hidden_size=hidden_size,
  237. intermediate_size=intermediate_size,
  238. nhead=nhead,
  239. dropout=dropout,
  240. )
  241. for _ in range(num_encoder_layers)
  242. ]
  243. )
  244. self.decoder = nn.ModuleList(
  245. [
  246. TransformerDecoderLayer(
  247. hidden_size=hidden_size,
  248. intermediate_size=intermediate_size,
  249. nhead=nhead,
  250. dropout=dropout,
  251. )
  252. for _ in range(num_decoder_layers)
  253. ]
  254. )
  255. self.register_buffer(
  256. "freqs_cis",
  257. precompute_freqs_cis(hidden_size // nhead, max_position, theta=10000.0),
  258. )
  259. causual_mask = torch.triu(
  260. torch.ones(max_position, max_position), diagonal=1
  261. ).bool()
  262. causual_mask = torch.zeros(max_position, max_position).masked_fill(
  263. causual_mask, float("-inf")
  264. )
  265. self.register_buffer("causual_mask", causual_mask)
  266. # The following are reserved for kv cache
  267. self.max_batch_size = -1
  268. self.max_seq_length = -1
  269. def setup_kv_caches(self, max_batch_size, max_seq_length):
  270. if (
  271. self.max_seq_length >= max_seq_length
  272. and self.max_batch_size >= max_batch_size
  273. ):
  274. return
  275. if max_seq_length % 8 != 0:
  276. max_seq_length = max_seq_length + (8 - max_seq_length % 8)
  277. self.max_seq_length = max_seq_length
  278. self.max_batch_size = max_batch_size
  279. for b in self.decoder:
  280. b.self_attention.kv_cache = KVCache(
  281. max_batch_size,
  282. max_seq_length,
  283. b.self_attention.nhead,
  284. b.self_attention.head_dim,
  285. ).to(b.self_attention_norm.weight.device)
  286. b.cross_attention.kv_cache = KVCache(
  287. max_batch_size,
  288. max_seq_length,
  289. b.cross_attention.nhead,
  290. b.cross_attention.head_dim,
  291. ).to(b.cross_attention_norm.weight.device)
  292. def get_key_padding_mask(self, key_padding_mask, q_size=None):
  293. # inputs: (B, T) bool ->
  294. assert key_padding_mask.dtype == torch.bool and key_padding_mask.ndim == 2
  295. key_padding_mask = (
  296. key_padding_mask.unsqueeze(1).unsqueeze(1).expand(-1, self.nhead, -1, -1)
  297. )
  298. key_padding_mask = key_padding_mask.reshape(
  299. key_padding_mask.shape[0], self.nhead, 1, key_padding_mask.shape[1]
  300. )
  301. if q_size is not None:
  302. key_padding_mask = key_padding_mask.expand(-1, -1, q_size, -1)
  303. new_mask = torch.zeros(
  304. *key_padding_mask.shape, dtype=torch.float, device=key_padding_mask.device
  305. )
  306. new_mask = new_mask.masked_fill(key_padding_mask, float("-inf"))
  307. return new_mask
  308. def forward_encoder(
  309. self, inputs, input_mask=None
  310. ) -> tuple[torch.Tensor, torch.Tensor]:
  311. # inputs: (B, T)
  312. # input_mask: (B, T), bool mask
  313. inputs = self.encoder_embedding(inputs)
  314. # Calculate mask
  315. if input_mask is None:
  316. # Assume no padding
  317. input_mask = torch.zeros(
  318. inputs.shape[0], inputs.shape[1], dtype=torch.bool, device=inputs.device
  319. )
  320. input_mask = self.get_key_padding_mask(input_mask, q_size=None).to(inputs.dtype)
  321. freqs_cis = self.freqs_cis[: inputs.shape[1]]
  322. input_mask_self = input_mask.expand(-1, -1, inputs.shape[1], -1)
  323. for layer in self.encoder:
  324. inputs = layer(inputs, freqs_cis=freqs_cis, attn_mask=input_mask_self)
  325. return inputs, input_mask
  326. def forward_decoder(
  327. self, codes, inputs, input_mask, codes_mask=None, input_pos=None
  328. ):
  329. # codes: (B, C, T)
  330. # inputs: (B, T, N)
  331. print(f"Codes: {codes.shape}, Inputs: {inputs.shape}")
  332. codes = rearrange(codes, "b c t -> c b t")
  333. codes = torch.stack(
  334. [emb(code) for emb, code in zip(self.decoder_embeddings, codes)], dim=0
  335. )
  336. codes = torch.mean(codes, dim=0) # (B, T)
  337. # If kv cache is enabled
  338. input_mask = input_mask.expand(-1, -1, codes.shape[1], -1)
  339. # Calculate mask
  340. if input_pos is not None:
  341. attn_mask = self.causual_mask[: codes.shape[1], : codes.shape[1]]
  342. else:
  343. attn_mask = None
  344. # if codes_mask is not None:
  345. # codes_mask = self.get_key_padding_mask(codes_mask)
  346. # attn_mask = attn_mask + codes_mask
  347. # For kv cache
  348. if input_pos is not None:
  349. freqs_cis_q = self.freqs_cis[input_pos]
  350. else:
  351. freqs_cis_q = self.freqs_cis[: codes.shape[1]]
  352. freqs_cis_kv = self.freqs_cis[: inputs.shape[1]]
  353. for layer in self.decoder:
  354. codes = layer(
  355. codes,
  356. inputs,
  357. freqs_cis_q=freqs_cis_q,
  358. freqs_cis_kv=freqs_cis_kv,
  359. self_attn_mask=attn_mask,
  360. cross_attn_mask=input_mask,
  361. input_pos=input_pos,
  362. )
  363. codes = self.decoder_head(codes)
  364. codes = rearrange(
  365. codes, "b t (c d) -> b c t d", c=self.num_codebooks, d=self.codebook_size
  366. )
  367. return codes
  368. def forward(
  369. self,
  370. inputs,
  371. codes,
  372. input_mask=None,
  373. codes_mask=None,
  374. input_pos=None,
  375. ):
  376. # inputs: (B, T)
  377. # codes: (B, C, T)
  378. # input_mask: (B, T), bool mask
  379. # codes_mask: (B, T), bool mask
  380. # input_pos: (B, T), int mask
  381. inputs, input_mask = self.forward_encoder(inputs, input_mask)
  382. codes = self.forward_decoder(codes, inputs, input_mask, codes_mask, input_pos)
  383. return codes
  384. if __name__ == "__main__":
  385. mha = MultiheadAttention(512, 8, dropout=0, is_cross_attention=True)
  386. mha.eval()
  387. mha.cuda()
  388. q, kv = torch.randn(2, 10, 16, 512)
  389. q, kv = q.cuda(), kv.cuda()
  390. mha.bfloat16()
  391. q, kv = q.bfloat16(), kv.bfloat16()
  392. freqs_cis = precompute_freqs_cis(512 // 8, 4096 * 2).cuda()[:16]
  393. # Causual mask
  394. attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool().cuda()
  395. o = mha(q, freqs_cis, kv=kv, attn_mask=attn_mask)
  396. trans = (
  397. Transformer(
  398. vocab_size=30000,
  399. codebook_size=120,
  400. num_codebooks=4,
  401. hidden_size=1024,
  402. intermediate_size=None,
  403. nhead=16,
  404. num_encoder_layers=12,
  405. num_decoder_layers=12,
  406. )
  407. .bfloat16()
  408. .cuda()
  409. )
  410. trans.eval()
  411. # Print n param
  412. print("Total params:", sum(i.numel() for i in trans.parameters()) / 1024 / 1024)
  413. inputs = torch.randint(0, 1000, (2, 16)).cuda()
  414. codes = torch.randint(0, 120, (2, 4, 128)).cuda()
  415. x = trans(inputs, codes)
  416. x1 = trans(inputs, codes)
  417. assert torch.allclose(x, x1, atol=1e-4, rtol=1e-3), "Model is not deterministic"
  418. print("Model is deterministic")
  419. # Test kv cache
  420. trans.setup_kv_caches(2, 1024)
  421. inputs, inputs_mask = trans.forward_encoder(inputs)
  422. outputs = []
  423. for i in range(128):
  424. code = codes[..., i].unsqueeze(-1)
  425. code_mask = torch.tensor([[1], [1]], dtype=torch.bool, device=code.device)
  426. input_pos = torch.tensor([i], dtype=torch.long, device=code.device)
  427. outputs.append(
  428. trans.forward_decoder(
  429. code, inputs, inputs_mask, code_mask, input_pos=input_pos
  430. )
  431. )
  432. outputs = torch.cat(outputs, dim=2)
  433. print(x.shape, outputs.shape)
  434. assert torch.allclose(x, outputs, atol=1e-4, rtol=1e-3), "KV cache is not working"