lit_module.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. from dataclasses import dataclass
  2. from typing import Any, Optional
  3. import lightning as L
  4. import loralib as lora
  5. import torch
  6. import torch.nn.functional as F
  7. from lightning.pytorch.utilities.types import OptimizerLRScheduler
  8. import fish_speech.utils as utils
  9. from fish_speech.models.text2semantic.llama import Transformer
# Module-level logger for this file. It is built with rank_zero_only=True;
# presumably that restricts output to the rank-zero process in distributed
# runs -- confirm against fish_speech.utils.RankedLogger.
log = utils.RankedLogger(__name__, rank_zero_only=True)
  11. @dataclass
  12. class LoraConfig:
  13. r: int
  14. lora_alpha: float
  15. lora_dropout: float = 0.0
  16. class TextToSemantic(L.LightningModule):
  17. def __init__(
  18. self,
  19. model: Transformer,
  20. optimizer: Any,
  21. lr_scheduler: Any,
  22. lora_config: Optional[LoraConfig] = None,
  23. save_lora_only: bool = False,
  24. use_dpo: bool = False,
  25. dpo_beta: float = 0.2,
  26. ):
  27. super().__init__()
  28. self.model = model
  29. self.optimizer_builder = optimizer
  30. self.lr_scheduler_builder = lr_scheduler
  31. self.lora_config = lora_config
  32. self.save_lora_only = save_lora_only
  33. self.use_dpo = use_dpo # We don't support reference model yet
  34. self.dpo_beta = dpo_beta
  35. if self.lora_config is not None:
  36. self.setup_lora()
  37. def setup_lora(self):
  38. # Replace the embedding layer with a LoRA layer
  39. self.model.embeddings = lora.Embedding(
  40. num_embeddings=self.model.embeddings.num_embeddings,
  41. embedding_dim=self.model.embeddings.embedding_dim,
  42. padding_idx=self.model.embeddings.padding_idx,
  43. r=self.lora_config.r,
  44. lora_alpha=self.lora_config.lora_alpha,
  45. )
  46. # Replace output layer with a LoRA layer
  47. linears = [(self.model, "output")]
  48. # Replace all linear layers with LoRA layers
  49. for layer in self.model.layers:
  50. linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
  51. linears.extend(
  52. [
  53. (layer.feed_forward, "w1"),
  54. (layer.feed_forward, "w2"),
  55. (layer.feed_forward, "w3"),
  56. ]
  57. )
  58. for module, layer in linears:
  59. updated_linear = lora.Linear(
  60. in_features=getattr(module, layer).in_features,
  61. out_features=getattr(module, layer).out_features,
  62. bias=getattr(module, layer).bias,
  63. r=self.lora_config.r,
  64. lora_alpha=self.lora_config.lora_alpha,
  65. lora_dropout=self.lora_config.lora_dropout,
  66. )
  67. setattr(module, layer, updated_linear)
  68. # Mark only the LoRA layers as trainable
  69. lora.mark_only_lora_as_trainable(self.model, bias="lora_only")
  70. def forward(self, x):
  71. return self.model(x)
  72. def on_save_checkpoint(self, checkpoint):
  73. if self.lora_config is None or self.save_lora_only is False:
  74. return
  75. # Save only LoRA parameters
  76. state_dict = checkpoint["state_dict"]
  77. for name in list(state_dict.keys()):
  78. if "lora" not in name:
  79. state_dict.pop(name)
  80. def configure_optimizers(self) -> OptimizerLRScheduler:
  81. optimizer = self.optimizer_builder(self.parameters())
  82. lr_scheduler = self.lr_scheduler_builder(optimizer)
  83. return {
  84. "optimizer": optimizer,
  85. "lr_scheduler": {
  86. "scheduler": lr_scheduler,
  87. "interval": "step",
  88. },
  89. }
  90. # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
  91. def get_batch_logps(
  92. self,
  93. logits: torch.FloatTensor,
  94. labels: torch.LongTensor,
  95. average_log_prob: bool = False,
  96. ) -> torch.FloatTensor:
  97. """Compute the log probabilities of the given labels under the given logits.
  98. Args:
  99. logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
  100. labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
  101. average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
  102. Returns:
  103. A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
  104. """
  105. assert logits.shape[:-1] == labels.shape
  106. labels = labels.clone()
  107. loss_mask = labels != -100
  108. # dummy token; we'll ignore the losses on these tokens later
  109. labels[labels == -100] = 0
  110. per_token_logps = torch.gather(
  111. logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
  112. ).squeeze(-1)
  113. if average_log_prob:
  114. return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
  115. else:
  116. return (per_token_logps * loss_mask).sum(-1)
  117. def _step(self, batch, batch_idx, stage: str):
  118. # Do positive and negative samples in the same batch to speed up training
  119. outputs = self.model(
  120. x=batch["inputs"],
  121. key_padding_mask=batch["attention_masks"],
  122. )
  123. labels = batch["labels"]
  124. token_logits = outputs.token_logits
  125. codebook_logits = outputs.codebook_logits
  126. if self.use_dpo:
  127. # Firtst half is positive, second half is negative
  128. token_logits, negative_token_logits = token_logits.chunk(2)
  129. codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
  130. labels, negative_labels = labels.chunk(2)
  131. # Generate labels
  132. base_loss = F.cross_entropy(
  133. token_logits.reshape(-1, token_logits.size(-1)),
  134. labels[:, 0].reshape(-1),
  135. ignore_index=-100,
  136. )
  137. # If we have a codebook, add the loss
  138. if self.model.config.num_codebooks != 0:
  139. codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
  140. semantic_loss = F.cross_entropy(
  141. codebook_logits.reshape(-1, codebook_logits.size(-1)),
  142. codebook_labels.reshape(-1),
  143. ignore_index=-100,
  144. )
  145. loss = base_loss + semantic_loss
  146. else:
  147. loss = base_loss
  148. # If we use dpo
  149. if self.use_dpo:
  150. negative_codebook_labels = negative_labels[
  151. :, 1 : 1 + self.model.config.num_codebooks
  152. ].mT
  153. positive_codebook_logps = self.get_batch_logps(
  154. codebook_logits, codebook_labels
  155. )
  156. negative_codebook_logps = self.get_batch_logps(
  157. negative_codebook_logits, negative_codebook_labels
  158. )
  159. # TODO: implement the reference model, avoid screwing up the gradients
  160. dpo_loss = -F.logsigmoid(
  161. (positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
  162. ).mean()
  163. chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
  164. rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
  165. reward_accuracy = (
  166. (positive_codebook_logps > negative_codebook_logps).float().mean()
  167. )
  168. chosen_rewards, rejected_rewards = (
  169. chosen_rewards.mean(),
  170. rejected_rewards.mean(),
  171. )
  172. loss = loss + dpo_loss
  173. self.log(
  174. f"{stage}/dpo_loss",
  175. dpo_loss,
  176. on_step=True,
  177. on_epoch=False,
  178. prog_bar=False,
  179. logger=True,
  180. )
  181. self.log(
  182. f"{stage}/chosen_rewards",
  183. chosen_rewards,
  184. on_step=True,
  185. on_epoch=False,
  186. prog_bar=False,
  187. logger=True,
  188. )
  189. self.log(
  190. f"{stage}/rejected_rewards",
  191. rejected_rewards,
  192. on_step=True,
  193. on_epoch=False,
  194. prog_bar=False,
  195. logger=True,
  196. )
  197. self.log(
  198. f"{stage}/reward_accuracy",
  199. reward_accuracy,
  200. on_step=True,
  201. on_epoch=False,
  202. prog_bar=False,
  203. logger=True,
  204. )
  205. self.log(
  206. f"{stage}/loss",
  207. loss,
  208. on_step=True,
  209. on_epoch=False,
  210. prog_bar=True,
  211. logger=True,
  212. )
  213. if self.model.config.num_codebooks != 0:
  214. self.log(
  215. f"{stage}/base_loss",
  216. base_loss,
  217. on_step=True,
  218. on_epoch=False,
  219. prog_bar=False,
  220. logger=True,
  221. )
  222. self.log(
  223. f"{stage}/semantic_loss",
  224. semantic_loss,
  225. on_step=True,
  226. on_epoch=False,
  227. prog_bar=False,
  228. logger=True,
  229. )
  230. # Top-5 accuracy
  231. if self.model.config.num_codebooks == 0:
  232. _, indices = token_logits.topk(5, dim=-1)
  233. correct = indices.eq(labels[:, 0].unsqueeze(-1))
  234. correct[labels[:, 0] == -100] = 0
  235. correct = correct.sum()
  236. accuracy = correct / (labels[:, 0] != -100).sum()
  237. else:
  238. _, indices = codebook_logits.topk(5, dim=-1)
  239. correct = indices.eq(codebook_labels.unsqueeze(-1))
  240. correct[codebook_labels == -100] = 0
  241. correct = correct.sum()
  242. accuracy = correct / (codebook_labels != -100).sum()
  243. self.log(
  244. f"{stage}/top_5_accuracy",
  245. accuracy,
  246. on_step=True,
  247. on_epoch=False,
  248. prog_bar=True,
  249. logger=True,
  250. )
  251. return loss
  252. def training_step(self, batch, batch_idx):
  253. return self._step(batch, batch_idx, "train")
  254. def validation_step(self, batch, batch_idx):
  255. return self._step(batch, batch_idx, "val")