# llama.py
  1. import math
  2. from dataclasses import dataclass
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. from einops import rearrange
  7. from torch import Tensor
  8. from torch.nn import functional as F
  9. from torch.utils.checkpoint import checkpoint
  10. def find_multiple(n: int, k: int) -> int:
  11. if n % k == 0:
  12. return n
  13. return n + k - (n % k)
  14. @dataclass
  15. class ModelArgs:
  16. vocab_size: int = 32000
  17. n_slow_layer: int = 32
  18. n_fast_layer: int = 4
  19. n_head: int = 32
  20. dim: int = 4096
  21. intermediate_size: int = None
  22. n_local_heads: int = -1
  23. head_dim: int = 64
  24. rope_base: float = 10000
  25. norm_eps: float = 1e-5
  26. max_seq_len: int = 2048
  27. dropout: float = 0.0
  28. # Additional decoding heads
  29. codebook_size: int = 160
  30. num_codebooks: int = 4
  31. codebook_padding_idx: int = 0
  32. # Gradient checkpointing
  33. use_gradient_checkpointing: bool = True
  34. def __post_init__(self):
  35. if self.n_local_heads == -1:
  36. self.n_local_heads = self.n_head
  37. if self.intermediate_size is None:
  38. hidden_dim = 4 * self.dim
  39. n_hidden = int(2 * hidden_dim / 3)
  40. self.intermediate_size = find_multiple(n_hidden, 256)
  41. self.head_dim = self.dim // self.n_head
  42. class KVCache(nn.Module):
  43. def __init__(
  44. self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
  45. ):
  46. super().__init__()
  47. cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
  48. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  49. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  50. def update(self, input_pos, k_val, v_val):
  51. # input_pos: [S], k_val: [B, H, S, D]
  52. assert input_pos.shape[0] == k_val.shape[2]
  53. k_out = self.k_cache
  54. v_out = self.v_cache
  55. k_out[:, :, input_pos] = k_val
  56. v_out[:, :, input_pos] = v_val
  57. return k_out, v_out
@dataclass
class TransformerForwardResult:
    """Outputs of Transformer.forward."""

    # Logits over the text vocabulary: (batch, seq_len, vocab_size).
    token_logits: Tensor
    # Logits over codebook entries: (batch, seq_len, num_codebooks,
    # codebook_size), re-padded with zeros at fully-padded positions.
    codebook_logits: Tensor
