llama.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. import math
  2. from dataclasses import dataclass
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. from einops import rearrange
  7. from torch import Tensor
  8. from torch.nn import functional as F
  9. from transformers.utils import is_flash_attn_2_available
  10. if is_flash_attn_2_available():
  11. from flash_attn import flash_attn_func, flash_attn_varlen_func
  12. from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
  13. def find_multiple(n: int, k: int) -> int:
  14. if n % k == 0:
  15. return n
  16. return n + k - (n % k)
  17. @dataclass
  18. class ModelArgs:
  19. vocab_size: int = 32000
  20. n_layer: int = 32
  21. n_head: int = 32
  22. dim: int = 4096
  23. intermediate_size: int = None
  24. n_local_heads: int = -1
  25. head_dim: int = 64
  26. rope_base: float = 10000
  27. norm_eps: float = 1e-5
  28. max_seq_len: int = 2048
  29. dropout: float = 0.0
  30. # Additional decoding heads
  31. codebook_size: int = 160
  32. num_codebooks: int = 4
  33. codebook_padding_idx: int = 0
  34. # Use flash attention
  35. use_flash_attention: bool = False
  36. # Gradient checkpointing
  37. use_gradient_checkpointing: bool = True
  38. # NEFT
  39. neft_alpha: float = 0
  40. def __post_init__(self):
  41. if self.n_local_heads == -1:
  42. self.n_local_heads = self.n_head
  43. if self.intermediate_size is None:
  44. hidden_dim = 4 * self.dim
  45. n_hidden = int(2 * hidden_dim / 3)
  46. self.intermediate_size = find_multiple(n_hidden, 256)
  47. self.head_dim = self.dim // self.n_head
  48. class KVCache(nn.Module):
  49. def __init__(
  50. self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
  51. ):
  52. super().__init__()
  53. cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
  54. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  55. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  56. def update(self, input_pos, k_val, v_val):
  57. # input_pos: [S], k_val: [B, H, S, D]
  58. assert input_pos.shape[0] == k_val.shape[2]
  59. k_out = self.k_cache
  60. v_out = self.v_cache
  61. k_out[:, :, input_pos] = k_val
  62. v_out[:, :, input_pos] = v_val
  63. return k_out, v_out
@dataclass
class TransformerForwardResult:
    # Logits over the text vocabulary: (batch, seq_len, vocab_size).
    token_logits: Tensor
    # Per-codebook logits: (batch, seq_len, num_codebooks, codebook_size).
    # None when the model is configured with num_codebooks == 0.
    codebook_logits: Optional[Tensor]
