# llama.py — LLaMA-style text/semantic-token transformer (fish-speech).
# Standard library
import dataclasses
import json
import math
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

# Third party
import torch
import torch.nn as nn
from einops import rearrange
from loguru import logger
from torch import Tensor
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.checkpoint import checkpoint
from transformers import AutoTokenizer

# Project
from fish_speech.conversation import SEMANTIC_TOKEN
from fish_speech.utils import RankedLogger

from .lora import LoraConfig, setup_lora

# Rank-zero-only logger so multi-GPU runs don't duplicate log lines.
log = RankedLogger(__name__, rank_zero_only=True)
  21. def find_multiple(n: int, k: int) -> int:
  22. if n % k == 0:
  23. return n
  24. return n + k - (n % k)
  25. @dataclass
  26. class BaseModelArgs:
  27. model_type: str = "base"
  28. vocab_size: int = 32000
  29. n_layer: int = 32
  30. n_head: int = 32
  31. dim: int = 4096
  32. intermediate_size: int = None
  33. n_local_heads: int = -1
  34. head_dim: int = 64
  35. rope_base: float = 10000
  36. norm_eps: float = 1e-5
  37. max_seq_len: int = 2048
  38. dropout: float = 0.0
  39. tie_word_embeddings: bool = True
  40. attention_qkv_bias: bool = False
  41. # Codebook configs
  42. codebook_size: int = 160
  43. num_codebooks: int = 4
  44. # Gradient checkpointing
  45. use_gradient_checkpointing: bool = True
  46. # Initialize the model
  47. initializer_range: float = 0.02
  48. # Dummy vars
  49. is_reward_model: bool = False
  50. share_codebook_embeddings: bool = True
  51. def __post_init__(self):
  52. if self.n_local_heads == -1:
  53. self.n_local_heads = self.n_head
  54. if self.intermediate_size is None:
  55. hidden_dim = 4 * self.dim
  56. n_hidden = int(2 * hidden_dim / 3)
  57. self.intermediate_size = find_multiple(n_hidden, 256)
  58. self.head_dim = self.dim // self.n_head
  59. @staticmethod
  60. def from_pretrained(path: str):
  61. path = Path(path)
  62. if path.is_dir():
  63. path = path / "config.json"
  64. with open(path, "r", encoding="utf-8") as f:
  65. data = json.load(f)
  66. match data["model_type"]:
  67. case "naive":
  68. cls = NaiveModelArgs
  69. case "dual_ar":
  70. cls = DualARModelArgs
  71. case _:
  72. raise ValueError(f"Unknown model type: {data['model_type']}")
  73. return cls(**data)
  74. def save(self, path: str):
  75. with open(path, "w") as f:
  76. json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
@dataclass
class NaiveModelArgs(BaseModelArgs):
    """Config for the single-transformer ("naive") variant, where one stack
    predicts the text token and all codebooks jointly."""

    model_type: str = "naive"
@dataclass
class DualARModelArgs(BaseModelArgs):
    """Config for the dual-autoregressive variant: a slow transformer over
    tokens plus a small fast transformer over codebooks.

    Any ``fast_*`` field left as None inherits the corresponding
    slow-transformer value in ``__post_init__``.
    """

    model_type: str = "dual_ar"
    n_fast_layer: int = 4
    # Fast-transformer overrides; None means "inherit the slow value".
    fast_dim: int | None = None
    fast_n_head: int | None = None
    fast_n_local_heads: int | None = None
    fast_head_dim: int | None = None
    fast_intermediate_size: int | None = None
    fast_attention_qkv_bias: bool | None = None

    def __post_init__(self):
        super().__post_init__()

        # Inherit slow-transformer hyperparameters for any unset fast_* field.
        self.fast_dim = self.fast_dim or self.dim
        self.fast_n_head = self.fast_n_head or self.n_head
        self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
        self.fast_head_dim = self.fast_head_dim or self.head_dim
        self.fast_intermediate_size = (
            self.fast_intermediate_size or self.intermediate_size
        )
        # qkv bias is a bool, so compare against None explicitly: an explicit
        # False override must not fall through to the slow value.
        self.fast_attention_qkv_bias = (
            self.fast_attention_qkv_bias
            if self.fast_attention_qkv_bias is not None
            else self.attention_qkv_bias
        )
  104. class KVCache(nn.Module):
  105. def __init__(
  106. self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
  107. ):
  108. super().__init__()
  109. cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
  110. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  111. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  112. def update(self, input_pos, k_val, v_val):
  113. # input_pos: [S], k_val: [B, H, S, D]
  114. assert input_pos.shape[0] == k_val.shape[2]
  115. k_out = self.k_cache
  116. v_out = self.v_cache
  117. k_out[:, :, input_pos] = k_val
  118. v_out[:, :, input_pos] = v_val
  119. return k_out, v_out
