llama.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from torch.nn import functional as F
from transformers.utils import is_flash_attn_2_available

# flash-attn is optional: only import its kernels when transformers reports
# that FlashAttention-2 is usable in this environment.
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
  12. def find_multiple(n: int, k: int) -> int:
  13. if n % k == 0:
  14. return n
  15. return n + k - (n % k)
  16. @dataclass
  17. class ModelArgs:
  18. vocab_size: int = 32000
  19. n_layer: int = 32
  20. n_head: int = 32
  21. dim: int = 4096
  22. intermediate_size: int = None
  23. n_local_heads: int = -1
  24. head_dim: int = 64
  25. rope_base: float = 10000
  26. norm_eps: float = 1e-5
  27. max_seq_len: int = 2048
  28. dropout: float = 0.0
  29. # Additional decoding heads
  30. codebook_size: int = 160
  31. num_codebooks: int = 4
  32. codebook_padding_idx: int = 0
  33. # Use flash attention
  34. use_flash_attention: bool = False
  35. # Gradient checkpointing
  36. use_gradient_checkpointing: bool = True
  37. def __post_init__(self):
  38. if self.n_local_heads == -1:
  39. self.n_local_heads = self.n_head
  40. if self.intermediate_size is None:
  41. hidden_dim = 4 * self.dim
  42. n_hidden = int(2 * hidden_dim / 3)
  43. self.intermediate_size = find_multiple(n_hidden, 256)
  44. self.head_dim = self.dim // self.n_head
  45. class KVCache(nn.Module):
  46. def __init__(
  47. self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
  48. ):
  49. super().__init__()
  50. cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
  51. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  52. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  53. def update(self, input_pos, k_val, v_val):
  54. # input_pos: [S], k_val: [B, H, S, D]
  55. assert input_pos.shape[0] == k_val.shape[2]
  56. k_out = self.k_cache
  57. v_out = self.v_cache
  58. k_out[:, :, input_pos] = k_val
  59. v_out[:, :, input_pos] = v_val
  60. return k_out, v_out
  61. @dataclass
  62. class TransformerForwardResult:
  63. token_logits: Tensor
  64. codebook_logits: Tensor