class Transformer(nn.Module):
    """Llama-style decoder with additional codebook decoding heads.

    One embedding table and one output projection are shared across the text
    vocabulary and ``num_codebooks`` codebooks of ``codebook_size`` entries;
    codebook ``i`` occupies ids ``vocab_size + i * codebook_size`` onward.
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        # Shared table: text tokens first, then every codebook's tokens.
        self.embeddings = nn.Embedding(
            config.vocab_size + config.codebook_size * config.num_codebooks, config.dim
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        # One projection producing text logits and all codebook logits at once.
        self.output = nn.Linear(
            config.dim,
            config.vocab_size + config.codebook_size * config.num_codebooks,
            bias=False,
        )
        # Rotary-embedding table covering every position up to max_seq_len.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
        )
        # Lower-triangular boolean mask; True means "query may attend to key".
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
        )
        # For kv cache; -1 marks "setup_caches not called yet".
        self.max_batch_size = -1
        self.max_seq_len = -1

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ) -> None:
        """Allocate per-layer KV caches for incremental decoding (idempotent
        when the existing caches are already large enough)."""
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return
        head_dim = self.config.dim // self.config.n_head
        # Round the cache length up to a multiple of 8 (alignment-friendly).
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def embed(self, x: Tensor) -> Tensor:
        """Embed a (batch, num_codebooks + 1, seq_len) id tensor into
        (batch, seq_len, dim) by summing the text and codebook embeddings."""
        # Here we want to merge the embeddings of the codebooks
        if self.config.num_codebooks == 0:
            return self.embeddings(x[:, 0])
        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            # Shift codebook i's ids into its slice of the shared table.
            emb = self.embeddings(
                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
            )
            # Zero out positions holding the codebook padding id.
            emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
            vocab_embeds.append(emb)
        x = torch.stack(vocab_embeds, dim=3)
        x = x.sum(dim=3)
        if self.config.neft_alpha > 0:
            # alpha / sqrt(L * D)
            # NOTE(review): after the sum, x.shape[2] == dim, so this scales by
            # alpha / sqrt(dim * dim); NEFTune scales by alpha / sqrt(seq_len * dim)
            # (which would be x.shape[1] * dim here) — verify intent.
            scaled_alpha = self.config.neft_alpha / math.sqrt(
                self.config.dim * x.shape[2]
            )
            # NOTE(review): rand_like is uniform on [0, 1); NEFTune draws from
            # Uniform(-1, 1) — confirm the one-sided noise is intentional.
            x += torch.rand_like(x) * scaled_alpha
        return x

    def compute(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Run the layer stack and split the output projection into token
        logits and (optionally) per-codebook logits."""
        for layer in self.layers:
            if self.config.use_gradient_checkpointing and self.training:
                # NOTE(review): PyTorch recommends use_reentrant=False for new
                # code; confirm reentrant checkpointing is required here.
                x = torch.utils.checkpoint.checkpoint(
                    layer, x, freqs_cis, mask, input_pos, use_reentrant=True
                )
            else:
                x = layer(x, freqs_cis, mask, input_pos=input_pos)
        x = self.norm(x)
        logits = self.output(x)
        # First vocab_size channels are the text-token logits.
        token_logits = logits[:, :, : self.config.vocab_size]
        if self.config.num_codebooks == 0:
            return TransformerForwardResult(
                token_logits=token_logits,
                codebook_logits=None,
            )
        # Remaining channels are num_codebooks contiguous codebook slices.
        codebook_logits = logits[:, :, self.config.vocab_size :]
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )
        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward(
        self, x: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        """Full-sequence (training) forward pass without KV caches."""
        # x: (batch, num_codebooks + 1, seq_len)
        seq_len = x.size(2)
        # Here we want to merge the embeddings of the codebooks
        x = self.embed(x)
        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[:seq_len]
        # Note that the causal mask here follows the definition of scaled_dot_product_attention
        # That is, FALSE means masked out
        # To maintain consistency, key_padding_mask use TRUE to mask out
        if self.config.use_flash_attention is False and key_padding_mask is not None:
            # Combine causal and padding masks for SDPA (True = attend).
            mask = mask & key_padding_mask[:, None, None, :].logical_not()
        elif self.config.use_flash_attention is True and key_padding_mask is not None:
            # Flash attention expects a (batch, seq) mask with 1 = non-padding.
            mask = key_padding_mask.logical_not()
        return self.compute(x, freqs_cis, mask)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        """Single-step (decoding) forward pass using the KV caches."""
        # x: (batch, num_codebooks + 1, 1)
        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"
        x = self.embed(x)
        # Select the causal-mask rows for the positions being generated.
        mask = self.causal_mask[
            None, None, input_pos, : self.max_seq_len
        ]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]
        # TODO: support key padding mask for generation
        return self.compute(x, freqs_cis, mask, input_pos=input_pos)
  201. class TransformerBlock(nn.Module):
  202. def __init__(self, config: ModelArgs) -> None:
  203. super().__init__()
  204. self.attention = Attention(config)
  205. self.feed_forward = FeedForward(config)
  206. self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
  207. self.attention_norm = RMSNorm(config.dim, config.norm_eps)
  208. def forward(
  209. self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
  210. ) -> Tensor:
  211. h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  212. out = h + self.feed_forward(self.ffn_norm(h))
  213. return out
