# lit_module.py
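"""Lightning module for the text-to-semantic Transformer, with optional LoRA fine-tuning."""
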
import platform
from dataclasses import dataclass
from typing import Any, Dict, Optional

import lightning as L
import loralib as lora
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler

import fish_speech.utils as utils
from fish_speech.models.text2semantic.llama import Transformer

log = utils.RankedLogger(__name__, rank_zero_only=True)
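

# LoRA adapter hyperparameters. In the standard LoRA parameterization the
# adapter update is scaled by lora_alpha / r, so alpha is tuned with the rank.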
@dataclass
class LoraConfig:
    r: int
    lora_alpha: float
    lora_dropout: float = 0.0


class TextToSemantic(L.LightningModule):
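    """Lightning wrapper around the text-to-semantic Transformer.

    `optimizer` and `lr_scheduler` are builder callables (e.g. partials created
    by the training config) that are invoked in `configure_optimizers`. When a
    `lora_config` is given, the embedding, attention, feed-forward, and output
    layers are replaced by LoRA layers and only the adapters remain trainable.
    """
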
    def __init__(
        self,
        model: Transformer,
        optimizer: Any,
        lr_scheduler: Any,
        lora_config: Optional[LoraConfig] = None,
    ):
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler
        self.lora_config = lora_config

        if self.lora_config is not None:
            self.setup_lora()

    def setup_lora(self):
        # Replace the embedding layer with a LoRA embedding
        self.model.embeddings = lora.Embedding(
            num_embeddings=self.model.embeddings.num_embeddings,
            embedding_dim=self.model.embeddings.embedding_dim,
            padding_idx=self.model.embeddings.padding_idx,
            r=self.lora_config.r,
            lora_alpha=self.lora_config.lora_alpha,
        )
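        # NOTE: the LoRA replacement layers are freshly initialized; pretrained
        # weights are assumed to be (re)loaded after setup_lora has run.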

        # Collect the output head plus every attention / feed-forward projection
        linears = [(self.model, "output")]
        for layer in self.model.layers:
            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
            linears.extend(
                [
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

        # Swap each collected nn.Linear for a lora.Linear of the same shape
        for module, name in linears:
            old_linear = getattr(module, name)
            updated_linear = lora.Linear(
                in_features=old_linear.in_features,
                out_features=old_linear.out_features,
                # nn.Linear expects a bool here, not the bias tensor itself
                bias=old_linear.bias is not None,
                r=self.lora_config.r,
                lora_alpha=self.lora_config.lora_alpha,
                lora_dropout=self.lora_config.lora_dropout,
            )
            setattr(module, name, updated_linear)

        # Freeze everything except the LoRA parameters; bias="lora_only" keeps
        # the bias terms of the adapted modules trainable as well
        lora.mark_only_lora_as_trainable(self.model, bias="lora_only")
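        # loralib layers default to merge_weights=True, so the adapter deltas
        # are merged into the base weights when the module switches to eval mode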

    def forward(self, x):
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        if self.lora_config is None:
            return

        # Keep only the LoRA parameters in the checkpoint; the frozen base
        # weights must be restored from the original model
        state_dict = checkpoint["state_dict"]
        for name in list(state_dict.keys()):
            if "lora" not in name:
                state_dict.pop(name)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                # Step the scheduler every optimizer step rather than per epoch
                "interval": "step",
            },
        }
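
    # For illustration only (hypothetical values; the real builders come from
    # the training config), the two builders might be constructed as:
    #   optimizer = functools.partial(torch.optim.AdamW, lr=3e-4)
    #   lr_scheduler = functools.partial(
    #       torch.optim.lr_scheduler.CosineAnnealingLR, T_max=10_000
    #   )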

    def _step(self, batch, batch_idx, stage: str):
        outputs = self.model(
            x=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )

        # Token-level loss; row 0 of `labels` holds the text-token targets and
        # -100 marks positions excluded from the loss
        labels = batch["labels"]
        base_loss = F.cross_entropy(
            outputs.token_logits.reshape(-1, outputs.token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        # If the model predicts codebooks, add the codebook (semantic) loss
        if self.model.config.num_codebooks != 0:
            # Rows 1..num_codebooks are the codebook targets; .mT swaps the last
            # two dims so the labels flatten in the same order as codebook_logits
            codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
            semantic_loss = F.cross_entropy(
                outputs.codebook_logits.reshape(-1, outputs.codebook_logits.size(-1)),
                codebook_labels.reshape(-1),
                ignore_index=-100,
            )
            loss = base_loss + semantic_loss
        else:
            loss = base_loss

        self.log(
            f"{stage}/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )

        if self.model.config.num_codebooks != 0:
            self.log(
                f"{stage}/base_loss",
                base_loss,
                on_step=True,
                on_epoch=False,
                prog_bar=False,
                logger=True,
            )
            self.log(
                f"{stage}/semantic_loss",
                semantic_loss,
                on_step=True,
                on_epoch=False,
                prog_bar=False,
                logger=True,
            )

        # Top-5 accuracy over positions that are not ignored (-100)
        if self.model.config.num_codebooks == 0:
            _, indices = outputs.token_logits.topk(5, dim=-1)
            correct = indices.eq(labels[:, 0].unsqueeze(-1))
            correct[labels[:, 0] == -100] = 0  # do not count ignored targets
            correct = correct.sum()
            accuracy = correct / (labels[:, 0] != -100).sum()
        else:
            _, indices = outputs.codebook_logits.topk(5, dim=-1)
            correct = indices.eq(codebook_labels.unsqueeze(-1))
            correct[codebook_labels == -100] = 0  # do not count ignored targets
            correct = correct.sum()
            accuracy = correct / (codebook_labels != -100).sum()

        self.log(
            f"{stage}/top_5_accuracy",
            accuracy,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )

        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")