# llama.py

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from torch.nn import functional as F


def find_multiple(n: int, k: int) -> int:
    # Round n up to the nearest multiple of k.
    if n % k == 0:
        return n
    return n + k - (n % k)


@dataclass
class ModelArgs:
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: Optional[int] = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048

    # Additional decoding heads
    codebook_size: int = 160
    num_codebooks: int = 4
    codebook_padding_idx: int = 0

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head
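
# Worked example of the default FFN sizing above (an added illustration, not part
# of the original file): dim = 4096 gives hidden_dim = 16384, n_hidden = 10922,
# and find_multiple(10922, 256) rounds that up to 11008, which matches the
# LLaMA-7B intermediate size.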


class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out
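
# Illustrative cache shapes (an added note, not part of the original file): with
# max_batch_size=2, n_heads=8, max_seq_len=128, head_dim=64, both buffers are
# (2, 8, 128, 64); a single decode step passes k_val/v_val of shape (2, 8, 1, 64)
# together with input_pos = tensor([t]), and update() writes slot t in place and
# returns the full caches so attention can see every cached position.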


@dataclass
class TransformerForwardResult:
    token_logits: Tensor
    codebook_logits: Optional[Tensor]


class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.embeddings = nn.Embedding(
            config.vocab_size + config.codebook_size * config.num_codebooks, config.dim
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(
            config.dim,
            config.vocab_size + config.codebook_size * config.num_codebooks,
            bias=False,
        )

        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
        )
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
        )

        # For kv cache
        self.max_batch_size = -1
        self.max_seq_len = -1

    def setup_caches(self, max_batch_size, max_seq_len):
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        head_dim = self.config.dim // self.config.n_head
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size, max_seq_len, self.config.n_local_heads, head_dim
            )

    def embed(self, x: Tensor) -> Tensor:
        # Merge the token embeddings with the embeddings of the codebooks
        if self.config.num_codebooks == 0:
            return self.embeddings(x[:, 0])

        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            emb = self.embeddings(
                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
            )
            emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
            vocab_embeds.append(emb)

        x = torch.stack(vocab_embeds, dim=3)
        return x.sum(dim=3)
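
    # Index arithmetic in embed(), as an added illustration (values assume the
    # defaults vocab_size=32000 and codebook_size=160): codebook i, token t maps
    # to row vocab_size + i * codebook_size + t of the shared embedding table,
    # e.g. codebook 2, token 5 -> 32000 + 2 * 160 + 5 = 32325. Entries equal to
    # codebook_padding_idx are zeroed so padding does not contribute to the sum.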

    def compute(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        for layer in self.layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        x = self.norm(x)
        logits = self.output(x)
        token_logits = logits[:, :, : self.config.vocab_size]

        if self.config.num_codebooks == 0:
            return TransformerForwardResult(
                token_logits=token_logits,
                codebook_logits=None,
            )

        codebook_logits = logits[:, :, self.config.vocab_size :]
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward(
        self, x: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        # x: (batch, num_codebooks + 1, seq_len)
        seq_len = x.size(2)

        # Merge the embeddings of the codebooks
        x = self.embed(x)

        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[:seq_len]

        # Note that the causal mask here follows the definition of scaled_dot_product_attention:
        # FALSE means masked out. For consistency, key_padding_mask uses TRUE to mask out.
        if key_padding_mask is not None:
            mask = mask & key_padding_mask[:, None, None, :].logical_not()

        return self.compute(x, freqs_cis, mask)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        # x: (batch, num_codebooks + 1, 1)
        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"

        x = self.embed(x)

        mask = self.causal_mask[
            None, None, input_pos, : self.max_seq_len
        ]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        # TODO: support key padding mask for generation
        return self.compute(x, freqs_cis, mask, input_pos=input_pos)


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        y = self.wo(y)
        return y
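
# Grouped-query attention note (added illustration): when n_local_heads < n_head,
# k and v are projected with only n_local_heads heads and repeat_interleave()
# expands them to n_head before attention, e.g. n_head=32, n_local_heads=8 repeats
# each KV head 4x. With the default n_local_heads == n_head the repeat factor is 1
# and this reduces to ordinary multi-head attention.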


class FeedForward(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # SwiGLU: silu(w1(x)) gated by w3(x), projected back down with w2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(seq_len: int, n_elem: int, base: float = 10000) -> Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)


if __name__ == "__main__":
    args = ModelArgs(
        max_seq_len=4096,
        vocab_size=32312,
        n_layer=12,
        n_head=12,
        dim=768,
        rope_base=10000,
        norm_eps=1e-5,
        codebook_size=0,
        num_codebooks=0,
    )

    model = Transformer(args)
    model = model.cuda().bfloat16()
    print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)

    inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
    key_padding_mask = torch.zeros(2, 128).bool().cuda()
    key_padding_mask[0, 2:] = True

    x1 = model(inputs, key_padding_mask=key_padding_mask)
    print(x1.token_logits.shape)
    # print(x1.codebook_logits.shape)
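
    # A minimal sketch (added, not part of the original demo) of the cached,
    # step-by-step decoding path; it assumes a CUDA device and calls .cuda()
    # again so the freshly allocated KV caches land on the GPU with the model.
    model.setup_caches(max_batch_size=2, max_seq_len=128)
    model = model.cuda()
    step = torch.randint(0, 100, (2, 1, 1)).cuda()  # (batch, num_codebooks + 1, 1)
    input_pos = torch.tensor([0], device="cuda")
    x2 = model.forward_generate(step, input_pos=input_pos)
    print(x2.token_logits.shape)  # torch.Size([2, 1, 32312])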