llama.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782
  1. import json
  2. import math
  3. from collections import OrderedDict
  4. from dataclasses import dataclass
  5. from pathlib import Path
  6. from typing import Optional
  7. import torch
  8. import torch.nn as nn
  9. from einops import rearrange
  10. from loguru import logger
  11. from torch import Tensor
  12. from torch.nn import functional as F
  13. from torch.nn.attention import SDPBackend, sdpa_kernel
  14. from torch.utils.checkpoint import checkpoint
  15. from transformers import AutoTokenizer
  16. from fish_speech.conversation import SEMANTIC_TOKEN
  17. from fish_speech.utils import RankedLogger
  18. from .lora import LoraConfig, setup_lora
  19. log = RankedLogger(__name__, rank_zero_only=True)
  20. def find_multiple(n: int, k: int) -> int:
  21. if n % k == 0:
  22. return n
  23. return n + k - (n % k)
  24. @dataclass
  25. class BaseModelArgs:
  26. model_type: str = "base"
  27. vocab_size: int = 32000
  28. n_layer: int = 32
  29. n_head: int = 32
  30. dim: int = 4096
  31. intermediate_size: int = None
  32. n_local_heads: int = -1
  33. head_dim: int = 64
  34. rope_base: float = 10000
  35. norm_eps: float = 1e-5
  36. max_seq_len: int = 2048
  37. dropout: float = 0.0
  38. tie_word_embeddings: bool = True
  39. attention_qkv_bias: bool = False
  40. # Codebook configs
  41. codebook_size: int = 160
  42. num_codebooks: int = 4
  43. # Gradient checkpointing
  44. use_gradient_checkpointing: bool = True
  45. # Initialize the model
  46. initializer_range: float = 0.02
  47. def __post_init__(self):
  48. if self.n_local_heads == -1:
  49. self.n_local_heads = self.n_head
  50. if self.intermediate_size is None:
  51. hidden_dim = 4 * self.dim
  52. n_hidden = int(2 * hidden_dim / 3)
  53. self.intermediate_size = find_multiple(n_hidden, 256)
  54. self.head_dim = self.dim // self.n_head
  55. @staticmethod
  56. def from_pretrained(path: str):
  57. path = Path(path)
  58. if path.is_dir():
  59. path = path / "config.json"
  60. with open(path, "r", encoding="utf-8") as f:
  61. data = json.load(f)
  62. match data["model_type"]:
  63. case "naive":
  64. cls = NaiveModelArgs
  65. case "dual_ar":
  66. cls = DualARModelArgs
  67. case _:
  68. raise ValueError(f"Unknown model type: {data['model_type']}")
  69. return cls(**data)
  70. def save(self, path: str):
  71. with open(path, "w") as f:
  72. json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
@dataclass
class NaiveModelArgs(BaseModelArgs):
    """Config for NaiveTransformer: one stack predicts all codebooks per position."""

    model_type: str = "naive"
@dataclass
class DualARModelArgs(BaseModelArgs):
    """Config for DualARTransformer: slow token-level stack plus a small fast codebook stack."""

    model_type: str = "dual_ar"
    n_fast_layer: int = 4  # depth of the fast (codebook-level) transformer
  80. class KVCache(nn.Module):
  81. def __init__(
  82. self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
  83. ):
  84. super().__init__()
  85. cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
  86. self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
  87. self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
  88. def update(self, input_pos, k_val, v_val):
  89. # input_pos: [S], k_val: [B, H, S, D]
  90. assert input_pos.shape[0] == k_val.shape[2]
  91. k_out = self.k_cache
  92. v_out = self.v_cache
  93. k_out[:, :, input_pos] = k_val
  94. v_out[:, :, input_pos] = v_val
  95. return k_out, v_out
