llama.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026
  1. import dataclasses
  2. import json
  3. import math
  4. from collections import OrderedDict
  5. from dataclasses import dataclass
  6. from pathlib import Path
  7. from typing import Optional
  8. import torch
  9. import torch.nn as nn
  10. from einops import rearrange
  11. from loguru import logger
  12. from torch import Tensor
  13. from torch.nn import functional as F
  14. from torch.nn.attention import SDPBackend, sdpa_kernel
  15. from torch.utils.checkpoint import checkpoint
  16. from fish_speech.models.text2semantic.lora import LoraConfig, setup_lora
  17. def find_multiple(n: int, k: int) -> int:
  18. if n % k == 0:
  19. return n
  20. return n + k - (n % k)
  21. @dataclass
  22. class BaseModelArgs:
  23. model_type: str = "base"
  24. vocab_size: int = 32000
  25. n_layer: int = 32
  26. n_head: int = 32
  27. dim: int = 4096
  28. intermediate_size: int = None
  29. n_local_heads: int = -1
  30. head_dim: int = 64
  31. rope_base: float = 10000
  32. norm_eps: float = 1e-5
  33. max_seq_len: int = 2048
  34. dropout: float = 0.0
  35. tie_word_embeddings: bool = True
  36. attention_qkv_bias: bool = False
  37. attention_o_bias: bool = False
  38. attention_qk_norm: bool = False
  39. # Codebook configs
  40. codebook_size: int = 160
  41. num_codebooks: int = 4
  42. semantic_begin_id: int = 0
  43. semantic_end_id: int = 0
  44. # Gradient checkpointing
  45. use_gradient_checkpointing: bool = True
  46. # Initialize the model
  47. initializer_range: float = 0.02
  48. # Dummy vars
  49. is_reward_model: bool = False
  50. scale_codebook_embeddings: bool = False
  51. audio_embed_dim: Optional[int] = None
  52. def __post_init__(self):
  53. if self.n_local_heads == -1:
  54. self.n_local_heads = self.n_head
  55. if self.intermediate_size is None:
  56. hidden_dim = 4 * self.dim
  57. n_hidden = int(2 * hidden_dim / 3)
  58. self.intermediate_size = find_multiple(n_hidden, 256)
  59. if self.head_dim is None:
  60. self.head_dim = self.dim // self.n_head
  61. @staticmethod
  62. def from_pretrained(path: str):
  63. path = Path(path)
  64. if path.is_dir():
  65. path = path / "config.json"
  66. with open(path, "r", encoding="utf-8") as f:
  67. data = json.load(f)
  68. match data["model_type"]:
  69. case "naive":
  70. cls = NaiveModelArgs
  71. case "dual_ar":
  72. cls = DualARModelArgs
  73. case "fish_qwen3_omni":
  74. return BaseModelArgs._from_fish_qwen3_omni(data)
  75. case _:
  76. raise ValueError(f"Unknown model type: {data['model_type']}")
  77. # Filter out unexpected keyword arguments
  78. valid_keys = {f.name for f in dataclasses.fields(cls)}
  79. data = {k: v for k, v in data.items() if k in valid_keys}
  80. return cls(**data)
  81. @staticmethod
  82. def _from_fish_qwen3_omni(data: dict) -> "DualARModelArgs":
  83. tc = data["text_config"]
  84. adc = data["audio_decoder_config"]
  85. flat = dict(
  86. model_type="dual_ar",
  87. vocab_size=tc["vocab_size"],
  88. n_layer=tc["n_layer"],
  89. n_head=tc["n_head"],
  90. n_local_heads=tc.get("n_local_heads", -1),
  91. head_dim=tc.get("head_dim"),
  92. dim=tc["dim"],
  93. intermediate_size=tc.get("intermediate_size"),
  94. rope_base=tc.get("rope_base", 10000),
  95. norm_eps=tc.get("norm_eps", 1e-5),
  96. max_seq_len=tc.get("max_seq_len", 2048),
  97. dropout=tc.get("dropout", 0.0),
  98. tie_word_embeddings=tc.get("tie_word_embeddings", True),
  99. attention_qkv_bias=tc.get("attention_qkv_bias", False),
  100. attention_o_bias=tc.get("attention_o_bias", False),
  101. attention_qk_norm=tc.get("attention_qk_norm", False),
  102. use_gradient_checkpointing=tc.get("use_gradient_checkpointing", True),
  103. initializer_range=tc.get("initializer_range", 0.02),
  104. semantic_begin_id=data.get("semantic_start_token_id", 0),
  105. semantic_end_id=data.get("semantic_end_token_id", 0),
  106. scale_codebook_embeddings=True,
  107. norm_fastlayer_input=True,
  108. audio_embed_dim=adc.get("text_dim", tc["dim"]),
  109. codebook_size=adc["vocab_size"],
  110. num_codebooks=adc["num_codebooks"],
  111. n_fast_layer=adc["n_layer"],
  112. fast_dim=adc.get("dim"),
  113. fast_n_head=adc.get("n_head"),
  114. fast_n_local_heads=adc.get("n_local_heads"),
  115. fast_head_dim=adc.get("head_dim"),
  116. fast_intermediate_size=adc.get("intermediate_size"),
  117. fast_attention_qkv_bias=adc.get("attention_qkv_bias"),
  118. fast_attention_qk_norm=adc.get("attention_qk_norm"),
  119. fast_attention_o_bias=adc.get("attention_o_bias"),
  120. )
  121. valid_keys = {f.name for f in dataclasses.fields(DualARModelArgs)}
  122. flat = {k: v for k, v in flat.items() if k in valid_keys and v is not None}
  123. return DualARModelArgs(**flat)
  124. def save(self, path: str):
  125. with open(path, "w") as f:
  126. json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
  127. @dataclass
  128. class NaiveModelArgs(BaseModelArgs):
  129. model_type: str = "naive"
  130. @dataclass
  131. class DualARModelArgs(BaseModelArgs):
  132. model_type: str = "dual_ar"
  133. n_fast_layer: int = 4
  134. fast_dim: int | None = None
  135. fast_n_head: int | None = None
  136. fast_n_local_heads: int | None = None
  137. fast_head_dim: int | None = None
  138. fast_intermediate_size: int | None = None
  139. fast_attention_qkv_bias: bool | None = None
  140. fast_attention_qk_norm: bool | None = None
  141. fast_attention_o_bias: bool | None = None
  142. norm_fastlayer_input: bool = False
  143. def __post_init__(self):
  144. super().__post_init__()
  145. self.fast_dim = self.fast_dim or self.dim
  146. self.fast_n_head = self.fast_n_head or self.n_head
  147. self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
  148. self.fast_head_dim = self.fast_head_dim or self.head_dim
  149. self.fast_intermediate_size = (
  150. self.fast_intermediate_size or self.intermediate_size
  151. )
  152. self.fast_attention_qkv_bias = (
  153. self.fast_attention_qkv_bias
  154. if self.fast_attention_qkv_bias is not None
  155. else self.attention_qkv_bias
  156. )
  157. self.fast_attention_qk_norm = (
  158. self.fast_attention_qk_norm
  159. if self.fast_attention_qk_norm is not None
  160. else self.attention_qk_norm
  161. )
  162. self.fast_attention_o_bias = (
  163. self.fast_attention_o_bias
  164. if self.fast_attention_o_bias is not None
  165. else self.attention_o_bias
  166. )
  167. class KVCache(nn.Module):
  168. def __init__(
  169. self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
  170. ):
  171. super().__init__()
  172. cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
  173. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  174. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  175. def update(self, input_pos, k_val, v_val):
  176. # input_pos: [S], k_val: [B, H, S, D]
  177. assert input_pos.shape[0] == k_val.shape[2]
  178. k_out = self.k_cache
  179. v_out = self.v_cache
  180. k_out[:, :, input_pos] = k_val
  181. v_out[:, :, input_pos] = v_val
  182. return k_out, v_out
  183. @dataclass
  184. class TransformerForwardResult:
  185. token_logits: Tensor
  186. codebook_logits: Tensor
  187. @dataclass
  188. class BaseTransformerForwardResult:
  189. logits: Tensor
  190. hidden_states: Tensor
  191. def _remap_fish_qwen3_omni_keys(weights: OrderedDict) -> OrderedDict:
  192. if not any(k.startswith(("text_model.", "audio_decoder.")) for k in weights):
  193. return weights
  194. new_weights = OrderedDict()
  195. for k, v in weights.items():
  196. if k.startswith("text_model.model."):
  197. new_key = k[len("text_model.model."):]
  198. elif k.startswith("audio_decoder."):
  199. suffix = k[len("audio_decoder."):]
  200. new_key = suffix if suffix.startswith("codebook_embeddings.") else "fast_" + suffix
  201. else:
  202. new_key = k
  203. new_weights[new_key] = v
  204. return new_weights
  205. class BaseTransformer(nn.Module):
  206. def __init__(
  207. self,
  208. config: BaseModelArgs,
  209. init_weights: bool = True,
  210. ) -> None:
  211. super().__init__()
  212. self.config = config
  213. # Slow transformer
  214. self.embeddings = nn.Embedding(
  215. config.vocab_size,
  216. config.dim,
  217. )
  218. self.codebook_embeddings = nn.Embedding(
  219. config.codebook_size * config.num_codebooks,
  220. config.dim,
  221. )
  222. self.layers = nn.ModuleList(
  223. TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
  224. )
  225. self.norm = RMSNorm(config.dim, eps=config.norm_eps)
  226. if self.config.tie_word_embeddings is False:
  227. self.output = nn.Linear(
  228. config.dim,
  229. config.vocab_size,
  230. bias=False,
  231. )
  232. self.register_buffer(
  233. "freqs_cis",
  234. precompute_freqs_cis(
  235. config.max_seq_len,
  236. config.head_dim,
  237. config.rope_base,
  238. ),
  239. persistent=False,
  240. )
  241. self.register_buffer(
  242. "causal_mask",
  243. torch.tril(
  244. torch.ones(
  245. config.max_seq_len,
  246. config.max_seq_len,
  247. dtype=torch.bool,
  248. )
  249. ),
  250. persistent=False,
  251. )
  252. # For kv cache
  253. self.max_batch_size = -1
  254. self.max_seq_len = -1
  255. if init_weights:
  256. self.apply(self._init_weights)
  257. def setup_caches(
  258. self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
  259. ):
  260. if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
  261. return
  262. max_seq_len = find_multiple(max_seq_len, 8)
  263. self.max_seq_len = max_seq_len
  264. self.max_batch_size = max_batch_size
  265. for b in self.layers:
  266. b.attention.kv_cache = KVCache(
  267. max_batch_size,
  268. max_seq_len,
  269. self.config.n_local_heads,
  270. self.config.head_dim,
  271. dtype=dtype,
  272. )
  273. def embed(self, inp: Tensor) -> Tensor:
  274. embeds = []
  275. for i in range(self.config.num_codebooks):
  276. emb = self.codebook_embeddings(
  277. inp[:, i + 1] + i * self.config.codebook_size
  278. )
  279. embeds.append(emb)
  280. vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
  281. is_semantic = (inp[:, 0] >= self.config.semantic_begin_id) & \
  282. (inp[:, 0] <= self.config.semantic_end_id)
  283. vq_embeds_sum[~is_semantic] = 0
  284. x = self.embeddings(inp[:, 0]) + vq_embeds_sum
  285. return x
  286. def forward(
  287. self,
  288. inp: Tensor,
  289. key_padding_mask: Optional[Tensor] = None,
  290. ) -> BaseTransformerForwardResult:
  291. seq_len = inp.size(2)
  292. # Here we want to merge the embeddings of the codebooks
  293. x = self.embed(inp)
  294. freqs_cis = self.freqs_cis[:seq_len]
  295. mask = None
  296. if key_padding_mask is not None:
  297. causal = self.causal_mask[:seq_len, :seq_len]
  298. causal = rearrange(causal, "q k -> 1 1 q k")
  299. atten_mask = rearrange(key_padding_mask, "b s -> b 1 1 s")
  300. atten_mask = atten_mask.logical_not()
  301. mask = causal & atten_mask
  302. for layer in self.layers:
  303. if self.config.use_gradient_checkpointing and self.training:
  304. x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
  305. else:
  306. x = layer(x, freqs_cis, mask)
  307. slow_out = self.norm(x)
  308. if self.config.tie_word_embeddings:
  309. token_logits = F.linear(slow_out, self.embeddings.weight)
  310. else:
  311. token_logits = self.output(slow_out)
  312. hidden_out = (
  313. slow_out
  314. if getattr(self.config, "norm_fastlayer_input", False)
  315. else x
  316. )
  317. return BaseTransformerForwardResult(
  318. logits=token_logits,
  319. hidden_states=hidden_out,
  320. )
  321. def forward_generate(
  322. self,
  323. inp: Tensor,
  324. input_pos: Optional[Tensor] = None,
  325. audio_masks: Optional[Tensor] = None,
  326. audio_parts: Optional[Tensor] = None,
  327. return_all: bool = False,
  328. ) -> BaseTransformerForwardResult:
  329. # Embedding logic replicated from embed() for compilation compatibility
  330. embeds = []
  331. for i in range(self.config.num_codebooks):
  332. emb = self.codebook_embeddings(
  333. inp[:, i + 1] + i * self.config.codebook_size
  334. )
  335. embeds.append(emb)
  336. vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
  337. vq_masks = (inp[:, 0] >= self.config.semantic_begin_id) & (
  338. inp[:, 0] <= self.config.semantic_end_id
  339. )
  340. vq_embeds_sum[~vq_masks] = 0
  341. x = self.embeddings(inp[:, 0]) + vq_embeds_sum
  342. if self.config.scale_codebook_embeddings:
  343. vq_masks_expanded = vq_masks.unsqueeze(-1).expand_as(x)
  344. x = torch.where(
  345. vq_masks_expanded, x / math.sqrt(self.config.num_codebooks + 1), x
  346. )
  347. # Audio embeddings
  348. if audio_parts is not None:
  349. # Note: This assumes self.audio_projector exists if audio_parts is used
  350. # It seems missing in init, but we keep existing logic
  351. if hasattr(self, "audio_projector"):
  352. audio_embeds = self.audio_projector(audio_parts)
  353. if self.config.scale_codebook_embeddings:
  354. x[audio_masks] = audio_embeds / math.sqrt(2)
  355. else:
  356. x[audio_masks] = audio_embeds
  357. else:
  358. logger.warning("audio_parts provided but model has no audio_projector")
  359. if input_pos is None:
  360. input_pos = torch.arange(inp.shape[-1], device=x.device)
  361. max_seq_len = inp.shape[-1]
  362. else:
  363. max_seq_len = self.max_seq_len
  364. mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
  365. freqs_cis = self.freqs_cis[input_pos]
  366. for layer in self.layers:
  367. x = layer(x, freqs_cis, mask, input_pos=input_pos)
  368. if x.size(1) > 1 and not return_all:
  369. x = x[:, -1:]
  370. slow_out = self.norm(x)
  371. if self.config.is_reward_model:
  372. token_logits = self.score_output(slow_out)
  373. elif self.config.tie_word_embeddings:
  374. token_logits = F.linear(slow_out, self.embeddings.weight)
  375. else:
  376. token_logits = self.output(slow_out)
  377. hidden_out = (
  378. slow_out
  379. if getattr(self.config, "norm_fastlayer_input", False)
  380. else x
  381. )
  382. return BaseTransformerForwardResult(
  383. logits=token_logits,
  384. hidden_states=hidden_out,
  385. )
  386. def _init_weights(self, module):
  387. std = self.config.initializer_range
  388. if isinstance(module, nn.Linear):
  389. module.weight.data.normal_(mean=0.0, std=std)
  390. if module.bias is not None:
  391. module.bias.data.zero_()
  392. elif isinstance(module, nn.Embedding):
  393. module.weight.data.normal_(mean=0.0, std=std)
  394. if module.padding_idx is not None:
  395. module.weight.data[module.padding_idx].zero_()
  396. @staticmethod
  397. def from_pretrained(
  398. path: str,
  399. load_weights: bool = False,
  400. max_length: int | None = None,
  401. lora_config: LoraConfig | None = None,
  402. rope_base: int | None = None,
  403. ) -> "BaseTransformer":
  404. # Import wrapper locally to avoid circular dependency or global import issues
  405. from fish_speech.tokenizer import FishTokenizer
  406. config = BaseModelArgs.from_pretrained(str(path))
  407. if max_length is not None:
  408. config.max_seq_len = max_length
  409. logger.info(f"Override max_seq_len to {max_length}")
  410. if rope_base is not None:
  411. config.rope_base = rope_base
  412. logger.info(f"Override rope_base to {rope_base}")
  413. try:
  414. tokenizer = FishTokenizer.from_pretrained(path)
  415. config.semantic_begin_id = tokenizer.semantic_begin_id
  416. config.semantic_end_id = tokenizer.semantic_end_id
  417. logger.info(f"Injected Semantic IDs into Config: {config.semantic_begin_id}-{config.semantic_end_id}")
  418. except Exception as e:
  419. logger.warning(f"Failed to load tokenizer for config injection: {e}. Semantic IDs might be 0.")
  420. match config.model_type:
  421. case "naive":
  422. model_cls = NaiveTransformer
  423. case "dual_ar":
  424. model_cls = DualARTransformer
  425. case _:
  426. raise ValueError(f"Unknown model type: {config.model_type}")
  427. logger.info(f"Loading model from {path}, config: {config}")
  428. # Initialize model without passing tokenizer explicitly to __init__
  429. model = model_cls(config)
  430. # Attach tokenizer to model instance for inference convenience (optional, but good for user scripts)
  431. model.tokenizer = tokenizer
  432. if load_weights is False:
  433. logger.info("Randomly initialized model")
  434. else:
  435. if "int8" in str(Path(path)):
  436. logger.info("Using int8 weight-only quantization!")
  437. from tools.llama.quantize import WeightOnlyInt8QuantHandler
  438. simple_quantizer = WeightOnlyInt8QuantHandler(model)
  439. model = simple_quantizer.convert_for_runtime()
  440. if "int4" in str(Path(path)):
  441. logger.info("Using int4 quantization!")
  442. path_comps = path.name.split("-")
  443. assert path_comps[-2].startswith("g")
  444. groupsize = int(path_comps[-2][1:])
  445. from tools.llama.quantize import WeightOnlyInt4QuantHandler
  446. simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
  447. model = simple_quantizer.convert_for_runtime()
  448. path_obj = Path(path)
  449. index_json = path_obj / "model.safetensors.index.json"
  450. single_st = path_obj / "model.safetensors"
  451. pth_file = path_obj / "model.pth"
  452. if index_json.exists():
  453. logger.info("Loading sharded safetensors weights")
  454. from safetensors.torch import load_file as st_load_file
  455. with open(index_json) as f:
  456. st_index = json.load(f)
  457. shard_files = sorted(set(st_index["weight_map"].values()))
  458. weights = OrderedDict()
  459. for shard in shard_files:
  460. weights.update(st_load_file(str(path_obj / shard), device="cpu"))
  461. weights = _remap_fish_qwen3_omni_keys(weights)
  462. elif single_st.exists():
  463. logger.info("Loading single safetensors weights")
  464. from safetensors.torch import load_file as st_load_file
  465. weights = OrderedDict(st_load_file(str(single_st), device="cpu"))
  466. weights = _remap_fish_qwen3_omni_keys(weights)
  467. elif pth_file.exists():
  468. weights = torch.load(
  469. pth_file,
  470. map_location="cpu",
  471. mmap=True,
  472. weights_only=True,
  473. )
  474. if "state_dict" in weights:
  475. weights = weights["state_dict"]
  476. if weights and next(iter(weights.keys())).startswith("model."):
  477. weights = OrderedDict(
  478. (k.replace("model.", ""), v) for k, v in weights.items()
  479. )
  480. for k in list(weights.keys()):
  481. if "audio_" in k:
  482. weights.pop(k)
  483. else:
  484. raise FileNotFoundError(f"No model weights found in {path_obj}")
  485. err = model.load_state_dict(weights, strict=False, assign=True)
  486. logger.info(f"Model weights loaded - Status: {err}")
  487. if lora_config is not None:
  488. setup_lora(model, lora_config)
  489. logger.info(f"LoRA setup: {lora_config}")
  490. return model
  491. def save_pretrained(self, path: str, drop_lora: bool = False):
  492. path = Path(path)
  493. path.mkdir(parents=True, exist_ok=True)
  494. self.config.save(path / "config.json")
  495. state_dict = self.state_dict()
  496. if drop_lora:
  497. for key in list(state_dict.keys()):
  498. if "lora" not in key:
  499. continue
  500. state_dict.pop(key)
  501. torch.save(state_dict, path / "model.pth")
  502. if hasattr(self, "tokenizer"):
  503. self.tokenizer.save_pretrained(path)
  504. class NaiveTransformer(BaseTransformer):
  505. def __init__(self, config: NaiveModelArgs) -> None:
  506. super().__init__(config, init_weights=False)
  507. self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
  508. self.codebook_output = nn.Linear(
  509. config.dim,
  510. config.codebook_size * config.num_codebooks,
  511. bias=False,
  512. )
  513. self.apply(self._init_weights)
  514. def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
  515. token_logits = result.logits
  516. x = result.hidden_states
  517. # Codebook
  518. codebook_logits = self.codebook_output(self.codebook_norm(x))
  519. codebook_logits = rearrange(
  520. codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
  521. )
  522. return TransformerForwardResult(
  523. token_logits=token_logits,
  524. codebook_logits=codebook_logits,
  525. )
  526. def forward(
  527. self,
  528. inp: Tensor,
  529. key_padding_mask: Optional[Tensor] = None,
  530. ) -> TransformerForwardResult:
  531. result = super().forward(
  532. inp=inp,
  533. key_padding_mask=key_padding_mask,
  534. )
  535. return self.decode(result)
  536. def forward_generate(
  537. self, x: Tensor, input_pos: Optional[Tensor] = None
  538. ) -> TransformerForwardResult:
  539. result = super().forward_generate(x, input_pos)
  540. return self.decode(result)
  541. class DualARTransformer(BaseTransformer):
  542. def __init__(self, config: NaiveModelArgs) -> None:
  543. super().__init__(config, init_weights=False)
  544. # Project to fast dim if needed
  545. if config.fast_dim is not None and config.fast_dim != config.dim:
  546. self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
  547. else:
  548. self.fast_project_in = nn.Identity()
  549. # Fast transformer
  550. self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
  551. # The equivalent bs is so large that sdpa doesn't work
  552. override_config = dataclasses.replace(
  553. config,
  554. dim=config.fast_dim,
  555. n_head=config.fast_n_head,
  556. n_local_heads=config.fast_n_local_heads,
  557. head_dim=config.fast_head_dim,
  558. intermediate_size=config.fast_intermediate_size,
  559. attention_qkv_bias=config.fast_attention_qkv_bias,
  560. attention_qk_norm=config.fast_attention_qk_norm,
  561. attention_o_bias=config.fast_attention_o_bias,
  562. )
  563. self.fast_layers = nn.ModuleList(
  564. TransformerBlock(override_config, use_sdpa=False)
  565. for _ in range(config.n_fast_layer)
  566. )
  567. self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
  568. self.fast_output = nn.Linear(
  569. config.fast_dim,
  570. config.codebook_size,
  571. bias=False,
  572. )
  573. self.register_buffer(
  574. "fast_freqs_cis",
  575. precompute_freqs_cis(
  576. config.num_codebooks,
  577. config.fast_head_dim,
  578. config.rope_base,
  579. ),
  580. persistent=False,
  581. )
  582. self.apply(self._init_weights)
  583. def setup_caches(
  584. self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
  585. ):
  586. super().setup_caches(max_batch_size, max_seq_len, dtype)
  587. # Fast transformer
  588. # The max seq len here is the number of codebooks
  589. for b in self.fast_layers:
  590. b.attention.kv_cache = KVCache(
  591. max_batch_size,
  592. self.config.num_codebooks,
  593. self.config.fast_n_local_heads,
  594. self.config.fast_head_dim,
  595. dtype=dtype,
  596. )
  597. def forward(
  598. self,
  599. inp: Tensor,
  600. labels: Optional[Tensor] = None,
  601. key_padding_mask: Optional[Tensor] = None,
  602. vq_parts: Optional[Tensor] = None,
  603. vq_masks: Optional[Tensor] = None,
  604. vq_require_losses: Optional[Tensor] = None,
  605. mel_parts: Optional[Tensor] = None,
  606. mel_masks: Optional[Tensor] = None,
  607. ) -> TransformerForwardResult:
  608. parent_result = super().forward(
  609. inp=inp,
  610. key_padding_mask=key_padding_mask,
  611. )
  612. token_logits = parent_result.logits
  613. x = parent_result.hidden_states
  614. # Fast transformer
  615. fast_seq_len = self.config.num_codebooks
  616. fast_mask = self.causal_mask[
  617. None, None, :fast_seq_len, :fast_seq_len
  618. ] # (B, N, Q, K)
  619. fast_freqs_cis = self.fast_freqs_cis[:fast_seq_len]
  620. # Extract corresponding parts with labels
  621. token_labels = labels[:, 0]
  622. # [MODIFIED] Use config instead of tokenizer
  623. codebook_mask = (token_labels >= self.config.semantic_begin_id) & (
  624. token_labels <= self.config.semantic_end_id
  625. )
  626. # This gives where input token is <|semantic|>
  627. x = x[codebook_mask]
  628. if x.shape[0] == 0:
  629. # Use dummy input when no vq is required
  630. x = torch.zeros(
  631. (4, self.config.dim),
  632. device=x.device,
  633. dtype=x.dtype,
  634. )
  635. codebooks = torch.zeros(
  636. (x.shape[0], self.config.num_codebooks - 1),
  637. device=x.device,
  638. dtype=torch.int,
  639. )
  640. else:
  641. all_codebooks = labels[:, 1:, :]
  642. all_codebooks_permuted = all_codebooks.permute(0, 2, 1)
  643. semantic_codebooks = all_codebooks_permuted[codebook_mask]
  644. codebooks = semantic_codebooks[:, :-1]
  645. x = self.fast_project_in(x)
  646. codebook_embeddings = self.fast_embeddings(codebooks)
  647. x = torch.cat([x[:, None], codebook_embeddings], dim=1)
  648. for layer in self.fast_layers:
  649. if self.config.use_gradient_checkpointing and self.training:
  650. x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
  651. else:
  652. x = layer(x, fast_freqs_cis, fast_mask)
  653. # unflatten the batch and num_codebooks
  654. fast_out = self.fast_norm(x)
  655. codebook_logits = self.fast_output(fast_out)
  656. assert codebook_logits.shape[1] == self.config.num_codebooks
  657. return TransformerForwardResult(
  658. token_logits=token_logits,
  659. codebook_logits=codebook_logits,
  660. )
  661. def forward_generate_fast(
  662. self, x: Tensor, input_pos: Optional[Tensor] = None
  663. ) -> Tensor:
  664. # Fast transformer
  665. x = x.view(x.shape[0], 1, -1)
  666. fast_mask = self.causal_mask[
  667. None, None, input_pos, : self.config.num_codebooks
  668. ] # (B, N, Q, K)
  669. fast_freqs_cis = self.fast_freqs_cis[input_pos]
  670. for layer in self.fast_layers:
  671. x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
  672. # unflatten the batch and num_codebooks
  673. fast_out = self.fast_norm(x) # only take the last token
  674. codebook_logits = self.fast_output(fast_out)
  675. return codebook_logits
  676. def forward_generate(
  677. self,
  678. x: Tensor,
  679. input_pos: Optional[Tensor] = None,
  680. audio_masks: Optional[Tensor] = None,
  681. audio_parts: Optional[Tensor] = None,
  682. ) -> TransformerForwardResult:
  683. x = super().forward_generate(x, input_pos, audio_masks, audio_parts)
  684. x.hidden_states = self.fast_project_in(x.hidden_states)
  685. return x
  686. class TransformerBlock(nn.Module):
  687. def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
  688. super().__init__()
  689. self.attention = Attention(config, use_sdpa=use_sdpa)
  690. self.feed_forward = FeedForward(config)
  691. self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
  692. self.attention_norm = RMSNorm(config.dim, config.norm_eps)
  693. def forward(
  694. self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
  695. ) -> Tensor:
  696. h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  697. out = h + self.feed_forward(self.ffn_norm(h))
  698. return out
  699. class Attention(nn.Module):
  700. def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
  701. super().__init__()
  702. assert config.dim % config.n_head == 0
  703. total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
  704. # key, query, value projections for all heads, but in a batch
  705. self.wqkv = nn.Linear(
  706. config.dim, total_head_dim, bias=config.attention_qkv_bias
  707. )
  708. self.wo = nn.Linear(
  709. config.n_head * config.head_dim, config.dim, bias=config.attention_o_bias
  710. )
  711. self.kv_cache = None
  712. if config.attention_qk_norm:
  713. self.q_norm = nn.RMSNorm(config.head_dim, config.norm_eps)
  714. self.k_norm = nn.RMSNorm(config.head_dim, config.norm_eps)
  715. self.dropout = config.dropout
  716. self.n_head = config.n_head
  717. self.head_dim = config.head_dim
  718. self.n_local_heads = config.n_local_heads
  719. self.dim = config.dim
  720. self.use_sdpa = use_sdpa
  721. self.attention_qk_norm = config.attention_qk_norm
  722. self.config = config
  723. self._register_load_state_dict_pre_hook(self.load_hook)
  724. def load_hook(self, state_dict, prefix, *args):
  725. if prefix + "wq.weight" in state_dict:
  726. wq = state_dict.pop(prefix + "wq.weight")
  727. wk = state_dict.pop(prefix + "wk.weight")
  728. wv = state_dict.pop(prefix + "wv.weight")
  729. state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
  730. def forward(
  731. self,
  732. x: Tensor,
  733. freqs_cis: Tensor,
  734. mask: Tensor,
  735. input_pos: Optional[Tensor] = None,
  736. ) -> Tensor:
  737. bsz, seqlen, _ = x.shape
  738. q_size = self.n_head * self.head_dim
  739. kv_size = self.n_local_heads * self.head_dim
  740. q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
  741. q = q.view(bsz, seqlen, self.n_head, self.head_dim)
  742. k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
  743. v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
  744. if self.attention_qk_norm:
  745. q = self.q_norm(q)
  746. k = self.k_norm(k)
  747. q = apply_rotary_emb(q, freqs_cis)
  748. k = apply_rotary_emb(k, freqs_cis)
  749. q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
  750. if self.kv_cache is not None:
  751. k, v = self.kv_cache.update(input_pos, k, v)
  752. k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
  753. v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
  754. if self.use_sdpa:
  755. if mask is None:
  756. with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
  757. y = F.scaled_dot_product_attention(
  758. q,
  759. k,
  760. v,
  761. dropout_p=self.dropout if self.training else 0.0,
  762. is_causal=True,
  763. # No third party attn_mask here to use flash_attention
  764. )
  765. else:
  766. y = F.scaled_dot_product_attention(
  767. q,
  768. k,
  769. v,
  770. attn_mask=mask,
  771. dropout_p=self.dropout if self.training else 0.0,
  772. )
  773. else:
  774. y = self.eq_scaled_dot_product_attention(
  775. q,
  776. k,
  777. v,
  778. attn_mask=mask,
  779. dropout_p=self.dropout if self.training else 0.0,
  780. )
  781. y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
  782. return self.wo(y)
  783. def eq_scaled_dot_product_attention(
  784. self,
  785. query,
  786. key,
  787. value,
  788. attn_mask=None,
  789. dropout_p=0.0,
  790. ) -> torch.Tensor:
  791. # This is a standard scaled dot product attention
  792. # It's low efficient, but it doesn't raise cuda error
  793. L, S = query.size(-2), key.size(-2)
  794. scale_factor = 1 / math.sqrt(query.size(-1))
  795. attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
  796. if attn_mask is not None:
  797. if attn_mask.dtype == torch.bool:
  798. attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
  799. else:
  800. attn_bias += attn_mask
  801. attn_weight = query @ key.transpose(-2, -1) * scale_factor
  802. attn_weight += attn_bias
  803. attn_weight = torch.softmax(attn_weight, dim=-1)
  804. attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
  805. return attn_weight @ value
  806. class FeedForward(nn.Module):
  807. def __init__(self, config: BaseModelArgs) -> None:
  808. super().__init__()
  809. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  810. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  811. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  812. def forward(self, x: Tensor) -> Tensor:
  813. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  814. class RMSNorm(nn.Module):
  815. def __init__(self, dim: int, eps: float = 1e-5):
  816. super().__init__()
  817. self.eps = eps
  818. self.weight = nn.Parameter(torch.ones(dim))
  819. def _norm(self, x):
  820. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  821. def forward(self, x: Tensor) -> Tensor:
  822. output = self._norm(x.float()).type_as(x)
  823. return output * self.weight
  824. def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
  825. """
  826. Precomputes frequency tensors for complex exponentials (cis)
  827. Args:
  828. seq_len: Length of the sequence for which positional embeddings are needed.
  829. n_elem: Number of elements in the frequency tensor.
  830. base: Base value for the frequency scaling (default: 10000).
  831. Returns:
  832. A tensor containing the precomputed frequencies in real and imaginary parts (bfloat16).
  833. """
  834. freqs = 1.0 / (
  835. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  836. )
  837. t = torch.arange(seq_len, device=freqs.device)
  838. freqs = torch.outer(t, freqs)
  839. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  840. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  841. return cache.to(dtype=torch.bfloat16)
  842. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  843. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  844. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  845. x_out2 = torch.stack(
  846. [
  847. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  848. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  849. ],
  850. -1,
  851. )
  852. x_out2 = x_out2.flatten(3)
  853. return x_out2.type_as(x)