# lit_module.py
  1. from typing import Any
  2. import lightning as L
  3. from lightning.pytorch.utilities.types import OptimizerLRScheduler
  4. from transformers import LlamaForCausalLM
  5. class TextToSemantic(L.LightningModule):
  6. def __init__(self, model: LlamaForCausalLM, optimizer: Any, lr_scheduler: Any):
  7. super().__init__()
  8. self.model = model
  9. self.optimizer_builder = optimizer
  10. self.lr_scheduler_builder = lr_scheduler
  11. def forward(self, x):
  12. return self.model(x)
  13. def configure_optimizers(self) -> OptimizerLRScheduler:
  14. optimizer = self.optimizer_builder(self.parameters())
  15. lr_scheduler = self.lr_scheduler_builder(optimizer)
  16. return {
  17. "optimizer": optimizer,
  18. "lr_scheduler": {
  19. "scheduler": lr_scheduler,
  20. "interval": "step",
  21. },
  22. }
  23. def _step(self, batch, batch_idx, stage: str):
  24. result = self.model(**batch)
  25. loss = result.loss
  26. logits = result.logits
  27. self.log(
  28. f"{stage}/loss",
  29. loss,
  30. on_step=True,
  31. on_epoch=False,
  32. prog_bar=True,
  33. logger=True,
  34. )
  35. # Top-5 accuracy
  36. _, indices = logits.topk(5, dim=-1)
  37. correct = indices.eq(batch["labels"].unsqueeze(-1)).sum()
  38. accuracy = correct / batch["labels"].numel()
  39. self.log(
  40. f"{stage}/accuracy",
  41. accuracy,
  42. on_step=True,
  43. on_epoch=False,
  44. prog_bar=True,
  45. logger=True,
  46. )
  47. return loss
  48. def training_step(self, batch, batch_idx):
  49. return self._step(batch, batch_idx, "train")
  50. def validation_step(self, batch, batch_idx):
  51. return self._step(batch, batch_idx, "val")