lit_module.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. from dataclasses import dataclass
  2. from typing import Any, Optional
  3. import lightning as L
  4. import loralib as lora
  5. import torch
  6. import torch.nn.functional as F
  7. from lightning.pytorch.utilities.types import OptimizerLRScheduler
  8. import fish_speech.utils as utils
  9. from fish_speech.models.text2semantic.llama import NaiveTransformer
# Module-level logger. rank_zero_only=True presumably restricts output to the
# rank-0 process in distributed runs — confirm against fish_speech.utils.RankedLogger.
log = utils.RankedLogger(__name__, rank_zero_only=True)
  11. @dataclass
  12. class LoraConfig:
  13. r: int
  14. lora_alpha: float
  15. lora_dropout: float = 0.0
  16. class TextToSemantic(L.LightningModule):
  17. def __init__(
  18. self,
  19. model: NaiveTransformer,
  20. optimizer: Any,
  21. lr_scheduler: Any,
  22. lora_config: Optional[LoraConfig] = None,
  23. save_lora_only: bool = False,
  24. use_dpo: bool = False,
  25. dpo_beta: float = 0.2,
  26. ):
  27. super().__init__()
  28. self.model = model
  29. self.optimizer_builder = optimizer
  30. self.lr_scheduler_builder = lr_scheduler
  31. self.lora_config = lora_config
  32. self.save_lora_only = save_lora_only
  33. self.use_dpo = use_dpo # We don't support reference model yet
  34. self.dpo_beta = dpo_beta
  35. if self.lora_config is not None:
  36. self.setup_lora()
  37. def setup_lora(self):
  38. # Replace the embedding layer with a LoRA layer
  39. self.model.embeddings = lora.Embedding(
  40. num_embeddings=self.model.embeddings.num_embeddings,
  41. embedding_dim=self.model.embeddings.embedding_dim,
  42. padding_idx=self.model.embeddings.padding_idx,
  43. r=self.lora_config.r,
  44. lora_alpha=self.lora_config.lora_alpha,
  45. )
  46. # Replace output layer with a LoRA layer
  47. linears = [(self.model, "output")]
  48. # Replace all linear layers with LoRA layers
  49. for layer in self.model.layers:
  50. linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
  51. linears.extend(
  52. [
  53. (layer.feed_forward, "w1"),
  54. (layer.feed_forward, "w2"),
  55. (layer.feed_forward, "w3"),
  56. ]
  57. )
  58. if hasattr(self.model, "fast_layers"):
  59. # Dual-AR model
  60. linears.extend([(self.model, "fast_output")])
  61. for layer in self.model.fast_layers:
  62. linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
  63. linears.extend(
  64. [
  65. (layer.feed_forward, "w1"),
  66. (layer.feed_forward, "w2"),
  67. (layer.feed_forward, "w3"),
  68. ]
  69. )
  70. for module, layer in linears:
  71. updated_linear = lora.Linear(
  72. in_features=getattr(module, layer).in_features,
  73. out_features=getattr(module, layer).out_features,
  74. bias=getattr(module, layer).bias,
  75. r=self.lora_config.r,
  76. lora_alpha=self.lora_config.lora_alpha,
  77. lora_dropout=self.lora_config.lora_dropout,
  78. )
  79. setattr(module, layer, updated_linear)
  80. # Mark only the LoRA layers as trainable
  81. lora.mark_only_lora_as_trainable(self.model, bias="lora_only")
  82. def forward(self, x):
  83. return self.model(x)
  84. def on_save_checkpoint(self, checkpoint):
  85. if self.lora_config is None or self.save_lora_only is False:
  86. return
  87. # Save only LoRA parameters
  88. state_dict = checkpoint["state_dict"]
  89. for name in list(state_dict.keys()):
  90. if "lora" not in name:
  91. state_dict.pop(name)
  92. def configure_optimizers(self) -> OptimizerLRScheduler:
  93. # Get weight decay parameters
  94. weight_decay_parameters, other_parameters = [], []
  95. for name, param in self.named_parameters():
  96. if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
  97. other_parameters.append(param)
  98. else:
  99. weight_decay_parameters.append(param)
  100. optimizer = self.optimizer_builder(
  101. [
  102. {"params": weight_decay_parameters},
  103. {"params": other_parameters, "weight_decay": 0.0},
  104. ]
  105. )
  106. # Print the parameters and their weight decay
  107. for i in optimizer.param_groups:
  108. log.info(
  109. f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
  110. )
  111. lr_scheduler = self.lr_scheduler_builder(optimizer)
  112. return {
  113. "optimizer": optimizer,
  114. "lr_scheduler": {
  115. "scheduler": lr_scheduler,
  116. "interval": "step",
  117. },
  118. }
  119. # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
  120. def get_batch_logps(
  121. self,
  122. logits: torch.FloatTensor,
  123. labels: torch.LongTensor,
  124. average_log_prob: bool = False,
  125. ) -> torch.FloatTensor:
  126. """Compute the log probabilities of the given labels under the given logits.
  127. Args:
  128. logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
  129. labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
  130. average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
  131. Returns:
  132. A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
  133. """
  134. assert logits.shape[:-1] == labels.shape
  135. labels = labels.clone()
  136. loss_mask = labels != -100
  137. # dummy token; we'll ignore the losses on these tokens later
  138. labels[labels == -100] = 0
  139. per_token_logps = torch.gather(
  140. logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
  141. ).squeeze(-1)
  142. if average_log_prob:
  143. return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
  144. else:
  145. return (per_token_logps * loss_mask).sum(-1)
  146. def _step(self, batch, batch_idx, stage: str):
  147. is_train = stage == "train"
  148. # Do positive and negative samples in the same batch to speed up training
  149. labels = batch["labels"]
  150. outputs = self.model(
  151. inp=batch["inputs"],
  152. key_padding_mask=batch["attention_masks"],
  153. )
  154. token_logits = outputs.token_logits
  155. codebook_logits = outputs.codebook_logits
  156. if self.use_dpo:
  157. # Firtst half is positive, second half is negative
  158. token_logits, negative_token_logits = token_logits.chunk(2)
  159. codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
  160. labels, negative_labels = labels.chunk(2)
  161. # Generate labels
  162. base_loss = F.cross_entropy(
  163. token_logits.reshape(-1, token_logits.size(-1)),
  164. labels[:, 0].reshape(-1),
  165. ignore_index=-100,
  166. )
  167. codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
  168. semantic_loss = F.cross_entropy(
  169. codebook_logits.reshape(-1, codebook_logits.size(-1)),
  170. codebook_labels.reshape(-1),
  171. ignore_index=-100,
  172. )
  173. loss = base_loss + semantic_loss
  174. # If we use dpo
  175. if self.use_dpo:
  176. negative_codebook_labels = negative_labels[
  177. :, 1 : 1 + self.model.config.num_codebooks
  178. ].mT
  179. positive_codebook_logps = self.get_batch_logps(
  180. codebook_logits, codebook_labels
  181. )
  182. negative_codebook_logps = self.get_batch_logps(
  183. negative_codebook_logits, negative_codebook_labels
  184. )
  185. # TODO: implement the reference model, avoid screwing up the gradients
  186. dpo_loss = -F.logsigmoid(
  187. (positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
  188. ).mean()
  189. chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
  190. rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
  191. reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
  192. chosen_rewards, rejected_rewards = (
  193. chosen_rewards.mean(),
  194. rejected_rewards.mean(),
  195. )
  196. loss = loss + dpo_loss
  197. self.log(
  198. f"{stage}/dpo_loss",
  199. dpo_loss,
  200. on_step=is_train,
  201. on_epoch=not is_train,
  202. prog_bar=False,
  203. logger=True,
  204. )
  205. self.log(
  206. f"{stage}/chosen_rewards",
  207. chosen_rewards,
  208. on_step=is_train,
  209. on_epoch=not is_train,
  210. prog_bar=False,
  211. logger=True,
  212. )
  213. self.log(
  214. f"{stage}/rejected_rewards",
  215. rejected_rewards,
  216. on_step=is_train,
  217. on_epoch=not is_train,
  218. prog_bar=False,
  219. logger=True,
  220. )
  221. self.log(
  222. f"{stage}/reward_accuracy",
  223. reward_accuracy,
  224. on_step=is_train,
  225. on_epoch=not is_train,
  226. prog_bar=False,
  227. logger=True,
  228. )
  229. self.log(
  230. f"{stage}/loss",
  231. loss,
  232. on_step=is_train,
  233. on_epoch=not is_train,
  234. prog_bar=True,
  235. logger=True,
  236. )
  237. self.log(
  238. f"{stage}/base_loss",
  239. base_loss,
  240. on_step=is_train,
  241. on_epoch=not is_train,
  242. prog_bar=False,
  243. logger=True,
  244. )
  245. self.log(
  246. f"{stage}/semantic_loss",
  247. semantic_loss,
  248. on_step=is_train,
  249. on_epoch=not is_train,
  250. prog_bar=False,
  251. logger=True,
  252. )
  253. # Top-5 accuracy
  254. accuracy = self.get_accuracy(codebook_logits, codebook_labels)
  255. self.log(
  256. f"{stage}/top_5_accuracy",
  257. accuracy,
  258. on_step=is_train,
  259. on_epoch=not is_train,
  260. prog_bar=True,
  261. logger=True,
  262. )
  263. if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
  264. accuracy = self.get_accuracy(
  265. codebook_logits[:, :, : self.model.config.num_in_codebooks],
  266. codebook_labels[:, :, : self.model.config.num_in_codebooks],
  267. )
  268. self.log(
  269. f"{stage}/top_5_accuracy_in",
  270. accuracy,
  271. on_step=is_train,
  272. on_epoch=not is_train,
  273. prog_bar=True,
  274. logger=True,
  275. )
  276. return loss
  277. def get_accuracy(self, logits, labels):
  278. _, indices = logits.topk(5, dim=-1)
  279. correct = indices.eq(labels.unsqueeze(-1))
  280. correct[labels == -100] = 0
  281. correct = correct.sum()
  282. accuracy = correct / (labels != -100).sum()
  283. return accuracy
  284. def training_step(self, batch, batch_idx):
  285. return self._step(batch, batch_idx, "train")
  286. def validation_step(self, batch, batch_idx):
  287. return self._step(batch, batch_idx, "val")