# llama.py
  1. import math
  2. from dataclasses import dataclass
  3. from typing import Optional
  4. import torch
  5. import torch.nn as nn
  6. from einops import rearrange
  7. from torch import Tensor
  8. from torch.nn import functional as F
  9. def find_multiple(n: int, k: int) -> int:
  10. if n % k == 0:
  11. return n
  12. return n + k - (n % k)
@dataclass
class ModelArgs:
    """Configuration for the llama-style Transformer with codebook heads."""

    vocab_size: int = 32000  # size of the base text vocabulary
    n_layer: int = 32  # number of transformer blocks
    n_head: int = 32  # number of attention (query) heads
    dim: int = 4096  # model / embedding width
    intermediate_size: Optional[int] = None  # FFN hidden size; derived in __post_init__ when None
    n_local_heads: int = -1  # number of KV heads (grouped-query attention); -1 means "same as n_head"
    head_dim: int = 64  # NOTE: always recomputed as dim // n_head in __post_init__
    rope_base: float = 10000  # rotary-embedding base frequency
    norm_eps: float = 1e-5  # epsilon inside RMSNorm
    max_seq_len: int = 2048  # maximum sequence length (sizes freqs_cis and the causal mask)
    dropout: float = 0.0  # attention dropout probability (training only)
    # Additional decoding heads
    codebook_size: int = 160  # entries per codebook
    num_codebooks: int = 4  # number of output codebooks
    num_in_codebooks: Optional[int] = None  # input codebooks; defaults to num_codebooks
    codebook_padding_idx: int = 0  # codebook id whose embedding is zeroed out
    # Gradient checkpointing
    use_gradient_checkpointing: bool = True
    # NEFT
    neft_alpha: float = 0  # NEFTune embedding-noise scale; 0 disables noise

    def __post_init__(self):
        if self.n_local_heads == -1:
            # Default to full multi-head attention (no grouped-query sharing).
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            # Llama-style FFN sizing: 2/3 of 4*dim, rounded up to a multiple of 256.
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        if self.num_in_codebooks is None:
            self.num_in_codebooks = self.num_codebooks
        # head_dim is derived, so any user-supplied value is overridden here.
        self.head_dim = self.dim // self.n_head
  45. class KVCache(nn.Module):
  46. def __init__(
  47. self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
  48. ):
  49. super().__init__()
  50. cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
  51. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  52. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  53. def update(self, input_pos, k_val, v_val):
  54. # input_pos: [S], k_val: [B, H, S, D]
  55. assert input_pos.shape[0] == k_val.shape[2]
  56. k_out = self.k_cache
  57. v_out = self.v_cache
  58. k_out[:, :, input_pos] = k_val
  59. v_out[:, :, input_pos] = v_val
  60. return k_out, v_out
