llama.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. from dataclasses import dataclass
  2. from typing import Optional
  3. import torch
  4. import torch.nn as nn
  5. from einops import rearrange
  6. from torch import Tensor
  7. from torch.nn import functional as F
  8. from transformers.utils import is_flash_attn_2_available
  9. if is_flash_attn_2_available():
  10. from flash_attn import flash_attn_func, flash_attn_varlen_func
  11. from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
  12. def find_multiple(n: int, k: int) -> int:
  13. if n % k == 0:
  14. return n
  15. return n + k - (n % k)
@dataclass
class ModelArgs:
    """Hyper-parameters for the Transformer model."""

    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    # Derived in __post_init__ when left as None (SwiGLU sizing rule).
    intermediate_size: Optional[int] = None
    # Number of key/value heads (GQA); -1 means same as n_head (plain MHA).
    n_local_heads: int = -1
    # Overwritten in __post_init__ as dim // n_head.
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048

    # Additional decoding heads
    codebook_size: int = 160
    num_codebooks: int = 4
    codebook_padding_idx: int = 0

    # Use flash attention
    use_flash_attention: bool = False

    # Gradient checkpointing
    use_gradient_checkpointing: bool = True

    def __post_init__(self):
        # Default to one KV head per query head when GQA is not requested.
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head

        if self.intermediate_size is None:
            # LLaMA-style SwiGLU FFN sizing: (2/3) * 4 * dim, rounded up to a
            # multiple of 256.
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)

        self.head_dim = self.dim // self.n_head
  44. class KVCache(nn.Module):
  45. def __init__(
  46. self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
  47. ):
  48. super().__init__()
  49. cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
  50. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  51. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  52. def update(self, input_pos, k_val, v_val):
  53. # input_pos: [S], k_val: [B, H, S, D]
  54. assert input_pos.shape[0] == k_val.shape[2]
  55. k_out = self.k_cache
  56. v_out = self.v_cache
  57. k_out[:, :, input_pos] = k_val
  58. v_out[:, :, input_pos] = v_val
  59. return k_out, v_out
@dataclass
class TransformerForwardResult:
    # Logits over the base vocabulary: (batch, seq, vocab_size).
    token_logits: Tensor
    # Per-codebook logits: (batch, seq, num_codebooks, codebook_size),
    # or None when the model is configured with num_codebooks == 0.
    codebook_logits: Optional[Tensor]
