# lit_module.py

from typing import Any

import lightning as L
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler


class TextToSemantic(L.LightningModule):
    def __init__(self, model, optimizer: Any, lr_scheduler: Any):
        super().__init__()
        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler
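
    # `optimizer` and `lr_scheduler` are deferred builder callables (e.g.
    # functools.partial around torch.optim classes) rather than ready-made
    # instances; configure_optimizers invokes them once the module's
    # parameters are available.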

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        optimizer = self.optimizer_builder(self.parameters())
        lr_scheduler = self.lr_scheduler_builder(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }
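
    # Lightning contract for the dict above: "interval": "step" advances the
    # scheduler after every optimizer step instead of once per epoch, which
    # suits step-based schedules such as warmup or cosine decay.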

    def _step(self, batch, batch_idx, stage: str):
        outputs = self.model(
            x=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )

        # Row 0 of the labels holds the token targets; the remaining rows
        # hold the per-codebook targets.
        labels = batch["labels"]
        token_loss = F.cross_entropy(
            outputs.token_logits.reshape(-1, outputs.token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        # Transpose [B, codebooks, T] -> [B, T, codebooks] so the labels line
        # up with the flattened codebook logits.
        codebook_labels = labels[:, 1:].mT
        semantic_loss = F.cross_entropy(
            outputs.codebook_logits.reshape(-1, outputs.codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )

        loss = token_loss + semantic_loss
        self.log(
            f"{stage}/loss",
            loss,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )

        # Top-5 accuracy over the codebook predictions: a position counts as
        # correct if its target appears among the five highest logits, and
        # ignored (-100) positions are masked out of both counts.
        _, indices = outputs.codebook_logits.topk(5, dim=-1)
        correct = indices.eq(codebook_labels.unsqueeze(-1))
        correct[codebook_labels == -100] = 0
        accuracy = correct.sum() / (codebook_labels != -100).sum()
        self.log(
            f"{stage}/top_5_accuracy",
            accuracy,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            logger=True,
        )
        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")
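

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the original module. DummyOutput,
# DummyModel, and the tensor shapes are assumptions inferred from _step
# (labels row 0 = token targets, rows 1.. = per-codebook targets), not the
# real project model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import functools
    from dataclasses import dataclass

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, Dataset

    VOCAB, CODEBOOKS, SEQ_LEN = 32, 4, 8

    @dataclass
    class DummyOutput:
        token_logits: torch.Tensor  # [B, T, VOCAB]
        codebook_logits: torch.Tensor  # [B, T, CODEBOOKS, VOCAB]

    class DummyModel(nn.Module):
        def __init__(self, dim: int = 16):
            super().__init__()
            self.emb = nn.Embedding(VOCAB, dim)
            self.tok_head = nn.Linear(dim, VOCAB)
            self.cb_head = nn.Linear(dim, CODEBOOKS * VOCAB)

        def forward(self, x, key_padding_mask=None):
            h = self.emb(x)  # [B, T, dim]; the mask is ignored in this stub
            b, t, _ = h.shape
            return DummyOutput(
                token_logits=self.tok_head(h),
                codebook_logits=self.cb_head(h).view(b, t, CODEBOOKS, VOCAB),
            )

    class DummyDataset(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            return {
                "inputs": torch.randint(0, VOCAB, (SEQ_LEN,)),
                "attention_masks": torch.zeros(SEQ_LEN, dtype=torch.bool),
                # Row 0: token targets; rows 1..CODEBOOKS: codebook targets.
                "labels": torch.randint(0, VOCAB, (1 + CODEBOOKS, SEQ_LEN)),
            }

    module = TextToSemantic(
        model=DummyModel(),
        optimizer=functools.partial(torch.optim.AdamW, lr=1e-4),
        lr_scheduler=functools.partial(
            torch.optim.lr_scheduler.CosineAnnealingLR, T_max=10
        ),
    )
    trainer = L.Trainer(fast_dev_run=True, accelerator="cpu", logger=False)
    trainer.fit(module, DataLoader(DummyDataset(), batch_size=2))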