llama.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940
  1. import dataclasses
  2. import json
  3. import math
  4. from collections import OrderedDict
  5. from dataclasses import dataclass
  6. from pathlib import Path
  7. from typing import Optional
  8. import torch
  9. import torch.nn as nn
  10. from einops import rearrange
  11. from loguru import logger
  12. from torch import Tensor
  13. from torch.nn import functional as F
  14. from torch.nn.attention import SDPBackend, sdpa_kernel
  15. from torch.utils.checkpoint import checkpoint
  16. from transformers import AutoTokenizer
  17. from fish_speech.models.text2semantic.lora import LoraConfig, setup_lora
  18. from fish_speech.tokenizer import SEMANTIC_TOKENS, FishTokenizer
  19. def find_multiple(n: int, k: int) -> int:
  20. if n % k == 0:
  21. return n
  22. return n + k - (n % k)
@dataclass
class BaseModelArgs:
    """Hyper-parameters for the slow (text-level) transformer backbone.

    ``__post_init__`` derives any field left at its sentinel default
    (``n_local_heads == -1``, ``intermediate_size is None``).
    """

    model_type: str = "base"

    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    # None -> derived in __post_init__ (LLaMA-style SwiGLU hidden size)
    intermediate_size: int | None = None
    # -1 -> same as n_head (i.e. no grouped-query attention)
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048
    dropout: float = 0.0
    tie_word_embeddings: bool = True
    attention_qkv_bias: bool = False
    attention_o_bias: bool = False
    attention_qk_norm: bool = False

    # Codebook configs
    codebook_size: int = 160
    num_codebooks: int = 4

    # Gradient checkpointing
    use_gradient_checkpointing: bool = True

    # Initialize the model
    initializer_range: float = 0.02

    # Dummy vars
    is_reward_model: bool = False
    scale_codebook_embeddings: bool = False

    def __post_init__(self):
        """Fill in fields that are derived from others when left unset."""
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head

        if self.intermediate_size is None:
            # 2/3 of 4*dim, rounded up to a multiple of 256 (LLaMA convention)
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)

        if self.head_dim is None:
            # NOTE(review): only reached when head_dim is passed explicitly as
            # None -- the field default is 64, not None.
            self.head_dim = self.dim // self.n_head

    @staticmethod
    def from_pretrained(path: str):
        """Load a config from a ``config.json`` file or a directory containing one.

        Dispatches to the concrete dataclass named by the serialized
        ``model_type`` field.
        """
        path = Path(path)

        if path.is_dir():
            path = path / "config.json"

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        match data["model_type"]:
            case "naive":
                cls = NaiveModelArgs
            case "dual_ar":
                cls = DualARModelArgs
            case _:
                raise ValueError(f"Unknown model type: {data['model_type']}")

        return cls(**data)

    def save(self, path: str):
        """Serialize this config to *path* as pretty-printed JSON."""
        with open(path, "w") as f:
            json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
@dataclass
class NaiveModelArgs(BaseModelArgs):
    """Config for the single-transformer (naive) variant; only the type tag differs."""

    model_type: str = "naive"
@dataclass
class DualARModelArgs(BaseModelArgs):
    """Config for the dual-AR variant: a slow text transformer plus a small
    "fast" transformer that autoregresses over the codebooks of each frame.

    Every ``fast_*`` field defaults to None and falls back to the matching
    slow-transformer value in ``__post_init__``.
    """

    model_type: str = "dual_ar"
    n_fast_layer: int = 4
    fast_dim: int | None = None
    fast_n_head: int | None = None
    fast_n_local_heads: int | None = None
    fast_head_dim: int | None = None
    fast_intermediate_size: int | None = None
    fast_attention_qkv_bias: bool | None = None
    fast_attention_qk_norm: bool | None = None
    fast_attention_o_bias: bool | None = None

    def __post_init__(self):
        super().__post_init__()

        # Dimensions: `or` is safe because 0/None are not valid sizes.
        self.fast_dim = self.fast_dim or self.dim
        self.fast_n_head = self.fast_n_head or self.n_head
        self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
        self.fast_head_dim = self.fast_head_dim or self.head_dim
        self.fast_intermediate_size = (
            self.fast_intermediate_size or self.intermediate_size
        )
        # Booleans need explicit None checks: False is a valid override.
        self.fast_attention_qkv_bias = (
            self.fast_attention_qkv_bias
            if self.fast_attention_qkv_bias is not None
            else self.attention_qkv_bias
        )
        self.fast_attention_qk_norm = (
            self.fast_attention_qk_norm
            if self.fast_attention_qk_norm is not None
            else self.attention_qk_norm
        )
        self.fast_attention_o_bias = (
            self.fast_attention_o_bias
            if self.fast_attention_o_bias is not None
            else self.attention_o_bias
        )
  117. class KVCache(nn.Module):
  118. def __init__(
  119. self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
  120. ):
  121. super().__init__()
  122. cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
  123. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  124. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  125. def update(self, input_pos, k_val, v_val):
  126. # input_pos: [S], k_val: [B, H, S, D]
  127. assert input_pos.shape[0] == k_val.shape[2]
  128. k_out = self.k_cache
  129. v_out = self.v_cache
  130. k_out[:, :, input_pos] = k_val
  131. v_out[:, :, input_pos] = v_val
  132. return k_out, v_out