@dataclass
class TransformerForwardResult:
    """Output container for Transformer.compute / forward."""

    # Logits over the text vocabulary: (batch, seq, vocab_size).
    token_logits: Tensor
    # Logits over the codebooks: (batch, seq, num_codebooks, codebook_size).
    # None when the model is configured with num_codebooks == 0.
    codebook_logits: Optional[Tensor]
  65. class Transformer(nn.Module):
  66. def __init__(self, config: ModelArgs) -> None:
  67. super().__init__()
  68. self.config = config
  69. self.embeddings = nn.Embedding(
  70. config.vocab_size + config.codebook_size * config.num_in_codebooks,
  71. config.dim,
  72. )
  73. self.layers = nn.ModuleList(
  74. TransformerBlock(config) for _ in range(config.n_layer)
  75. )
  76. self.norm = RMSNorm(config.dim, eps=config.norm_eps)
  77. self.output = nn.Linear(
  78. config.dim,
  79. config.vocab_size + config.codebook_size * config.num_codebooks,
  80. bias=False,
  81. )
  82. self.register_buffer(
  83. "freqs_cis",
  84. precompute_freqs_cis(
  85. config.max_seq_len,
  86. config.dim // config.n_head,
  87. config.rope_base,
  88. ),
  89. )
  90. self.register_buffer(
  91. "causal_mask",
  92. torch.tril(
  93. torch.ones(
  94. config.max_seq_len,
  95. config.max_seq_len,
  96. dtype=torch.bool,
  97. )
  98. ),
  99. )
  100. # For kv cache
  101. self.max_batch_size = -1
  102. self.max_seq_len = -1
  103. def setup_caches(
  104. self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
  105. ):
  106. if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
  107. return
  108. head_dim = self.config.dim // self.config.n_head
  109. max_seq_len = find_multiple(max_seq_len, 8)
  110. self.max_seq_len = max_seq_len
  111. self.max_batch_size = max_batch_size
  112. for b in self.layers:
  113. b.attention.kv_cache = KVCache(
  114. max_batch_size,
  115. max_seq_len,
  116. self.config.n_local_heads,
  117. head_dim,
  118. dtype=dtype,
  119. )
  120. def embed(self, x: Tensor) -> Tensor:
  121. # Here we want to merge the embeddings of the codebooks
  122. if self.config.num_in_codebooks == 0:
  123. return self.embeddings(x[:, 0])
  124. vocab_embeds = [self.embeddings(x[:, 0])]
  125. for i in range(self.config.num_in_codebooks):
  126. emb = self.embeddings(
  127. x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
  128. )
  129. emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
  130. vocab_embeds.append(emb)
  131. x = torch.stack(vocab_embeds, dim=3)
  132. x = x.sum(dim=3)
  133. if self.config.neft_alpha > 0 and self.training:
  134. # alpha / sqrt(L * D)
  135. scaled_alpha = self.config.neft_alpha / math.sqrt(
  136. self.config.dim * x.shape[2]
  137. )
  138. x += torch.rand_like(x) * scaled_alpha
  139. return x
  140. def compute(
  141. self,
  142. x: Tensor,
  143. freqs_cis: Tensor,
  144. mask: Tensor,
  145. input_pos: Optional[Tensor] = None,
  146. ) -> TransformerForwardResult:
  147. for layer in self.layers:
  148. if self.config.use_gradient_checkpointing and self.training:
  149. x = torch.utils.checkpoint.checkpoint(
  150. layer, x, freqs_cis, mask, input_pos, use_reentrant=True
  151. )
  152. else:
  153. x = layer(x, freqs_cis, mask, input_pos=input_pos)
  154. x = self.norm(x)
  155. logits = self.output(x)
  156. token_logits = logits[:, :, : self.config.vocab_size]
  157. if self.config.num_codebooks == 0:
  158. return TransformerForwardResult(
  159. token_logits=token_logits,
  160. codebook_logits=None,
  161. )
  162. codebook_logits = logits[:, :, self.config.vocab_size :]
  163. codebook_logits = rearrange(
  164. codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
  165. )
  166. return TransformerForwardResult(
  167. token_logits=token_logits,
  168. codebook_logits=codebook_logits,
  169. )
  170. def forward(
  171. self, x: Tensor, key_padding_mask: Optional[Tensor] = None
  172. ) -> TransformerForwardResult:
  173. # x: (batch, num_codebooks + 1, seq_len)
  174. seq_len = x.size(2)
  175. # Here we want to merge the embeddings of the codebooks
  176. x = self.embed(x)
  177. mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
  178. freqs_cis = self.freqs_cis[:seq_len]
  179. # Not that the causal mask here follows the definition of scaled_dot_product_attention
  180. # That is, FALSE means masked out
  181. # To maintain consistency, key_padding_mask use TRUE to mask out
  182. if key_padding_mask is not None:
  183. mask = mask & key_padding_mask[:, None, None, :].logical_not()
  184. return self.compute(x, freqs_cis, mask)
  185. def forward_generate(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
  186. # x: (batch, num_codebooks + 1, 1)
  187. assert (
  188. self.max_seq_len != -1 and self.max_batch_size != -1
  189. ), "Please call setup_caches before forward_generate"
  190. x = self.embed(x)
  191. mask = self.causal_mask[
  192. None, None, input_pos, : self.max_seq_len
  193. ] # (B, N, Q, K)
  194. freqs_cis = self.freqs_cis[input_pos]
  195. # TODO: support key padding mask for generation
  196. return self.compute(x, freqs_cis, mask, input_pos=input_pos)
  197. class TransformerBlock(nn.Module):
  198. def __init__(self, config: ModelArgs) -> None:
  199. super().__init__()
  200. self.attention = Attention(config)
  201. self.feed_forward = FeedForward(config)
  202. self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
  203. self.attention_norm = RMSNorm(config.dim, config.norm_eps)
  204. def forward(
  205. self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
  206. ) -> Tensor:
  207. h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  208. out = h + self.feed_forward(self.ffn_norm(h))
  209. return out
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings, a fused QKV
    projection, grouped-query attention (n_local_heads KV heads) and an
    optional KV cache for incremental decoding.
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        # Queries use n_head heads; keys and values each use n_local_heads.
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        # Installed externally (Transformer.setup_caches) when decoding.
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        # Lets checkpoints with separate wq/wk/wv weights load into wqkv.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # Fuse legacy per-projection weights into the single wqkv matrix.
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        # One matmul for all projections, then split into q / k / v.
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        # Rotary position embedding applied to queries and keys only.
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        # (B, S, H, D) -> (B, H, S, D) for attention.
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            # Incremental decoding: merge new k/v into the cache at input_pos
            # and attend over the full cached sequence.
            k, v = self.kv_cache.update(input_pos, k, v)

        # Expand KV heads so each query head has a matching key/value head.
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,  # bool mask: False means masked out
            dropout_p=self.dropout if self.training else 0.0,
        )

        # (B, H, S, D) -> (B, S, H*D); n_head * head_dim == dim (see ModelArgs).
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)
  260. class FeedForward(nn.Module):
  261. def __init__(self, config: ModelArgs) -> None:
  262. super().__init__()
  263. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  264. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  265. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  266. def forward(self, x: Tensor) -> Tensor:
  267. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  268. class RMSNorm(nn.Module):
  269. def __init__(self, dim: int, eps: float = 1e-5):
  270. super().__init__()
  271. self.eps = eps
  272. self.weight = nn.Parameter(torch.ones(dim))
  273. def _norm(self, x):
  274. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  275. def forward(self, x: Tensor) -> Tensor:
  276. output = self._norm(x.float()).type_as(x)
  277. return output * self.weight
  278. def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
  279. freqs = 1.0 / (
  280. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  281. )
  282. t = torch.arange(seq_len, device=freqs.device)
  283. freqs = torch.outer(t, freqs)
  284. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  285. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  286. return cache.to(dtype=torch.bfloat16)
  287. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  288. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  289. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  290. x_out2 = torch.stack(
  291. [
  292. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  293. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  294. ],
  295. -1,
  296. )
  297. x_out2 = x_out2.flatten(3)
  298. return x_out2.type_as(x)
if __name__ == "__main__":
    # Smoke test: small text-only configuration (no codebook heads).
    args = ModelArgs(
        max_seq_len=4096,
        vocab_size=32312,
        n_layer=12,
        n_head=12,
        dim=768,
        rope_base=10000,
        norm_eps=1e-5,
        codebook_size=0,
        num_codebooks=0,
    )

    model = Transformer(args)
    # NOTE(review): requires a CUDA device; bfloat16 matches freqs_cis dtype.
    model = model.cuda().bfloat16()
    # Parameter count reported in units of 1024*1024 (roughly "millions").
    print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)

    # (batch=2, 5 rows, seq_len=128); only row 0 is embedded since
    # num_in_codebooks == 0 in this configuration.
    inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
    # True marks key positions to mask out (see Transformer.forward).
    key_padding_mask = torch.zeros(2, 128).bool().cuda()
    key_padding_mask[0, 2:] = True
    x1 = model(inputs, key_padding_mask=key_padding_mask)
    print(x1.token_logits.shape)
    # codebook_logits is None here because num_codebooks == 0.
    # print(x1.codebook_logits.shape)