# llama.py — LLaMA-style decoder-only transformer with optional codebook prediction heads.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from torch.nn import functional as F
  8. def find_multiple(n: int, k: int) -> int:
  9. if n % k == 0:
  10. return n
  11. return n + k - (n % k)
  12. @dataclass
  13. class ModelArgs:
  14. vocab_size: int = 32000
  15. n_layer: int = 32
  16. n_head: int = 32
  17. dim: int = 4096
  18. intermediate_size: int = None
  19. n_local_heads: int = -1
  20. head_dim: int = 64
  21. rope_base: float = 10000
  22. norm_eps: float = 1e-5
  23. max_seq_len: int = 2048
  24. # Additional decoding heads
  25. codebook_size: int = 160
  26. num_codebooks: int = 4
  27. codebook_padding_idx: int = 0
  28. def __post_init__(self):
  29. if self.n_local_heads == -1:
  30. self.n_local_heads = self.n_head
  31. if self.intermediate_size is None:
  32. hidden_dim = 4 * self.dim
  33. n_hidden = int(2 * hidden_dim / 3)
  34. self.intermediate_size = find_multiple(n_hidden, 256)
  35. self.head_dim = self.dim // self.n_head
  36. class KVCache(nn.Module):
  37. def __init__(
  38. self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
  39. ):
  40. super().__init__()
  41. cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
  42. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  43. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  44. def update(self, input_pos, k_val, v_val):
  45. # input_pos: [S], k_val: [B, H, S, D]
  46. assert input_pos.shape[0] == k_val.shape[2]
  47. k_out = self.k_cache
  48. v_out = self.v_cache
  49. k_out[:, :, input_pos] = k_val
  50. v_out[:, :, input_pos] = v_val
  51. return k_out, v_out
  52. @dataclass
  53. class TransformerForwardResult:
  54. token_logits: Tensor
  55. codebook_logits: Tensor
