# lit_module.py
  1. from typing import Any, Optional
  2. import lightning as L
  3. import torch
  4. import torch.nn.functional as F
  5. from lightning.pytorch.utilities.types import OptimizerLRScheduler
  6. import fish_speech.utils as utils
  7. from fish_speech.models.text2semantic.llama import NaiveTransformer
  8. from fish_speech.models.text2semantic.lora_utils import LoraConfig, setup_lora
  9. log = utils.RankedLogger(__name__, rank_zero_only=True)
  10. class TextToSemantic(L.LightningModule):
  11. def __init__(
  12. self,
  13. model: NaiveTransformer,
  14. optimizer: Any,
  15. lr_scheduler: Any,
  16. lora_config: Optional[LoraConfig] = None,
  17. use_dpo: bool = False,
  18. dpo_beta: float = 0.2,
  19. ):
  20. super().__init__()
  21. self.model = model
  22. self.optimizer_builder = optimizer
  23. self.lr_scheduler_builder = lr_scheduler
  24. self.lora_config = lora_config
  25. self.use_dpo = use_dpo # We don't support reference model yet
  26. self.dpo_beta = dpo_beta
  27. if self.lora_config is not None:
  28. setup_lora(self.model, self.lora_config)
  29. def forward(self, x):
  30. return self.model(x)
  31. def on_save_checkpoint(self, checkpoint):
  32. if self.lora_config is None:
  33. return
  34. # Save only LoRA parameters
  35. state_dict = checkpoint["state_dict"]
  36. for name in list(state_dict.keys()):
  37. if "lora" not in name:
  38. state_dict.pop(name)
  39. def configure_optimizers(self) -> OptimizerLRScheduler:
  40. # Get weight decay parameters
  41. weight_decay_parameters, other_parameters = [], []
  42. for name, param in self.named_parameters():
  43. if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
  44. other_parameters.append(param)
  45. else:
  46. weight_decay_parameters.append(param)
  47. optimizer = self.optimizer_builder(
  48. [
  49. {"params": weight_decay_parameters},
  50. {"params": other_parameters, "weight_decay": 0.0},
  51. ]
  52. )
  53. # Print the parameters and their weight decay
  54. for i in optimizer.param_groups:
  55. log.info(
  56. f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
  57. )
  58. lr_scheduler = self.lr_scheduler_builder(optimizer)
  59. return {
  60. "optimizer": optimizer,
  61. "lr_scheduler": {
  62. "scheduler": lr_scheduler,
  63. "interval": "step",
  64. },
  65. }
  66. # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
  67. def get_batch_logps(
  68. self,
  69. logits: torch.FloatTensor,
  70. labels: torch.LongTensor,
  71. average_log_prob: bool = False,
  72. ) -> torch.FloatTensor:
  73. """Compute the log probabilities of the given labels under the given logits.
  74. Args:
  75. logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
  76. labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
  77. average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
  78. Returns:
  79. A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
  80. """
  81. assert logits.shape[:-1] == labels.shape
  82. labels = labels.clone()
  83. loss_mask = labels != -100
  84. # dummy token; we'll ignore the losses on these tokens later
  85. labels[labels == -100] = 0
  86. per_token_logps = torch.gather(
  87. logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
  88. ).squeeze(-1)
  89. if average_log_prob:
  90. return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
  91. else:
  92. return (per_token_logps * loss_mask).sum(-1)
  93. def _step(self, batch, batch_idx, stage: str):
  94. is_train = stage == "train"
  95. if is_train:
  96. # Key part to make lora work
  97. # Otherwise the parameters are merged, which lead to incorrect gradients
  98. self.model.train()
  99. # Do positive and negative samples in the same batch to speed up training
  100. labels = batch["labels"]
  101. outputs = self.model(
  102. inp=batch["inputs"],
  103. key_padding_mask=batch["attention_masks"],
  104. )
  105. token_logits = outputs.token_logits
  106. codebook_logits = outputs.codebook_logits
  107. if self.use_dpo:
  108. # Firtst half is positive, second half is negative
  109. token_logits, negative_token_logits = token_logits.chunk(2)
  110. codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
  111. labels, negative_labels = labels.chunk(2)
  112. # Generate labels
  113. base_loss = F.cross_entropy(
  114. token_logits.reshape(-1, token_logits.size(-1)),
  115. labels[:, 0].reshape(-1),
  116. ignore_index=-100,
  117. )
  118. codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
  119. semantic_loss = F.cross_entropy(
  120. codebook_logits.reshape(-1, codebook_logits.size(-1)),
  121. codebook_labels.reshape(-1),
  122. ignore_index=-100,
  123. )
  124. loss = base_loss + semantic_loss
  125. # If we use dpo
  126. if self.use_dpo:
  127. negative_codebook_labels = negative_labels[
  128. :, 1 : 1 + self.model.config.num_codebooks
  129. ].mT
  130. positive_codebook_logps = self.get_batch_logps(
  131. codebook_logits, codebook_labels
  132. )
  133. negative_codebook_logps = self.get_batch_logps(
  134. negative_codebook_logits, negative_codebook_labels
  135. )
  136. # TODO: implement the reference model, avoid screwing up the gradients
  137. dpo_loss = -F.logsigmoid(
  138. (positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
  139. ).mean()
  140. chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
  141. rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
  142. reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
  143. chosen_rewards, rejected_rewards = (
  144. chosen_rewards.mean(),
  145. rejected_rewards.mean(),
  146. )
  147. loss = loss + dpo_loss
  148. self.log(
  149. f"{stage}/dpo_loss",
  150. dpo_loss,
  151. on_step=is_train,
  152. on_epoch=not is_train,
  153. prog_bar=False,
  154. logger=True,
  155. sync_dist=not is_train,
  156. )
  157. self.log(
  158. f"{stage}/chosen_rewards",
  159. chosen_rewards,
  160. on_step=is_train,
  161. on_epoch=not is_train,
  162. prog_bar=False,
  163. logger=True,
  164. sync_dist=not is_train,
  165. )
  166. self.log(
  167. f"{stage}/rejected_rewards",
  168. rejected_rewards,
  169. on_step=is_train,
  170. on_epoch=not is_train,
  171. prog_bar=False,
  172. logger=True,
  173. sync_dist=not is_train,
  174. )
  175. self.log(
  176. f"{stage}/reward_accuracy",
  177. reward_accuracy,
  178. on_step=is_train,
  179. on_epoch=not is_train,
  180. prog_bar=False,
  181. logger=True,
  182. sync_dist=not is_train,
  183. )
  184. self.log(
  185. f"{stage}/loss",
  186. loss,
  187. on_step=is_train,
  188. on_epoch=not is_train,
  189. prog_bar=True,
  190. logger=True,
  191. sync_dist=not is_train,
  192. )
  193. self.log(
  194. f"{stage}/base_loss",
  195. base_loss,
  196. on_step=is_train,
  197. on_epoch=not is_train,
  198. prog_bar=False,
  199. logger=True,
  200. sync_dist=not is_train,
  201. )
  202. self.log(
  203. f"{stage}/semantic_loss",
  204. semantic_loss,
  205. on_step=is_train,
  206. on_epoch=not is_train,
  207. prog_bar=False,
  208. logger=True,
  209. sync_dist=not is_train,
  210. )
  211. # Top-5 accuracy
  212. accuracy = self.get_accuracy(codebook_logits, codebook_labels)
  213. self.log(
  214. f"{stage}/top_5_accuracy",
  215. accuracy,
  216. on_step=is_train,
  217. on_epoch=not is_train,
  218. prog_bar=True,
  219. logger=True,
  220. sync_dist=not is_train,
  221. )
  222. if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
  223. accuracy = self.get_accuracy(
  224. codebook_logits[:, :, : self.model.config.num_in_codebooks],
  225. codebook_labels[:, :, : self.model.config.num_in_codebooks],
  226. )
  227. self.log(
  228. f"{stage}/top_5_accuracy_in",
  229. accuracy,
  230. on_step=is_train,
  231. on_epoch=not is_train,
  232. prog_bar=True,
  233. logger=True,
  234. sync_dist=not is_train,
  235. )
  236. return loss
  237. def get_accuracy(self, logits, labels):
  238. _, indices = logits.topk(5, dim=-1)
  239. correct = indices.eq(labels.unsqueeze(-1))
  240. correct[labels == -100] = 0
  241. correct = correct.sum()
  242. accuracy = correct / (labels != -100).sum()
  243. return accuracy
  244. def training_step(self, batch, batch_idx):
  245. return self._step(batch, batch_idx, "train")
  246. def validation_step(self, batch, batch_idx):
  247. return self._step(batch, batch_idx, "val")