llama.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110
  1. import dataclasses
  2. import json
  3. import math
  4. from collections import OrderedDict
  5. from dataclasses import dataclass
  6. from pathlib import Path
  7. from typing import Optional
  8. import torch
  9. import torch.nn as nn
  10. from einops import rearrange
  11. from loguru import logger
  12. from torch import Tensor
  13. from torch.nn import functional as F
  14. from torch.nn.attention import SDPBackend, sdpa_kernel
  15. from torch.utils.checkpoint import checkpoint
  16. from fish_speech.models.text2semantic.lora import LoraConfig, setup_lora
  17. def find_multiple(n: int, k: int) -> int:
  18. if n % k == 0:
  19. return n
  20. return n + k - (n % k)
  21. @dataclass
  22. class BaseModelArgs:
  23. model_type: str = "base"
  24. vocab_size: int = 32000
  25. n_layer: int = 32
  26. n_head: int = 32
  27. dim: int = 4096
  28. intermediate_size: int = None
  29. n_local_heads: int = -1
  30. head_dim: int = 64
  31. rope_base: float = 10000
  32. norm_eps: float = 1e-5
  33. max_seq_len: int = 2048
  34. dropout: float = 0.0
  35. tie_word_embeddings: bool = True
  36. attention_qkv_bias: bool = False
  37. attention_o_bias: bool = False
  38. attention_qk_norm: bool = False
  39. # Codebook configs
  40. codebook_size: int = 160
  41. num_codebooks: int = 4
  42. semantic_begin_id: int = 0
  43. semantic_end_id: int = 0
  44. # Gradient checkpointing
  45. use_gradient_checkpointing: bool = True
  46. # Initialize the model
  47. initializer_range: float = 0.02
  48. # Dummy vars
  49. is_reward_model: bool = False
  50. scale_codebook_embeddings: bool = False
  51. audio_embed_dim: Optional[int] = None
  52. def __post_init__(self):
  53. if self.n_local_heads == -1:
  54. self.n_local_heads = self.n_head
  55. if self.intermediate_size is None:
  56. hidden_dim = 4 * self.dim
  57. n_hidden = int(2 * hidden_dim / 3)
  58. self.intermediate_size = find_multiple(n_hidden, 256)
  59. if self.head_dim is None:
  60. self.head_dim = self.dim // self.n_head
  61. @staticmethod
  62. def from_pretrained(path: str):
  63. path = Path(path)
  64. if path.is_dir():
  65. path = path / "config.json"
  66. with open(path, "r", encoding="utf-8") as f:
  67. data = json.load(f)
  68. match data["model_type"]:
  69. case "naive":
  70. cls = NaiveModelArgs
  71. case "dual_ar":
  72. cls = DualARModelArgs
  73. case "fish_qwen3_omni":
  74. return BaseModelArgs._from_fish_qwen3_omni(data)
  75. case _:
  76. raise ValueError(f"Unknown model type: {data['model_type']}")
  77. # Filter out unexpected keyword arguments
  78. valid_keys = {f.name for f in dataclasses.fields(cls)}
  79. data = {k: v for k, v in data.items() if k in valid_keys}
  80. return cls(**data)
  81. @staticmethod
  82. def _from_fish_qwen3_omni(data: dict) -> "DualARModelArgs":
  83. tc = data["text_config"]
  84. adc = data["audio_decoder_config"]
  85. flat = dict(
  86. model_type="dual_ar",
  87. vocab_size=tc["vocab_size"],
  88. n_layer=tc["n_layer"],
  89. n_head=tc["n_head"],
  90. n_local_heads=tc.get("n_local_heads", -1),
  91. head_dim=tc.get("head_dim"),
  92. dim=tc["dim"],
  93. intermediate_size=tc.get("intermediate_size"),
  94. rope_base=tc.get("rope_base", 10000),
  95. norm_eps=tc.get("norm_eps", 1e-5),
  96. max_seq_len=tc.get("max_seq_len", 2048),
  97. dropout=tc.get("dropout", 0.0),
  98. tie_word_embeddings=tc.get("tie_word_embeddings", True),
  99. attention_qkv_bias=tc.get("attention_qkv_bias", False),
  100. attention_o_bias=tc.get("attention_o_bias", False),
  101. attention_qk_norm=tc.get("attention_qk_norm", False),
  102. use_gradient_checkpointing=tc.get("use_gradient_checkpointing", True),
  103. initializer_range=tc.get("initializer_range", 0.02),
  104. semantic_begin_id=data.get("semantic_start_token_id", 0),
  105. semantic_end_id=data.get("semantic_end_token_id", 0),
  106. scale_codebook_embeddings=True,
  107. norm_fastlayer_input=True,
  108. audio_embed_dim=adc.get("text_dim", tc["dim"]),
  109. codebook_size=adc["vocab_size"],
  110. num_codebooks=adc["num_codebooks"],
  111. n_fast_layer=adc["n_layer"],
  112. fast_dim=adc.get("dim"),
  113. fast_n_head=adc.get("n_head"),
  114. fast_n_local_heads=adc.get("n_local_heads"),
  115. fast_head_dim=adc.get("head_dim"),
  116. fast_intermediate_size=adc.get("intermediate_size"),
  117. fast_attention_qkv_bias=adc.get("attention_qkv_bias"),
  118. fast_attention_qk_norm=adc.get("attention_qk_norm"),
  119. fast_attention_o_bias=adc.get("attention_o_bias"),
  120. )
  121. valid_keys = {f.name for f in dataclasses.fields(DualARModelArgs)}
  122. flat = {k: v for k, v in flat.items() if k in valid_keys and v is not None}
  123. return DualARModelArgs(**flat)
  124. def save(self, path: str):
  125. with open(path, "w") as f:
  126. json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
  127. @dataclass
  128. class NaiveModelArgs(BaseModelArgs):
  129. model_type: str = "naive"
  130. @dataclass
  131. class DualARModelArgs(BaseModelArgs):
  132. model_type: str = "dual_ar"
  133. n_fast_layer: int = 4
  134. fast_dim: int | None = None
  135. fast_n_head: int | None = None
  136. fast_n_local_heads: int | None = None
  137. fast_head_dim: int | None = None
  138. fast_intermediate_size: int | None = None
  139. fast_attention_qkv_bias: bool | None = None
  140. fast_attention_qk_norm: bool | None = None
  141. fast_attention_o_bias: bool | None = None
  142. norm_fastlayer_input: bool = False
  143. def __post_init__(self):
  144. super().__post_init__()
  145. self.fast_dim = self.fast_dim or self.dim
  146. self.fast_n_head = self.fast_n_head or self.n_head
  147. self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
  148. self.fast_head_dim = self.fast_head_dim or self.head_dim
  149. self.fast_intermediate_size = (
  150. self.fast_intermediate_size or self.intermediate_size
  151. )
  152. self.fast_attention_qkv_bias = (
  153. self.fast_attention_qkv_bias
  154. if self.fast_attention_qkv_bias is not None
  155. else self.attention_qkv_bias
  156. )
  157. self.fast_attention_qk_norm = (
  158. self.fast_attention_qk_norm
  159. if self.fast_attention_qk_norm is not None
  160. else self.attention_qk_norm
  161. )
  162. self.fast_attention_o_bias = (
  163. self.fast_attention_o_bias
  164. if self.fast_attention_o_bias is not None
  165. else self.attention_o_bias
  166. )
  167. class KVCache(nn.Module):
  168. def __init__(
  169. self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
  170. ):
  171. super().__init__()
  172. logger.info(f"Initializing KVCache max_batch_size={max_batch_size}, max_seq_len={max_seq_len}, n_heads={n_heads}, head_dim={head_dim}")
  173. cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
  174. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  175. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  176. def update(self, input_pos, k_val, v_val):
  177. # input_pos: [S], k_val: [B, H, S, D]
  178. assert input_pos.shape[0] == k_val.shape[2]
  179. k_out = self.k_cache
  180. v_out = self.v_cache
  181. k_out[:, :, input_pos] = k_val
  182. v_out[:, :, input_pos] = v_val
  183. return k_out, v_out
  184. @dataclass
  185. class TransformerForwardResult:
  186. token_logits: Tensor
  187. codebook_logits: Tensor
  188. @dataclass
  189. class BaseTransformerForwardResult:
  190. logits: Tensor
  191. hidden_states: Tensor
  192. def _remap_fish_qwen3_omni_keys(weights: OrderedDict) -> OrderedDict:
  193. if not any(k.startswith(("text_model.", "audio_decoder.")) for k in weights):
  194. return weights
  195. new_weights = OrderedDict()
  196. for k, v in weights.items():
  197. if k.startswith("text_model.model."):
  198. new_key = k[len("text_model.model.") :]
  199. elif k.startswith("audio_decoder."):
  200. suffix = k[len("audio_decoder.") :]
  201. new_key = (
  202. suffix
  203. if suffix.startswith("codebook_embeddings.")
  204. else "fast_" + suffix
  205. )
  206. else:
  207. new_key = k
  208. new_weights[new_key] = v
  209. return new_weights
  210. class BaseTransformer(nn.Module):
  211. def __init__(
  212. self,
  213. config: BaseModelArgs,
  214. init_weights: bool = True,
  215. ) -> None:
  216. super().__init__()
  217. self.config = config
  218. # Slow transformer
  219. self.embeddings = nn.Embedding(
  220. config.vocab_size,
  221. config.dim,
  222. )
  223. self.codebook_embeddings = nn.Embedding(
  224. config.codebook_size * config.num_codebooks,
  225. config.dim,
  226. )
  227. self.layers = nn.ModuleList(
  228. TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
  229. )
  230. self.norm = RMSNorm(config.dim, eps=config.norm_eps)
  231. if self.config.tie_word_embeddings is False:
  232. self.output = nn.Linear(
  233. config.dim,
  234. config.vocab_size,
  235. bias=False,
  236. )
  237. self.register_buffer(
  238. "freqs_cis",
  239. precompute_freqs_cis(
  240. config.max_seq_len,
  241. config.head_dim,
  242. config.rope_base,
  243. ),
  244. persistent=False,
  245. )
  246. self.register_buffer(
  247. "causal_mask",
  248. torch.tril(
  249. torch.ones(
  250. config.max_seq_len,
  251. config.max_seq_len,
  252. dtype=torch.bool,
  253. )
  254. ),
  255. persistent=False,
  256. )
  257. # For kv cache
  258. self.max_batch_size = -1
  259. self.max_seq_len = -1
  260. if init_weights:
  261. self.apply(self._init_weights)
  262. def setup_caches(
  263. self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
  264. ):
  265. if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
  266. return
  267. max_seq_len = find_multiple(max_seq_len, 8)
  268. self.max_seq_len = max_seq_len
  269. self.max_batch_size = max_batch_size
  270. logger.info(f"max_batch_size: {max_batch_size}, max_seq_len: {max_seq_len}")
  271. logger.info(f"self.layers: {len(self.layers)}")
  272. for b in self.layers:
  273. b.attention.kv_cache = KVCache(
  274. max_batch_size,
  275. max_seq_len,
  276. self.config.n_local_heads,
  277. self.config.head_dim,
  278. dtype=dtype,
  279. )
  280. def embed(self, inp: Tensor) -> Tensor:
  281. embeds = []
  282. for i in range(self.config.num_codebooks):
  283. emb = self.codebook_embeddings(
  284. inp[:, i + 1] + i * self.config.codebook_size
  285. )
  286. embeds.append(emb)
  287. vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
  288. is_semantic = (inp[:, 0] >= self.config.semantic_begin_id) & (
  289. inp[:, 0] <= self.config.semantic_end_id
  290. )
  291. vq_embeds_sum[~is_semantic] = 0
  292. x = self.embeddings(inp[:, 0]) + vq_embeds_sum
  293. return x
  294. def forward(
  295. self,
  296. inp: Tensor,
  297. key_padding_mask: Optional[Tensor] = None,
  298. ) -> BaseTransformerForwardResult:
  299. seq_len = inp.size(2)
  300. # Here we want to merge the embeddings of the codebooks
  301. x = self.embed(inp)
  302. freqs_cis = self.freqs_cis[:seq_len]
  303. mask = None
  304. if key_padding_mask is not None:
  305. causal = self.causal_mask[:seq_len, :seq_len]
  306. causal = rearrange(causal, "q k -> 1 1 q k")
  307. atten_mask = rearrange(key_padding_mask, "b s -> b 1 1 s")
  308. atten_mask = atten_mask.logical_not()
  309. mask = causal & atten_mask
  310. for layer in self.layers:
  311. if self.config.use_gradient_checkpointing and self.training:
  312. x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
  313. else:
  314. x = layer(x, freqs_cis, mask)
  315. slow_out = self.norm(x)
  316. if self.config.tie_word_embeddings:
  317. token_logits = F.linear(slow_out, self.embeddings.weight)
  318. else:
  319. token_logits = self.output(slow_out)
  320. hidden_out = (
  321. slow_out if getattr(self.config, "norm_fastlayer_input", False) else x
  322. )
  323. return BaseTransformerForwardResult(
  324. logits=token_logits,
  325. hidden_states=hidden_out,
  326. )
  327. def forward_generate(
  328. self,
  329. inp: Tensor,
  330. input_pos: Optional[Tensor] = None,
  331. audio_masks: Optional[Tensor] = None,
  332. audio_parts: Optional[Tensor] = None,
  333. return_all: bool = False,
  334. ) -> BaseTransformerForwardResult:
  335. # Embedding logic replicated from embed() for compilation compatibility
  336. embeds = []
  337. for i in range(self.config.num_codebooks):
  338. emb = self.codebook_embeddings(
  339. inp[:, i + 1] + i * self.config.codebook_size
  340. )
  341. embeds.append(emb)
  342. vq_embeds_sum = torch.stack(embeds, dim=1).sum(dim=1)
  343. vq_masks = (inp[:, 0] >= self.config.semantic_begin_id) & (
  344. inp[:, 0] <= self.config.semantic_end_id
  345. )
  346. vq_embeds_sum[~vq_masks] = 0
  347. x = self.embeddings(inp[:, 0]) + vq_embeds_sum
  348. if self.config.scale_codebook_embeddings:
  349. vq_masks_expanded = vq_masks.unsqueeze(-1).expand_as(x)
  350. x = torch.where(
  351. vq_masks_expanded, x / math.sqrt(self.config.num_codebooks + 1), x
  352. )
  353. # Audio embeddings
  354. if audio_parts is not None:
  355. # Note: This assumes self.audio_projector exists if audio_parts is used
  356. # It seems missing in init, but we keep existing logic
  357. if hasattr(self, "audio_projector"):
  358. audio_embeds = self.audio_projector(audio_parts)
  359. if self.config.scale_codebook_embeddings:
  360. x[audio_masks] = audio_embeds / math.sqrt(2)
  361. else:
  362. x[audio_masks] = audio_embeds
  363. else:
  364. logger.warning("audio_parts provided but model has no audio_projector")
  365. if input_pos is None:
  366. input_pos = torch.arange(inp.shape[-1], device=x.device)
  367. max_seq_len = inp.shape[-1]
  368. else:
  369. max_seq_len = self.max_seq_len
  370. mask = self.causal_mask[None, None, input_pos, :max_seq_len] # (B, N, Q, K)
  371. freqs_cis = self.freqs_cis[input_pos]
  372. for layer in self.layers:
  373. x = layer(x, freqs_cis, mask, input_pos=input_pos)
  374. if x.size(1) > 1 and not return_all:
  375. x = x[:, -1:]
  376. slow_out = self.norm(x)
  377. if self.config.is_reward_model:
  378. token_logits = self.score_output(slow_out)
  379. elif self.config.tie_word_embeddings:
  380. token_logits = F.linear(slow_out, self.embeddings.weight)
  381. else:
  382. token_logits = self.output(slow_out)
  383. hidden_out = (
  384. slow_out if getattr(self.config, "norm_fastlayer_input", False) else x
  385. )
  386. return BaseTransformerForwardResult(
  387. logits=token_logits,
  388. hidden_states=hidden_out,
  389. )
  390. def _init_weights(self, module):
  391. std = self.config.initializer_range
  392. if isinstance(module, nn.Linear):
  393. module.weight.data.normal_(mean=0.0, std=std)
  394. if module.bias is not None:
  395. module.bias.data.zero_()
  396. elif isinstance(module, nn.Embedding):
  397. module.weight.data.normal_(mean=0.0, std=std)
  398. if module.padding_idx is not None:
  399. module.weight.data[module.padding_idx].zero_()
  400. @staticmethod
  401. def from_pretrained(
  402. path: str,
  403. load_weights: bool = False,
  404. max_length: int | None = None,
  405. lora_config: LoraConfig | None = None,
  406. rope_base: int | None = None,
  407. ) -> "BaseTransformer":
  408. # Import wrapper locally to avoid circular dependency or global import issues
  409. from fish_speech.tokenizer import FishTokenizer
  410. config = BaseModelArgs.from_pretrained(str(path))
  411. if max_length is not None:
  412. config.max_seq_len = max_length
  413. logger.info(f"Override max_seq_len to {max_length}")
  414. if rope_base is not None:
  415. config.rope_base = rope_base
  416. logger.info(f"Override rope_base to {rope_base}")
  417. tokenizer = None
  418. try:
  419. tokenizer = FishTokenizer.from_pretrained(path)
  420. config.semantic_begin_id = tokenizer.semantic_begin_id
  421. config.semantic_end_id = tokenizer.semantic_end_id
  422. logger.info(
  423. f"Injected Semantic IDs into Config: {config.semantic_begin_id}-{config.semantic_end_id}"
  424. )
  425. except Exception as e:
  426. logger.warning(
  427. f"Failed to load tokenizer for config injection: {e}. Semantic IDs might be 0."
  428. )
  429. match config.model_type:
  430. case "naive":
  431. model_cls = NaiveTransformer
  432. case "dual_ar":
  433. model_cls = DualARTransformer
  434. case _:
  435. raise ValueError(f"Unknown model type: {config.model_type}")
  436. logger.info(f"Loading model from {path}, config: {config}")
  437. # Initialize model without passing tokenizer explicitly to __init__
  438. model = model_cls(config)
  439. # Attach tokenizer to model instance for inference convenience (optional, but good for user scripts)
  440. model.tokenizer = tokenizer
  441. if load_weights is False:
  442. logger.info("Randomly initialized model")
  443. else:
  444. if "int8" in str(Path(path)):
  445. logger.info("Using int8 weight-only quantization!")
  446. from tools.llama.quantize import WeightOnlyInt8QuantHandler
  447. simple_quantizer = WeightOnlyInt8QuantHandler(model)
  448. model = simple_quantizer.convert_for_runtime()
  449. if "int4" in str(Path(path)):
  450. logger.info("Using int4 quantization!")
  451. path_comps = path.name.split("-")
  452. assert path_comps[-2].startswith("g")
  453. groupsize = int(path_comps[-2][1:])
  454. from tools.llama.quantize import WeightOnlyInt4QuantHandler
  455. simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
  456. model = simple_quantizer.convert_for_runtime()
  457. path_obj = Path(path)
  458. index_json = path_obj / "model.safetensors.index.json"
  459. single_st = path_obj / "model.safetensors"
  460. pth_file = path_obj / "model.pth"
  461. if index_json.exists():
  462. logger.info("Loading sharded safetensors weights")
  463. from safetensors.torch import load_file as st_load_file
  464. with open(index_json) as f:
  465. st_index = json.load(f)
  466. shard_files = sorted(set(st_index["weight_map"].values()))
  467. weights = OrderedDict()
  468. for shard in shard_files:
  469. weights.update(st_load_file(str(path_obj / shard), device="cpu"))
  470. weights = _remap_fish_qwen3_omni_keys(weights)
  471. elif single_st.exists():
  472. logger.info("Loading single safetensors weights")
  473. from safetensors.torch import load_file as st_load_file
  474. weights = OrderedDict(st_load_file(str(single_st), device="cpu"))
  475. weights = _remap_fish_qwen3_omni_keys(weights)
  476. elif pth_file.exists():
  477. weights = torch.load(
  478. pth_file,
  479. map_location="cpu",
  480. mmap=True,
  481. weights_only=True,
  482. )
  483. if "state_dict" in weights:
  484. weights = weights["state_dict"]
  485. if weights and next(iter(weights.keys())).startswith("model."):
  486. weights = OrderedDict(
  487. (k.replace("model.", ""), v) for k, v in weights.items()
  488. )
  489. for k in list(weights.keys()):
  490. if "audio_" in k:
  491. weights.pop(k)
  492. else:
  493. raise FileNotFoundError(f"No model weights found in {path_obj}")
  494. err = model.load_state_dict(weights, strict=False, assign=True)
  495. logger.info(f"Model weights loaded - Status: {err}")
  496. if lora_config is not None:
  497. setup_lora(model, lora_config)
  498. logger.info(f"LoRA setup: {lora_config}")
  499. return model
  500. def save_pretrained(self, path: str, drop_lora: bool = False):
  501. path = Path(path)
  502. path.mkdir(parents=True, exist_ok=True)
  503. self.config.save(path / "config.json")
  504. state_dict = self.state_dict()
  505. if drop_lora:
  506. for key in list(state_dict.keys()):
  507. if "lora" not in key:
  508. continue
  509. state_dict.pop(key)
  510. torch.save(state_dict, path / "model.pth")
  511. if hasattr(self, "tokenizer"):
  512. self.tokenizer.save_pretrained(path)
  513. class NaiveTransformer(BaseTransformer):
  514. def __init__(self, config: NaiveModelArgs) -> None:
  515. super().__init__(config, init_weights=False)
  516. self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
  517. self.codebook_output = nn.Linear(
  518. config.dim,
  519. config.codebook_size * config.num_codebooks,
  520. bias=False,
  521. )
  522. self.apply(self._init_weights)
  523. def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
  524. token_logits = result.logits
  525. x = result.hidden_states
  526. # Codebook
  527. codebook_logits = self.codebook_output(self.codebook_norm(x))
  528. codebook_logits = rearrange(
  529. codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
  530. )
  531. return TransformerForwardResult(
  532. token_logits=token_logits,
  533. codebook_logits=codebook_logits,
  534. )
  535. def forward(
  536. self,
  537. inp: Tensor,
  538. key_padding_mask: Optional[Tensor] = None,
  539. ) -> TransformerForwardResult:
  540. result = super().forward(
  541. inp=inp,
  542. key_padding_mask=key_padding_mask,
  543. )
  544. return self.decode(result)
  545. def forward_generate(
  546. self, x: Tensor, input_pos: Optional[Tensor] = None
  547. ) -> TransformerForwardResult:
  548. result = super().forward_generate(x, input_pos)
  549. return self.decode(result)
  550. class DualARTransformer(BaseTransformer):
  551. def __init__(self, config: NaiveModelArgs) -> None:
  552. super().__init__(config, init_weights=False)
  553. # Project to fast dim if needed
  554. if config.fast_dim is not None and config.fast_dim != config.dim:
  555. self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
  556. else:
  557. self.fast_project_in = nn.Identity()
  558. # Fast transformer
  559. self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
  560. # The equivalent bs is so large that sdpa doesn't work
  561. override_config = dataclasses.replace(
  562. config,
  563. dim=config.fast_dim,
  564. n_head=config.fast_n_head,
  565. n_local_heads=config.fast_n_local_heads,
  566. head_dim=config.fast_head_dim,
  567. intermediate_size=config.fast_intermediate_size,
  568. attention_qkv_bias=config.fast_attention_qkv_bias,
  569. attention_qk_norm=config.fast_attention_qk_norm,
  570. attention_o_bias=config.fast_attention_o_bias,
  571. )
  572. self.fast_layers = nn.ModuleList(
  573. TransformerBlock(override_config, use_sdpa=False)
  574. for _ in range(config.n_fast_layer)
  575. )
  576. self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
  577. self.fast_output = nn.Linear(
  578. config.fast_dim,
  579. config.codebook_size,
  580. bias=False,
  581. )
  582. self.register_buffer(
  583. "fast_freqs_cis",
  584. precompute_freqs_cis(
  585. config.num_codebooks,
  586. config.fast_head_dim,
  587. config.rope_base,
  588. ),
  589. persistent=False,
  590. )
  591. self.apply(self._init_weights)
  592. def setup_caches(
  593. self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
  594. ):
  595. super().setup_caches(max_batch_size, max_seq_len, dtype)
  596. logger.info(f"max_batch_size: {max_batch_size}, max_seq_len: {max_seq_len}")
  597. logger.info(f"self.fast_layers: {len(self.fast_layers)}")
  598. # Fast transformer
  599. # The max seq len here is the number of codebooks
  600. for b in self.fast_layers:
  601. b.attention.kv_cache = KVCache(
  602. max_batch_size,
  603. self.config.num_codebooks,
  604. self.config.fast_n_local_heads,
  605. self.config.fast_head_dim,
  606. dtype=dtype,
  607. )
  608. def forward(
  609. self,
  610. inp: Tensor,
  611. labels: Optional[Tensor] = None,
  612. key_padding_mask: Optional[Tensor] = None,
  613. vq_parts: Optional[Tensor] = None,
  614. vq_masks: Optional[Tensor] = None,
  615. vq_require_losses: Optional[Tensor] = None,
  616. mel_parts: Optional[Tensor] = None,
  617. mel_masks: Optional[Tensor] = None,
  618. ) -> TransformerForwardResult:
  619. parent_result = super().forward(
  620. inp=inp,
  621. key_padding_mask=key_padding_mask,
  622. )
  623. token_logits = parent_result.logits
  624. x = parent_result.hidden_states
  625. # Fast transformer
  626. fast_seq_len = self.config.num_codebooks
  627. fast_mask = self.causal_mask[
  628. None, None, :fast_seq_len, :fast_seq_len
  629. ] # (B, N, Q, K)
  630. fast_freqs_cis = self.fast_freqs_cis[:fast_seq_len]
  631. # Extract corresponding parts with labels
  632. token_labels = labels[:, 0]
  633. # [MODIFIED] Use config instead of tokenizer
  634. codebook_mask = (token_labels >= self.config.semantic_begin_id) & (
  635. token_labels <= self.config.semantic_end_id
  636. )
  637. # This gives where input token is <|semantic|>
  638. x = x[codebook_mask]
  639. if x.shape[0] == 0:
  640. # Use dummy input when no vq is required
  641. x = torch.zeros(
  642. (4, self.config.dim),
  643. device=x.device,
  644. dtype=x.dtype,
  645. )
  646. codebooks = torch.zeros(
  647. (x.shape[0], self.config.num_codebooks - 1),
  648. device=x.device,
  649. dtype=torch.int,
  650. )
  651. else:
  652. all_codebooks = labels[:, 1:, :]
  653. all_codebooks_permuted = all_codebooks.permute(0, 2, 1)
  654. semantic_codebooks = all_codebooks_permuted[codebook_mask]
  655. codebooks = semantic_codebooks[:, :-1]
  656. x = self.fast_project_in(x)
  657. codebook_embeddings = self.fast_embeddings(codebooks)
  658. x = torch.cat([x[:, None], codebook_embeddings], dim=1)
  659. for layer in self.fast_layers:
  660. if self.config.use_gradient_checkpointing and self.training:
  661. x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
  662. else:
  663. x = layer(x, fast_freqs_cis, fast_mask)
  664. # unflatten the batch and num_codebooks
  665. fast_out = self.fast_norm(x)
  666. codebook_logits = self.fast_output(fast_out)
  667. assert codebook_logits.shape[1] == self.config.num_codebooks
  668. return TransformerForwardResult(
  669. token_logits=token_logits,
  670. codebook_logits=codebook_logits,
  671. )
  672. def forward_generate_fast(
  673. self, x: Tensor, input_pos: Optional[Tensor] = None
  674. ) -> Tensor:
  675. # Fast transformer
  676. x = x.view(x.shape[0], 1, -1)
  677. fast_mask = self.causal_mask[
  678. None, None, input_pos, : self.config.num_codebooks
  679. ] # (B, N, Q, K)
  680. fast_freqs_cis = self.fast_freqs_cis[input_pos]
  681. for layer in self.fast_layers:
  682. x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
  683. # unflatten the batch and num_codebooks
  684. fast_out = self.fast_norm(x) # only take the last token
  685. codebook_logits = self.fast_output(fast_out)
  686. return codebook_logits
  687. def forward_generate(
  688. self,
  689. x: Tensor,
  690. input_pos: Optional[Tensor] = None,
  691. audio_masks: Optional[Tensor] = None,
  692. audio_parts: Optional[Tensor] = None,
  693. ) -> TransformerForwardResult:
  694. x = super().forward_generate(x, input_pos, audio_masks, audio_parts)
  695. x.hidden_states = self.fast_project_in(x.hidden_states)
  696. return x
  697. class TransformerBlock(nn.Module):
  698. def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
  699. super().__init__()
  700. self.attention = Attention(config, use_sdpa=use_sdpa)
  701. self.feed_forward = FeedForward(config)
  702. self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
  703. self.attention_norm = RMSNorm(config.dim, config.norm_eps)
  704. def forward(
  705. self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
  706. ) -> Tensor:
  707. h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  708. out = h + self.feed_forward(self.ffn_norm(h))
  709. return out
  710. class Attention(nn.Module):
    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        """Build the attention projections described by ``config``.

        Queries use ``n_head`` heads. Keys and values each use ``n_local_heads``
        heads. All heads have size ``head_dim`` and are packed into one
        ``wqkv`` projection.
        """
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # Query, key and value projections for all heads in one matrix,
        # packed in q, k, v order (split the same way in forward).
        self.wqkv = nn.Linear(
            config.dim, total_head_dim, bias=config.attention_qkv_bias
        )
        self.wo = nn.Linear(
            config.n_head * config.head_dim, config.dim, bias=config.attention_o_bias
        )
        # Replaced by a KVCache in setup_caches(); None means no caching.
        self.kv_cache = None

        if config.attention_qk_norm:
            # Per-head RMSNorm applied to queries and keys before RoPE/attention.
            self.q_norm = nn.RMSNorm(config.head_dim, config.norm_eps)
            self.k_norm = nn.RMSNorm(config.head_dim, config.norm_eps)

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        # NOTE(review): presumably selects scaled_dot_product_attention in
        # forward (not fully visible here) — confirm.
        self.use_sdpa = use_sdpa
        self.attention_qk_norm = config.attention_qk_norm
        self.config = config

        # Lets checkpoints with separate wq/wk/wv weights load into wqkv.
        self._register_load_state_dict_pre_hook(self.load_hook)
  735. def load_hook(self, state_dict, prefix, *args):
  736. if prefix + "wq.weight" in state_dict:
  737. wq = state_dict.pop(prefix + "wq.weight")
  738. wk = state_dict.pop(prefix + "wk.weight")
  739. wv = state_dict.pop(prefix + "wv.weight")
  740. state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
  741. def forward(
  742. self,
  743. x: torch.Tensor,
  744. freqs_cis: torch.Tensor,
  745. mask: Optional[torch.Tensor],
  746. input_pos: Optional[torch.Tensor] = None,
  747. ):
  748. bsz, seqlen, _ = x.shape
  749. q_size = self.n_head * self.head_dim
  750. kv_size = self.n_local_heads * self.head_dim
  751. # =========================
  752. # QKV projection
  753. # =========================
  754. q, k, v = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
  755. q = q.view(bsz, seqlen, self.n_head, self.head_dim)
  756. k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
  757. v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
  758. if self.attention_qk_norm:
  759. q = self.q_norm(q)
  760. k = self.k_norm(k)
  761. # [B, H, T, D]
  762. q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
  763. # =========================
  764. # KV Cache + Sliding Window
  765. # =========================
  766. start = 0
  767. if self.kv_cache is not None:
  768. # update cache
  769. k, v = self.kv_cache.update(input_pos, k, v)
  770. max_context = 4096 # 可调
  771. if input_pos is not None and seqlen == 1:
  772. # decode
  773. seq_len = int(input_pos.item()) + 1
  774. else:
  775. # prefill
  776. seq_len = k.size(2)
  777. start = max(0, seq_len - max_context)
  778. # window 裁剪
  779. k = k[:, :, start:seq_len, :]
  780. v = v[:, :, start:seq_len, :]
  781. if mask is not None:
  782. mask = mask[:, :, :, start:seq_len]
  783. else:
  784. seq_len = seqlen
  785. # =========================
  786. # RoPE(核心:以 K 为基准)
  787. # =========================
  788. T_k = k.size(2)
  789. if self.kv_cache is not None:
  790. freqs = freqs_cis[start: start + T_k]
  791. else:
  792. freqs = freqs_cis[:T_k]
  793. # =========================
  794. # Apply RoPE(Q/K 分开对齐)
  795. # =========================
  796. T_q = q.size(2)
  797. print("Q:", q.shape)
  798. print("K:", k.shape)
  799. print("freqs:", freqs.shape)
  800. print("start:", start)
  801. assert k.size(2) == freqs.size(0), f"K vs freqs mismatch"
  802. assert q.size(2) <= k.size(2), f"Q longer than K??"
  803. # Q 用最后 T_q 个位置
  804. q = apply_rotary_emb(q, freqs[-T_q:])
  805. # K 用完整 window
  806. k = apply_rotary_emb(k, freqs)
  807. # =========================
  808. # GQA expand
  809. # =========================
  810. if self.n_head != self.n_local_heads:
  811. repeat = self.n_head // self.n_local_heads
  812. k = k.repeat_interleave(repeat, dim=1)
  813. v = v.repeat_interleave(repeat, dim=1)
  814. # =========================
  815. # Attention
  816. # =========================
  817. if self.use_sdpa:
  818. if mask is None:
  819. y = torch.nn.functional.scaled_dot_product_attention(
  820. q,
  821. k,
  822. v,
  823. dropout_p=self.dropout if self.training else 0.0,
  824. is_causal=True,
  825. )
  826. else:
  827. y = torch.nn.functional.scaled_dot_product_attention(
  828. q,
  829. k,
  830. v,
  831. attn_mask=mask,
  832. dropout_p=self.dropout if self.training else 0.0,
  833. )
  834. else:
  835. y = self.eq_scaled_dot_product_attention(
  836. q,
  837. k,
  838. v,
  839. attn_mask=mask,
  840. dropout_p=self.dropout if self.training else 0.0,
  841. )
  842. # =========================
  843. # Output
  844. # =========================
  845. y = y.transpose(1, 2).contiguous().view(bsz, seqlen, q_size)
  846. return self.wo(y)
  847. def eq_scaled_dot_product_attention(
  848. self,
  849. query,
  850. key,
  851. value,
  852. attn_mask=None,
  853. dropout_p=0.0,
  854. ) -> torch.Tensor:
  855. # This is a standard scaled dot product attention
  856. # It's low efficient, but it doesn't raise cuda error
  857. L, S = query.size(-2), key.size(-2)
  858. scale_factor = 1 / math.sqrt(query.size(-1))
  859. attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
  860. if attn_mask is not None:
  861. if attn_mask.dtype == torch.bool:
  862. attn_bias = torch.where(
  863. attn_mask.logical_not(), float("-inf"), attn_bias
  864. )
  865. else:
  866. attn_bias = attn_bias + attn_mask
  867. attn_weight = query @ key.transpose(-2, -1) * scale_factor
  868. attn_weight += attn_bias
  869. attn_weight = torch.softmax(attn_weight, dim=-1)
  870. attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
  871. return attn_weight @ value
  872. class FeedForward(nn.Module):
  873. def __init__(self, config: BaseModelArgs) -> None:
  874. super().__init__()
  875. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  876. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  877. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  878. def forward(self, x: Tensor) -> Tensor:
  879. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  880. class RMSNorm(nn.Module):
  881. def __init__(self, dim: int, eps: float = 1e-5):
  882. super().__init__()
  883. self.eps = eps
  884. self.weight = nn.Parameter(torch.ones(dim))
  885. def _norm(self, x):
  886. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  887. def forward(self, x: Tensor) -> Tensor:
  888. output = self._norm(x.float()).type_as(x)
  889. return output * self.weight
  890. def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
  891. """
  892. Precomputes frequency tensors for complex exponentials (cis)
  893. Args:
  894. seq_len: Length of the sequence for which positional embeddings are needed.
  895. n_elem: Number of elements in the frequency tensor.
  896. base: Base value for the frequency scaling (default: 10000).
  897. Returns:
  898. A tensor containing the precomputed frequencies in real and imaginary parts (bfloat16).
  899. """
  900. freqs = 1.0 / (
  901. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  902. )
  903. t = torch.arange(seq_len, device=freqs.device)
  904. freqs = torch.outer(t, freqs)
  905. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  906. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  907. return cache.to(dtype=torch.bfloat16)
  908. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  909. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  910. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  911. x_out2 = torch.stack(
  912. [
  913. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  914. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  915. ],
  916. -1,
  917. )
  918. x_out2 = x_out2.flatten(3)
  919. return x_out2.type_as(x)