# lit_module.py
from typing import Any

import lightning as L
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from transformers import LlamaForCausalLM
  6. class TextToSemantic(L.LightningModule):
  7. def __init__(self, model: LlamaForCausalLM, optimizer: Any, lr_scheduler: Any):
  8. super().__init__()
  9. self.model = model
  10. self.optimizer_builder = optimizer
  11. self.lr_scheduler_builder = lr_scheduler
  12. def forward(self, x):
  13. return self.model(x)
  14. def configure_optimizers(self) -> OptimizerLRScheduler:
  15. optimizer = self.optimizer_builder(self.parameters())
  16. lr_scheduler = self.lr_scheduler_builder(optimizer)
  17. return {
  18. "optimizer": optimizer,
  19. "lr_scheduler": {
  20. "scheduler": lr_scheduler,
  21. "interval": "step",
  22. },
  23. }
  24. def _step(self, batch, batch_idx, stage: str):
  25. logits = self.model(
  26. inputs=batch["inputs"],
  27. input_mask=batch["input_mask"],
  28. codes=batch["codes"][..., :-1],
  29. codes_mask=batch["codes_mask"][..., :-1],
  30. )
  31. # Generate labels
  32. labels = batch["codes"][..., 1:].contiguous()
  33. label_mask = batch["codes_mask"][..., 1:]
  34. label_mask = label_mask[:, None, :]
  35. label_mask = label_mask.expand(-1, labels.size(1), -1)
  36. labels = labels.masked_fill(label_mask, -100)
  37. loss = F.cross_entropy(
  38. logits.view(-1, logits.size(-1)),
  39. labels.view(-1),
  40. ignore_index=-100,
  41. )
  42. self.log(
  43. f"{stage}/loss",
  44. loss,
  45. on_step=True,
  46. on_epoch=False,
  47. prog_bar=True,
  48. logger=True,
  49. )
  50. # Top-5 accuracy
  51. _, indices = logits.topk(5, dim=-1)
  52. correct = indices.eq(labels.unsqueeze(-1)).sum()
  53. accuracy = correct / labels.numel()
  54. self.log(
  55. f"{stage}/top_5_accuracy",
  56. accuracy,
  57. on_step=True,
  58. on_epoch=False,
  59. prog_bar=True,
  60. logger=True,
  61. )
  62. return loss
  63. def training_step(self, batch, batch_idx):
  64. return self._step(batch, batch_idx, "train")
  65. def validation_step(self, batch, batch_idx):
  66. return self._step(batch, batch_idx, "val")