llama.py

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)
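
# For example, find_multiple(1000, 256) == 1024, while find_multiple(1024, 256)
# returns 1024 unchanged; it is used below to round hidden sizes and cache
# lengths up to friendly multiples.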

@dataclass
class ModelArgs:
    vocab_size: int = 32000
    n_slow_layer: int = 32
    n_fast_layer: int = 4
    n_head: int = 32
    dim: int = 4096
    intermediate_size: Optional[int] = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048
    dropout: float = 0.0

    # Additional decoding heads
    codebook_size: int = 160
    num_codebooks: int = 4
    codebook_padding_idx: int = 0

    # Gradient checkpointing
    use_gradient_checkpointing: bool = True

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head
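
# Worked example: with the default dim=4096, __post_init__ gives
# hidden_dim=16384, n_hidden=int(2 * 16384 / 3)=10922, and
# intermediate_size=find_multiple(10922, 256)=11008, the SwiGLU sizing used by
# LLaMA-7B.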

class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out
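
# A minimal usage sketch (hypothetical helper, never called): prefill writes a
# run of positions at once, and each later decode step writes one new position;
# update() always returns the full buffers, with unwritten slots still zero.
def _kv_cache_example():
    cache = KVCache(max_batch_size=1, max_seq_len=8, n_heads=2, head_dim=4)
    k = torch.randn(1, 2, 3, 4, dtype=torch.bfloat16)
    v = torch.randn(1, 2, 3, 4, dtype=torch.bfloat16)
    k_out, v_out = cache.update(torch.arange(3), k, v)  # fill positions 0..2
    assert k_out.shape == (1, 2, 8, 4)  # positions 3..7 remain zero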

@dataclass
class TransformerForwardResult:
    token_logits: Tensor
    codebook_logits: Tensor

class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        # Slow transformer
        self.embeddings = nn.Embedding(
            config.vocab_size + config.codebook_size * config.num_codebooks,
            config.dim,
        )
        self.slow_layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_slow_layer)
        )
        self.slow_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.slow_output = nn.Linear(
            config.dim,
            config.vocab_size,
            bias=False,
        )

        # Fast transformer
        self.fast_embeddings = nn.Embedding(
            config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
        )
        # The effective batch size (batch * seq_len) after flattening is so
        # large that SDPA fails with CUDA errors, so the fast layers use a
        # manual equivalent instead
        self.fast_layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.dim,
            config.codebook_size,
            bias=False,
        )

        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
            persistent=False,
        )

        # For kv cache
        self.max_batch_size = -1
        self.max_seq_len = -1

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        head_dim = self.config.dim // self.config.n_head
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        # Slow transformer
        for b in self.slow_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def embed(self, x: Tensor) -> Tensor:
        # Here we want to merge the embeddings of the codebooks
        if self.config.num_codebooks == 0:
            return self.embeddings(x[:, 0])

        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            emb = self.embeddings(
                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
            )
            emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
            vocab_embeds.append(emb)

        x = torch.stack(vocab_embeds, dim=3)
        x = x.sum(dim=3)

        return x
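
    # Layout of the shared embedding table (for reference): rows [0, vocab_size)
    # hold text tokens, and rows [vocab_size + i * codebook_size,
    # vocab_size + (i + 1) * codebook_size) hold codebook i, which is why
    # codebook ids are offset by i * codebook_size + vocab_size above. Padding
    # entries are zeroed so they do not contribute to the sum.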

    def forward(
        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        # inp: (batch, num_codebooks + 1, seq_len)
        seq_len = inp.size(2)
        codebooks = inp[:, 1:]

        # Here we want to merge the embeddings of the codebooks
        x = self.embed(inp)

        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[:seq_len]

        # Note that the causal mask here follows the definition of
        # scaled_dot_product_attention: FALSE means masked out.
        # For consistency, key_padding_mask uses TRUE to mask out.
        if key_padding_mask is not None:
            mask = mask & key_padding_mask[:, None, None, :].logical_not()

        for layer in self.slow_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
            else:
                x = layer(x, freqs_cis, mask)

        # We got slow_out here
        slow_out = self.slow_norm(x)
        token_logits = self.slow_output(slow_out)

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[:fast_seq_len]

        # Drop the last codebook and shift left by one token (teacher forcing)
        codebooks = codebooks[:, :-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
        codebook_embeddings = self.fast_embeddings(codebooks)
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove padded part
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == self.config.codebook_padding_idx).all(dim=-1)
        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
            else:
                x = layer(x, fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits (match dtype so bf16 logits copy in cleanly)
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=x.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_slow(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> tuple[Tensor, Tensor]:
        ### TODO: fix this
        # x: (batch, num_codebooks + 1, 1)
        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"

        x = self.embed(x)
        mask = self.causal_mask[
            None, None, input_pos, : self.max_seq_len
        ]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.slow_layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        # If prefilling, we only calculate the logits of the last token
        if x.size(1) > 1:
            x = x[:, -1:]

        # We got slow_out here
        slow_out = self.slow_norm(x)
        token_logits = self.slow_output(slow_out)

        return x, token_logits

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        # Fast transformer
        x = x.view(1, 1, -1)
        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits

class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        # Pre-norm residual block: norm -> attention -> add, norm -> MLP -> add
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

class Attention(nn.Module):
    def __init__(self, config: ModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # Fuse separate wq/wk/wv weights from older checkpoints into wqkv
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Expand grouped KV heads to match the number of query heads (GQA)
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.use_sdpa:
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        # This is a standard scaled dot product attention. It is less efficient
        # than F.scaled_dot_product_attention, but it doesn't raise CUDA errors
        # at the large effective batch sizes used by the fast transformer.
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=self.training)
        return attn_weight @ value

class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # SwiGLU: silu(w1(x)) gates w3(x), then w2 projects back to dim
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

def precompute_freqs_cis(seq_len: int, n_elem: int, base: float = 10000) -> Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    # Store (cos, sin) pairs instead of complex numbers so the buffer is bf16-safe
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
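
# Hypothetical sanity check (not called anywhere): rotating by position 0 is the
# identity, since freqs_cis[0] holds (cos 0, sin 0) = (1, 0) for every
# frequency pair.
def _rope_identity_check():
    freqs_cis = precompute_freqs_cis(seq_len=4, n_elem=8)
    x = torch.randn(1, 1, 2, 8)  # (batch, seq=1, n_heads, head_dim)
    out = apply_rotary_emb(x, freqs_cis[:1])
    assert torch.allclose(out, x)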

if __name__ == "__main__":
    args = ModelArgs(
        max_seq_len=4096,
        vocab_size=32312,
        n_slow_layer=12,
        n_fast_layer=4,
        n_head=12,
        dim=768,
        rope_base=10000,
        norm_eps=1e-5,
        codebook_size=128,
        num_codebooks=4,
    )

    model = Transformer(args)
    model = model.cuda().bfloat16()
    print("Total params (M):", sum(i.numel() for i in model.parameters()) / 1024 / 1024)

    inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
    key_padding_mask = torch.zeros(2, 128).bool().cuda()
    key_padding_mask[0, 2:] = True
    x1 = model(inputs, key_padding_mask=key_padding_mask)
    print(x1.token_logits.shape)
    # print(x1.codebook_logits.shape)
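
    # Hedged sketch of single-step decoding through the KV caches; note that
    # forward_generate_slow is marked TODO above, so this only smoke-tests the
    # cache path. setup_caches creates fresh CPU buffers, hence the second
    # .cuda() to move them alongside the rest of the model.
    with torch.no_grad():
        model.eval()
        model.setup_caches(max_batch_size=1, max_seq_len=4096)
        model.cuda()
        step = torch.randint(0, 100, (1, 5, 1)).cuda()  # one (token, codebooks) step
        input_pos = torch.tensor([0], device="cuda")
        hidden, token_logits = model.forward_generate_slow(step, input_pos=input_pos)
        print(hidden.shape, token_logits.shape)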