class Transformer(nn.Module):
    """LLaMA-style decoder with extra codebook decoding heads.

    The text vocabulary and every codebook share one embedding table and one
    output projection; codebook ids are offset past ``vocab_size``.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        # One table covers the text vocab plus all codebook vocabs.
        self.embeddings = nn.Embedding(
            config.vocab_size + config.codebook_size * config.num_codebooks, config.dim
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(
            config.dim,
            config.vocab_size + config.codebook_size * config.num_codebooks,
            bias=False,
        )

        # Rotary-embedding table and causal mask live in buffers so they
        # follow the module across device/dtype moves.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
        )
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
        )

        # For kv cache; -1 means setup_caches has not been called yet.
        self.max_batch_size = -1
        self.max_seq_len = -1

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Allocate a KVCache on every layer for incremental decoding."""
        # Skip if existing caches are already large enough.
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        head_dim = self.config.dim // self.config.n_head
        # Round the cache length up to a multiple of 8 for alignment.
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def embed(self, x: Tensor) -> Tensor:
        """Embed a (batch, num_codebooks + 1, seq) id tensor into (batch, seq, dim).

        Here we want to merge the embeddings of the codebooks by summation.
        """
        if self.config.num_codebooks == 0:
            return self.embeddings(x[:, 0])

        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            # Codebook i's ids live at offset vocab_size + i * codebook_size
            # inside the shared embedding table.
            emb = self.embeddings(
                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
            )
            # Zero out embeddings at codebook padding positions.
            emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
            vocab_embeds.append(emb)

        x = torch.stack(vocab_embeds, dim=3)
        return x.sum(dim=3)

    def compute(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Run the layer stack and split the logits into token/codebook heads."""
        for layer in self.layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = torch.utils.checkpoint.checkpoint(
                    layer, x, freqs_cis, mask, input_pos, use_reentrant=True
                )
            else:
                x = layer(x, freqs_cis, mask, input_pos=input_pos)

        x = self.norm(x)
        logits = self.output(x)
        token_logits = logits[:, :, : self.config.vocab_size]

        if self.config.num_codebooks == 0:
            return TransformerForwardResult(
                token_logits=token_logits,
                codebook_logits=None,
            )

        # The remaining logits are the concatenated codebook heads.
        codebook_logits = logits[:, :, self.config.vocab_size :]
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward(
        self, x: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        # x: (batch, num_codebooks + 1, seq_len)
        seq_len = x.size(2)

        # Here we want to merge the embeddings of the codebooks
        x = self.embed(x)

        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[:seq_len]

        # Note that the causal mask here follows the definition of
        # scaled_dot_product_attention: FALSE means masked out.
        # To maintain consistency, key_padding_mask uses TRUE to mask out.
        if self.config.use_flash_attention is False and key_padding_mask is not None:
            mask = mask & key_padding_mask[:, None, None, :].logical_not()
        elif self.config.use_flash_attention is True and key_padding_mask is not None:
            # Flash attention takes a (batch, seq) mask with TRUE = keep.
            mask = key_padding_mask.logical_not()

        return self.compute(x, freqs_cis, mask)

    def forward_generate(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        """Single decoding step using the KV caches; x: (batch, num_codebooks + 1, 1)."""
        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"

        x = self.embed(x)

        mask = self.causal_mask[
            None, None, input_pos, : self.max_seq_len
        ]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        # TODO: support key padding mask for generation
        return self.compute(x, freqs_cis, mask, input_pos=input_pos)
  190. class TransformerBlock(nn.Module):
  191. def __init__(self, config: ModelArgs) -> None:
  192. super().__init__()
  193. self.attention = Attention(config)
  194. self.feed_forward = FeedForward(config)
  195. self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
  196. self.attention_norm = RMSNorm(config.dim, config.norm_eps)
  197. def forward(
  198. self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
  199. ) -> Tensor:
  200. h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  201. out = h + self.feed_forward(self.ffn_norm(h))
  202. return out
  203. class Attention(nn.Module):
  204. def __init__(self, config: ModelArgs):
  205. super().__init__()
  206. assert config.dim % config.n_head == 0
  207. total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
  208. # key, query, value projections for all heads, but in a batch
  209. self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
  210. self.wo = nn.Linear(config.dim, config.dim, bias=False)
  211. self.kv_cache = None
  212. self.n_head = config.n_head
  213. self.head_dim = config.head_dim
  214. self.n_local_heads = config.n_local_heads
  215. self.dim = config.dim
  216. self.use_flash_attention = config.use_flash_attention
  217. self._register_load_state_dict_pre_hook(self.load_hook)
  218. def load_hook(self, state_dict, prefix, *args):
  219. if prefix + "wq.weight" in state_dict:
  220. wq = state_dict.pop(prefix + "wq.weight")
  221. wk = state_dict.pop(prefix + "wk.weight")
  222. wv = state_dict.pop(prefix + "wv.weight")
  223. state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
  224. def forward(
  225. self,
  226. x: Tensor,
  227. freqs_cis: Tensor,
  228. mask: Tensor,
  229. input_pos: Optional[Tensor] = None,
  230. ) -> Tensor:
  231. bsz, seqlen, _ = x.shape
  232. kv_size = self.n_local_heads * self.head_dim
  233. q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
  234. q = q.view(bsz, seqlen, self.n_head, self.head_dim)
  235. k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
  236. v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
  237. q = apply_rotary_emb(q, freqs_cis)
  238. k = apply_rotary_emb(k, freqs_cis)
  239. if self.use_flash_attention is False:
  240. q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
  241. if self.kv_cache is not None:
  242. k, v = self.kv_cache.update(input_pos, k, v)
  243. k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
  244. v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
  245. y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
  246. y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
  247. else:
  248. assert (
  249. self.kv_cache is None
  250. ), "kv_cache is not supported for flash attention for now"
  251. # We don't need to transpose q, k, v here because flash_attn_varlen_func
  252. attn_output = self._flash_attention_forward(
  253. q, k, v, mask, seqlen, dropout=0.0
  254. )
  255. y = attn_output.reshape(bsz, seqlen, self.dim).contiguous()
  256. return self.wo(y)
  257. def _flash_attention_forward(
  258. self,
  259. query_states,
  260. key_states,
  261. value_states,
  262. attention_mask,
  263. query_length,
  264. dropout=0.0,
  265. softmax_scale=None,
  266. ):
  267. """
  268. Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
  269. first unpad the input, then computes the attention scores and pad the final attention scores.
  270. Args:
  271. query_states (`torch.Tensor`):
  272. Input query states to be passed to Flash Attention API
  273. key_states (`torch.Tensor`):
  274. Input key states to be passed to Flash Attention API
  275. value_states (`torch.Tensor`):
  276. Input value states to be passed to Flash Attention API
  277. attention_mask (`torch.Tensor`):
  278. The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
  279. position of padding tokens and 1 for the position of non-padding tokens.
  280. dropout (`int`, *optional*):
  281. Attention dropout
  282. softmax_scale (`float`, *optional*):
  283. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
  284. """
  285. # Contains at least one padding token in the sequence
  286. if attention_mask is not None:
  287. batch_size = query_states.shape[0]
  288. (
  289. query_states,
  290. key_states,
  291. value_states,
  292. indices_q,
  293. cu_seq_lens,
  294. max_seq_lens,
  295. ) = self._upad_input(
  296. query_states, key_states, value_states, attention_mask, query_length
  297. )
  298. cu_seqlens_q, cu_seqlens_k = cu_seq_lens
  299. max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
  300. attn_output_unpad = flash_attn_varlen_func(
  301. query_states,
  302. key_states,
  303. value_states,
  304. cu_seqlens_q=cu_seqlens_q,
  305. cu_seqlens_k=cu_seqlens_k,
  306. max_seqlen_q=max_seqlen_in_batch_q,
  307. max_seqlen_k=max_seqlen_in_batch_k,
  308. dropout_p=dropout,
  309. softmax_scale=softmax_scale,
  310. causal=True,
  311. )
  312. attn_output = pad_input(
  313. attn_output_unpad, indices_q, batch_size, query_length
  314. )
  315. else:
  316. attn_output = flash_attn_func(
  317. query_states,
  318. key_states,
  319. value_states,
  320. dropout,
  321. softmax_scale=softmax_scale,
  322. causal=True,
  323. )
  324. return attn_output
  325. def _get_unpad_data(self, attention_mask):
  326. seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
  327. indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
  328. max_seqlen_in_batch = seqlens_in_batch.max().item()
  329. cu_seqlens = F.pad(
  330. torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
  331. )
  332. return (
  333. indices,
  334. cu_seqlens,
  335. max_seqlen_in_batch,
  336. )
  337. def _upad_input(
  338. self, query_layer, key_layer, value_layer, attention_mask, query_length
  339. ):
  340. indices_k, cu_seqlens_k, max_seqlen_in_batch_k = self._get_unpad_data(
  341. attention_mask
  342. )
  343. batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
  344. key_layer = index_first_axis(
  345. key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
  346. indices_k,
  347. )
  348. value_layer = index_first_axis(
  349. value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
  350. indices_k,
  351. )
  352. if query_length == kv_seq_len:
  353. query_layer = index_first_axis(
  354. query_layer.reshape(batch_size * kv_seq_len, self.n_head, head_dim),
  355. indices_k,
  356. )
  357. cu_seqlens_q = cu_seqlens_k
  358. max_seqlen_in_batch_q = max_seqlen_in_batch_k
  359. indices_q = indices_k
  360. elif query_length == 1:
  361. max_seqlen_in_batch_q = 1
  362. cu_seqlens_q = torch.arange(
  363. batch_size + 1, dtype=torch.int32, device=query_layer.device
  364. ) # There is a memcpy here, that is very bad.
  365. indices_q = cu_seqlens_q[:-1]
  366. query_layer = query_layer.squeeze(1)
  367. else:
  368. # The -q_len: slice assumes left padding.
  369. attention_mask = attention_mask[:, -query_length:]
  370. query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
  371. query_layer, attention_mask
  372. )
  373. return (
  374. query_layer,
  375. key_layer,
  376. value_layer,
  377. indices_q,
  378. (cu_seqlens_q, cu_seqlens_k),
  379. (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
  380. )
  381. class FeedForward(nn.Module):
  382. def __init__(self, config: ModelArgs) -> None:
  383. super().__init__()
  384. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  385. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  386. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  387. def forward(self, x: Tensor) -> Tensor:
  388. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  389. class RMSNorm(nn.Module):
  390. def __init__(self, dim: int, eps: float = 1e-5):
  391. super().__init__()
  392. self.eps = eps
  393. self.weight = nn.Parameter(torch.ones(dim))
  394. def _norm(self, x):
  395. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  396. def forward(self, x: Tensor) -> Tensor:
  397. output = self._norm(x.float()).type_as(x)
  398. return output * self.weight
  399. def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
  400. freqs = 1.0 / (
  401. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  402. )
  403. t = torch.arange(seq_len, device=freqs.device)
  404. freqs = torch.outer(t, freqs)
  405. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  406. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  407. return cache.to(dtype=torch.bfloat16)
  408. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  409. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  410. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  411. x_out2 = torch.stack(
  412. [
  413. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  414. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  415. ],
  416. -1,
  417. )
  418. x_out2 = x_out2.flatten(3)
  419. return x_out2.type_as(x)
  420. if __name__ == "__main__":
  421. args = ModelArgs(
  422. max_seq_len=4096,
  423. vocab_size=32312,
  424. n_layer=12,
  425. n_head=12,
  426. dim=768,
  427. rope_base=10000,
  428. norm_eps=1e-5,
  429. codebook_size=0,
  430. num_codebooks=0,
  431. )
  432. model = Transformer(args)
  433. model = model.cuda().bfloat16()
  434. print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
  435. inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
  436. key_padding_mask = torch.zeros(2, 128).bool().cuda()
  437. key_padding_mask[0, 2:] = True
  438. x1 = model(inputs, key_padding_mask=key_padding_mask)
  439. print(x1.token_logits.shape)
  440. # print(x1.codebook_logits.shape)