# lit_module.py
  1. import platform
  2. from typing import Any, Optional
  3. import lightning as L
  4. import torch
  5. import torch.nn.functional as F
  6. from lightning.pytorch.utilities.types import OptimizerLRScheduler
  7. import fish_speech.utils as utils
  8. log = utils.RankedLogger(__name__, rank_zero_only=True)
  9. class TextToSemantic(L.LightningModule):
  10. def __init__(self, model, optimizer: Any, lr_scheduler: Any):
  11. super().__init__()
  12. self.model = model
  13. self.optimizer_builder = optimizer
  14. self.lr_scheduler_builder = lr_scheduler
  15. def forward(self, x):
  16. return self.model(x)
  17. def configure_optimizers(self) -> OptimizerLRScheduler:
  18. optimizer = self.optimizer_builder(self.parameters())
  19. lr_scheduler = self.lr_scheduler_builder(optimizer)
  20. return {
  21. "optimizer": optimizer,
  22. "lr_scheduler": {
  23. "scheduler": lr_scheduler,
  24. "interval": "step",
  25. },
  26. }
  27. def _step(self, batch, batch_idx, stage: str):
  28. outputs = self.model(
  29. x=batch["inputs"],
  30. key_padding_mask=batch["attention_masks"],
  31. )
  32. # Generate labels
  33. labels = batch["labels"]
  34. loss = F.cross_entropy(
  35. outputs.token_logits.reshape(-1, outputs.token_logits.size(-1)),
  36. labels[:, 0].reshape(-1),
  37. ignore_index=-100,
  38. )
  39. # If we have a codebook, add the loss
  40. if self.model.config.num_codebooks != 0:
  41. codebook_labels = labels[:, 1:].mT
  42. semantic_loss = F.cross_entropy(
  43. outputs.codebook_logits.reshape(-1, outputs.codebook_logits.size(-1)),
  44. codebook_labels.reshape(-1),
  45. ignore_index=-100,
  46. )
  47. loss = loss + semantic_loss
  48. self.log(
  49. f"{stage}/loss",
  50. loss,
  51. on_step=True,
  52. on_epoch=False,
  53. prog_bar=True,
  54. logger=True,
  55. )
  56. # Top-5 accuracy
  57. if self.model.config.num_codebooks == 0:
  58. _, indices = outputs.token_logits.topk(5, dim=-1)
  59. correct = indices.eq(labels[:, 0].unsqueeze(-1))
  60. correct[labels[:, 0] == -100] = 0
  61. correct = correct.sum()
  62. accuracy = correct / (labels[:, 0] != -100).sum()
  63. else:
  64. _, indices = outputs.codebook_logits.topk(5, dim=-1)
  65. correct = indices.eq(codebook_labels.unsqueeze(-1))
  66. correct[codebook_labels == -100] = 0
  67. correct = correct.sum()
  68. accuracy = correct / (codebook_labels != -100).sum()
  69. self.log(
  70. f"{stage}/top_5_accuracy",
  71. accuracy,
  72. on_step=True,
  73. on_epoch=False,
  74. prog_bar=True,
  75. logger=True,
  76. )
  77. return loss
  78. def training_step(self, batch, batch_idx):
  79. return self._step(batch, batch_idx, "train")
  80. def validation_step(self, batch, batch_idx):
  81. return self._step(batch, batch_idx, "val")