# lit_module.py
  1. import platform
  2. from dataclasses import dataclass
  3. from typing import Any, Dict, Optional
  4. import lightning as L
  5. import loralib as lora
  6. import torch
  7. import torch.nn.functional as F
  8. from lightning.pytorch.utilities.types import OptimizerLRScheduler
  9. import fish_speech.utils as utils
  10. from fish_speech.models.text2semantic.llama import Transformer
# Module-level logger; rank_zero_only=True presumably restricts emission to
# rank 0 under distributed training — confirm against utils.RankedLogger.
log = utils.RankedLogger(__name__, rank_zero_only=True)
@dataclass
class LoraConfig:
    """Hyper-parameters for LoRA adapters injected by ``TextToSemantic.setup_lora``."""

    # Rank of the low-rank update matrices (loralib's ``r``).
    r: int
    # Scaling factor applied to the LoRA update (loralib's ``lora_alpha``).
    lora_alpha: float
    # Dropout probability on the LoRA path; applied to the linear layers only.
    lora_dropout: float = 0.0
  17. class TextToSemantic(L.LightningModule):
  18. def __init__(
  19. self,
  20. model: Transformer,
  21. optimizer: Any,
  22. lr_scheduler: Any,
  23. lora_config: Optional[LoraConfig] = None,
  24. ):
  25. super().__init__()
  26. self.model = model
  27. self.optimizer_builder = optimizer
  28. self.lr_scheduler_builder = lr_scheduler
  29. self.lora_config = lora_config
  30. if self.lora_config is not None:
  31. self.setup_lora()
  32. def setup_lora(self):
  33. # Replace the embedding layer with a LoRA layer
  34. self.model.embeddings = lora.Embedding(
  35. num_embeddings=self.model.embeddings.num_embeddings,
  36. embedding_dim=self.model.embeddings.embedding_dim,
  37. padding_idx=self.model.embeddings.padding_idx,
  38. r=self.lora_config.r,
  39. lora_alpha=self.lora_config.lora_alpha,
  40. )
  41. # Replace output layer with a LoRA layer
  42. linears = [(self.model, "output")]
  43. # Replace all linear layers with LoRA layers
  44. for layer in self.model.layers:
  45. linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
  46. linears.extend(
  47. [
  48. (layer.feed_forward, "w1"),
  49. (layer.feed_forward, "w2"),
  50. (layer.feed_forward, "w3"),
  51. ]
  52. )
  53. for module, layer in linears:
  54. updated_linear = lora.Linear(
  55. in_features=getattr(module, layer).in_features,
  56. out_features=getattr(module, layer).out_features,
  57. bias=getattr(module, layer).bias,
  58. r=self.lora_config.r,
  59. lora_alpha=self.lora_config.lora_alpha,
  60. lora_dropout=self.lora_config.lora_dropout,
  61. )
  62. setattr(module, layer, updated_linear)
  63. # Mark only the LoRA layers as trainable
  64. lora.mark_only_lora_as_trainable(self.model, bias="lora_only")
  65. def forward(self, x):
  66. return self.model(x)
  67. def configure_optimizers(self) -> OptimizerLRScheduler:
  68. optimizer = self.optimizer_builder(self.parameters())
  69. lr_scheduler = self.lr_scheduler_builder(optimizer)
  70. return {
  71. "optimizer": optimizer,
  72. "lr_scheduler": {
  73. "scheduler": lr_scheduler,
  74. "interval": "step",
  75. },
  76. }
  77. def _step(self, batch, batch_idx, stage: str):
  78. outputs = self.model(
  79. x=batch["inputs"],
  80. key_padding_mask=batch["attention_masks"],
  81. )
  82. # Generate labels
  83. labels = batch["labels"]
  84. base_loss = F.cross_entropy(
  85. outputs.token_logits.reshape(-1, outputs.token_logits.size(-1)),
  86. labels[:, 0].reshape(-1),
  87. ignore_index=-100,
  88. )
  89. # If we have a codebook, add the loss
  90. if self.model.config.num_codebooks != 0:
  91. codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
  92. semantic_loss = F.cross_entropy(
  93. outputs.codebook_logits.reshape(-1, outputs.codebook_logits.size(-1)),
  94. codebook_labels.reshape(-1),
  95. ignore_index=-100,
  96. )
  97. loss = base_loss + semantic_loss
  98. else:
  99. loss = base_loss
  100. self.log(
  101. f"{stage}/loss",
  102. loss,
  103. on_step=True,
  104. on_epoch=False,
  105. prog_bar=True,
  106. logger=True,
  107. )
  108. if self.model.config.num_codebooks != 0:
  109. self.log(
  110. f"{stage}/base_loss",
  111. base_loss,
  112. on_step=True,
  113. on_epoch=False,
  114. prog_bar=False,
  115. logger=True,
  116. )
  117. self.log(
  118. f"{stage}/semantic_loss",
  119. semantic_loss,
  120. on_step=True,
  121. on_epoch=False,
  122. prog_bar=False,
  123. logger=True,
  124. )
  125. # Top-5 accuracy
  126. if self.model.config.num_codebooks == 0:
  127. _, indices = outputs.token_logits.topk(5, dim=-1)
  128. correct = indices.eq(labels[:, 0].unsqueeze(-1))
  129. correct[labels[:, 0] == -100] = 0
  130. correct = correct.sum()
  131. accuracy = correct / (labels[:, 0] != -100).sum()
  132. else:
  133. _, indices = outputs.codebook_logits.topk(5, dim=-1)
  134. # print(codebook_labels[0, :10], torch.argmax(outputs.codebook_logits[0, :10], dim=-1))
  135. # print(codebook_labels[codebook_labels != -100][:10], indices[codebook_labels != -100][:10])
  136. correct = indices.eq(codebook_labels.unsqueeze(-1))
  137. correct[codebook_labels == -100] = 0
  138. correct = correct.sum()
  139. accuracy = correct / (codebook_labels != -100).sum()
  140. self.log(
  141. f"{stage}/top_5_accuracy",
  142. accuracy,
  143. on_step=True,
  144. on_epoch=False,
  145. prog_bar=True,
  146. logger=True,
  147. )
  148. return loss
  149. def training_step(self, batch, batch_idx):
  150. return self._step(batch, batch_idx, "train")
  151. def validation_step(self, batch, batch_idx):
  152. return self._step(batch, batch_idx, "val")