lit_module.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. from dataclasses import dataclass
  2. from typing import Any, Optional
  3. import lightning as L
  4. import loralib as lora
  5. import torch
  6. import torch.nn.functional as F
  7. from lightning.pytorch.utilities.types import OptimizerLRScheduler
  8. import fish_speech.utils as utils
  9. from fish_speech.models.text2semantic.llama import Transformer
# Module-level logger; RankedLogger with rank_zero_only=True so messages are
# emitted only once (on the rank-zero process) in distributed training.
log = utils.RankedLogger(__name__, rank_zero_only=True)
  11. @dataclass
  12. class LoraConfig:
  13. r: int
  14. lora_alpha: float
  15. lora_dropout: float = 0.0
  16. class TextToSemantic(L.LightningModule):
  17. def __init__(
  18. self,
  19. model: Transformer,
  20. optimizer: Any,
  21. lr_scheduler: Any,
  22. lora_config: Optional[LoraConfig] = None,
  23. save_lora_only: bool = False,
  24. use_dpo: bool = False,
  25. dpo_beta: float = 0.2,
  26. ):
  27. super().__init__()
  28. self.model = model
  29. self.optimizer_builder = optimizer
  30. self.lr_scheduler_builder = lr_scheduler
  31. self.lora_config = lora_config
  32. self.save_lora_only = save_lora_only
  33. self.use_dpo = use_dpo # We don't support reference model yet
  34. self.dpo_beta = dpo_beta
  35. if self.lora_config is not None:
  36. self.setup_lora()
  37. def setup_lora(self):
  38. # Replace the embedding layer with a LoRA layer
  39. self.model.embeddings = lora.Embedding(
  40. num_embeddings=self.model.embeddings.num_embeddings,
  41. embedding_dim=self.model.embeddings.embedding_dim,
  42. padding_idx=self.model.embeddings.padding_idx,
  43. r=self.lora_config.r,
  44. lora_alpha=self.lora_config.lora_alpha,
  45. )
  46. # Replace output layer with a LoRA layer
  47. linears = [(self.model, "output")]
  48. # Replace all linear layers with LoRA layers
  49. for layer in self.model.layers:
  50. linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
  51. linears.extend(
  52. [
  53. (layer.feed_forward, "w1"),
  54. (layer.feed_forward, "w2"),
  55. (layer.feed_forward, "w3"),
  56. ]
  57. )
  58. for module, layer in linears:
  59. updated_linear = lora.Linear(
  60. in_features=getattr(module, layer).in_features,
  61. out_features=getattr(module, layer).out_features,
  62. bias=getattr(module, layer).bias,
  63. r=self.lora_config.r,
  64. lora_alpha=self.lora_config.lora_alpha,
  65. lora_dropout=self.lora_config.lora_dropout,
  66. )
  67. setattr(module, layer, updated_linear)
  68. # Mark only the LoRA layers as trainable
  69. lora.mark_only_lora_as_trainable(self.model, bias="lora_only")
  70. def forward(self, x):
  71. return self.model(x)
  72. def on_save_checkpoint(self, checkpoint):
  73. if self.lora_config is None or self.save_lora_only is False:
  74. return
  75. # Save only LoRA parameters
  76. state_dict = checkpoint["state_dict"]
  77. for name in list(state_dict.keys()):
  78. if "lora" not in name:
  79. state_dict.pop(name)
  80. def configure_optimizers(self) -> OptimizerLRScheduler:
  81. # Get weight decay parameters
  82. weight_decay_parameters, other_parameters = [], []
  83. for name, param in self.named_parameters():
  84. if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
  85. other_parameters.append(param)
  86. else:
  87. weight_decay_parameters.append(param)
  88. optimizer = self.optimizer_builder(
  89. [
  90. {"params": weight_decay_parameters},
  91. {"params": other_parameters, "weight_decay": 0.0},
  92. ]
  93. )
  94. # Print the parameters and their weight decay
  95. for i in optimizer.param_groups:
  96. log.info(
  97. f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
  98. )
  99. lr_scheduler = self.lr_scheduler_builder(optimizer)
  100. return {
  101. "optimizer": optimizer,
  102. "lr_scheduler": {
  103. "scheduler": lr_scheduler,
  104. "interval": "step",
  105. },
  106. }
  107. # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
  108. def get_batch_logps(
  109. self,
  110. logits: torch.FloatTensor,
  111. labels: torch.LongTensor,
  112. average_log_prob: bool = False,
  113. ) -> torch.FloatTensor:
  114. """Compute the log probabilities of the given labels under the given logits.
  115. Args:
  116. logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
  117. labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
  118. average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
  119. Returns:
  120. A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
  121. """
  122. assert logits.shape[:-1] == labels.shape
  123. labels = labels.clone()
  124. loss_mask = labels != -100
  125. # dummy token; we'll ignore the losses on these tokens later
  126. labels[labels == -100] = 0
  127. per_token_logps = torch.gather(
  128. logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
  129. ).squeeze(-1)
  130. if average_log_prob:
  131. return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
  132. else:
  133. return (per_token_logps * loss_mask).sum(-1)
  134. def _step(self, batch, batch_idx, stage: str):
  135. # Do positive and negative samples in the same batch to speed up training
  136. outputs = self.model(
  137. x=batch["inputs"],
  138. key_padding_mask=batch["attention_masks"],
  139. )
  140. labels = batch["labels"]
  141. token_logits = outputs.token_logits
  142. codebook_logits = outputs.codebook_logits
  143. if self.use_dpo:
  144. # Firtst half is positive, second half is negative
  145. token_logits, negative_token_logits = token_logits.chunk(2)
  146. codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
  147. labels, negative_labels = labels.chunk(2)
  148. # Generate labels
  149. base_loss = F.cross_entropy(
  150. token_logits.reshape(-1, token_logits.size(-1)),
  151. labels[:, 0].reshape(-1),
  152. ignore_index=-100,
  153. )
  154. # If we have a codebook, add the loss
  155. if self.model.config.num_codebooks != 0:
  156. codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
  157. semantic_loss = F.cross_entropy(
  158. codebook_logits.reshape(-1, codebook_logits.size(-1)),
  159. codebook_labels.reshape(-1),
  160. ignore_index=-100,
  161. )
  162. loss = base_loss + semantic_loss
  163. else:
  164. loss = base_loss
  165. # If we use dpo
  166. if self.use_dpo:
  167. negative_codebook_labels = negative_labels[
  168. :, 1 : 1 + self.model.config.num_codebooks
  169. ].mT
  170. positive_codebook_logps = self.get_batch_logps(
  171. codebook_logits, codebook_labels
  172. )
  173. negative_codebook_logps = self.get_batch_logps(
  174. negative_codebook_logits, negative_codebook_labels
  175. )
  176. # TODO: implement the reference model, avoid screwing up the gradients
  177. dpo_loss = -F.logsigmoid(
  178. (positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
  179. ).mean()
  180. chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
  181. rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
  182. reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
  183. chosen_rewards, rejected_rewards = (
  184. chosen_rewards.mean(),
  185. rejected_rewards.mean(),
  186. )
  187. loss = loss + dpo_loss
  188. self.log(
  189. f"{stage}/dpo_loss",
  190. dpo_loss,
  191. on_step=True,
  192. on_epoch=False,
  193. prog_bar=False,
  194. logger=True,
  195. )
  196. self.log(
  197. f"{stage}/chosen_rewards",
  198. chosen_rewards,
  199. on_step=True,
  200. on_epoch=False,
  201. prog_bar=False,
  202. logger=True,
  203. )
  204. self.log(
  205. f"{stage}/rejected_rewards",
  206. rejected_rewards,
  207. on_step=True,
  208. on_epoch=False,
  209. prog_bar=False,
  210. logger=True,
  211. )
  212. self.log(
  213. f"{stage}/reward_accuracy",
  214. reward_accuracy,
  215. on_step=True,
  216. on_epoch=False,
  217. prog_bar=False,
  218. logger=True,
  219. )
  220. self.log(
  221. f"{stage}/loss",
  222. loss,
  223. on_step=True,
  224. on_epoch=False,
  225. prog_bar=True,
  226. logger=True,
  227. )
  228. if self.model.config.num_codebooks != 0:
  229. self.log(
  230. f"{stage}/base_loss",
  231. base_loss,
  232. on_step=True,
  233. on_epoch=False,
  234. prog_bar=False,
  235. logger=True,
  236. )
  237. self.log(
  238. f"{stage}/semantic_loss",
  239. semantic_loss,
  240. on_step=True,
  241. on_epoch=False,
  242. prog_bar=False,
  243. logger=True,
  244. )
  245. # Top-5 accuracy
  246. if self.model.config.num_codebooks == 0:
  247. _, indices = token_logits.topk(5, dim=-1)
  248. correct = indices.eq(labels[:, 0].unsqueeze(-1))
  249. correct[labels[:, 0] == -100] = 0
  250. correct = correct.sum()
  251. accuracy = correct / (labels[:, 0] != -100).sum()
  252. else:
  253. _, indices = codebook_logits.topk(5, dim=-1)
  254. correct = indices.eq(codebook_labels.unsqueeze(-1))
  255. correct[codebook_labels == -100] = 0
  256. correct = correct.sum()
  257. accuracy = correct / (codebook_labels != -100).sum()
  258. self.log(
  259. f"{stage}/top_5_accuracy",
  260. accuracy,
  261. on_step=True,
  262. on_epoch=False,
  263. prog_bar=True,
  264. logger=True,
  265. )
  266. if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
  267. _, indices = codebook_logits[
  268. :, :, : self.model.config.num_in_codebooks
  269. ].topk(5, dim=-1)
  270. codebook_labels = codebook_labels[
  271. :, :, : self.model.config.num_in_codebooks
  272. ]
  273. correct = indices.eq(codebook_labels.unsqueeze(-1))
  274. correct[codebook_labels == -100] = 0
  275. correct = correct.sum()
  276. accuracy = correct / (codebook_labels != -100).sum()
  277. self.log(
  278. f"{stage}/top_5_accuracy_in",
  279. accuracy,
  280. on_step=True,
  281. on_epoch=False,
  282. prog_bar=True,
  283. logger=True,
  284. )
  285. return loss
  286. def training_step(self, batch, batch_idx):
  287. return self._step(batch, batch_idx, "train")
  288. def validation_step(self, batch, batch_idx):
  289. return self._step(batch, batch_idx, "val")