class Transformer(nn.Module):
    """Llama-style decoder-only transformer with extra codebook logit heads.

    Text tokens and codebook tokens share a single embedding table and a
    single output projection: codebook ``i`` occupies rows
    ``[vocab_size + i * codebook_size, vocab_size + (i + 1) * codebook_size)``.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.embeddings = nn.Embedding(
            config.vocab_size + config.codebook_size * config.num_codebooks, config.dim
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        # One matmul produces both the text logits and all codebook logits.
        self.output = nn.Linear(
            config.dim,
            config.vocab_size + config.codebook_size * config.num_codebooks,
            bias=False,
        )

        # Rotary-embedding table and a boolean lower-triangular causal mask,
        # both sized for config.max_seq_len.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
        )
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
        )

        # For kv cache; -1 means setup_caches has not been called yet.
        self.max_batch_size = -1
        self.max_seq_len = -1

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Allocate per-layer KV caches for incremental decoding.

        No-op when the existing caches are already large enough.
        """
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        head_dim = self.config.dim // self.config.n_head
        # Round the cached sequence length up to a multiple of 8.
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def embed(self, x: Tensor) -> Tensor:
        # Here we want to merge the embeddings of the codebooks.
        # x: (batch, num_codebooks + 1, seq_len); row 0 holds the text tokens.
        if self.config.num_codebooks == 0:
            return self.embeddings(x[:, 0])

        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            # Offset codebook i into its own slice of the shared table.
            emb = self.embeddings(
                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
            )
            # Zero the embedding wherever the codebook holds its padding index.
            emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
            vocab_embeds.append(emb)

        # Sum text and codebook embeddings per position.
        x = torch.stack(vocab_embeds, dim=3)
        return x.sum(dim=3)

    def compute(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Run the layer stack over pre-embedded inputs and split the output
        projection into text and codebook logits."""
        for layer in self.layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = torch.utils.checkpoint.checkpoint(
                    layer, x, freqs_cis, mask, input_pos, use_reentrant=True
                )
            else:
                x = layer(x, freqs_cis, mask, input_pos=input_pos)

        x = self.norm(x)
        logits = self.output(x)
        # First vocab_size columns are text logits; the rest are codebooks.
        token_logits = logits[:, :, : self.config.vocab_size]

        if self.config.num_codebooks == 0:
            # NOTE(review): codebook_logits is annotated as Tensor on
            # TransformerForwardResult but is None here — callers must check.
            return TransformerForwardResult(
                token_logits=token_logits,
                codebook_logits=None,
            )

        codebook_logits = logits[:, :, self.config.vocab_size :]
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward(
        self, x: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        # x: (batch, num_codebooks + 1, seq_len)
        seq_len = x.size(2)

        # Here we want to merge the embeddings of the codebooks
        x = self.embed(x)

        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[:seq_len]

        # Note that the causal mask here follows the definition of
        # scaled_dot_product_attention: FALSE means masked out.
        # To maintain consistency, key_padding_mask uses TRUE to mask out.
        if self.config.use_flash_attention is False and key_padding_mask is not None:
            mask = mask & key_padding_mask[:, None, None, :].logical_not()
        elif self.config.use_flash_attention is True and key_padding_mask is not None:
            # The flash path instead takes a (batch, seq) mask where True/1
            # marks tokens to KEEP — hence the inversion.
            mask = key_padding_mask.logical_not()

        return self.compute(x, freqs_cis, mask)

    def forward_generate(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        """Incremental decoding step; requires setup_caches to have run."""
        # x: (batch, num_codebooks + 1, 1)
        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"

        x = self.embed(x)
        # Select the causal-mask rows and rope frequencies for these positions.
        mask = self.causal_mask[
            None, None, input_pos, : self.max_seq_len
        ]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        # TODO: support key padding mask for generation
        return self.compute(x, freqs_cis, mask, input_pos=input_pos)
  191. class TransformerBlock(nn.Module):
  192. def __init__(self, config: ModelArgs) -> None:
  193. super().__init__()
  194. self.attention = Attention(config)
  195. self.feed_forward = FeedForward(config)
  196. self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
  197. self.attention_norm = RMSNorm(config.dim, config.norm_eps)
  198. def forward(
  199. self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
  200. ) -> Tensor:
  201. h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  202. out = h + self.feed_forward(self.ffn_norm(h))
  203. return out
class Attention(nn.Module):
    """Multi-head attention with grouped KV heads, rotary embeddings, an
    optional KV cache, and an optional flash-attention path."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_flash_attention = config.use_flash_attention
        # Merge legacy wq/wk/wv checkpoint weights into the fused wqkv.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # Convert checkpoints that store wq/wk/wv as separate matrices.
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        if self.use_flash_attention is False:
            # SDPA expects (B, H, S, D).
            q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

            if self.kv_cache is not None:
                k, v = self.kv_cache.update(input_pos, k, v)

            # Expand grouped KV heads so each query head has its own K/V.
            k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
            v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

            y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        else:
            assert (
                self.kv_cache is None
            ), "kv_cache is not supported for flash attention for now"
            # We don't need to transpose q, k, v here because flash_attn_varlen_func
            # consumes the (B, S, H, D) layout directly.
            attn_output = self._flash_attention_forward(
                q, k, v, mask, seqlen, dropout=self.dropout if self.training else 0.0
            )
            y = attn_output.reshape(bsz, seqlen, self.dim).contiguous()

        return self.wo(y)

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`int`, *optional*):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=True,
            )

            # Scatter the packed rows back to a padded (batch, seq, ...) layout.
            attn_output = pad_input(
                attn_output_unpad, indices_q, batch_size, query_length
            )
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=True,
            )

        return attn_output

    def _get_unpad_data(self, attention_mask):
        """Per-sequence valid lengths, flat indices of the non-pad positions,
        and cumulative sequence lengths, from a (batch, seq) keep-mask."""
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        # Cumulative lengths prefixed with 0, as flash_attn_varlen expects.
        # (torch.torch is simply an alias of the torch module itself.)
        cu_seqlens = F.pad(
            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
        )
        return (
            indices,
            cu_seqlens,
            max_seqlen_in_batch,
        )

    def _upad_input(
        self, query_layer, key_layer, value_layer, attention_mask, query_length
    ):
        """Strip padding rows from Q/K/V so flash_attn_varlen_func can run on
        packed (total_tokens, heads, head_dim) tensors."""
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = self._get_unpad_data(
            attention_mask
        )
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )

        if query_length == kv_seq_len:
            # Prefill: queries share the key layout, so reuse its metadata.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.n_head, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-token decoding: one query per batch element.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
  389. class FeedForward(nn.Module):
  390. def __init__(self, config: ModelArgs) -> None:
  391. super().__init__()
  392. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  393. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  394. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  395. def forward(self, x: Tensor) -> Tensor:
  396. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  397. class RMSNorm(nn.Module):
  398. def __init__(self, dim: int, eps: float = 1e-5):
  399. super().__init__()
  400. self.eps = eps
  401. self.weight = nn.Parameter(torch.ones(dim))
  402. def _norm(self, x):
  403. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  404. def forward(self, x: Tensor) -> Tensor:
  405. output = self._norm(x.float()).type_as(x)
  406. return output * self.weight
  407. def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
  408. freqs = 1.0 / (
  409. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  410. )
  411. t = torch.arange(seq_len, device=freqs.device)
  412. freqs = torch.outer(t, freqs)
  413. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  414. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  415. return cache.to(dtype=torch.bfloat16)
  416. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  417. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  418. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  419. x_out2 = torch.stack(
  420. [
  421. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  422. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  423. ],
  424. -1,
  425. )
  426. x_out2 = x_out2.flatten(3)
  427. return x_out2.type_as(x)
  428. if __name__ == "__main__":
  429. args = ModelArgs(
  430. max_seq_len=4096,
  431. vocab_size=32312,
  432. n_layer=12,
  433. n_head=12,
  434. dim=768,
  435. rope_base=10000,
  436. norm_eps=1e-5,
  437. codebook_size=0,
  438. num_codebooks=0,
  439. )
  440. model = Transformer(args)
  441. model = model.cuda().bfloat16()
  442. print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
  443. inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
  444. key_padding_mask = torch.zeros(2, 128).bool().cuda()
  445. key_padding_mask[0, 2:] = True
  446. x1 = model(inputs, key_padding_mask=key_padding_mask)
  447. print(x1.token_logits.shape)
  448. # print(x1.codebook_logits.shape)