# lit_module.py
  1. from typing import Any
  2. import lightning as L
  3. import torch.nn.functional as F
  4. from lightning.pytorch.utilities.types import OptimizerLRScheduler
  5. class TextToSemantic(L.LightningModule):
  6. def __init__(self, model, optimizer: Any, lr_scheduler: Any):
  7. super().__init__()
  8. self.model = model
  9. self.optimizer_builder = optimizer
  10. self.lr_scheduler_builder = lr_scheduler
  11. def forward(self, x):
  12. return self.model(x)
  13. def configure_optimizers(self) -> OptimizerLRScheduler:
  14. optimizer = self.optimizer_builder(self.parameters())
  15. lr_scheduler = self.lr_scheduler_builder(optimizer)
  16. return {
  17. "optimizer": optimizer,
  18. "lr_scheduler": {
  19. "scheduler": lr_scheduler,
  20. "interval": "step",
  21. },
  22. }
  23. def _step(self, batch, batch_idx, stage: str):
  24. outputs = self.model(
  25. x=batch["inputs"],
  26. key_padding_mask=batch["attention_masks"],
  27. )
  28. # Generate labels
  29. labels = batch["labels"]
  30. loss = F.cross_entropy(
  31. outputs.token_logits.reshape(-1, outputs.token_logits.size(-1)),
  32. labels[:, 0].reshape(-1),
  33. ignore_index=-100,
  34. )
  35. # If we have a codebook, add the loss
  36. if self.model.config.num_codebooks != 0:
  37. codebook_labels = labels[:, 1:].mT
  38. semantic_loss = F.cross_entropy(
  39. outputs.codebook_logits.reshape(-1, outputs.codebook_logits.size(-1)),
  40. codebook_labels.reshape(-1),
  41. ignore_index=-100,
  42. )
  43. loss = loss + semantic_loss
  44. self.log(
  45. f"{stage}/loss",
  46. loss,
  47. on_step=True,
  48. on_epoch=False,
  49. prog_bar=True,
  50. logger=True,
  51. )
  52. # Top-5 accuracy
  53. if self.model.config.num_codebooks == 0:
  54. _, indices = outputs.token_logits.topk(5, dim=-1)
  55. correct = indices.eq(labels[:, 0].unsqueeze(-1))
  56. correct[labels[:, 0] == -100] = 0
  57. correct = correct.sum()
  58. accuracy = correct / (labels[:, 0] != -100).sum()
  59. else:
  60. _, indices = outputs.codebook_logits.topk(5, dim=-1)
  61. correct = indices.eq(codebook_labels.unsqueeze(-1))
  62. correct[codebook_labels == -100] = 0
  63. correct = correct.sum()
  64. accuracy = correct / (codebook_labels != -100).sum()
  65. self.log(
  66. f"{stage}/top_5_accuracy",
  67. accuracy,
  68. on_step=True,
  69. on_epoch=False,
  70. prog_bar=True,
  71. logger=True,
  72. )
  73. return loss
  74. def training_step(self, batch, batch_idx):
  75. return self._step(batch, batch_idx, "train")
  76. def validation_step(self, batch, batch_idx):
  77. return self._step(batch, batch_idx, "val")