@dataclass
class TransformerForwardResult:
    """Decoded outputs: logits over the text vocabulary and per-codebook logits."""

    token_logits: Tensor
    codebook_logits: Tensor
@dataclass
class BaseTransformerForwardResult:
    """Raw backbone outputs: token logits plus the pre-norm hidden states."""

    logits: Tensor
    hidden_states: Tensor
  141. class BaseTransformer(nn.Module):
  142. def __init__(
  143. self,
  144. config: BaseModelArgs,
  145. tokenizer: FishTokenizer,
  146. init_weights: bool = True,
  147. ) -> None:
  148. super().__init__()
  149. self.config = config
  150. self.tokenizer = tokenizer
  151. self.semantic_token_ids = list(tokenizer.semantic_id_to_token_id.values())
  152. # Slow transformer
  153. self.embeddings = nn.Embedding(
  154. config.vocab_size,
  155. config.dim,
  156. )
  157. self.codebook_embeddings = nn.Embedding(
  158. config.codebook_size * config.num_codebooks,
  159. config.dim,
  160. )
  161. self.layers = nn.ModuleList(
  162. TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
  163. )
  164. self.norm = RMSNorm(config.dim, eps=config.norm_eps)
  165. if self.config.tie_word_embeddings is False:
  166. self.output = nn.Linear(
  167. config.dim,
  168. config.vocab_size,
  169. bias=False,
  170. )
  171. self.register_buffer(
  172. "freqs_cis",
  173. precompute_freqs_cis(
  174. config.max_seq_len,
  175. config.head_dim,
  176. config.rope_base,
  177. ),
  178. persistent=False,
  179. )
  180. self.register_buffer(
  181. "causal_mask",
  182. torch.tril(
  183. torch.ones(
  184. config.max_seq_len,
  185. config.max_seq_len,
  186. dtype=torch.bool,
  187. )
  188. ),
  189. persistent=False,
  190. )
  191. # For kv cache
  192. self.max_batch_size = -1
  193. self.max_seq_len = -1
  194. if init_weights:
  195. self.apply(self._init_weights)
  196. def setup_caches(
  197. self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
  198. ):
  199. if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
  200. return
  201. max_seq_len = find_multiple(max_seq_len, 8)
  202. self.max_seq_len = max_seq_len
  203. self.max_batch_size = max_batch_size
  204. for b in self.layers:
  205. b.attention.kv_cache = KVCache(
  206. max_batch_size,
  207. max_seq_len,
  208. self.config.n_local_heads,
  209. self.config.head_dim,
  210. dtype=dtype,
  211. )
  212. def embed(self, inp: Tensor) -> Tensor:
  213. embeds = []
  214. semantic_token_ids_tensor = torch.tensor(
  215. self.semantic_token_ids, device=inp.device, dtype=inp.dtype
  216. )
  217. for i in range(self.config.num_codebooks):
  218. emb = self.codebook_embeddings(
  219. inp[:, i + 1] + i * self.config.codebook_size
  220. )
  221. embeds.append(emb)
  222. vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
  223. vq_embeds_sum[~torch.isin(inp[:, 0], semantic_token_ids_tensor)] = 0
  224. x = self.embeddings(inp[:, 0]) + vq_embeds_sum
  225. return x
  226. def forward(
  227. self,
  228. inp: Tensor,
  229. key_padding_mask: Optional[Tensor] = None,
  230. ) -> BaseTransformerForwardResult:
  231. seq_len = inp.size(2)
  232. # Here we want to merge the embeddings of the codebooks
  233. x = self.embed(inp)
  234. freqs_cis = self.freqs_cis[:seq_len]
  235. # Not that the causal mask here follows the definition of scaled_dot_product_attention
  236. # That is, FALSE means masked out
  237. # To maintain consistency, key_padding_mask use TRUE to mask out
  238. mask = None
  239. if key_padding_mask is not None:
  240. causal = self.causal_mask[:seq_len, :seq_len]
  241. causal = rearrange(causal, "q k -> 1 1 q k")
  242. atten_mask = rearrange(key_padding_mask, "b s -> b 1 1 s")
  243. atten_mask = atten_mask.logical_not()
  244. mask = causal & atten_mask
  245. # return freqs_cis, mask
  246. for layer in self.layers:
  247. if self.config.use_gradient_checkpointing and self.training:
  248. x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
  249. else:
  250. x = layer(x, freqs_cis, mask)
  251. # We got slow_out here
  252. slow_out = self.norm(x)
  253. if self.config.tie_word_embeddings:
  254. token_logits = F.linear(slow_out, self.embeddings.weight)
  255. else:
  256. token_logits = self.output(slow_out)
  257. return BaseTransformerForwardResult(
  258. logits=token_logits,
  259. hidden_states=x,
  260. )
  261. def forward_generate(
  262. self,
  263. inp: Tensor,
  264. input_pos: Optional[Tensor] = None,
  265. audio_masks: Optional[Tensor] = None,
  266. audio_parts: Optional[Tensor] = None,
  267. return_all: bool = False,
  268. ) -> BaseTransformerForwardResult:
  269. # This is used for generation, optimized for torch compile
  270. # assert (
  271. # self.max_seq_len != -1 and self.max_batch_size != -1
  272. # ), "Please call setup_caches before forward_generate"
  273. embeds = []
  274. for i in range(self.config.num_codebooks):
  275. emb = self.codebook_embeddings(
  276. inp[:, i + 1] + i * self.config.codebook_size
  277. )
  278. embeds.append(emb)
  279. vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
  280. vq_masks = (inp[:, 0] >= self.tokenizer.semantic_begin_id) & (
  281. inp[:, 0] <= self.tokenizer.semantic_end_id
  282. )
  283. vq_embeds_sum[~vq_masks] = 0
  284. x = self.embeddings(inp[:, 0]) + vq_embeds_sum
  285. if self.config.scale_codebook_embeddings:
  286. # Expand vq_masks to match x's shape
  287. vq_masks_expanded = vq_masks.unsqueeze(-1).expand_as(x)
  288. x = torch.where(
  289. vq_masks_expanded, x / math.sqrt(self.config.num_codebooks + 1), x
  290. )
  291. # Audio embeddings
  292. if audio_parts is not None:
  293. audio_embeds = self.audio_projector(audio_parts)
  294. if self.config.scale_codebook_embeddings:
  295. x[audio_masks] = audio_embeds / math.sqrt(2)
  296. else:
  297. x[audio_masks] = audio_embeds
  298. if input_pos is None:
  299. input_pos = torch.arange(inp.shape[-1], device=x.device)
  300. max_seq_len = inp.shape[-1]
  301. else:
  302. max_seq_len = self.max_seq_len
  303. mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
  304. freqs_cis = self.freqs_cis[input_pos]
  305. for layer in self.layers:
  306. x = layer(x, freqs_cis, mask, input_pos=input_pos)
  307. # If prefill, we only calculate the logits of last token
  308. if x.size(1) > 1 and not return_all:
  309. x = x[:, -1:]
  310. # We got slow_out here
  311. slow_out = self.norm(x)
  312. if self.config.is_reward_model:
  313. token_logits = self.score_output(slow_out)
  314. elif self.config.tie_word_embeddings:
  315. token_logits = F.linear(slow_out, self.embeddings.weight)
  316. else:
  317. token_logits = self.output(slow_out)
  318. return BaseTransformerForwardResult(
  319. logits=token_logits,
  320. hidden_states=x,
  321. )
  322. def _init_weights(self, module):
  323. std = self.config.initializer_range
  324. if isinstance(module, nn.Linear):
  325. module.weight.data.normal_(mean=0.0, std=std)
  326. if module.bias is not None:
  327. module.bias.data.zero_()
  328. elif isinstance(module, nn.Embedding):
  329. module.weight.data.normal_(mean=0.0, std=std)
  330. if module.padding_idx is not None:
  331. module.weight.data[module.padding_idx].zero_()
  332. @staticmethod
  333. def from_pretrained(
  334. path: str,
  335. load_weights: bool = False,
  336. max_length: int | None = None,
  337. lora_config: LoraConfig | None = None,
  338. rope_base: int | None = None,
  339. ) -> "BaseTransformer":
  340. config = BaseModelArgs.from_pretrained(str(path))
  341. if max_length is not None:
  342. config.max_seq_len = max_length
  343. logger.info(f"Override max_seq_len to {max_length}")
  344. if rope_base is not None:
  345. config.rope_base = rope_base
  346. logger.info(f"Override rope_base to {rope_base}")
  347. match config.model_type:
  348. case "naive":
  349. model_cls = NaiveTransformer
  350. case "dual_ar":
  351. model_cls = DualARTransformer
  352. case _:
  353. raise ValueError(f"Unknown model type: {config.model_type}")
  354. tokenizer = FishTokenizer.from_pretrained(path)
  355. logger.info(f"Loading model from {path}, config: {config}")
  356. model = model_cls(config, tokenizer=tokenizer)
  357. if load_weights is False:
  358. logger.info("Randomly initialized model")
  359. else:
  360. if "int8" in str(Path(path)):
  361. logger.info("Using int8 weight-only quantization!")
  362. from tools.llama.quantize import WeightOnlyInt8QuantHandler
  363. simple_quantizer = WeightOnlyInt8QuantHandler(model)
  364. model = simple_quantizer.convert_for_runtime()
  365. if "int4" in str(Path(path)):
  366. logger.info("Using int4 quantization!")
  367. path_comps = path.name.split("-")
  368. assert path_comps[-2].startswith("g")
  369. groupsize = int(path_comps[-2][1:])
  370. from tools.llama.quantize import WeightOnlyInt4QuantHandler
  371. simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
  372. model = simple_quantizer.convert_for_runtime()
  373. weights = torch.load(
  374. Path(path) / "model.pth",
  375. map_location="cpu",
  376. mmap=True,
  377. weights_only=True,
  378. )
  379. if "state_dict" in weights:
  380. logger.warning(
  381. "Using a TextToSemantic LightningModule checkpoint, "
  382. "please make sure it is a full model, not a LoRA model."
  383. )
  384. weights = weights["state_dict"]
  385. if next(iter(weights.keys())).startswith("model."):
  386. logger.info(
  387. f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
  388. )
  389. new_weights = OrderedDict()
  390. for k, v in weights.items():
  391. new_weights[k.replace("model.", "")] = v
  392. weights = new_weights
  393. # Remove audio related weights
  394. for k in list(weights.keys()):
  395. if "audio_" in k:
  396. weights.pop(k)
  397. # Verify the name and shape of parameters since strict=False in load_state_dict.
  398. for k, v in model.named_parameters():
  399. if k not in weights:
  400. logger.warning(f"No weight for {k}")
  401. elif v.shape != weights[k].shape:
  402. logger.warning(
  403. f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
  404. )
  405. err = model.load_state_dict(weights, strict=False, assign=True)
  406. logger.info(f"Model weights loaded - Status: {err}")
  407. if lora_config is not None:
  408. setup_lora(model, lora_config)
  409. logger.info(f"LoRA setup: {lora_config}")
  410. return model
  411. def save_pretrained(self, path: str, drop_lora: bool = False):
  412. path = Path(path)
  413. path.mkdir(parents=True, exist_ok=True)
  414. self.config.save(path / "config.json")
  415. state_dict = self.state_dict()
  416. if drop_lora:
  417. for key in list(state_dict.keys()):
  418. if "lora" not in key:
  419. continue
  420. state_dict.pop(key)
  421. logger.info(f"Drop LoRA parameter: {key}")
  422. torch.save(state_dict, path / "model.pth")
  423. self.tokenizer.save_pretrained(path)
