# llama.py

import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from loguru import logger
from torch import Tensor
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.checkpoint import checkpoint
from transformers import AutoTokenizer

from fish_speech.conversation import SEMANTIC_TOKEN
from fish_speech.utils import RankedLogger

from .lora import LoraConfig, setup_lora

log = RankedLogger(__name__, rank_zero_only=True)


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

@dataclass
class BaseModelArgs:
    model_type: str = "base"

    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048
    dropout: float = 0.0
    tie_word_embeddings: bool = True
    attention_qkv_bias: bool = False

    # Codebook configs
    codebook_size: int = 160
    num_codebooks: int = 4

    # Gradient checkpointing
    use_gradient_checkpointing: bool = True

    # Initialize the model
    initializer_range: float = 0.02

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head

        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)

        self.head_dim = self.dim // self.n_head

    @staticmethod
    def from_pretrained(path: str):
        path = Path(path)

        if path.is_dir():
            path = path / "config.json"

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        match data["model_type"]:
            case "naive":
                cls = NaiveModelArgs
            case "dual_ar":
                cls = DualARModelArgs
            case _:
                raise ValueError(f"Unknown model type: {data['model_type']}")

        return cls(**data)

    def save(self, path: str):
        with open(path, "w") as f:
            json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)

@dataclass
class NaiveModelArgs(BaseModelArgs):
    model_type: str = "naive"


@dataclass
class DualARModelArgs(BaseModelArgs):
    model_type: str = "dual_ar"
    n_fast_layer: int = 4

class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out

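# Usage note: BaseTransformer.setup_caches() attaches one KVCache per TransformerBlock;
# during forward_generate, Attention.forward calls kv_cache.update(input_pos, k, v),
# which writes the new keys/values at input_pos and returns the full cached tensors.
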
@dataclass
class TransformerForwardResult:
    token_logits: Tensor
    codebook_logits: Tensor


@dataclass
class BaseTransformerForwardResult:
    logits: Tensor
    hidden_states: Tensor

class BaseTransformer(nn.Module):
    def __init__(
        self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
    ) -> None:
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)

        # Slow transformer
        self.embeddings = nn.Embedding(
            config.vocab_size,
            config.dim,
        )
        self.codebook_embeddings = nn.Embedding(
            config.codebook_size * config.num_codebooks,
            config.dim,
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)

        if self.config.tie_word_embeddings is False:
            self.output = nn.Linear(
                config.dim,
                config.vocab_size,
                bias=False,
            )

        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
            persistent=False,
        )

        # For kv cache
        self.max_batch_size = -1
        self.max_seq_len = -1

        if init_weights:
            self.apply(self._init_weights)
    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        head_dim = self.config.dim // self.config.n_head
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def embed(self, x: Tensor) -> Tensor:
        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
            emb[x[:, 0] != self.semantic_token_id] = 0
            vocab_embeds.append(emb)

        x = torch.stack(vocab_embeds, dim=3)
        x = x.sum(dim=3)

        return x
    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> BaseTransformerForwardResult:
        seq_len = inp.size(2)

        # Here we want to merge the embeddings of the codebooks
        x = self.embed(inp)

        freqs_cis = self.freqs_cis[:seq_len]

        # Note that the causal mask here follows the definition of scaled_dot_product_attention,
        # that is, False means masked out.
        # To maintain consistency, key_padding_mask uses True to mask out.
        mask = None
        if key_padding_mask is not None:
            mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
            mask = mask & key_padding_mask[:, None, None, :].logical_not()

        for layer in self.layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
            else:
                x = layer(x, freqs_cis, mask)

        # Output of the slow transformer
        slow_out = self.norm(x)

        if self.config.tie_word_embeddings:
            token_logits = F.linear(slow_out, self.embeddings.weight)
        else:
            token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )
    def forward_generate(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        return_all: bool = False,
    ) -> BaseTransformerForwardResult:
        # This is used for generation, optimized for torch compile
        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"

        x = self.embed(x)

        mask = self.causal_mask[
            None, None, input_pos, : self.max_seq_len
        ]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        # During prefill, we only calculate the logits of the last token
        if x.size(1) > 1 and not return_all:
            x = x[:, -1:]

        # Output of the slow transformer
        slow_out = self.norm(x)

        if self.config.tie_word_embeddings:
            token_logits = F.linear(slow_out, self.embeddings.weight)
        else:
            token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )
    def _init_weights(self, module):
        std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
    @staticmethod
    def from_pretrained(
        path: str,
        load_weights: bool = False,
        max_length: int | None = None,
        lora_config: LoraConfig | None = None,
        rope_base: int | None = None,
    ) -> "BaseTransformer":
        config = BaseModelArgs.from_pretrained(str(path))

        if max_length is not None:
            config.max_seq_len = max_length
            log.info(f"Override max_seq_len to {max_length}")

        if rope_base is not None:
            config.rope_base = rope_base
            log.info(f"Override rope_base to {rope_base}")

        match config.model_type:
            case "naive":
                model_cls = NaiveTransformer
            case "dual_ar":
                model_cls = DualARTransformer
            case _:
                raise ValueError(f"Unknown model type: {config.model_type}")

        tokenizer = AutoTokenizer.from_pretrained(str(path))
        log.info(f"Loading model from {path}, config: {config}")
        model = model_cls(config, tokenizer=tokenizer)

        if lora_config is not None:
            setup_lora(model, lora_config)
            log.info(f"LoRA setup: {lora_config}")

        if load_weights is False:
            log.info("Randomly initialized model")
        else:
            if "int8" in str(Path(path)):
                logger.info("Using int8 weight-only quantization!")
                from tools.llama.quantize import WeightOnlyInt8QuantHandler

                simple_quantizer = WeightOnlyInt8QuantHandler(model)
                model = simple_quantizer.convert_for_runtime()

            if "int4" in str(Path(path)):
                logger.info("Using int4 quantization!")
                path_comps = Path(path).name.split("-")
                assert path_comps[-2].startswith("g")
                groupsize = int(path_comps[-2][1:])
                from tools.llama.quantize import WeightOnlyInt4QuantHandler

                simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
                model = simple_quantizer.convert_for_runtime()

            weights = torch.load(
                Path(path) / "model.pth", map_location="cpu", mmap=True
            )
            err = model.load_state_dict(weights, strict=False, assign=True)
            log.info(f"Loaded weights with error: {err}")

        return model
    def save_pretrained(self, path: str, drop_lora: bool = False):
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        self.config.save(path / "config.json")
        state_dict = self.state_dict()

        if drop_lora:
            for key in list(state_dict.keys()):
                if "lora" not in key:
                    continue

                state_dict.pop(key)
                log.info(f"Drop LoRA parameter: {key}")

        torch.save(state_dict, path / "model.pth")
        self.tokenizer.save_pretrained(path)

class NaiveTransformer(BaseTransformer):
    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.codebook_output = nn.Linear(
            config.dim,
            config.codebook_size * config.num_codebooks,
            bias=False,
        )

        self.apply(self._init_weights)

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        token_logits = result.logits
        x = result.hidden_states

        # Codebook
        codebook_logits = self.codebook_output(self.codebook_norm(x))
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        result = super().forward(
            inp=inp,
            key_padding_mask=key_padding_mask,
        )
        return self.decode(result)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        result = super().forward_generate(x, input_pos)
        return self.decode(result)

class DualARTransformer(BaseTransformer):
    def __init__(self, config: DualARModelArgs, tokenizer: AutoTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)

        # The effective batch size is so large that SDPA doesn't work
        self.fast_layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.dim,
            config.codebook_size,
            bias=False,
        )

        self.apply(self._init_weights)
    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        head_dim = self.config.dim // self.config.n_head

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )
    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[:fast_seq_len]

        # Drop the last token and rotate left
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=0)
        codebook_embeddings = self.fast_embeddings(codebooks)
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove padded part
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == 0).all(dim=-1)

        if torch.all(codebook_mask):
            # If all codebooks are padded, we keep the first 8 to make sure the model runs
            codebook_mask[:8] = False

        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
            else:
                x = layer(x, fast_freqs_cis, fast_mask)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )
    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        # Fast transformer
        x = x.view(1, 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # unflatten the batch and num_codebooks
        fast_out = self.fast_norm(x)  # only take the last token
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits

class TransformerBlock(nn.Module):
    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
    ) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

class Attention(nn.Module):
    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(
            config.dim, total_head_dim, bias=config.attention_qkv_bias
        )
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.use_sdpa:
            if mask is None:
                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                    y = F.scaled_dot_product_attention(
                        q,
                        k,
                        v,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=True,
                        # No attn_mask is passed here so the flash attention kernel can be used
                    )
            else:
                y = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=mask,
                    dropout_p=self.dropout if self.training else 0.0,
                )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)
    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        # This is a standard scaled dot product attention.
        # It's less efficient, but it doesn't raise a CUDA error.
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

        return attn_weight @ value

class FeedForward(nn.Module):
    def __init__(self, config: BaseModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)

def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
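

# Minimal sanity-check sketch: the shapes below are arbitrary example values, chosen
# only to show that apply_rotary_emb preserves the (batch, time, heads, head_dim)
# layout produced inside Attention.forward before the transpose to (B, H, T, D).
if __name__ == "__main__":
    freqs = precompute_freqs_cis(seq_len=16, n_elem=64)  # (16, 64 // 2, 2), bfloat16
    q = torch.randn(2, 16, 8, 64)  # (batch, time, heads, head_dim)
    q_rot = apply_rotary_emb(q, freqs)
    assert q_rot.shape == q.shape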