class Transformer(nn.Module):
    """Two-stage decoder: a "slow" stack predicts the next text token from the
    merged token+codebook embeddings, and a small "fast" stack predicts the
    codebook entries for that step, conditioned on the slow hidden state.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        # Slow transformer. One shared embedding table covers the text
        # vocabulary followed by a contiguous id range per codebook.
        self.embeddings = nn.Embedding(
            config.vocab_size + config.codebook_size * config.num_codebooks,
            config.dim,
        )
        self.slow_layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_slow_layer)
        )
        self.slow_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.slow_output = nn.Linear(
            config.dim,
            config.vocab_size,
            bias=False,
        )

        # Fast transformer: separate embedding with the padding id pinned to a
        # zero vector.
        self.fast_embeddings = nn.Embedding(
            config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
        )
        # The equivalent bs is so large that sdpa doesn't work
        self.fast_layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.dim,
            config.codebook_size,
            bias=False,
        )

        # Rotary table and boolean causal mask; non-persistent so they are
        # rebuilt rather than loaded from checkpoints.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
            persistent=False,
        )

        # For kv cache; -1 means setup_caches() has not been called yet.
        self.max_batch_size = -1
        self.max_seq_len = -1

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Allocate per-layer KV caches for incremental generation."""
        # Re-use existing caches when they are already large enough.
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        head_dim = self.config.dim // self.config.n_head
        # Pad the cached sequence length up to a multiple of 8.
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        # Slow transformer caches span the full text sequence.
        for b in self.slow_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def embed(self, x: Tensor) -> Tensor:
        """Sum the text-token embedding with all codebook embeddings.

        x: (batch, num_codebooks + 1, seq_len); row 0 holds text token ids,
        rows 1..num_codebooks hold codebook indices.
        """
        # Here we want to merge the embeddings of the codebooks
        if self.config.num_codebooks == 0:
            return self.embeddings(x[:, 0])

        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            # Codebook i lives in its own id range after the text vocabulary.
            emb = self.embeddings(
                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
            )
            # Zero the embedding wherever this codebook is padding.
            emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
            vocab_embeds.append(emb)

        x = torch.stack(vocab_embeds, dim=3)
        x = x.sum(dim=3)

        return x

    def forward(
        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        """Training-time forward pass through both stacks.

        inp: (batch, num_codebooks + 1, seq_len); row 0 is text tokens.
        key_padding_mask: (batch, seq_len); True marks positions to mask out.
        """
        # x: (batch, num_codebooks + 1, seq_len)
        seq_len = inp.size(2)

        codebooks = inp[:, 1:]

        # Here we want to merge the embeddings of the codebooks
        x = self.embed(inp)

        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[:seq_len]

        # Note that the causal mask here follows the definition of
        # scaled_dot_product_attention: FALSE means masked out.
        # To maintain consistency, key_padding_mask uses TRUE to mask out.
        if key_padding_mask is not None:
            mask = mask & key_padding_mask[:, None, None, :].logical_not()

        for layer in self.slow_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
            else:
                x = layer(x, freqs_cis, mask)

        # We got slow_out here
        slow_out = self.slow_norm(x)
        token_logits = self.slow_output(slow_out)

        # Fast transformer: attends over the codebook axis, so its sequence
        # length is num_codebooks.
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[:fast_seq_len]

        # Teacher forcing: drop the last codebook row and the first time step,
        # then pad so step i of the fast stack conditions on codebook i-1.
        codebooks = codebooks[:, :-1, 1:]
        codebooks = F.pad(
            codebooks, (0, 1, 1, 0), value=self.config.codebook_padding_idx
        )
        codebook_embeddings = self.fast_embeddings(codebooks)
        # Broadcast the slow hidden state over the codebook axis.
        x = codebook_embeddings + x[:, None]  # (B, num_codebooks, S, D)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove positions whose codebooks are entirely padding to save compute.
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == self.config.codebook_padding_idx).all(dim=-1)
        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
            else:
                x = layer(x, fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits: fully-padded positions get zero logits.
        # NOTE(review): buffer is created with the default dtype (float32) even
        # if codebook_logits is bf16, so the assignment upcasts — confirm this
        # is intended for the loss computation.
        buffer = torch.zeros(x_bs, x_len, codebook_logits.size(-1), device=x.device)
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_fast(self, x: Tensor) -> Tensor:
        """Run already-embedded inputs through the fast stack (no KV cache)."""
        # Fast transformer
        fast_seq_len = x.shape[1]
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[:fast_seq_len]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask)

        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits

    def forward_generate_slow(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> tuple[Tensor, Tensor]:
        """Prefill or single-step pass through the slow stack with KV caches.

        Returns (last hidden state, token_logits), both trimmed to the final
        position when prefilling more than one token.
        """
        ### TODO: fix this
        # x: (batch, num_codebooks + 1, 1)
        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"

        x = self.embed(x)

        # Rows of the causal mask for the current positions, over the whole
        # cached sequence length.
        mask = self.causal_mask[
            None, None, input_pos, : self.max_seq_len
        ]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.slow_layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        # If prefill, we only calculate the logits of last token
        if x.size(1) > 1:
            x = x[:, -1:]

        # We got slow_out here
        slow_out = self.slow_norm(x)
        token_logits = self.slow_output(slow_out)

        return x, token_logits

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        """Single-step pass through the fast stack using its KV caches.

        NOTE(review): the reshape below hard-codes batch size 1 — confirm that
        generation is only ever run with a single sample.
        """
        # Fast transformer
        x = x.view(1, 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits
  275. class TransformerBlock(nn.Module):
  276. def __init__(self, config: ModelArgs, use_sdpa: bool = True) -> None:
  277. super().__init__()
  278. self.attention = Attention(config, use_sdpa=use_sdpa)
  279. self.feed_forward = FeedForward(config)
  280. self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
  281. self.attention_norm = RMSNorm(config.dim, config.norm_eps)
  282. def forward(
  283. self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
  284. ) -> Tensor:
  285. h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  286. out = h + self.feed_forward(self.ffn_norm(h))
  287. return out
class Attention(nn.Module):
    """Multi-head attention with a fused QKV projection, rotary position
    embeddings, grouped-query KV heads, and an optional KV cache for
    incremental decoding.
    """

    def __init__(self, config: ModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        # Q uses n_head heads; K and V each use n_local_heads (grouped-query).
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        # Populated by Transformer.setup_caches() before generation.
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        # Rewrites legacy checkpoints (separate wq/wk/wv) into the fused layout.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """Merge separate wq/wk/wv checkpoint weights into fused wqkv."""
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Attend over x; `input_pos` is required only when a KV cache is set."""
        bsz, seqlen, _ = x.shape

        # Single matmul produces Q, K, V; split by their head counts.
        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        # Rotary embeddings on queries and keys only.
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        # (B, S, H, D) -> (B, H, S, D)
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        # During generation, merge new k/v into the cache and attend over the
        # whole cached sequence.
        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Expand grouped KV heads to match the number of query heads.
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.use_sdpa:
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        # (B, H, S, D) -> (B, S, H*D); H*D == dim by construction.
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        # This is a standard scaled dot product attention
        # It's low efficient, but it doesn't raise cuda error
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Boolean masks follow SDPA semantics: False means masked out.
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        # NOTE(review): train=True unconditionally — this relies on callers
        # passing dropout_p=0.0 in eval mode, which forward() does.
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        return attn_weight @ value
  371. class FeedForward(nn.Module):
  372. def __init__(self, config: ModelArgs) -> None:
  373. super().__init__()
  374. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  375. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  376. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  377. def forward(self, x: Tensor) -> Tensor:
  378. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  379. class RMSNorm(nn.Module):
  380. def __init__(self, dim: int, eps: float = 1e-5):
  381. super().__init__()
  382. self.eps = eps
  383. self.weight = nn.Parameter(torch.ones(dim))
  384. def _norm(self, x):
  385. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  386. def forward(self, x: Tensor) -> Tensor:
  387. output = self._norm(x.float()).type_as(x)
  388. return output * self.weight
  389. def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
  390. freqs = 1.0 / (
  391. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  392. )
  393. t = torch.arange(seq_len, device=freqs.device)
  394. freqs = torch.outer(t, freqs)
  395. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  396. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  397. return cache.to(dtype=torch.bfloat16)
  398. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  399. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  400. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  401. x_out2 = torch.stack(
  402. [
  403. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  404. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  405. ],
  406. -1,
  407. )
  408. x_out2 = x_out2.flatten(3)
  409. return x_out2.type_as(x)
  410. if __name__ == "__main__":
  411. args = ModelArgs(
  412. max_seq_len=4096,
  413. vocab_size=32312,
  414. n_slow_layer=12,
  415. n_fast_layer=4,
  416. n_head=12,
  417. dim=768,
  418. rope_base=10000,
  419. norm_eps=1e-5,
  420. codebook_size=128,
  421. num_codebooks=4,
  422. )
  423. model = Transformer(args)
  424. model = model.cuda().bfloat16()
  425. print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
  426. inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
  427. key_padding_mask = torch.zeros(2, 128).bool().cuda()
  428. key_padding_mask[0, 2:] = True
  429. x1 = model(inputs, key_padding_mask=key_padding_mask)
  430. print(x1.token_logits.shape)
  431. # print(x1.codebook_logits.shape)