@dataclass
class TransformerForwardResult:
    """Final model output: vocabulary logits plus per-codebook logits."""

    # Logits over the text vocabulary, one row per input position.
    token_logits: Tensor
    # Codebook logits; trailing dims are (num_codebooks, codebook_size) after rearrange.
    codebook_logits: Tensor
@dataclass
class BaseTransformerForwardResult:
    """Intermediate output of the slow (token-level) stack."""

    # Vocabulary logits from the output head (or tied input embeddings).
    logits: Tensor
    # Pre-head hidden states, consumed by the codebook decoders.
    hidden_states: Tensor
  104. class BaseTransformer(nn.Module):
  105. def __init__(
  106. self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
  107. ) -> None:
  108. super().__init__()
  109. self.config = config
  110. self.tokenizer = tokenizer
  111. self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
  112. # Slow transformer
  113. self.embeddings = nn.Embedding(
  114. config.vocab_size,
  115. config.dim,
  116. )
  117. self.codebook_embeddings = nn.Embedding(
  118. config.codebook_size * config.num_codebooks,
  119. config.dim,
  120. )
  121. self.layers = nn.ModuleList(
  122. TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
  123. )
  124. self.norm = RMSNorm(config.dim, eps=config.norm_eps)
  125. if self.config.tie_word_embeddings is False:
  126. self.output = nn.Linear(
  127. config.dim,
  128. config.vocab_size,
  129. bias=False,
  130. )
  131. self.register_buffer(
  132. "freqs_cis",
  133. precompute_freqs_cis(
  134. config.max_seq_len,
  135. config.dim // config.n_head,
  136. config.rope_base,
  137. ),
  138. persistent=False,
  139. )
  140. self.register_buffer(
  141. "causal_mask",
  142. torch.tril(
  143. torch.ones(
  144. config.max_seq_len,
  145. config.max_seq_len,
  146. dtype=torch.bool,
  147. )
  148. ),
  149. persistent=False,
  150. )
  151. # For kv cache
  152. self.max_batch_size = -1
  153. self.max_seq_len = -1
  154. if init_weights:
  155. self.apply(self._init_weights)
  156. def setup_caches(
  157. self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
  158. ):
  159. if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
  160. return
  161. head_dim = self.config.dim // self.config.n_head
  162. max_seq_len = find_multiple(max_seq_len, 8)
  163. self.max_seq_len = max_seq_len
  164. self.max_batch_size = max_batch_size
  165. for b in self.layers:
  166. b.attention.kv_cache = KVCache(
  167. max_batch_size,
  168. max_seq_len,
  169. self.config.n_local_heads,
  170. head_dim,
  171. dtype=dtype,
  172. )
  173. def embed(self, x: Tensor) -> Tensor:
  174. vocab_embeds = [self.embeddings(x[:, 0])]
  175. for i in range(self.config.num_codebooks):
  176. emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
  177. emb[x[:, 0] != self.semantic_token_id] = 0
  178. vocab_embeds.append(emb)
  179. x = torch.stack(vocab_embeds, dim=3)
  180. x = x.sum(dim=3)
  181. return x
  182. def forward(
  183. self,
  184. inp: Tensor,
  185. key_padding_mask: Optional[Tensor] = None,
  186. ) -> BaseTransformerForwardResult:
  187. seq_len = inp.size(2)
  188. # Here we want to merge the embeddings of the codebooks
  189. x = self.embed(inp)
  190. freqs_cis = self.freqs_cis[:seq_len]
  191. # Not that the causal mask here follows the definition of scaled_dot_product_attention
  192. # That is, FALSE means masked out
  193. # To maintain consistency, key_padding_mask use TRUE to mask out
  194. mask = None
  195. if key_padding_mask is not None:
  196. mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
  197. mask = mask & key_padding_mask[:, None, None, :].logical_not()
  198. for layer in self.layers:
  199. if self.config.use_gradient_checkpointing and self.training:
  200. x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
  201. else:
  202. x = layer(x, freqs_cis, mask)
  203. # We got slow_out here
  204. slow_out = self.norm(x)
  205. if self.config.tie_word_embeddings:
  206. token_logits = F.linear(slow_out, self.embeddings.weight)
  207. else:
  208. token_logits = self.output(slow_out)
  209. return BaseTransformerForwardResult(
  210. logits=token_logits,
  211. hidden_states=x,
  212. )
  213. def forward_generate(
  214. self,
  215. x: Tensor,
  216. input_pos: Optional[Tensor] = None,
  217. return_all: bool = False,
  218. ) -> BaseTransformerForwardResult:
  219. # This is used for generation, optimized for torch compile
  220. assert (
  221. self.max_seq_len != -1 and self.max_batch_size != -1
  222. ), "Please call setup_caches before forward_generate"
  223. x = self.embed(x)
  224. mask = self.causal_mask[
  225. None, None, input_pos, : self.max_seq_len
  226. ] # (B, N, Q, K)
  227. freqs_cis = self.freqs_cis[input_pos]
  228. for layer in self.layers:
  229. x = layer(x, freqs_cis, mask, input_pos=input_pos)
  230. # If prefill, we only calculate the logits of last token
  231. if x.size(1) > 1 and not return_all:
  232. x = x[:, -1:]
  233. # We got slow_out here
  234. slow_out = self.norm(x)
  235. if self.config.tie_word_embeddings:
  236. token_logits = F.linear(slow_out, self.embeddings.weight)
  237. else:
  238. token_logits = self.output(slow_out)
  239. return BaseTransformerForwardResult(
  240. logits=token_logits,
  241. hidden_states=x,
  242. )
  243. def _init_weights(self, module):
  244. std = self.config.initializer_range
  245. if isinstance(module, nn.Linear):
  246. module.weight.data.normal_(mean=0.0, std=std)
  247. if module.bias is not None:
  248. module.bias.data.zero_()
  249. elif isinstance(module, nn.Embedding):
  250. module.weight.data.normal_(mean=0.0, std=std)
  251. if module.padding_idx is not None:
  252. module.weight.data[module.padding_idx].zero_()
  253. @staticmethod
  254. def from_pretrained(
  255. path: str,
  256. load_weights: bool = False,
  257. max_length: int | None = None,
  258. lora_config: LoraConfig | None = None,
  259. rope_base: int | None = None,
  260. ) -> "BaseTransformer":
  261. config = BaseModelArgs.from_pretrained(str(path))
  262. if max_length is not None:
  263. config.max_seq_len = max_length
  264. log.info(f"Override max_seq_len to {max_length}")
  265. if rope_base is not None:
  266. config.rope_base = rope_base
  267. log.info(f"Override rope_base to {rope_base}")
  268. match config.model_type:
  269. case "naive":
  270. model_cls = NaiveTransformer
  271. case "dual_ar":
  272. model_cls = DualARTransformer
  273. case _:
  274. raise ValueError(f"Unknown model type: {config.model_type}")
  275. tokenizer = AutoTokenizer.from_pretrained(str(path))
  276. log.info(f"Loading model from {path}, config: {config}")
  277. model = model_cls(config, tokenizer=tokenizer)
  278. if lora_config is not None:
  279. setup_lora(model, lora_config)
  280. log.info(f"LoRA setup: {lora_config}")
  281. if load_weights is False:
  282. log.info("Randomly initialized model")
  283. else:
  284. if "int8" in str(Path(path)):
  285. logger.info("Using int8 weight-only quantization!")
  286. from tools.llama.quantize import WeightOnlyInt8QuantHandler
  287. simple_quantizer = WeightOnlyInt8QuantHandler(model)
  288. model = simple_quantizer.convert_for_runtime()
  289. if "int4" in str(Path(path)):
  290. logger.info("Using int4 quantization!")
  291. path_comps = path.name.split("-")
  292. assert path_comps[-2].startswith("g")
  293. groupsize = int(path_comps[-2][1:])
  294. from tools.llama.quantize import WeightOnlyInt4QuantHandler
  295. simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
  296. model = simple_quantizer.convert_for_runtime()
  297. weights = torch.load(
  298. Path(path) / "model.pth",
  299. map_location="cpu",
  300. mmap=True,
  301. weights_only=True,
  302. )
  303. if "state_dict" in weights:
  304. logger.warning(
  305. "Using a TextToSemantic LightningModule checkpoint, "
  306. "please make sure it is a full model, not a LoRA model."
  307. )
  308. weights = weights["state_dict"]
  309. if next(iter(weights.keys())).startswith("model."):
  310. logger.info(
  311. f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
  312. )
  313. new_weights = OrderedDict()
  314. for k, v in weights.items():
  315. new_weights[k.replace("model.", "")] = v
  316. weights = new_weights
  317. # Verify the name and shape of parameters since strict=False in load_state_dict.
  318. for k, v in model.named_parameters():
  319. if k not in weights:
  320. logger.warning(f"No weight for {k}")
  321. elif v.shape != weights[k].shape:
  322. logger.warning(
  323. f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
  324. )
  325. err = model.load_state_dict(weights, strict=False, assign=True)
  326. log.info(f"Loaded weights with error: {err}")
  327. return model
  328. def save_pretrained(self, path: str, drop_lora: bool = False):
  329. path = Path(path)
  330. path.mkdir(parents=True, exist_ok=True)
  331. self.config.save(path / "config.json")
  332. state_dict = self.state_dict()
  333. if drop_lora:
  334. for key in list(state_dict.keys()):
  335. if "lora" not in key:
  336. continue
  337. state_dict.pop(key)
  338. log.info(f"Drop LoRA parameter: {key}")
  339. torch.save(state_dict, path / "model.pth")
  340. self.tokenizer.save_pretrained(path)
  341. class NaiveTransformer(BaseTransformer):
  342. def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
  343. super().__init__(config, init_weights=False, tokenizer=tokenizer)
  344. self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
  345. self.codebook_output = nn.Linear(
  346. config.dim,
  347. config.codebook_size * config.num_codebooks,
  348. bias=False,
  349. )
  350. self.apply(self._init_weights)
  351. def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
  352. token_logits = result.logits
  353. x = result.hidden_states
  354. # Codebook
  355. codebook_logits = self.codebook_output(self.codebook_norm(x))
  356. codebook_logits = rearrange(
  357. codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
  358. )
  359. return TransformerForwardResult(
  360. token_logits=token_logits,
  361. codebook_logits=codebook_logits,
  362. )
  363. def forward(
  364. self,
  365. inp: Tensor,
  366. key_padding_mask: Optional[Tensor] = None,
  367. ) -> TransformerForwardResult:
  368. result = super().forward(
  369. inp=inp,
  370. key_padding_mask=key_padding_mask,
  371. )
  372. return self.decode(result)
  373. def forward_generate(
  374. self, x: Tensor, input_pos: Optional[Tensor] = None
  375. ) -> TransformerForwardResult:
  376. result = super().forward_generate(x, input_pos)
  377. return self.decode(result)