class Attention(nn.Module):
    """Multi-head attention with grouped-query KV heads, rotary embeddings,
    an optional KV cache, and an optional flash-attention path."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0
        # Q uses n_head heads; K and V use n_local_heads heads each (GQA).
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        # Set externally by Transformer.setup_caches for incremental decoding.
        self.kv_cache = None
        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_flash_attention = config.use_flash_attention
        # Rewrites legacy checkpoints (separate wq/wk/wv) into the fused wqkv.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """state_dict pre-hook: fuse separate wq/wk/wv weights into wqkv."""
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape
        kv_size = self.n_local_heads * self.head_dim
        # One fused projection, then split into Q / K / V.
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)
        if self.use_flash_attention is False:
            # SDPA expects (batch, heads, seq, head_dim).
            q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
            if self.kv_cache is not None:
                k, v = self.kv_cache.update(input_pos, k, v)
            # Expand the KV heads so each query head has a matching KV head.
            k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
            v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
            y = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )
            y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
        else:
            assert (
                self.kv_cache is None
            ), "kv_cache is not supported for flash attention for now"
            # We don't need to transpose q, k, v here because flash_attn_varlen_func
            attn_output = self._flash_attention_forward(
                q, k, v, mask, seqlen, dropout=self.dropout if self.training else 0.0
            )
            y = attn_output.reshape(bsz, seqlen, self.dim).contiguous()
        return self.wo(y)

    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.
        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            dropout (`float`, *optional*):
                Attention dropout probability
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=True,
            )
            # Scatter the unpadded rows back into (batch, seq, ...) layout.
            attn_output = pad_input(
                attn_output_unpad, indices_q, batch_size, query_length
            )
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=True,
            )
        return attn_output

    def _get_unpad_data(self, attention_mask):
        """Return (flat indices of non-padding tokens, cumulative sequence
        lengths with a leading 0, longest sequence length in the batch)."""
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = F.pad(
            # NOTE(review): torch.torch is an alias of torch, so this resolves
            # to torch.int32 — the double "torch." looks like a typo.
            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
        )
        return (
            indices,
            cu_seqlens,
            max_seqlen_in_batch,
        )

    def _upad_input(
        self, query_layer, key_layer, value_layer, attention_mask, query_length
    ):
        """Strip padding tokens from Q/K/V ahead of flash_attn_varlen_func."""
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = self._get_unpad_data(
            attention_mask
        )
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            # Prefill: queries share the keys' padding layout.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.n_head, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-token decode: one query per batch element.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask
            )
        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
  399. class FeedForward(nn.Module):
  400. def __init__(self, config: ModelArgs) -> None:
  401. super().__init__()
  402. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  403. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  404. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  405. def forward(self, x: Tensor) -> Tensor:
  406. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  407. class RMSNorm(nn.Module):
  408. def __init__(self, dim: int, eps: float = 1e-5):
  409. super().__init__()
  410. self.eps = eps
  411. self.weight = nn.Parameter(torch.ones(dim))
  412. def _norm(self, x):
  413. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  414. def forward(self, x: Tensor) -> Tensor:
  415. output = self._norm(x.float()).type_as(x)
  416. return output * self.weight
  417. def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
  418. freqs = 1.0 / (
  419. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  420. )
  421. t = torch.arange(seq_len, device=freqs.device)
  422. freqs = torch.outer(t, freqs)
  423. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  424. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  425. return cache.to(dtype=torch.bfloat16)
  426. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  427. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  428. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  429. x_out2 = torch.stack(
  430. [
  431. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  432. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  433. ],
  434. -1,
  435. )
  436. x_out2 = x_out2.flatten(3)
  437. return x_out2.type_as(x)
if __name__ == "__main__":
    # Smoke test: build a small text-only model (no codebooks) and run one
    # masked forward pass on random ids. Requires a CUDA device.
    args = ModelArgs(
        max_seq_len=4096,
        vocab_size=32312,
        n_layer=12,
        n_head=12,
        dim=768,
        rope_base=10000,
        norm_eps=1e-5,
        codebook_size=0,
        num_codebooks=0,
    )
    model = Transformer(args)
    model = model.cuda().bfloat16()
    # Parameter count in "millions" (divides by 1024^2, not 10^6).
    print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
    # (batch, codebook-rows, seq_len); with num_codebooks == 0 only row 0 is
    # embedded, so the extra 4 rows here are ignored — see Transformer.embed.
    inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
    # key_padding_mask: True marks padded positions to be masked out.
    key_padding_mask = torch.zeros(2, 128).bool().cuda()
    key_padding_mask[0, 2:] = True
    x1 = model(inputs, key_padding_mask=key_padding_mask)
    print(x1.token_logits.shape)
    # print(x1.codebook_logits.shape)