class Transformer(nn.Module):
    """Decoder-only transformer whose embedding and output spaces cover the
    text vocabulary plus ``num_codebooks`` codebooks of ``codebook_size``
    entries each (codebook rows are appended after the text vocab rows).
    """

    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config
        # One shared table: rows [0, vocab_size) are text tokens, the rest codebook tokens.
        self.embeddings = nn.Embedding(
            config.vocab_size + config.codebook_size * config.num_codebooks, config.dim
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        # Projects to text-vocab logits followed by all per-codebook logits.
        self.output = nn.Linear(
            config.dim,
            config.vocab_size + config.codebook_size * config.num_codebooks,
            bias=False,
        )
        # Rotary-embedding table for positions [0, max_seq_len).
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
        )
        # Lower-triangular boolean mask; True = position may be attended to
        # (the convention of F.scaled_dot_product_attention).
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
        )
        # For kv cache; -1 means setup_caches() has not been called yet.
        self.max_batch_size = -1
        self.max_seq_len = -1

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Allocate a KVCache on every layer for incremental decoding.

        No-op if existing caches already cover the requested batch size and
        sequence length.
        """
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        head_dim = self.config.dim // self.config.n_head
        # Round the cached sequence length up to a multiple of 8.
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def embed(self, x: Tensor) -> Tensor:
        """Embed token ids, summing the text embedding with all codebook embeddings.

        x: (batch, num_codebooks + 1, seq_len); row 0 holds text token ids,
        rows 1..num_codebooks hold per-codebook ids.
        Returns (batch, seq_len, dim).
        """
        # Here we want to merge the embeddings of the codebooks
        if self.config.num_codebooks == 0:
            return self.embeddings(x[:, 0])

        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            # Offset codebook ids into their region of the shared embedding table.
            emb = self.embeddings(
                x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
            )
            # Padding positions contribute nothing to the summed embedding.
            emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
            vocab_embeds.append(emb)

        x = torch.stack(vocab_embeds, dim=3)
        return x.sum(dim=3)

    def compute(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Run the layer stack, final norm, and output head on embedded input.

        Splits the output logits into token logits and (if configured)
        per-codebook logits; codebook_logits is None when num_codebooks == 0.
        """
        for layer in self.layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        x = self.norm(x)
        logits = self.output(x)
        token_logits = logits[:, :, : self.config.vocab_size]

        if self.config.num_codebooks == 0:
            return TransformerForwardResult(
                token_logits=token_logits,
                codebook_logits=None,
            )

        # Remaining logits are (num_codebooks * codebook_size); split per codebook.
        codebook_logits = logits[:, :, self.config.vocab_size :]
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward(
        self, x: Tensor, key_padding_mask: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        """Full-sequence (training-style) forward pass, no KV cache.

        x: (batch, num_codebooks + 1, seq_len) token ids.
        key_padding_mask: (batch, seq_len) bool, True = padded position to mask out.
        """
        seq_len = x.size(2)

        # Here we want to merge the embeddings of the codebooks
        x = self.embed(x)

        mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[:seq_len]

        # Not that the causal mask here follows the definition of scaled_dot_product_attention
        # That is, FALSE means masked out
        # To maintain consistency, key_padding_mask use TRUE to mask out
        if key_padding_mask is not None:
            mask = mask & key_padding_mask[:, None, None, :].logical_not()

        return self.compute(x, freqs_cis, mask)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        """Single-step (or few-step) decoding forward pass using the KV caches.

        x: (batch, num_codebooks + 1, 1) token ids for the new position(s);
        input_pos: positions being written, used to slice mask and rope table.
        Requires a prior setup_caches() call.
        """
        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"

        x = self.embed(x)

        mask = self.causal_mask[
            None, None, input_pos, : self.max_seq_len
        ]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        # TODO: support key padding mask for generation
        return self.compute(x, freqs_cis, mask, input_pos=input_pos)
  175. class TransformerBlock(nn.Module):
  176. def __init__(self, config: ModelArgs) -> None:
  177. super().__init__()
  178. self.attention = Attention(config)
  179. self.feed_forward = FeedForward(config)
  180. self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
  181. self.attention_norm = RMSNorm(config.dim, config.norm_eps)
  182. def forward(
  183. self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
  184. ) -> Tensor:
  185. h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  186. out = h + self.feed_forward(self.ffn_norm(h))
  187. return out
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings, a fused QKV
    projection, an optional KV cache, and grouped-query attention
    (``n_local_heads`` KV heads shared across ``n_head`` query heads).
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        # Q uses n_head heads; K and V each use n_local_heads heads.
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        # Rewrites legacy checkpoints (separate wq/wk/wv) into the fused wqkv on load.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """state_dict pre-hook: fuse separate wq/wk/wv weights into wqkv."""
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Attend over x (B, S, dim) and return (B, S, dim).

        When a kv_cache is set, k/v are written at input_pos and attention
        runs over the whole cached sequence.
        """
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        # self.dim == n_head * head_dim (guaranteed by ModelArgs.__post_init__).
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        # Rotary position embedding on queries and keys only.
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        # (B, S, H, D) -> (B, H, S, D) for scaled_dot_product_attention.
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Grouped-query attention: replicate each KV head across its query-head group.
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        y = self.wo(y)
        return y
  232. class FeedForward(nn.Module):
  233. def __init__(self, config: ModelArgs) -> None:
  234. super().__init__()
  235. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  236. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  237. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  238. def forward(self, x: Tensor) -> Tensor:
  239. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  240. class RMSNorm(nn.Module):
  241. def __init__(self, dim: int, eps: float = 1e-5):
  242. super().__init__()
  243. self.eps = eps
  244. self.weight = nn.Parameter(torch.ones(dim))
  245. def _norm(self, x):
  246. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  247. def forward(self, x: Tensor) -> Tensor:
  248. output = self._norm(x.float()).type_as(x)
  249. return output * self.weight
  250. def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
  251. freqs = 1.0 / (
  252. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  253. )
  254. t = torch.arange(seq_len, device=freqs.device)
  255. freqs = torch.outer(t, freqs)
  256. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  257. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  258. return cache.to(dtype=torch.bfloat16)
  259. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  260. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  261. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  262. x_out2 = torch.stack(
  263. [
  264. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  265. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  266. ],
  267. -1,
  268. )
  269. x_out2 = x_out2.flatten(3)
  270. return x_out2.type_as(x)
  271. if __name__ == "__main__":
  272. args = ModelArgs(
  273. max_seq_len=4096,
  274. vocab_size=32312,
  275. n_layer=12,
  276. n_head=12,
  277. dim=768,
  278. rope_base=10000,
  279. norm_eps=1e-5,
  280. codebook_size=0,
  281. num_codebooks=0,
  282. )
  283. model = Transformer(args)
  284. model = model.cuda().bfloat16()
  285. print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
  286. inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
  287. key_padding_mask = torch.zeros(2, 128).bool().cuda()
  288. key_padding_mask[0, 2:] = True
  289. x1 = model(inputs, key_padding_mask=key_padding_mask)
  290. print(x1.token_logits.shape)
  291. # print(x1.codebook_logits.shape)