# llama.py

import dataclasses
import json
import math
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from loguru import logger
from torch import Tensor
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.checkpoint import checkpoint
from transformers import AutoTokenizer

from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
from fish_speech.utils import RankedLogger

from .lora import LoraConfig, setup_lora

log = RankedLogger(__name__, rank_zero_only=True)


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)


@dataclass
class BaseModelArgs:
    model_type: str = "base"

    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int | None = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048
    dropout: float = 0.0
    tie_word_embeddings: bool = True
    attention_qkv_bias: bool = False

    # Codebook configs
    codebook_size: int = 160
    num_codebooks: int = 4

    # Gradient checkpointing
    use_gradient_checkpointing: bool = True

    # Initialize the model
    initializer_range: float = 0.02

    # Dummy vars
    is_reward_model: bool = False
    share_codebook_embeddings: bool = True
    scale_codebook_embeddings: bool = False

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head

        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)

        self.head_dim = self.dim // self.n_head
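
    # Worked example (a sketch using the defaults above, dim=4096, n_head=32):
    # __post_init__ derives
    #   intermediate_size = find_multiple(int(2 * 4 * 4096 / 3), 256)
    #                     = find_multiple(10922, 256) = 11008
    # and head_dim = 4096 // 32 = 128, overriding the declared default of 64.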

    @staticmethod
    def from_pretrained(path: str):
        path = Path(path)

        if path.is_dir():
            path = path / "config.json"

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        match data["model_type"]:
            case "naive":
                cls = NaiveModelArgs
            case "dual_ar":
                cls = DualARModelArgs
            case _:
                raise ValueError(f"Unknown model type: {data['model_type']}")

        return cls(**data)

    def save(self, path: str):
        with open(path, "w") as f:
            json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)


@dataclass
class NaiveModelArgs(BaseModelArgs):
    model_type: str = "naive"


@dataclass
class DualARModelArgs(BaseModelArgs):
    model_type: str = "dual_ar"
    n_fast_layer: int = 4
    fast_dim: int | None = None
    fast_n_head: int | None = None
    fast_n_local_heads: int | None = None
    fast_head_dim: int | None = None
    fast_intermediate_size: int | None = None
    fast_attention_qkv_bias: bool | None = None

    def __post_init__(self):
        super().__post_init__()

        self.fast_dim = self.fast_dim or self.dim
        self.fast_n_head = self.fast_n_head or self.n_head
        self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
        self.fast_head_dim = self.fast_head_dim or self.head_dim
        self.fast_intermediate_size = (
            self.fast_intermediate_size or self.intermediate_size
        )
        self.fast_attention_qkv_bias = (
            self.fast_attention_qkv_bias
            if self.fast_attention_qkv_bias is not None
            else self.attention_qkv_bias
        )
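
# BaseModelArgs.from_pretrained builds these dataclasses from a config.json via
# cls(**data), so a config file only needs "model_type" plus whichever fields it
# overrides. An illustrative (not official) dual_ar example:
#   {"model_type": "dual_ar", "vocab_size": 32000, "n_layer": 32, "n_head": 32,
#    "dim": 4096, "n_fast_layer": 4, "codebook_size": 160, "num_codebooks": 4}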


class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out
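
# Illustrative sketch of how the cache behaves during decoding (sizes are made up):
# KVCache(max_batch_size=1, max_seq_len=1024, n_heads=8, head_dim=64) holds
# k_cache/v_cache of shape (1, 8, 1024, 64); update(input_pos, k, v) writes the new
# key/value states at those positions and returns the full caches for attention.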


@dataclass
class TransformerForwardResult:
    token_logits: Tensor
    codebook_logits: Tensor


@dataclass
class BaseTransformerForwardResult:
    logits: Tensor
    hidden_states: Tensor


class BaseTransformer(nn.Module):
    def __init__(
        self,
        config: BaseModelArgs,
        tokenizer: FishTokenizer,
        init_weights: bool = True,
    ) -> None:
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.semantic_token_ids = [
            tokenizer.get_token_id(SEMANTIC_TOKEN) for SEMANTIC_TOKEN in SEMANTIC_TOKENS
        ]

        # Slow transformer
        self.embeddings = nn.Embedding(
            config.vocab_size,
            config.dim,
        )
        self.codebook_embeddings = nn.Embedding(
            config.codebook_size * config.num_codebooks,
            config.dim,
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)

        if self.config.tie_word_embeddings is False:
            self.output = nn.Linear(
                config.dim,
                config.vocab_size,
                bias=False,
            )

        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
            persistent=False,
        )

        # For kv cache
        self.max_batch_size = -1
        self.max_seq_len = -1

        if init_weights:
            self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        head_dim = self.config.dim // self.config.n_head
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def embed(self, inp: Tensor, share_codebook_embeddings=True) -> Tensor:
        embeds = []
        semantic_token_ids_tensor = torch.tensor(
            self.semantic_token_ids, device=inp.device, dtype=inp.dtype
        )

        for i in range(self.config.num_codebooks):
            if share_codebook_embeddings:
                emb = self.codebook_embeddings(
                    inp[:, i + 1] + i * self.config.codebook_size
                )
            else:
                emb = self.codebook_embeddings(inp[:, i + 1])

            embeds.append(emb)

        vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
        vq_embeds_sum[~torch.isin(inp[:, 0], semantic_token_ids_tensor)] = 0

        x = self.embeddings(inp[:, 0]) + vq_embeds_sum

        return x
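
    # Layout assumed by embed() and forward(): inp has shape (B, 1 + num_codebooks, S),
    # where row 0 holds text/semantic token ids and rows 1..num_codebooks hold the
    # codebook indices. The codebook embeddings are summed and then zeroed out for
    # positions whose row-0 token is not a semantic token.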

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> BaseTransformerForwardResult:
        seq_len = inp.size(2)

        # Here we want to merge the embeddings of the codebooks
        x = self.embed(inp)
        freqs_cis = self.freqs_cis[:seq_len]

        # Note that the causal mask here follows the definition of scaled_dot_product_attention:
        # False means masked out. For consistency, key_padding_mask uses True to mask out.
        mask = None
        if key_padding_mask is not None:
            causal = self.causal_mask[:seq_len, :seq_len]
            causal = rearrange(causal, "q k -> 1 1 q k")

            atten_mask = rearrange(key_padding_mask, "b s -> b 1 1 s")
            atten_mask = atten_mask.logical_not()
            mask = causal & atten_mask

        for layer in self.layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
            else:
                x = layer(x, freqs_cis, mask)

        # Slow (token-level) transformer output
        slow_out = self.norm(x)

        if self.config.tie_word_embeddings:
            token_logits = F.linear(slow_out, self.embeddings.weight)
        else:
            token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )

    def forward_generate(
        self,
        inp: Tensor,
        input_pos: Optional[Tensor] = None,
        return_all: bool = False,
    ) -> BaseTransformerForwardResult:
        x = self.embed(
            inp, share_codebook_embeddings=self.config.share_codebook_embeddings
        )

        if input_pos is None:
            input_pos = torch.arange(inp.shape[-1], device=x.device)
            max_seq_len = inp.shape[-1]
        else:
            max_seq_len = self.max_seq_len

        mask = self.causal_mask[None, None, input_pos, :max_seq_len]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        # When prefilling, we only need the logits of the last token
        if x.size(1) > 1 and not return_all:
            x = x[:, -1:]

        # Slow (token-level) transformer output
        slow_out = self.norm(x)

        if self.config.is_reward_model:
            token_logits = self.score_output(slow_out)
        elif self.config.tie_word_embeddings:
            token_logits = F.linear(slow_out, self.embeddings.weight)
        else:
            token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )

    def _init_weights(self, module):
        std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @staticmethod
    def from_pretrained(
        path: str,
        load_weights: bool = False,
        max_length: int | None = None,
        lora_config: LoraConfig | None = None,
        rope_base: int | None = None,
        is_agent: bool = False,
    ) -> "BaseTransformer":
        config = BaseModelArgs.from_pretrained(str(path))

        if max_length is not None:
            config.max_seq_len = max_length
            log.info(f"Override max_seq_len to {max_length}")

        if rope_base is not None:
            config.rope_base = rope_base
            log.info(f"Override rope_base to {rope_base}")

        match config.model_type:
            case "naive":
                model_cls = NaiveTransformer
            case "dual_ar":
                model_cls = DualARTransformer
            case _:
                raise ValueError(f"Unknown model type: {config.model_type}")

        tokenizer_path = str(path) + "/tokenizer.tiktoken"
        tokenizer = FishTokenizer(tokenizer_path)

        log.info(f"Loading model from {path}, config: {config}")
        model = model_cls(config, tokenizer=tokenizer)

        if lora_config is not None:
            setup_lora(model, lora_config)
            log.info(f"LoRA setup: {lora_config}")

        if load_weights is False:
            log.info("Randomly initialized model")
        else:
            if "int8" in str(Path(path)):
                logger.info("Using int8 weight-only quantization!")
                from tools.llama.quantize import WeightOnlyInt8QuantHandler

                simple_quantizer = WeightOnlyInt8QuantHandler(model)
                model = simple_quantizer.convert_for_runtime()

            if "int4" in str(Path(path)):
                logger.info("Using int4 quantization!")
                path_comps = Path(path).name.split("-")
                assert path_comps[-2].startswith("g")
                groupsize = int(path_comps[-2][1:])
                from tools.llama.quantize import WeightOnlyInt4QuantHandler

                simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
                model = simple_quantizer.convert_for_runtime()

            weights = torch.load(
                Path(path) / "model.pth",
                map_location="cpu",
                mmap=True,
                weights_only=True,
            )

            if "state_dict" in weights:
                logger.warning(
                    "Using a TextToSemantic LightningModule checkpoint, "
                    "please make sure it is a full model, not a LoRA model."
                )
                weights = weights["state_dict"]

            if next(iter(weights.keys())).startswith("model."):
                logger.info(
                    "Removing prefix 'model.' created by TextToSemantic LightningModule from keys"
                )
                new_weights = OrderedDict()
                for k, v in weights.items():
                    new_weights[k.replace("model.", "")] = v
                weights = new_weights

            # Verify the name and shape of parameters since strict=False in load_state_dict.
            for k, v in model.named_parameters():
                if k not in weights:
                    logger.warning(f"No weight for {k}")
                elif v.shape != weights[k].shape:
                    logger.warning(
                        f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
                    )

            err = model.load_state_dict(weights, strict=False, assign=True)
            log.info(f"Loaded weights with error: {err}")

        return model
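
    # Illustrative call (the checkpoint directory below is hypothetical; it is
    # expected to contain config.json, tokenizer.tiktoken and model.pth):
    #
    #   model = BaseTransformer.from_pretrained(
    #       "checkpoints/my-model", load_weights=True, max_length=2048
    #   )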

    def save_pretrained(self, path: str, drop_lora: bool = False):
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        self.config.save(path / "config.json")
        state_dict = self.state_dict()

        if drop_lora:
            for key in list(state_dict.keys()):
                if "lora" not in key:
                    continue

                state_dict.pop(key)
                log.info(f"Drop LoRA parameter: {key}")

        torch.save(state_dict, path / "model.pth")
        self.tokenizer.save_pretrained(path)


class NaiveTransformer(BaseTransformer):
    def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.codebook_output = nn.Linear(
            config.dim,
            config.codebook_size * config.num_codebooks,
            bias=False,
        )

        self.apply(self._init_weights)

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        token_logits = result.logits
        x = result.hidden_states

        # Codebook
        codebook_logits = self.codebook_output(self.codebook_norm(x))
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        result = super().forward(
            inp=inp,
            key_padding_mask=key_padding_mask,
        )
        return self.decode(result)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        result = super().forward_generate(x, input_pos)
        return self.decode(result)


class DualARTransformer(BaseTransformer):
    def __init__(self, config: DualARModelArgs, tokenizer: FishTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        # Project to fast dim if needed
        if config.fast_dim is not None and config.fast_dim != config.dim:
            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
        else:
            self.fast_project_in = nn.Identity()

        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)

        # The equivalent batch size is so large that sdpa doesn't work well
        override_config = dataclasses.replace(
            config,
            dim=config.fast_dim,
            n_head=config.fast_n_head,
            n_local_heads=config.fast_n_local_heads,
            head_dim=config.fast_head_dim,
            intermediate_size=config.fast_intermediate_size,
            attention_qkv_bias=config.fast_attention_qkv_bias,
        )

        self.fast_layers = nn.ModuleList(
            TransformerBlock(override_config, use_sdpa=False)
            for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.fast_dim,
            config.codebook_size,
            bias=False,
        )

        self.register_buffer(
            "fast_freqs_cis",
            precompute_freqs_cis(
                config.num_codebooks,
                config.fast_dim // config.fast_n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        head_dim = self.config.fast_dim // self.config.fast_n_head

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.fast_n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states
        x = self.fast_project_in(x)

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)

        # Drop the last token and rotate left
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=0)
        codebook_embeddings = self.fast_embeddings(codebooks)
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove padded part
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == 0).all(dim=-1)

        if torch.all(codebook_mask):
            # If all codebooks are padded, keep the first 8 so the model still runs
            codebook_mask[:8] = False

        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(
                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
                )
            else:
                x = layer(x, self.fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )
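
    # In forward() above, every slow-transformer position becomes its own short
    # sequence for the fast transformer: the projected hidden state is prepended to
    # the left-shifted codebook embeddings, and batch/seq are folded together
    # ("b n s d -> (b s) n d"), so the fast layers predict the num_codebooks
    # residual tokens for each position.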

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        # Fast transformer
        x = x.view(1, 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.fast_freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits

    def forward_generate(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        vq_masks: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        x = super().forward_generate(x, input_pos, vq_masks)
        x.hidden_states = self.fast_project_in(x.hidden_states)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
    ) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(
            config.dim, total_head_dim, bias=config.attention_qkv_bias
        )
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
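
    # The fused wqkv projection packs q (n_head * head_dim) followed by k and v
    # (n_local_heads * head_dim each, i.e. grouped-query attention); load_hook above
    # merges legacy separate wq/wk/wv weights into this layout when loading older
    # checkpoints.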

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.use_sdpa:
            if mask is None:
                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                    y = F.scaled_dot_product_attention(
                        q,
                        k,
                        v,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=True,
                        # No attn_mask here, so the flash attention kernel can be used
                    )
            else:
                y = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=mask,
                    dropout_p=self.dropout if self.training else 0.0,
                )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        # This is a standard scaled dot product attention.
        # It is less efficient, but it does not raise CUDA errors.
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

        return attn_weight @ value


class FeedForward(nn.Module):
    def __init__(self, config: BaseModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    """
    Precomputes frequency tensors for complex exponentials (cis).

    Args:
        seq_len: Length of the sequence for which positional embeddings are needed.
        n_elem: Number of elements in the frequency tensor.
        base: Base value for the frequency scaling (default: 10000).

    Returns:
        A tensor containing the precomputed frequencies in real and imaginary parts (bfloat16).
    """
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)
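
# Note: the returned cache has shape (seq_len, n_elem // 2, 2); the last axis holds
# the cosine (real) and sine (imaginary) components consumed by apply_rotary_emb below.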


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)