# lit_module.py — Lightning training module for the text-to-semantic model.
  1. from typing import Any, Optional
  2. import lightning as L
  3. import torch
  4. import torch.nn.functional as F
  5. from lightning.pytorch.utilities.types import OptimizerLRScheduler
  6. import fish_speech.utils as utils
  7. CODEBOOK_PAD_TOKEN_ID = 0
  8. from fish_speech.models.text2semantic.llama import NaiveTransformer
  9. log = utils.RankedLogger(__name__, rank_zero_only=True)
  10. class TextToSemantic(L.LightningModule):
  11. def __init__(
  12. self,
  13. model: NaiveTransformer,
  14. optimizer: Any,
  15. lr_scheduler: Any,
  16. ):
  17. super().__init__()
  18. self.model = model
  19. self.optimizer_builder = optimizer
  20. self.lr_scheduler_builder = lr_scheduler
  21. def forward(self, x):
  22. return self.model(x)
  23. def on_save_checkpoint(self, checkpoint):
  24. # Save only LoRA parameters
  25. state_dict = checkpoint["state_dict"]
  26. use_lora = any("lora" in name for name in state_dict.keys())
  27. if not use_lora:
  28. return
  29. for name in list(state_dict.keys()):
  30. if "lora" not in name:
  31. state_dict.pop(name)
  32. def configure_optimizers(self) -> OptimizerLRScheduler:
  33. # Get weight decay parameters
  34. weight_decay_parameters, other_parameters = [], []
  35. for name, param in self.named_parameters():
  36. if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
  37. other_parameters.append(param)
  38. else:
  39. weight_decay_parameters.append(param)
  40. optimizer = self.optimizer_builder(
  41. [
  42. {"params": weight_decay_parameters},
  43. {"params": other_parameters, "weight_decay": 0.0},
  44. ]
  45. )
  46. # Print the parameters and their weight decay
  47. for i in optimizer.param_groups:
  48. log.info(
  49. f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
  50. )
  51. lr_scheduler = self.lr_scheduler_builder(optimizer)
  52. return {
  53. "optimizer": optimizer,
  54. "lr_scheduler": {
  55. "scheduler": lr_scheduler,
  56. "interval": "step",
  57. },
  58. }
  59. # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
  60. def get_batch_logps(
  61. self,
  62. logits: torch.FloatTensor,
  63. labels: torch.LongTensor,
  64. average_log_prob: bool = False,
  65. ) -> torch.FloatTensor:
  66. """Compute the log probabilities of the given labels under the given logits.
  67. Args:
  68. logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
  69. labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
  70. average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
  71. Returns:
  72. A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
  73. """
  74. assert logits.shape[:-1] == labels.shape
  75. labels = labels.clone()
  76. loss_mask = labels != -100
  77. # dummy token; we'll ignore the losses on these tokens later
  78. labels[labels == -100] = 0
  79. per_token_logps = torch.gather(
  80. logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
  81. ).squeeze(-1)
  82. if average_log_prob:
  83. return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
  84. else:
  85. return (per_token_logps * loss_mask).sum(-1)
  86. def _step(self, batch, batch_idx, stage: str):
  87. is_train = stage == "train"
  88. if is_train:
  89. # Key part to make lora work
  90. # Otherwise the parameters are merged, which lead to incorrect gradients
  91. self.model.train()
  92. # Do positive and negative samples in the same batch to speed up training
  93. labels = batch["labels"]
  94. outputs = self.model(
  95. inp=batch["inputs"],
  96. key_padding_mask=batch["attention_masks"],
  97. labels=batch["labels"],
  98. )
  99. token_logits = outputs.token_logits
  100. codebook_logits = outputs.codebook_logits
  101. # Generate labels
  102. base_loss = F.cross_entropy(
  103. token_logits.view(-1, token_logits.size(-1)),
  104. labels[:, 0].reshape(-1),
  105. ignore_index=-100,
  106. )
  107. token_ids = labels[:, 0]
  108. semantic_mask = (token_ids >= self.model.tokenizer.semantic_begin_id) & (
  109. token_ids <= self.model.tokenizer.semantic_end_id
  110. )
  111. all_codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks]
  112. all_codebook_labels_permuted = all_codebook_labels.permute(0, 2, 1)
  113. filtered_codebook_labels = all_codebook_labels_permuted[semantic_mask]
  114. semantic_loss = F.cross_entropy(
  115. codebook_logits.reshape(-1, codebook_logits.size(-1)),
  116. filtered_codebook_labels.reshape(-1),
  117. ignore_index=-100,
  118. )
  119. loss = base_loss + semantic_loss
  120. self.log(
  121. f"{stage}/loss",
  122. loss,
  123. on_step=is_train,
  124. on_epoch=not is_train,
  125. prog_bar=True,
  126. logger=True,
  127. sync_dist=not is_train,
  128. )
  129. self.log(
  130. f"{stage}/base_loss",
  131. base_loss,
  132. on_step=is_train,
  133. on_epoch=not is_train,
  134. prog_bar=False,
  135. logger=True,
  136. sync_dist=not is_train,
  137. )
  138. self.log(
  139. f"{stage}/semantic_loss",
  140. semantic_loss,
  141. on_step=is_train,
  142. on_epoch=not is_train,
  143. prog_bar=False,
  144. logger=True,
  145. sync_dist=not is_train,
  146. )
  147. # Top-5 accuracy
  148. accuracy = self.get_accuracy(codebook_logits, filtered_codebook_labels)
  149. self.log(
  150. f"{stage}/top_5_accuracy",
  151. accuracy,
  152. on_step=is_train,
  153. on_epoch=not is_train,
  154. prog_bar=True,
  155. logger=True,
  156. sync_dist=not is_train,
  157. )
  158. return loss
  159. def get_accuracy(self, logits, labels):
  160. mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
  161. if mask.sum() == 0:
  162. return torch.tensor(0.0, device=logits.device)
  163. _, indices = logits.topk(5, dim=-1)
  164. correct = indices.eq(labels.unsqueeze(-1))
  165. correct[~mask] = 0
  166. correct = correct.sum()
  167. accuracy = correct / mask.sum()
  168. return accuracy
  169. def training_step(self, batch, batch_idx):
  170. return self._step(batch, batch_idx, "train")
  171. def validation_step(self, batch, batch_idx):
  172. return self._step(batch, batch_idx, "val")