@dataclass
class TransformerForwardResult:
    """Training-time output: text-token logits plus per-codebook logits."""

    # (batch, seq_len, vocab_size) — logits over the text vocabulary.
    token_logits: Tensor
    # (batch, seq_len, num_codebooks, codebook_size) — logits per codebook.
    codebook_logits: Tensor
@dataclass
class BaseTransformerForwardResult:
    """Raw output of the slow transformer before any codebook decoding."""

    # Text-token logits from the (possibly tied) output head.
    logits: Tensor
    # Final-layer hidden states (pre-norm), reused by subclass decoders.
    hidden_states: Tensor
class BaseTransformer(nn.Module):
    """LLaMA-style decoder over mixed text/semantic tokens.

    Inputs are expected as (batch, 1 + num_codebooks, seq_len): row 0 holds
    text-token ids, rows 1..num_codebooks hold codebook ids. Codebook
    embeddings are summed into the token embedding only at positions whose
    text token is the semantic token.
    """

    def __init__(
        self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
    ) -> None:
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        # Id of the special semantic token; gates codebook embeddings in embed().
        self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)

        # Slow transformer
        self.embeddings = nn.Embedding(
            config.vocab_size,
            config.dim,
        )
        # All codebooks share one table, offset by i * codebook_size per book.
        self.codebook_embeddings = nn.Embedding(
            config.codebook_size * config.num_codebooks,
            config.dim,
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)

        # With tied embeddings the input table doubles as the output head.
        if self.config.tie_word_embeddings is False:
            self.output = nn.Linear(
                config.dim,
                config.vocab_size,
                bias=False,
            )

        # Rotary-embedding table; non-persistent so checkpoints stay small.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        # Lower-triangular boolean mask (True = attend), sliced per forward.
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
            persistent=False,
        )

        # For kv cache; -1 marks "setup_caches not called yet".
        self.max_batch_size = -1
        self.max_seq_len = -1

        if init_weights:
            self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Allocate per-layer KV caches for incremental decoding.

        No-op if already-allocated caches are at least as large.
        """
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        head_dim = self.config.dim // self.config.n_head
        # Round up for alignment-friendly cache shapes.
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def embed(self, x: Tensor) -> Tensor:
        """Merge text and codebook embeddings.

        x: (batch, 1 + num_codebooks, seq_len) ids. Codebook embeddings are
        zeroed wherever the text row is not the semantic token, then all
        (1 + num_codebooks) embeddings are summed per position.
        """
        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            # Offset into the shared codebook table for book i.
            emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
            # Only semantic-token positions carry codebook information.
            emb[x[:, 0] != self.semantic_token_id] = 0
            vocab_embeds.append(emb)

        x = torch.stack(vocab_embeds, dim=3)
        x = x.sum(dim=3)

        return x

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> BaseTransformerForwardResult:
        """Full-sequence (training) forward pass.

        inp: (batch, 1 + num_codebooks, seq_len) ids.
        key_padding_mask: optional (batch, seq_len) bool, True = masked out.
        """
        seq_len = inp.size(2)

        # Here we want to merge the embeddings of the codebooks
        x = self.embed(inp)

        freqs_cis = self.freqs_cis[:seq_len]

        # Note that the causal mask here follows the definition of scaled_dot_product_attention
        # That is, FALSE means masked out
        # To maintain consistency, key_padding_mask use TRUE to mask out
        mask = None
        if key_padding_mask is not None:
            mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
            mask = mask & key_padding_mask[:, None, None, :].logical_not()

        for layer in self.layers:
            if self.config.use_gradient_checkpointing and self.training:
                # Trade compute for memory during training.
                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
            else:
                x = layer(x, freqs_cis, mask)

        # We got slow_out here
        slow_out = self.norm(x)

        if self.config.tie_word_embeddings:
            token_logits = F.linear(slow_out, self.embeddings.weight)
        else:
            token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )

    def forward_generate(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        return_all: bool = False,
    ) -> BaseTransformerForwardResult:
        """Incremental forward pass using the KV caches.

        Requires setup_caches() first. Unless return_all is set, a multi-token
        prefill returns logits for the last position only.
        """
        # This is used for generation, optimized for torch compile
        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"

        x = self.embed(x)

        mask = self.causal_mask[
            None, None, input_pos, : self.max_seq_len
        ]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        # If prefill, we only calculate the logits of last token
        if x.size(1) > 1 and not return_all:
            x = x[:, -1:]

        # We got slow_out here
        slow_out = self.norm(x)

        if self.config.tie_word_embeddings:
            token_logits = F.linear(slow_out, self.embeddings.weight)
        else:
            token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for linear/embedding weights."""
        std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @staticmethod
    def from_pretrained(
        path: str,
        load_weights: bool = False,
        max_length: int | None = None,
        lora_config: LoraConfig | None = None,
        rope_base: int | None = None,
    ) -> "BaseTransformer":
        """Build a NaiveTransformer or DualARTransformer from a checkpoint dir.

        Optionally overrides max_seq_len / rope_base, applies LoRA, selects an
        int8/int4 quantized runtime based on the path name, and loads weights
        from ``model.pth``.
        """
        config = BaseModelArgs.from_pretrained(str(path))
        if max_length is not None:
            config.max_seq_len = max_length
            log.info(f"Override max_seq_len to {max_length}")

        if rope_base is not None:
            config.rope_base = rope_base
            log.info(f"Override rope_base to {rope_base}")

        match config.model_type:
            case "naive":
                model_cls = NaiveTransformer
            case "dual_ar":
                model_cls = DualARTransformer
            case _:
                raise ValueError(f"Unknown model type: {config.model_type}")

        tokenizer = AutoTokenizer.from_pretrained(str(path))

        log.info(f"Loading model from {path}, config: {config}")
        model = model_cls(config, tokenizer=tokenizer)

        if lora_config is not None:
            setup_lora(model, lora_config)
            log.info(f"LoRA setup: {lora_config}")

        if load_weights is False:
            log.info("Randomly initialized model")
        else:
            # Quantization mode is inferred from the checkpoint path name.
            if "int8" in str(Path(path)):
                logger.info("Using int8 weight-only quantization!")
                from tools.llama.quantize import WeightOnlyInt8QuantHandler

                simple_quantizer = WeightOnlyInt8QuantHandler(model)
                model = simple_quantizer.convert_for_runtime()

            if "int4" in str(Path(path)):
                logger.info("Using int4 quantization!")
                # NOTE(review): `path` is annotated str but `.name` is a Path
                # attribute — callers appear to pass a Path here; confirm.
                path_comps = path.name.split("-")
                assert path_comps[-2].startswith("g")
                groupsize = int(path_comps[-2][1:])
                from tools.llama.quantize import WeightOnlyInt4QuantHandler

                simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
                model = simple_quantizer.convert_for_runtime()

            weights = torch.load(
                Path(path) / "model.pth",
                map_location="cpu",
                mmap=True,
                weights_only=True,
            )

            if "state_dict" in weights:
                logger.warning(
                    "Using a TextToSemantic LightningModule checkpoint, "
                    "please make sure it is a full model, not a LoRA model."
                )
                weights = weights["state_dict"]

            if next(iter(weights.keys())).startswith("model."):
                logger.info(
                    f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
                )
                new_weights = OrderedDict()
                for k, v in weights.items():
                    new_weights[k.replace("model.", "")] = v
                weights = new_weights

            # Verify the name and shape of parameters since strict=False in load_state_dict.
            for k, v in model.named_parameters():
                if k not in weights:
                    logger.warning(f"No weight for {k}")
                elif v.shape != weights[k].shape:
                    logger.warning(
                        f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
                    )

            err = model.load_state_dict(weights, strict=False, assign=True)
            log.info(f"Loaded weights with error: {err}")

        return model

    def save_pretrained(self, path: str, drop_lora: bool = False):
        """Save config, weights and tokenizer to ``path``.

        With drop_lora, LoRA adapter tensors are stripped from the state dict
        before saving.
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        self.config.save(path / "config.json")
        state_dict = self.state_dict()

        if drop_lora:
            for key in list(state_dict.keys()):
                if "lora" not in key:
                    continue

                state_dict.pop(key)
                log.info(f"Drop LoRA parameter: {key}")

        torch.save(state_dict, path / "model.pth")
        self.tokenizer.save_pretrained(path)
  365. class NaiveTransformer(BaseTransformer):
  366. def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
  367. super().__init__(config, init_weights=False, tokenizer=tokenizer)
  368. self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
  369. self.codebook_output = nn.Linear(
  370. config.dim,
  371. config.codebook_size * config.num_codebooks,
  372. bias=False,
  373. )
  374. self.apply(self._init_weights)
  375. def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
  376. token_logits = result.logits
  377. x = result.hidden_states
  378. # Codebook
  379. codebook_logits = self.codebook_output(self.codebook_norm(x))
  380. codebook_logits = rearrange(
  381. codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
  382. )
  383. return TransformerForwardResult(
  384. token_logits=token_logits,
  385. codebook_logits=codebook_logits,
  386. )
  387. def forward(
  388. self,
  389. inp: Tensor,
  390. key_padding_mask: Optional[Tensor] = None,
  391. ) -> TransformerForwardResult:
  392. result = super().forward(
  393. inp=inp,
  394. key_padding_mask=key_padding_mask,
  395. )
  396. return self.decode(result)
  397. def forward_generate(
  398. self, x: Tensor, input_pos: Optional[Tensor] = None
  399. ) -> TransformerForwardResult:
  400. result = super().forward_generate(x, input_pos)
  401. return self.decode(result)
class DualARTransformer(BaseTransformer):
    """Dual-autoregressive variant: the inherited slow transformer runs over
    the token sequence, and a small fast transformer runs over the codebook
    axis at each position to predict codebooks autoregressively."""

    def __init__(self, config: DualARModelArgs, tokenizer: AutoTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        # Project to fast dim if needed
        if config.fast_dim is not None and config.fast_dim != config.dim:
            self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
        else:
            self.fast_project_in = nn.Identity()

        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)

        # The equivalent bs is so large that sdpa doesn't work
        override_config = dataclasses.replace(
            config,
            dim=config.fast_dim,
            n_head=config.fast_n_head,
            n_local_heads=config.fast_n_local_heads,
            head_dim=config.fast_head_dim,
            intermediate_size=config.fast_intermediate_size,
            attention_qkv_bias=config.fast_attention_qkv_bias,
        )

        self.fast_layers = nn.ModuleList(
            TransformerBlock(override_config, use_sdpa=False)
            for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.fast_dim,
            config.codebook_size,
            bias=False,
        )

        # RoPE table for the fast axis: its "sequence" length is num_codebooks.
        self.register_buffer(
            "fast_freqs_cis",
            precompute_freqs_cis(
                config.num_codebooks,
                config.fast_dim // config.fast_n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Allocate slow-transformer caches, then fast-transformer caches
        whose sequence length is the number of codebooks."""
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        head_dim = self.config.fast_dim // self.config.fast_n_head

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.fast_n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Training forward: slow transformer over tokens, then the fast
        transformer over the codebook axis at every sequence position."""
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states
        x = self.fast_project_in(x)

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)

        # Drop the last token and rotate left
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=0)
        codebook_embeddings = self.fast_embeddings(codebooks)
        # Prepend the slow hidden state as step 0 of the fast sequence.
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove padded part
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == 0).all(dim=-1)

        if torch.all(codebook_mask):
            # If all codebooks are padded, we keep first 8 to make sure the model runs
            codebook_mask[:8] = False

        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(
                    layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
                )
            else:
                x = layer(x, self.fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        """One step of the fast transformer during generation; returns logits
        over the codebook vocabulary for the current codebook position."""
        # Fast transformer
        x = x.view(1, 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.fast_freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> BaseTransformerForwardResult:
        """Slow-transformer generation step; hidden states are projected into
        the fast dimension for the subsequent forward_generate_fast calls."""
        x = super().forward_generate(x, input_pos)
        x.hidden_states = self.fast_project_in(x.hidden_states)
        return x
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward sublayers, each
    wrapped in an RMSNorm and a residual connection."""

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        # Residual around attention, then residual around the feed-forward.
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
class Attention(nn.Module):
    """Multi-head attention with a fused QKV projection, rotary position
    embeddings, grouped-query KV heads, and an optional KV cache attached by
    the owning model for incremental decoding."""

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(
            config.dim, total_head_dim, bias=config.attention_qkv_bias
        )
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        # Assigned externally by setup_caches(); None means no caching.
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """Merge legacy separate wq/wk/wv checkpoint weights into wqkv."""
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Attend over x of shape (batch, seq_len, dim).

        mask follows sdpa semantics (True = attend); input_pos selects
        KV-cache slots when a cache is attached.
        """
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        # (B, S, H, D) -> (B, H, S, D)
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Expand grouped KV heads to one per query head (GQA).
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.use_sdpa:
            if mask is None:
                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                    y = F.scaled_dot_product_attention(
                        q,
                        k,
                        v,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=True,
                        # No third party attn_mask here to use flash_attention
                    )
            else:
                y = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=mask,
                    dropout_p=self.dropout if self.training else 0.0,
                )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        # This is a standard scaled dot product attention
        # It's low efficient, but it doesn't raise cuda error
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Boolean mask: True = attend, False = -inf bias.
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        # NOTE(review): train=True is hard-coded, but the caller passes
        # dropout_p=0.0 when not self.training, so this is a no-op in eval.
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        return attn_weight @ value
  648. class FeedForward(nn.Module):
  649. def __init__(self, config: BaseModelArgs) -> None:
  650. super().__init__()
  651. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  652. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  653. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  654. def forward(self, x: Tensor) -> Tensor:
  655. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  656. class RMSNorm(nn.Module):
  657. def __init__(self, dim: int, eps: float = 1e-5):
  658. super().__init__()
  659. self.eps = eps
  660. self.weight = nn.Parameter(torch.ones(dim))
  661. def _norm(self, x):
  662. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  663. def forward(self, x: Tensor) -> Tensor:
  664. output = self._norm(x.float()).type_as(x)
  665. return output * self.weight
  666. def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
  667. freqs = 1.0 / (
  668. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  669. )
  670. t = torch.arange(seq_len, device=freqs.device)
  671. freqs = torch.outer(t, freqs)
  672. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  673. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  674. return cache.to(dtype=torch.bfloat16)
  675. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  676. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  677. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  678. x_out2 = torch.stack(
  679. [
  680. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  681. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  682. ],
  683. -1,
  684. )
  685. x_out2 = x_out2.flatten(3)
  686. return x_out2.type_as(x)