# llama.py

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

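# With the default dim = 4096, __post_init__ below computes
# int(2 * (4 * 4096) / 3) == 10922, and find_multiple(10922, 256) == 11008,
# the familiar Llama-7B intermediate size.
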
@dataclass
class BaseModelArgs:
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: Optional[int] = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048
    dropout: float = 0.0

    # Codebook configs
    codebook_size: int = 160
    num_codebooks: int = 4
    num_in_codebooks: Optional[int] = None
    codebook_padding_idx: int = 0

    # Gradient checkpointing
    use_gradient_checkpointing: bool = True

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        if self.num_in_codebooks is None:
            self.num_in_codebooks = self.num_codebooks
        # head_dim is always derived from dim and n_head, overriding the default
        self.head_dim = self.dim // self.n_head

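# Worked example of the derived fields (hypothetical config): with dim = 768
# and n_head = 12, __post_init__ yields intermediate_size
# find_multiple(int(2 * 3072 / 3), 256) == 2048 and head_dim 768 // 12 == 64.
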
@dataclass
class NaiveModelArgs(BaseModelArgs):
    pass


@dataclass
class DualARModelArgs(BaseModelArgs):
    n_fast_layer: int = 4

class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out

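# Illustrative usage (hypothetical tensors, not part of the module): new
# keys/values are scattered into the preallocated buffers at input_pos, and
# the full buffers are returned so attention can span everything cached so far.
#
#   cache = KVCache(max_batch_size=1, max_seq_len=8, n_heads=2, head_dim=4)
#   k, v = cache.update(torch.tensor([3]), k_val, v_val)  # k_val, v_val: (1, 2, 1, 4)
#   k.shape  # torch.Size([1, 2, 8, 4])
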
@dataclass
class TransformerForwardResult:
    token_logits: Tensor
    codebook_logits: Tensor


@dataclass
class BaseTransformerForwardResult:
    logits: Tensor
    hidden_states: Tensor

class BaseTransformer(nn.Module):
    def __init__(self, config: BaseModelArgs) -> None:
        super().__init__()
        self.config = config

        # Slow transformer
        self.embeddings = nn.Embedding(
            config.vocab_size + config.codebook_size * config.num_in_codebooks,
            config.dim,
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(
            config.dim,
            config.vocab_size,
            bias=False,
        )

        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
            persistent=False,
        )

        # For kv cache
        self.max_batch_size = -1
        self.max_seq_len = -1

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        head_dim = self.config.dim // self.config.n_head
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def embed(self, x: Tensor) -> Tensor:
        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_in_codebooks):
            emb = self.embeddings(
                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
            )
            emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
            vocab_embeds.append(emb)

        x = torch.stack(vocab_embeds, dim=3)
        x = x.sum(dim=3)

        return x
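    # Shape sketch for embed (illustrative): with num_in_codebooks = 4, an input
    # of shape (B, 5, S) carries text tokens in row 0 and codebook tokens in
    # rows 1-4. Row i + 1 is offset by vocab_size + i * codebook_size into the
    # shared embedding table, padding positions are zeroed, and the five
    # embeddings are summed into a single (B, S, dim) tensor.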
    def forward(
        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> BaseTransformerForwardResult:
        # inp: (batch, num_codebooks + 1, seq_len)
        seq_len = inp.size(2)

        # Here we want to merge the embeddings of the codebooks
        x = self.embed(inp)

        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[:seq_len]

        # Note that the causal mask here follows the definition of
        # scaled_dot_product_attention: False means masked out.
        # To maintain consistency, key_padding_mask uses True to mask out.
        if key_padding_mask is not None:
            mask = mask & key_padding_mask[:, None, None, :].logical_not()

        for layer in self.layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
            else:
                x = layer(x, freqs_cis, mask)

        # We got slow_out here
        slow_out = self.norm(x)
        token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )
    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> BaseTransformerForwardResult:
        # This is used for generation, optimized for torch compile
        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"

        x = self.embed(x)

        mask = self.causal_mask[
            None, None, input_pos, : self.max_seq_len
        ]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        # If prefill, we only calculate the logits of the last token
        if x.size(1) > 1:
            x = x[:, -1:]

        # We got slow_out here
        slow_out = self.norm(x)
        token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )

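# A minimal decode-loop sketch (hypothetical names, not part of this module):
#
#   model.setup_caches(max_batch_size=1, max_seq_len=2048)
#   cur = prompt  # (1, num_codebooks + 1, prompt_len)
#   out = model.forward_generate(cur, input_pos=torch.arange(prompt_len))
#   for t in range(prompt_len, max_new_tokens):
#       next_tok = sample(out.logits)  # (1, num_codebooks + 1, 1)
#       out = model.forward_generate(next_tok, input_pos=torch.tensor([t]))
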
class NaiveTransformer(BaseTransformer):
    def __init__(self, config: NaiveModelArgs) -> None:
        super().__init__(config)

        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.codebook_output = nn.Linear(
            config.dim,
            config.codebook_size * config.num_codebooks,
            bias=False,
        )

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        token_logits = result.logits
        x = result.hidden_states

        # Codebook
        codebook_logits = self.codebook_output(self.codebook_norm(x))
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward(
        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        result = super().forward(inp, key_padding_mask)
        return self.decode(result)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        result = super().forward_generate(x, input_pos)
        return self.decode(result)

class DualARTransformer(BaseTransformer):
    def __init__(self, config: DualARModelArgs) -> None:
        super().__init__(config)

        # Fast transformer
        self.fast_embeddings = nn.Embedding(
            config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
        )

        # The equivalent batch size is so large that sdpa doesn't work
        self.fast_layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.dim,
            config.codebook_size,
            bias=False,
        )

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        head_dim = self.config.dim // self.config.n_head

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )
    def forward(
        self, inp: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[:fast_seq_len]

        # Drop the last codebook row (never needed as an input) and rotate the
        # time axis left: the slow hidden state at time t is paired with the
        # codebooks of frame t + 1 (teacher forcing)
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=self.config.codebook_padding_idx)
        codebook_embeddings = self.fast_embeddings(codebooks)
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove padded part
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == self.config.codebook_padding_idx).all(dim=-1)
        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
            else:
                x = layer(x, fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )
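    # Shape walkthrough for the fast branch (illustrative, b=2, s=128, nc=4):
    #   slow hidden states x: (2, 128, d); shifted codebooks: (2, 3, 128);
    #   cat along the codebook axis and flatten -> (2 * 128, 4, d), i.e. every
    #   frame becomes its own length-4 "fast" sequence. Fully padded frames are
    #   filtered out before the fast layers and re-padded afterwards, yielding
    #   codebook_logits of shape (2, 128, 4, codebook_size).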
    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        # Fast transformer
        x = x.view(1, 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # x contains only the current token, so no slicing is needed here
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits

class TransformerBlock(nn.Module):
    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

class Attention(nn.Module):
    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # Merge separate wq/wk/wv weights from older checkpoints into wqkv
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Grouped-query attention: expand the local kv heads to match the query heads
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.use_sdpa:
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)
    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        # This is a standard scaled dot product attention.
        # It's less efficient than F.scaled_dot_product_attention,
        # but it doesn't raise a CUDA error for very large batch sizes.
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

        return attn_weight @ value

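# The method above computes textbook attention,
#   softmax(Q K^T / sqrt(d_k) + bias) V,
# where bias is 0 for attended positions and -inf for masked ones.
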
class FeedForward(nn.Module):
    def __init__(self, config: BaseModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # SwiGLU: silu(w1(x)) gates w3(x) before projecting back down
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

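# RMSNorm normalizes by the root mean square instead of mean and variance:
#   y = x / sqrt(mean(x^2) + eps) * weight
# i.e. LayerNorm without mean subtraction and without a bias term.
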
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)

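# RoPE cache: entry (t, i) holds (cos, sin) of the angle t * base^(-2i / n_elem),
# so each channel pair is rotated at its own frequency. With n_elem = 64 the
# frequencies range from 1 (i = 0) down to 10000^(-62/64) (i = 31).
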
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    # Treat consecutive channel pairs as complex numbers and multiply by
    # (cos + i sin): a 2D rotation of each pair by its position-dependent angle
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)

if __name__ == "__main__":
    args = DualARModelArgs(
        max_seq_len=4096,
        vocab_size=32312,
        n_layer=12,
        n_fast_layer=4,
        n_head=12,
        dim=768,
        rope_base=10000,
        norm_eps=1e-5,
        codebook_size=128,
        num_codebooks=4,
    )

    model = DualARTransformer(args)
    model = model.cuda().bfloat16()
    # Parameter count, in units of 1024^2 (~millions)
    print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)

    inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
    key_padding_mask = torch.zeros(2, 128).bool().cuda()
    key_padding_mask[0, 2:] = True
    x1 = model(inputs, key_padding_mask=key_padding_mask)
    print(x1.token_logits.shape)
    print(x1.codebook_logits.shape)
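    # Expected output shapes for this configuration:
    #   token_logits:    (2, 128, 32312)
    #   codebook_logits: (2, 128, 4, 128)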