from pathlib import Path

import hydra
import pyrootutils
import torch
from lightning.fabric import Fabric
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
from transformers import LlamaForCausalLM
from transformers.utils import is_flash_attn_available

# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# Register the project root and the eval resolver
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
OmegaConf.register_new_resolver("eval", eval)
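# The "eval" resolver lets YAML values compute Python expressions,
# e.g. `lr: ${eval:"3e-4 * 0.5"}`; only use it with trusted configs,
# since it calls the builtin eval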

# flake8: noqa: E402
from speech_lm.logger import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)


def train(
    model: LlamaForCausalLM,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    dataloader: torch.utils.data.DataLoader,
    global_step: int,
    fabric: Fabric,
    cfg: DictConfig,
):
    model.train()

    bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
    bar.update(global_step)

    accumulate_steps = 0
    optimizer.zero_grad()

    while global_step < cfg.schedule.max_steps:
        for batch in dataloader:
            # Count this micro-batch before checking, so the optimizer steps
            # once every `gradient_accumulation_steps` micro-batches
            accumulate_steps += 1
            is_accumulating = (
                accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
            )
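
            # `no_backward_sync` skips the cross-rank gradient all-reduce on
            # accumulation steps (a no-op on a single device), so DDP only
            # synchronizes gradients once per optimizer step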
            with fabric.no_backward_sync(model, enabled=is_accumulating):
                outputs = model(**batch)
                loss = outputs.loss
                # Optional extra metrics, if the model output exposes them
                metrics = getattr(outputs, "metrics", {})
                fabric.backward(loss)

            if is_accumulating:
                continue

            # Clip gradients; Fabric returns the total gradient norm, which
            # we log below
            grad_norm = fabric.clip_gradients(
                model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
            )

            # Update weights, reset gradients, and advance the LR schedule
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
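
            # Log scalars for this optimizer step; note `loss` is the last
            # micro-batch's loss, not an average over the accumulation window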
            fabric.log_dict(
                {
                    "train/loss": loss,
                    "train/lr": optimizer.param_groups[0]["lr"],
                    "train/grad_norm": grad_norm,
                    **{f"train/{k}": v for k, v in metrics.items()},
                },
                step=global_step,
            )

            global_step += 1
            bar.update(1)

            if global_step % cfg.schedule.save_interval == 0:
                state = {
                    "model": model,
                    "optimizer": optimizer,
                    "scheduler": scheduler,
                    "global_step": global_step,
                }
                fabric.save(
                    Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
                    state,
                )
                # Also refresh "last.ckpt", which main() looks for on resume
                fabric.save(Path(cfg.paths.checkpoint_dir) / "last.ckpt", state)

            if global_step >= cfg.schedule.max_steps:
                break


@hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
def main(cfg: DictConfig):
    log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")

    if not is_flash_attn_available():
        log.warning("Flash attention is not available, using default attention")
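
    # cfg.trainer is expected to instantiate a lightning.fabric.Fabric (see
    # the annotation below); launch() starts or joins the worker processes
    # for the configured strategy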
    fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
    fabric.launch()
    log.info(f"Fabric: {fabric}")

    model = hydra.utils.instantiate(cfg.model)
    log.info(f"Model: {repr(model)}")

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    log.info(f"Trainable parameters: {trainable_params / 1e6:.2f}M")
    log.info(f"Frozen parameters: {frozen_params / 1e6:.2f}M")

    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
    log.info(f"Optimizer: {optimizer}")
    log.info(f"Scheduler: {scheduler}")

    # Build training state and restore from the last checkpoint if present
    global_step = 0
    checkpoint_dir = Path(cfg.paths.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = checkpoint_dir / "last.ckpt"

    if checkpoint_path.exists():
        log.info(f"Restoring checkpoint from {checkpoint_path}")
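        # fabric.load restores the listed objects in place and returns any
        # remaining (unconsumed) entries of the checkpoint dict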
        remainder = fabric.load(
            checkpoint_path,
            {
                "model": model,
                "optimizer": optimizer,
                "scheduler": scheduler,
            },
        )
        global_step = remainder["global_step"]
        log.info(f"Restored global step: {global_step}")

    log.info("Setting up Fabric model and dataloaders")
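    # Fabric.setup wraps the module and its optimizers; LR schedulers are
    # plain Python objects and are not passed through Fabric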
    model, optimizer = fabric.setup(model, optimizer)

    train_dataloader = hydra.utils.instantiate(cfg.dataloader)
    log.info(f"Dataloader: {train_dataloader}")
    train_dataloader = fabric.setup_dataloaders(train_dataloader)

    log.info("Begin training")
    train(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        dataloader=train_dataloader,
        global_step=global_step,
        fabric=fabric,
        cfg=cfg,
    )
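

# Typical launch: Hydra composes configs/pretrain.yaml and any config key can
# be overridden from the CLI, e.g. (example values):
#   python train.py schedule.max_steps=100000 paths.checkpoint_dir=results/run1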
if __name__ == "__main__":
    main()