class NaiveTransformer(BaseTransformer):
    """Single-transformer variant: predicts all codebooks in one shot from the
    slow hidden state via a stacked linear head."""

    def __init__(self, config: NaiveModelArgs, tokenizer: FishTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        # One fused head over all codebooks; reshaped to (..., c, d) in decode
        self.codebook_output = nn.Linear(
            config.dim,
            config.codebook_size * config.num_codebooks,
            bias=False,
        )

        self.apply(self._init_weights)

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        """Project backbone hidden states to per-codebook logits."""
        token_logits = result.logits
        x = result.hidden_states

        # Codebook
        codebook_logits = self.codebook_output(self.codebook_norm(x))
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Full-sequence forward: backbone pass, then codebook head."""
        result = super().forward(
            inp=inp,
            key_padding_mask=key_padding_mask,
        )
        return self.decode(result)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        """Incremental decoding step, then codebook head."""
        result = super().forward_generate(x, input_pos)
        return self.decode(result)
class DualARTransformer(BaseTransformer):
    """Dual autoregressive variant: the slow transformer runs over time, and a
    small "fast" transformer autoregresses over the codebooks of each frame."""

    def __init__(self, config: DualARModelArgs, tokenizer: FishTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        # Project to fast dim if needed
        if config.fast_dim is not None and config.fast_dim != config.dim:
            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
        else:
            self.fast_project_in = nn.Identity()

        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)

        # The equivalent bs is so large that sdpa doesn't work
        override_config = dataclasses.replace(
            config,
            dim=config.fast_dim,
            n_head=config.fast_n_head,
            n_local_heads=config.fast_n_local_heads,
            head_dim=config.fast_head_dim,
            intermediate_size=config.fast_intermediate_size,
            attention_qkv_bias=config.fast_attention_qkv_bias,
            attention_qk_norm=config.fast_attention_qk_norm,
            attention_o_bias=config.fast_attention_o_bias,
        )

        self.fast_layers = nn.ModuleList(
            TransformerBlock(override_config, use_sdpa=False)
            for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.fast_dim,
            config.codebook_size,
            bias=False,
        )

        # Fast positions index codebooks, not time steps
        self.register_buffer(
            "fast_freqs_cis",
            precompute_freqs_cis(
                config.num_codebooks,
                config.fast_head_dim,
                config.rope_base,
            ),
            persistent=False,
        )
        self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Allocate KV caches for both the slow and the fast transformer."""
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.fast_n_local_heads,
                self.config.fast_head_dim,
                dtype=dtype,
            )

    def forward(
        self,
        inp: Tensor,
        labels: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        vq_parts: Optional[Tensor] = None,
        vq_masks: Optional[Tensor] = None,
        vq_require_losses: Optional[Tensor] = None,
        mel_parts: Optional[Tensor] = None,
        mel_masks: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Training forward: slow pass over time, then a fast pass along the
        codebook axis of every semantic position.

        labels: (B, num_codebooks + 1, S) grid; row 0 locates semantic
        positions, rows 1.. provide teacher-forcing codebook inputs.
        vq_* / mel_* arguments are accepted but unused in this body.
        """
        parent_result = super().forward(
            inp=inp,
            key_padding_mask=key_padding_mask,
        )
        token_logits = parent_result.logits
        x = parent_result.hidden_states

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.fast_freqs_cis[:fast_seq_len]

        # Extract corresponding parts with labels
        token_labels = labels[:, 0]
        codebook_mask = (token_labels >= self.tokenizer.semantic_begin_id) & (
            token_labels <= self.tokenizer.semantic_end_id
        )

        # This gives where input token is <|semantic|>
        x = x[codebook_mask]

        if x.shape[0] == 0:
            # Use dummy input when no vq is required
            x = torch.zeros(
                (4, self.config.dim),
                device=x.device,
                dtype=x.dtype,
            )
            codebooks = torch.zeros(
                (x.shape[0], self.config.num_codebooks - 1),
                device=x.device,
                dtype=torch.int,
            )
        else:
            all_codebooks = labels[:, 1:, :]
            all_codebooks_permuted = all_codebooks.permute(0, 2, 1)
            semantic_codebooks = all_codebooks_permuted[codebook_mask]
            # Teacher forcing: feed codebooks 0..C-2 to predict 1..C-1
            codebooks = semantic_codebooks[:, :-1]

        x = self.fast_project_in(x)
        codebook_embeddings = self.fast_embeddings(codebooks)
        # Prepend the slow hidden state as the first fast-sequence token
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
            else:
                x = layer(x, fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        assert codebook_logits.shape[1] == self.config.num_codebooks

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        """One fast-transformer decoding step at codebook position input_pos."""
        # Fast transformer
        x = x.view(x.shape[0], 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.fast_freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits

    def forward_generate(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        audio_masks: Optional[Tensor] = None,
        audio_parts: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Slow decoding step; hidden states are pre-projected to fast_dim so
        they can seed forward_generate_fast."""
        x = super().forward_generate(x, input_pos, audio_masks, audio_parts)
        x.hidden_states = self.fast_project_in(x.hidden_states)
        return x
  605. class TransformerBlock(nn.Module):
  606. def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
  607. super().__init__()
  608. self.attention = Attention(config, use_sdpa=use_sdpa)
  609. self.feed_forward = FeedForward(config)
  610. self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
  611. self.attention_norm = RMSNorm(config.dim, config.norm_eps)
  612. def forward(
  613. self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
  614. ) -> Tensor:
  615. h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  616. out = h + self.feed_forward(self.ffn_norm(h))
  617. return out
class Attention(nn.Module):
    """Multi-head attention with a fused QKV projection, rotary embeddings,
    optional grouped-query attention (n_local_heads < n_head), optional QK
    norm, and an optional KV cache for incremental decoding."""

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(
            config.dim, total_head_dim, bias=config.attention_qkv_bias
        )
        self.wo = nn.Linear(
            config.n_head * config.head_dim, config.dim, bias=config.attention_o_bias
        )
        # Installed later by setup_caches when decoding incrementally
        self.kv_cache = None

        if config.attention_qk_norm:
            self.q_norm = nn.RMSNorm(config.head_dim, config.norm_eps)
            self.k_norm = nn.RMSNorm(config.head_dim, config.norm_eps)

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self.attention_qk_norm = config.attention_qk_norm
        self.config = config

        # Migrate legacy checkpoints with separate wq/wk/wv weights
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """State-dict pre-hook: fuse legacy wq/wk/wv weights into wqkv."""
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Attend over x (B, S, dim); mask follows SDPA semantics (False =
        masked out), and input_pos addresses the KV cache when present."""
        bsz, seqlen, _ = x.shape

        q_size = self.n_head * self.head_dim
        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        if self.attention_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        # (B, S, H, D) -> (B, H, S, D)
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Expand grouped KV heads so every query head has a matching KV head
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.use_sdpa:
            if mask is None:
                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                    y = F.scaled_dot_product_attention(
                        q,
                        k,
                        v,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=True,
                        # No third party attn_mask here to use flash_attention
                    )
            else:
                y = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=mask,
                    dropout_p=self.dropout if self.training else 0.0,
                )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)

        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        # This is a standard scaled dot product attention
        # It's low efficient, but it doesn't raise cuda error
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        # NOTE(review): train=True applies dropout even in eval mode; the
        # caller passes dropout_p=0.0 when not training, so this appears
        # benign -- confirm no other call sites pass a nonzero p in eval.
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        return attn_weight @ value
  725. class FeedForward(nn.Module):
  726. def __init__(self, config: BaseModelArgs) -> None:
  727. super().__init__()
  728. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  729. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  730. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  731. def forward(self, x: Tensor) -> Tensor:
  732. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  733. class RMSNorm(nn.Module):
  734. def __init__(self, dim: int, eps: float = 1e-5):
  735. super().__init__()
  736. self.eps = eps
  737. self.weight = nn.Parameter(torch.ones(dim))
  738. def _norm(self, x):
  739. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  740. def forward(self, x: Tensor) -> Tensor:
  741. output = self._norm(x.float()).type_as(x)
  742. return output * self.weight
  743. def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
  744. """
  745. Precomputes frequency tensors for complex exponentials (cis)
  746. Args:
  747. seq_len: Length of the sequence for which positional embeddings are needed.
  748. n_elem: Number of elements in the frequency tensor.
  749. base: Base value for the frequency scaling (default: 10000).
  750. Returns:
  751. A tensor containing the precomputed frequencies in real and imaginary parts (bfloat16).
  752. """
  753. freqs = 1.0 / (
  754. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  755. )
  756. t = torch.arange(seq_len, device=freqs.device)
  757. freqs = torch.outer(t, freqs)
  758. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  759. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  760. return cache.to(dtype=torch.bfloat16)
  761. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  762. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  763. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  764. x_out2 = torch.stack(
  765. [
  766. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  767. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  768. ],
  769. -1,
  770. )
  771. x_out2 = x_out2.flatten(3)
  772. return x_out2.type_as(x)