class DualARTransformer(BaseTransformer):
    """Dual-autoregressive variant: the inherited "slow" stack models the token
    sequence, while a small "fast" stack autoregresses over the codebooks
    within each frame (one step per codebook)."""

    def __init__(self, config: DualARModelArgs, tokenizer: AutoTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)

        # The equivalent bs is so large that sdpa doesn't work
        self.fast_layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.dim,
            config.codebook_size,
            bias=False,
        )

        self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        """Allocate slow-stack caches, plus fast-stack caches whose sequence
        axis is the number of codebooks (one generation step per codebook)."""
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        head_dim = self.config.dim // self.config.n_head

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        """Training forward: run the slow stack, then train the fast stack on
        (hidden_state, codebook[0..n-2]) -> codebook[0..n-1] per position.

        Positions whose codebooks are all 0 are treated as padding and skipped
        in the fast stack, then re-inserted as zero logits afterwards.
        """
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[:fast_seq_len]

        # Drop the last token and rotate left
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=0)
        codebook_embeddings = self.fast_embeddings(codebooks)
        # Prepend the slow hidden state as step 0 of the fast sequence.
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove padded part
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == 0).all(dim=-1)

        if torch.all(codebook_mask):
            # If all codebooks are padded, we keep first 8 to make sure the model runs
            codebook_mask[:8] = False

        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
            else:
                x = layer(x, fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        """One incremental fast-stack step: given the current (1, dim) input,
        return codebook logits for the position in `input_pos`.

        NOTE(review): the view(1, 1, -1) hard-codes batch size 1 during
        generation — confirm against callers before batching.
        """
        # Fast transformer
        x = x.view(1, 1, -1)
        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits
  483. class TransformerBlock(nn.Module):
  484. def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
  485. super().__init__()
  486. self.attention = Attention(config, use_sdpa=use_sdpa)
  487. self.feed_forward = FeedForward(config)
  488. self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
  489. self.attention_norm = RMSNorm(config.dim, config.norm_eps)
  490. def forward(
  491. self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
  492. ) -> Tensor:
  493. h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
  494. out = h + self.feed_forward(self.ffn_norm(h))
  495. return out
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings, grouped-query KV heads
    (when n_local_heads < n_head), an optional KV cache for incremental
    decoding, and a pure-PyTorch SDPA fallback (use_sdpa=False)."""

    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(
            config.dim, total_head_dim, bias=config.attention_qkv_bias
        )
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        # Populated by BaseTransformer.setup_caches(); None during training.
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        # Migrate legacy checkpoints with separate wq/wk/wv into the fused wqkv.
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Attend over x (b, s, dim); mask follows SDPA's convention (False = masked)."""
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Expand grouped KV heads so every query head has a matching KV head.
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.use_sdpa:
            if mask is None:
                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                    y = F.scaled_dot_product_attention(
                        q,
                        k,
                        v,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=True,
                        # No third party attn_mask here to use flash_attention
                    )
            else:
                y = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=mask,
                    dropout_p=self.dropout if self.training else 0.0,
                )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        # This is a standard scaled dot product attention
        # It's low efficient, but it doesn't raise cuda error
        # NOTE: train=True is always passed to torch.dropout; callers make this
        # a no-op at eval time by passing dropout_p=0.0 when not training.
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

        return attn_weight @ value
  592. class FeedForward(nn.Module):
  593. def __init__(self, config: BaseModelArgs) -> None:
  594. super().__init__()
  595. self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  596. self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
  597. self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
  598. def forward(self, x: Tensor) -> Tensor:
  599. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  600. class RMSNorm(nn.Module):
  601. def __init__(self, dim: int, eps: float = 1e-5):
  602. super().__init__()
  603. self.eps = eps
  604. self.weight = nn.Parameter(torch.ones(dim))
  605. def _norm(self, x):
  606. return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
  607. def forward(self, x: Tensor) -> Tensor:
  608. output = self._norm(x.float()).type_as(x)
  609. return output * self.weight
  610. def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
  611. freqs = 1.0 / (
  612. base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
  613. )
  614. t = torch.arange(seq_len, device=freqs.device)
  615. freqs = torch.outer(t, freqs)
  616. freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
  617. cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
  618. return cache.to(dtype=torch.bfloat16)
  619. def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
  620. xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
  621. freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
  622. x_out2 = torch.stack(
  623. [
  624. xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
  625. xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
  626. ],
  627. -1,
  628. )
  629. x_out2 = x_out2.flatten(3)
  630. return x_out2.type_as(x)