# llama.py

import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.checkpoint import checkpoint
from transformers import AutoTokenizer

from fish_speech.conversation import SEMANTIC_TOKEN
from fish_speech.utils import RankedLogger

from .lora import LoraConfig, setup_lora

log = RankedLogger(__name__, rank_zero_only=True)


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)
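
# find_multiple rounds n up to the next multiple of k, e.g. find_multiple(100, 8) == 104.
# It is used below to pad the KV-cache length and the SwiGLU hidden size to
# hardware-friendly multiples.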


@dataclass
class BaseModelArgs:
    model_type: str = "base"

    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    max_seq_len: int = 2048
    dropout: float = 0.0
    tie_word_embeddings: bool = True
    attention_qkv_bias: bool = False

    # Codebook configs
    codebook_size: int = 160
    num_codebooks: int = 4

    # Gradient checkpointing
    use_gradient_checkpointing: bool = True

    # Initialize the model
    initializer_range: float = 0.02

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head

        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)

        self.head_dim = self.dim // self.n_head

    @staticmethod
    def from_pretrained(path: str):
        path = Path(path)

        if path.is_dir():
            path = path / "config.json"

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        match data["model_type"]:
            case "naive":
                cls = NaiveModelArgs
            case "dual_ar":
                cls = DualARModelArgs
            case _:
                raise ValueError(f"Unknown model type: {data['model_type']}")

        return cls(**data)

    def save(self, path: str):
        with open(path, "w") as f:
            json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)


@dataclass
class NaiveModelArgs(BaseModelArgs):
    model_type: str = "naive"


@dataclass
class DualARModelArgs(BaseModelArgs):
    model_type: str = "dual_ar"
    n_fast_layer: int = 4
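
# Config loading dispatches on the "model_type" field of config.json:
# "naive" -> NaiveModelArgs, "dual_ar" -> DualARModelArgs.
# A minimal sketch (the checkpoint path is hypothetical):
#     args = BaseModelArgs.from_pretrained("checkpoints/fish-speech")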


class KVCache(nn.Module):
    def __init__(
        self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
    ):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out
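
# The cache preallocates (max_batch_size, n_heads, max_seq_len, head_dim) buffers and
# scatter-writes new keys/values at input_pos; update() returns the full buffers, so
# attention always spans max_seq_len and the causal mask hides the unused slots.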


@dataclass
class TransformerForwardResult:
    token_logits: Tensor
    codebook_logits: Tensor


@dataclass
class BaseTransformerForwardResult:
    logits: Tensor
    hidden_states: Tensor


class BaseTransformer(nn.Module):
    def __init__(
        self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
    ) -> None:
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)

        # Slow transformer
        self.embeddings = nn.Embedding(
            config.vocab_size,
            config.dim,
        )
        self.codebook_embeddings = nn.Embedding(
            config.codebook_size * config.num_codebooks,
            config.dim,
        )
        self.layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)

        if self.config.tie_word_embeddings is False:
            self.output = nn.Linear(
                config.dim,
                config.vocab_size,
                bias=False,
            )

        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.max_seq_len,
                config.dim // config.n_head,
                config.rope_base,
            ),
            persistent=False,
        )
        self.register_buffer(
            "causal_mask",
            torch.tril(
                torch.ones(
                    config.max_seq_len,
                    config.max_seq_len,
                    dtype=torch.bool,
                )
            ),
            persistent=False,
        )

        # For kv cache
        self.max_batch_size = -1
        self.max_seq_len = -1

        if init_weights:
            self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
            return

        head_dim = self.config.dim // self.config.n_head
        max_seq_len = find_multiple(max_seq_len, 8)
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        for b in self.layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                max_seq_len,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def embed(self, x: Tensor) -> Tensor:
        vocab_embeds = [self.embeddings(x[:, 0])]
        for i in range(self.config.num_codebooks):
            emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
            emb[x[:, 0] != self.semantic_token_id] = 0
            vocab_embeds.append(emb)

        x = torch.stack(vocab_embeds, dim=3)
        x = x.sum(dim=3)

        return x
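
    # embed() expects inp of shape (B, num_codebooks + 1, S): row 0 holds text tokens and
    # rows 1..num_codebooks hold codebook indices. Codebook embeddings are zeroed wherever
    # the row-0 token is not the semantic token, then summed with the text embedding.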

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> BaseTransformerForwardResult:
        seq_len = inp.size(2)

        # Here we want to merge the embeddings of the codebooks
        x = self.embed(inp)

        freqs_cis = self.freqs_cis[:seq_len]

        # Note that the causal mask here follows the definition of scaled_dot_product_attention:
        # FALSE means masked out.
        # To stay consistent, key_padding_mask uses TRUE to mask out.
        mask = None
        if key_padding_mask is not None:
            mask = self.causal_mask[None, None, :seq_len, :seq_len]  # (B, N, Q, K)
            mask = mask & key_padding_mask[:, None, None, :].logical_not()

        for layer in self.layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
            else:
                x = layer(x, freqs_cis, mask)

        # Slow transformer output
        slow_out = self.norm(x)

        if self.config.tie_word_embeddings:
            token_logits = F.linear(slow_out, self.embeddings.weight)
        else:
            token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )

    def forward_generate(
        self,
        x: Tensor,
        input_pos: Optional[Tensor] = None,
        return_all: bool = False,
    ) -> BaseTransformerForwardResult:
        # This is used for generation and is optimized for torch.compile
        assert (
            self.max_seq_len != -1 and self.max_batch_size != -1
        ), "Please call setup_caches before forward_generate"

        x = self.embed(x)

        mask = self.causal_mask[
            None, None, input_pos, : self.max_seq_len
        ]  # (B, N, Q, K)
        freqs_cis = self.freqs_cis[input_pos]

        for layer in self.layers:
            x = layer(x, freqs_cis, mask, input_pos=input_pos)

        # During prefill, we only calculate the logits of the last token
        if x.size(1) > 1 and not return_all:
            x = x[:, -1:]

        # Slow transformer output
        slow_out = self.norm(x)

        if self.config.tie_word_embeddings:
            token_logits = F.linear(slow_out, self.embeddings.weight)
        else:
            token_logits = self.output(slow_out)

        return BaseTransformerForwardResult(
            logits=token_logits,
            hidden_states=x,
        )
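
    # A minimal decoding sketch (illustrative only; tensor names and sizes below are
    # assumptions, not taken from this file; C = num_codebooks):
    #     model.setup_caches(max_batch_size=1, max_seq_len=2048, dtype=torch.bfloat16)
    #     prefill = model.forward_generate(prompt, input_pos)   # prompt: (1, C + 1, S)
    #     step = model.forward_generate(next_frame, input_pos=torch.tensor([S]))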

    def _init_weights(self, module):
        std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @staticmethod
    def from_pretrained(
        path: str,
        load_weights: bool = False,
        max_length: int | None = None,
        lora_config: LoraConfig | None = None,
        rope_base: int | None = None,
    ) -> "BaseTransformer":
        config = BaseModelArgs.from_pretrained(path)

        if max_length is not None:
            config.max_seq_len = max_length
            log.info(f"Override max_seq_len to {max_length}")

        if rope_base is not None:
            config.rope_base = rope_base
            log.info(f"Override rope_base to {rope_base}")

        match config.model_type:
            case "naive":
                model_cls = NaiveTransformer
            case "dual_ar":
                model_cls = DualARTransformer
            case _:
                raise ValueError(f"Unknown model type: {config.model_type}")

        tokenizer = AutoTokenizer.from_pretrained(str(path))

        log.info(f"Loading model from {path}, config: {config}")
        model = model_cls(config, tokenizer=tokenizer)

        if lora_config is not None:
            setup_lora(model, lora_config)
            log.info(f"LoRA setup: {lora_config}")

        if load_weights is False:
            log.info("Randomly initialized model")
        else:
            weights = torch.load(
                Path(path) / "model.pth", map_location="cpu", mmap=True
            )
            err = model.load_state_dict(weights, strict=False, assign=True)
            log.info(f"Loaded weights with error: {err}")

        return model

    def save_pretrained(self, path: str, drop_lora: bool = False):
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        self.config.save(path / "config.json")
        state_dict = self.state_dict()

        if drop_lora:
            for key in list(state_dict.keys()):
                if "lora" not in key:
                    continue

                state_dict.pop(key)
                log.info(f"Drop LoRA parameter: {key}")

        torch.save(state_dict, path / "model.pth")
        self.tokenizer.save_pretrained(path)
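
# Checkpoint layout used by from_pretrained / save_pretrained: a directory containing
# config.json, model.pth (a plain state_dict), and the Hugging Face tokenizer files.
# A minimal round-trip sketch (the paths are hypothetical):
#     model = BaseTransformer.from_pretrained("checkpoints/base", load_weights=True)
#     model.save_pretrained("checkpoints/exported", drop_lora=True)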


class NaiveTransformer(BaseTransformer):
    def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.codebook_output = nn.Linear(
            config.dim,
            config.codebook_size * config.num_codebooks,
            bias=False,
        )

        self.apply(self._init_weights)

    def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
        token_logits = result.logits
        x = result.hidden_states

        # Codebook
        codebook_logits = self.codebook_output(self.codebook_norm(x))
        codebook_logits = rearrange(
            codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        result = super().forward(
            inp=inp,
            key_padding_mask=key_padding_mask,
        )
        return self.decode(result)

    def forward_generate(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> TransformerForwardResult:
        result = super().forward_generate(x, input_pos)
        return self.decode(result)


class DualARTransformer(BaseTransformer):
    def __init__(self, config: DualARModelArgs, tokenizer: AutoTokenizer) -> None:
        super().__init__(config, init_weights=False, tokenizer=tokenizer)

        # Fast transformer
        self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)

        # The effective batch size is so large that SDPA does not work here
        self.fast_layers = nn.ModuleList(
            TransformerBlock(config, use_sdpa=False) for _ in range(config.n_fast_layer)
        )
        self.fast_norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.fast_output = nn.Linear(
            config.dim,
            config.codebook_size,
            bias=False,
        )

        self.apply(self._init_weights)

    def setup_caches(
        self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
    ):
        super().setup_caches(max_batch_size, max_seq_len, dtype)

        head_dim = self.config.dim // self.config.n_head

        # Fast transformer
        # The max seq len here is the number of codebooks
        for b in self.fast_layers:
            b.attention.kv_cache = KVCache(
                max_batch_size,
                self.config.num_codebooks,
                self.config.n_local_heads,
                head_dim,
                dtype=dtype,
            )

    def forward(
        self,
        inp: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> TransformerForwardResult:
        parent_result = super().forward(inp, key_padding_mask)
        token_logits = parent_result.logits
        x = parent_result.hidden_states

        # Fast transformer
        fast_seq_len = self.config.num_codebooks
        fast_mask = self.causal_mask[
            None, None, :fast_seq_len, :fast_seq_len
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[:fast_seq_len]

        # Drop the last token and rotate left
        codebooks = inp[:, 1:-1, 1:]
        codebooks = F.pad(codebooks, (0, 1), value=0)
        codebook_embeddings = self.fast_embeddings(codebooks)
        x = torch.cat([x[:, None], codebook_embeddings], dim=1)
        b, s = x.size(0), x.size(2)
        x = rearrange(x, "b n s d -> (b s) n d")  # flatten the batch and seq_len

        # Remove the padded part
        codebooks = rearrange(codebooks, "b n s -> (b s) n")
        codebook_mask = (codebooks == 0).all(dim=-1)

        if torch.all(codebook_mask):
            # If all codebooks are padded, keep the first 8 to make sure the model runs
            codebook_mask[:8] = False

        x_bs, x_len = x.size(0), x.size(1)
        x = x[~codebook_mask]

        for layer in self.fast_layers:
            if self.config.use_gradient_checkpointing and self.training:
                x = checkpoint(layer, x, fast_freqs_cis, fast_mask, use_reentrant=True)
            else:
                x = layer(x, fast_freqs_cis, fast_mask)

        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        # Re-pad the codebook_logits
        buffer = torch.zeros(
            x_bs,
            x_len,
            codebook_logits.size(-1),
            device=codebook_logits.device,
            dtype=codebook_logits.dtype,
        )
        buffer[~codebook_mask] = codebook_logits
        codebook_logits = buffer

        # Unflatten the batch and num_codebooks
        assert codebook_logits.shape[1] == self.config.num_codebooks
        codebook_logits = rearrange(
            codebook_logits,
            "(b s) n d -> b s n d",
            b=b,
            s=s,
            n=self.config.num_codebooks,
        )

        return TransformerForwardResult(
            token_logits=token_logits,
            codebook_logits=codebook_logits,
        )

    def forward_generate_fast(
        self, x: Tensor, input_pos: Optional[Tensor] = None
    ) -> Tensor:
        # Fast transformer
        x = x.view(1, 1, -1)

        fast_mask = self.causal_mask[
            None, None, input_pos, : self.config.num_codebooks
        ]  # (B, N, Q, K)
        fast_freqs_cis = self.freqs_cis[input_pos]

        for layer in self.fast_layers:
            x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)

        # Only take the last token
        fast_out = self.fast_norm(x)
        codebook_logits = self.fast_output(fast_out)

        return codebook_logits
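
# Dual-AR decoding: the slow transformer predicts the text/semantic token and exposes its
# hidden state; forward_generate_fast then runs the fast layers over that hidden state and
# the previously sampled codebook embeddings, producing one codebook's logits per call
# (input_pos selects the codebook slot, 0..num_codebooks - 1).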


class TransformerBlock(nn.Module):
    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
        super().__init__()
        self.attention = Attention(config, use_sdpa=use_sdpa)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
    ) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim

        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(
            config.dim, total_head_dim, bias=config.attention_qkv_bias
        )
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.dropout = config.dropout
        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self.use_sdpa = use_sdpa
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        # Grouped-query attention: repeat k/v heads to match the query heads
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

        if self.use_sdpa:
            if mask is None:
                with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                    y = F.scaled_dot_product_attention(
                        q,
                        k,
                        v,
                        dropout_p=self.dropout if self.training else 0.0,
                        is_causal=True,
                        # No attn_mask here, so the flash attention kernel can be used
                    )
            else:
                y = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=mask,
                    dropout_p=self.dropout if self.training else 0.0,
                )
        else:
            y = self.eq_scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
            )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        return self.wo(y)

    def eq_scaled_dot_product_attention(
        self,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
    ) -> torch.Tensor:
        # A reference (unfused) scaled dot-product attention.
        # It is inefficient, but it does not raise CUDA errors.
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

        return attn_weight @ value


class FeedForward(nn.Module):
    def __init__(self, config: BaseModelArgs) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
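
# Shape notes for the RoPE helpers above: precompute_freqs_cis returns a
# (seq_len, n_elem // 2, 2) bfloat16 table of (cos, sin) pairs, and apply_rotary_emb
# expects x of shape (B, S, H, D) with freqs_cis sliced to the current positions.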