# llama.py
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
# Explicit import: `torch.utils.checkpoint` is not guaranteed to be loaded as
# a side effect of `import torch`, and Transformer.compute() attribute-accesses it.
import torch.utils.checkpoint
from einops import rearrange
from torch import Tensor
from torch.nn import functional as F
  9. def find_multiple(n: int, k: int) -> int:
  10. if n % k == 0:
  11. return n
  12. return n + k - (n % k)
  13. @dataclass
  14. class ModelArgs:
  15. vocab_size: int = 32000
  16. n_layer: int = 32
  17. n_head: int = 32
  18. dim: int = 4096
  19. intermediate_size: int = None
  20. n_local_heads: int = -1
  21. head_dim: int = 64
  22. rope_base: float = 10000
  23. norm_eps: float = 1e-5
  24. max_seq_len: int = 2048
  25. dropout: float = 0.0
  26. # Additional decoding heads
  27. codebook_size: int = 160
  28. num_codebooks: int = 4
  29. num_in_codebooks: Optional[int] = None
  30. codebook_padding_idx: int = 0
  31. # Gradient checkpointing
  32. use_gradient_checkpointing: bool = True
  33. # NEFT
  34. neft_alpha: float = 0
  35. def __post_init__(self):
  36. if self.n_local_heads == -1:
  37. self.n_local_heads = self.n_head
  38. if self.intermediate_size is None:
  39. hidden_dim = 4 * self.dim
  40. n_hidden = int(2 * hidden_dim / 3)
  41. self.intermediate_size = find_multiple(n_hidden, 256)
  42. if self.num_in_codebooks is None:
  43. self.num_in_codebooks = self.num_codebooks
  44. self.head_dim = self.dim // self.n_head
  45. class KVCache(nn.Module):
  46. def __init__(
  47. self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
  48. ):
  49. super().__init__()
  50. cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
  51. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  52. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  53. def update(self, input_pos, k_val, v_val):
  54. # input_pos: [S], k_val: [B, H, S, D]
  55. assert input_pos.shape[0] == k_val.shape[2]
  56. k_out = self.k_cache
  57. v_out = self.v_cache
  58. k_out[:, :, input_pos] = k_val
  59. v_out[:, :, input_pos] = v_val
  60. return k_out, v_out
