from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from torch.nn import functional as F


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


@dataclass
class ModelArgs:
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: Optional[int] = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048

    # Additional decoding heads
    codebook_size: int = 160
    num_codebooks: int = 4

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head

        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)

        self.head_dim = self.dim // self.n_head


class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out
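

# A minimal sketch (not part of the original module) of how KVCache is used during
# incremental decoding: the buffers are preallocated up to max_seq_len, and update()
# scatters the step's keys/values into the slots named by input_pos while returning
# the full buffers for attention. All sizes below are illustrative.
def _kv_cache_demo():
    cache = KVCache(max_batch_size=1, max_seq_len=16, n_heads=2, head_dim=8, dtype=torch.float32)

    # Decode a single token at position 3; shapes are [B, H, S, D] with S == len(input_pos).
    input_pos = torch.tensor([3])
    k_val = torch.randn(1, 2, 1, 8)
    v_val = torch.randn(1, 2, 1, 8)

    k_out, v_out = cache.update(input_pos, k_val, v_val)
    assert k_out.shape == (1, 2, 16, 8)  # the whole preallocated cache is returned
    assert torch.equal(k_out[:, :, 3], k_val[:, :, 0])  # the new entry landed at position 3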


@dataclass
class TransformerForwardResult:
    token_logits: Tensor
    codebook_logits: Tensor


class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.embeddings = nn.Embedding(
            config.vocab_size + config.codebook_size * config.num_codebooks, config.dim
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(
            config.dim,
            config.vocab_size + config.codebook_size * config.num_codebooks,
            bias=False,
        )

        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
        )
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
        )

        # For kv cache
        self.max_batch_size = -1
        self.max_seq_len = -1

    def setup_caches(self, max_batch_size, max_seq_len):
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        head_dim = self.config.dim // self.config.n_head
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size, max_seq_len, self.config.n_local_heads, head_dim
            )
    def forward(
        self, x: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        # x: (batch, num_codebooks + 1, seq_len)
        seq_len = x.size(2)

        # Here we want to merge the embeddings of the codebooks
        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            emb = self.embeddings(
                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
            )
            vocab_embeds.append(emb)

        x = torch.stack(vocab_embeds, dim=3)
        x = x.mean(dim=3)

        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[:seq_len]

        # Note that the causal mask here follows the definition of scaled_dot_product_attention:
        # FALSE means masked out. To stay consistent, key_padding_mask uses TRUE to mask out.
        if key_padding_mask is not None:
            mask = mask & key_padding_mask[:, None, None, :].logical_not()

        for layer in self.layers:
            x = layer(x, freqs_cis, mask)

        x = self.norm(x)
        logits = self.output(x)
        token_logits = logits[:, :, : self.config.vocab_size]
        codebook_logits = logits[:, :, self.config.vocab_size :]
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )
    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        # x: (batch, num_codebooks + 1, 1)
        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"

        # Here we want to merge the embeddings of the codebooks
        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            emb = self.embeddings(
                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
            )
            vocab_embeds.append(emb)

        x = torch.stack(vocab_embeds, dim=3)
        x = x.mean(dim=3)

        mask = self.causal_mask[
            None, None, input_pos, : self.max_seq_len
        ]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        x = self.norm(x)
        logits = self.output(x)
        token_logits = logits[:, :, : self.config.vocab_size]
        codebook_logits = logits[:, :, self.config.vocab_size :]
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )
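

# A minimal sketch (not part of the original module) showing how forward() and
# forward_generate() fit together: forward() scores a full (batch, num_codebooks + 1,
# seq_len) sequence, while forward_generate() decodes one position at a time after
# setup_caches() has allocated the per-layer KV caches. The tiny config is illustrative,
# and bfloat16 is used to match the default KVCache dtype, as in the __main__ block below.
def _decode_step_demo():
    args = ModelArgs(
        n_layer=2, n_head=4, dim=64, vocab_size=128,
        codebook_size=8, num_codebooks=2, max_seq_len=32,
    )
    model = Transformer(args).bfloat16().eval()
    model.setup_caches(max_batch_size=1, max_seq_len=32)

    with torch.no_grad():
        # One decoding step at position 0: one token row plus one row per codebook.
        x = torch.randint(0, 8, (1, args.num_codebooks + 1, 1))
        out = model.forward_generate(x, input_pos=torch.tensor([0]))

    # token_logits: (1, 1, vocab_size); codebook_logits: (1, 1, num_codebooks, codebook_size)
    return out.token_logits.shape, out.codebook_logits.shape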


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        y = self.wo(y)
        return y
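

# A small sketch (not in the original file) of how the fused wqkv projection is sized
# for grouped-query attention: queries get n_head heads, while keys and values each get
# n_local_heads heads that Attention.forward later repeat_interleave's up to n_head.
# The numbers below are illustrative only.
def _gqa_split_demo():
    cfg = ModelArgs(dim=64, n_head=8, n_local_heads=2)  # head_dim becomes 64 // 8 = 8
    attn = Attention(cfg)

    x = torch.randn(1, 4, cfg.dim)
    kv_size = cfg.n_local_heads * cfg.head_dim  # 2 * 8 = 16
    q, k, v = attn.wqkv(x).split([cfg.dim, kv_size, kv_size], dim=-1)

    # Query keeps all 8 heads (64 dims); key/value carry only 2 heads (16 dims) each.
    assert q.shape[-1] == 64 and k.shape[-1] == 16 and v.shape[-1] == 16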


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # SwiGLU: silu-gated projection followed by the down projection.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
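

# A small sketch (not in the original file) relating apply_rotary_emb to the complex-number
# view of RoPE: each (real, imag) pair in freqs_cis rotates one 2-element slice of the head
# dimension by a position-dependent angle, i.e. it multiplies by e^{i * theta}. The check
# below is illustrative only.
def _rope_demo():
    head_dim, seq_len = 8, 4
    freqs_cis = precompute_freqs_cis(seq_len, head_dim).float()  # (seq_len, head_dim // 2, 2)
    x = torch.randn(1, seq_len, 2, head_dim)  # (batch, seq_len, n_heads, head_dim)

    rotated = apply_rotary_emb(x, freqs_cis)

    # The same rotation expressed as complex multiplication.
    x_c = torch.view_as_complex(x.reshape(1, seq_len, 2, head_dim // 2, 2))
    f_c = torch.view_as_complex(freqs_cis).view(1, seq_len, 1, head_dim // 2)
    expected = torch.view_as_real(x_c * f_c).flatten(3)

    assert torch.allclose(rotated, expected, atol=1e-5)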


if __name__ == "__main__":
    args = ModelArgs(
        max_seq_len=4096,
        vocab_size=32312,
        n_layer=12,
        n_head=12,
        dim=768,
        rope_base=10000,
        norm_eps=1e-5,
        codebook_size=168,
        num_codebooks=4,
    )

    model = Transformer(args)
    model = model.cuda().bfloat16()
    print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)

    inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
    key_padding_mask = torch.zeros(2, 128).bool().cuda()
    key_padding_mask[0, 2:] = True
    x1 = model(inputs, key_padding_mask=key_padding_mask)
    print(x1.token_logits.shape)
    print(x1.codebook_logits.shape)