# lit_module.py
  1. from typing import Any, Optional
  2. import lightning as L
  3. import torch
  4. import torch.nn.functional as F
  5. from lightning.pytorch.utilities.types import OptimizerLRScheduler
  6. import fish_speech.utils as utils
  7. from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
  8. from fish_speech.models.text2semantic.llama import NaiveTransformer
# Module-level logger; rank_zero_only=True so only the rank-0 process emits
# records in distributed training (see fish_speech.utils.RankedLogger).
log = utils.RankedLogger(__name__, rank_zero_only=True)
  10. class TextToSemantic(L.LightningModule):
  11. def __init__(
  12. self,
  13. model: NaiveTransformer,
  14. optimizer: Any,
  15. lr_scheduler: Any,
  16. ):
  17. super().__init__()
  18. self.model = model
  19. self.optimizer_builder = optimizer
  20. self.lr_scheduler_builder = lr_scheduler
  21. def forward(self, x):
  22. return self.model(x)
  23. def on_save_checkpoint(self, checkpoint):
  24. # Save only LoRA parameters
  25. state_dict = checkpoint["state_dict"]
  26. use_lora = any("lora" in name for name in state_dict.keys())
  27. if not use_lora:
  28. return
  29. for name in list(state_dict.keys()):
  30. if "lora" not in name:
  31. state_dict.pop(name)
  32. def configure_optimizers(self) -> OptimizerLRScheduler:
  33. # Get weight decay parameters
  34. weight_decay_parameters, other_parameters = [], []
  35. for name, param in self.named_parameters():
  36. if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
  37. other_parameters.append(param)
  38. else:
  39. weight_decay_parameters.append(param)
  40. optimizer = self.optimizer_builder(
  41. [
  42. {"params": weight_decay_parameters},
  43. {"params": other_parameters, "weight_decay": 0.0},
  44. ]
  45. )
  46. # Print the parameters and their weight decay
  47. for i in optimizer.param_groups:
  48. log.info(
  49. f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
  50. )
  51. lr_scheduler = self.lr_scheduler_builder(optimizer)
  52. return {
  53. "optimizer": optimizer,
  54. "lr_scheduler": {
  55. "scheduler": lr_scheduler,
  56. "interval": "step",
  57. },
  58. }
  59. # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
  60. def get_batch_logps(
  61. self,
  62. logits: torch.FloatTensor,
  63. labels: torch.LongTensor,
  64. average_log_prob: bool = False,
  65. ) -> torch.FloatTensor:
  66. """Compute the log probabilities of the given labels under the given logits.
  67. Args:
  68. logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
  69. labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
  70. average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
  71. Returns:
  72. A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
  73. """
  74. assert logits.shape[:-1] == labels.shape
  75. labels = labels.clone()
  76. loss_mask = labels != -100
  77. # dummy token; we'll ignore the losses on these tokens later
  78. labels[labels == -100] = 0
  79. per_token_logps = torch.gather(
  80. logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
  81. ).squeeze(-1)
  82. if average_log_prob:
  83. return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
  84. else:
  85. return (per_token_logps * loss_mask).sum(-1)
  86. def _step(self, batch, batch_idx, stage: str):
  87. is_train = stage == "train"
  88. if is_train:
  89. # Key part to make lora work
  90. # Otherwise the parameters are merged, which lead to incorrect gradients
  91. self.model.train()
  92. # Do positive and negative samples in the same batch to speed up training
  93. labels = batch["labels"]
  94. outputs = self.model(
  95. inp=batch["inputs"],
  96. key_padding_mask=batch["attention_masks"],
  97. )
  98. token_logits = outputs.token_logits
  99. codebook_logits = outputs.codebook_logits
  100. # Generate labels
  101. base_loss = F.cross_entropy(
  102. token_logits.view(-1, token_logits.size(-1)),
  103. labels[:, 0].reshape(-1),
  104. ignore_index=-100,
  105. )
  106. codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
  107. semantic_loss = F.cross_entropy(
  108. codebook_logits.view(-1, codebook_logits.size(-1)),
  109. codebook_labels.reshape(-1),
  110. ignore_index=-100,
  111. )
  112. loss = base_loss + semantic_loss
  113. self.log(
  114. f"{stage}/loss",
  115. loss,
  116. on_step=is_train,
  117. on_epoch=not is_train,
  118. prog_bar=True,
  119. logger=True,
  120. sync_dist=not is_train,
  121. )
  122. self.log(
  123. f"{stage}/base_loss",
  124. base_loss,
  125. on_step=is_train,
  126. on_epoch=not is_train,
  127. prog_bar=False,
  128. logger=True,
  129. sync_dist=not is_train,
  130. )
  131. self.log(
  132. f"{stage}/semantic_loss",
  133. semantic_loss,
  134. on_step=is_train,
  135. on_epoch=not is_train,
  136. prog_bar=False,
  137. logger=True,
  138. sync_dist=not is_train,
  139. )
  140. # Top-5 accuracy
  141. accuracy = self.get_accuracy(codebook_logits, codebook_labels)
  142. self.log(
  143. f"{stage}/top_5_accuracy",
  144. accuracy,
  145. on_step=is_train,
  146. on_epoch=not is_train,
  147. prog_bar=True,
  148. logger=True,
  149. sync_dist=not is_train,
  150. )
  151. return loss
  152. def get_accuracy(self, logits, labels):
  153. mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
  154. if mask.sum() == 0:
  155. return torch.tensor(0.0, device=logits.device)
  156. _, indices = logits.topk(5, dim=-1)
  157. correct = indices.eq(labels.unsqueeze(-1))
  158. correct[~mask] = 0
  159. correct = correct.sum()
  160. accuracy = correct / mask.sum()
  161. return accuracy
  162. def training_step(self, batch, batch_idx):
  163. return self._step(batch, batch_idx, "train")
  164. def validation_step(self, batch, batch_idx):
  165. return self._step(batch, batch_idx, "val")