@dataclass
class TransformerForwardResult:
    """Logits produced by the Transformer forward passes."""

    # Logits over the text vocabulary: (batch, seq, vocab_size).
    token_logits: Tensor
    # Logits over the codebooks: (batch, seq, num_codebooks, codebook_size),
    # or None when the model is configured with num_codebooks == 0.
    codebook_logits: Optional[Tensor]
  65. class Transformer(nn.Module):
  66. def __init__(self, config: ModelArgs) -> None:
  67. super().__init__()
  68. self.config = config
  69. self.embeddings = nn.Embedding(
  70. config.vocab_size + config.codebook_size * config.num_in_codebooks,
  71. config.dim,
  72. )
  73. self.layers = nn.ModuleList(
  74. TransformerBlock(config) for _ in range(config.n_layer)
  75. )
  76. self.norm = RMSNorm(config.dim, eps=config.norm_eps)
  77. self.output = nn.Linear(
  78. config.dim,
  79. config.vocab_size + config.codebook_size * config.num_codebooks,
  80. bias=False,
  81. )
  82. self.register_buffer(
  83. "freqs_cis",
  84. precompute_freqs_cis(
  85. config.max_seq_len,
  86. config.dim // config.n_head,
  87. config.rope_base,
  88. ),
  89. )
  90. self.register_buffer(
  91. "causal_mask",
  92. torch.tril(
  93. torch.ones(
  94. config.max_seq_len,
  95. config.max_seq_len,
  96. dtype=torch.bool,
  97. )
  98. ),
  99. )
  100. # For kv cache
  101. self.max_batch_size = -1
  102. self.max_seq_len = -1
  103. def setup_caches(
  104. self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
  105. ):
  106. if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
  107. return
  108. head_dim = self.config.dim // self.config.n_head
  109. max_seq_len = find_multiple(max_seq_len, 8)
  110. self.max_seq_len = max_seq_len
  111. self.max_batch_size = max_batch_size
  112. for b in self.layers:
  113. b.attention.kv_cache = KVCache(
  114. max_batch_size,
  115. max_seq_len,
  116. self.config.n_local_heads,
  117. head_dim,
  118. dtype=dtype,
  119. )
  120. def embed(self, x: Tensor) -> Tensor:
  121. # Here we want to merge the embeddings of the codebooks
  122. if self.config.num_in_codebooks == 0:
  123. return self.embeddings(x[:, 0])
  124. vocab_embeds = [self.embeddings(x[:, 0])]
  125. for i in range(self.config.num_in_codebooks):
  126. emb = self.embeddings(
  127. x[:, i + 1] + i * self.config.codebook_size + self.config.vocab_size
  128. )
  129. emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
  130. vocab_embeds.append(emb)
  131. x = torch.stack(vocab_embeds, dim=3)
  132. x = x.sum(dim=3)
  133. if self.config.neft_alpha > 0 and self.training:
  134. # alpha / sqrt(L * D)
  135. scaled_alpha = self.config.neft_alpha / math.sqrt(
  136. self.config.dim * x.shape[2]
  137. )
  138. x += torch.rand_like(x) * scaled_alpha
  139. print("NEFT alpha:", scaled_alpha)
  140. return x
  141. def compute(
  142. self,
  143. x: Tensor,
  144. freqs_cis: Tensor,
  145. mask: Tensor,
  146. input_pos: Optional[Tensor] = None,
  147. ) -> TransformerForwardResult:
  148. for layer in self.layers:
  149. if self.config.use_gradient_checkpointing and self.training:
  150. x = torch.utils.checkpoint.checkpoint(
  151. layer, x, freqs_cis, mask, input_pos, use_reentrant=True
  152. )
  153. else:
  154. x = layer(x, freqs_cis, mask, input_pos=input_pos)
  155. x = self.norm(x)
  156. logits = self.output(x)
  157. token_logits = logits[:, :, : self.config.vocab_size]
  158. if self.config.num_codebooks == 0:
  159. return TransformerForwardResult(
  160. token_logits=token_logits,
  161. codebook_logits=None,
  162. )
  163. codebook_logits = logits[:, :, self.config.vocab_size :]
  164. codebook_logits = rearrange(
  165. codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
  166. )
  167. return TransformerForwardResult(
  168. token_logits=token_logits,
  169. codebook_logits=codebook_logits,
  170. )
  171. def forward(
  172. self, x: Tensor, key_padding_mask: Optional[Tensor] = None
  173. ) -> TransformerForwardResult:
  174. # x: (batch, num_codebooks + 1, seq_len)
  175. seq_len = x.size(2)
  176. # Here we want to merge the embeddings of the codebooks
  177. x = self.embed(x)
  178. mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
  179. freqs_cis = self.freqs_cis[:seq_len]
  180. # Not that the causal mask here follows the definition of scaled_dot_product_attention
  181. # That is, FALSE means masked out
  182. # To maintain consistency, key_padding_mask use TRUE to mask out
  183. if key_padding_mask is not None:
  184. mask = mask & key_padding_mask[:, None, None, :].logical_not()
  185. return self.compute(x, freqs_cis, mask)
  186. def forward_generate(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
  187. # x: (batch, num_codebooks + 1, 1)
  188. assert (
  189. self.max_seq_len != -1 and self.max_batch_size != -1
  190. ), "Please call setup_caches before forward_generate"
  191. x = self.embed(x)
  192. mask = self.causal_mask[
  193. None, None, input_pos, : self.max_seq_len
  194. ] # (B, N, Q, K)
  195. freqs_cis = self.freqs_cis[input_pos]
  196. # TODO: support key padding mask for generation
  197. return self.compute(x, freqs_cis, mask, input_pos=input_pos)
  198. class TransformerBlock(nn.Module):
  199. def __init__(self, config: ModelArgs) -> None:
  200. super().__init__()
  201. self.attention = Attention(config)
  202. self.feed_forward = FeedForward(config)
  203. self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
  204. self.attention_norm = RMSNorm(config.dim, config.norm_eps)
  205. def forward(
  206. self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
  207. ) -> Tensor:
  208. h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  209. out = h + self.feed_forward(self.ffn_norm(h))
  210. return out
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings, grouped-query KV
    heads (n_local_heads <= n_head), and an optional KVCache for
    incremental decoding."""

    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        # Q uses n_head heads; K and V each use n_local_heads heads.
        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        # Installed later by Transformer.setup_caches(); None during training.
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        # Rewrites legacy checkpoints (separate wq/wk/wv) into the fused wqkv.
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """Merge separate wq/wk/wv weights from older checkpoints into wqkv."""
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """x: (batch, seq, dim); mask follows SDPA semantics (False = masked)."""
        bsz, seqlen, _ = x.shape

        # Split the fused projection back into Q / K / V.
        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        # Rotary embeddings are applied while heads are still in dim 2.
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        # (B, S, H, D) -> (B, H, S, D) for attention.
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            # Incremental decoding: merge new K/V into the cache and attend
            # over the full cached sequence.
            k, v = self.kv_cache.update(input_pos, k, v)

        # Grouped-query attention: replicate KV heads up to n_head.
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask,
            dropout_p=self.dropout if self.training else 0.0,
        )

        # (B, H, S, D) -> (B, S, dim) before the output projection.
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)
  261. class FeedForward(nn.Module):
  262. def __init__(self, config: ModelArgs) -> None:
  263. super().__init__()
  264. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  265. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  266. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  267. def forward(self, x: Tensor) -> Tensor:
  268. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  269. class RMSNorm(nn.Module):
  270. def __init__(self, dim: int, eps: float = 1e-5):
  271. super().__init__()
  272. self.eps = eps
  273. self.weight = nn.Parameter(torch.ones(dim))
  274. def _norm(self, x):
  275. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  276. def forward(self, x: Tensor) -> Tensor:
  277. output = self._norm(x.float()).type_as(x)
  278. return output * self.weight
  279. def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
  280. freqs = 1.0 / (
  281. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  282. )
  283. t = torch.arange(seq_len, device=freqs.device)
  284. freqs = torch.outer(t, freqs)
  285. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  286. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  287. return cache.to(dtype=torch.bfloat16)
  288. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  289. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  290. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  291. x_out2 = torch.stack(
  292. [
  293. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  294. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  295. ],
  296. -1,
  297. )
  298. x_out2 = x_out2.flatten(3)
  299. return x_out2.type_as(x)
  300. if __name__ == "__main__":
  301. args = ModelArgs(
  302. max_seq_len=4096,
  303. vocab_size=32312,
  304. n_layer=12,
  305. n_head=12,
  306. dim=768,
  307. rope_base=10000,
  308. norm_eps=1e-5,
  309. codebook_size=0,
  310. num_codebooks=0,
  311. )
  312. model = Transformer(args)
  313. model = model.cuda().bfloat16()
  314. print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
  315. inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
  316. key_padding_mask = torch.zeros(2, 128).bool().cuda()
  317. key_padding_mask[0, 2:] = True
  318. x1 = model(inputs, key_padding_mask=key_padding_mask)
  319. print(x1.token_logits.shape)
  320. # print(x1.codebook_logits.shape)