# train.py

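"""Pretraining entry point: a Lightning Fabric training loop for a LLaMA-style
speech LM, configured and launched through Hydra (see ./configs/pretrain.yaml)."""
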
from pathlib import Path

import hydra
import pyrootutils
import torch
from lightning.fabric import Fabric
from natsort import natsorted
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
from transformers import LlamaForCausalLM
from transformers.utils import is_flash_attn_available

# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# Register eval resolver and project root before project imports
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
OmegaConf.register_new_resolver("eval", eval)

# flake8: noqa: E402
from speech_lm.logger import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)


def train(
    model: LlamaForCausalLM,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    dataloader: torch.utils.data.DataLoader,
    global_step: int,
    fabric: Fabric,
    cfg: DictConfig,
):
    model.train()

    bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
    bar.update(global_step)

    accumulate_steps = 0
    optimizer.zero_grad()

    while global_step < cfg.schedule.max_steps:
        for batch in dataloader:
            # Accumulate gradients
            is_accumulating = (
                accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
            )
            accumulate_steps += 1

            # Train one step, skipping the gradient all-reduce while accumulating
            with fabric.no_backward_sync(model, enabled=is_accumulating):
                outputs = model(**batch)
                loss = outputs.loss
                metrics = getattr(outputs, "metrics", {})
                fabric.backward(loss)

            if is_accumulating:
                continue
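
            # A full window of gradient_accumulation_steps micro-batches has
            # been accumulated at this point; the effective batch size per
            # optimizer step is per-device batch size *
            # gradient_accumulation_steps * world size.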

            # Perform gradient clipping
            grad_norm = fabric.clip_gradients(
                model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
            )

            # Update
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            fabric.log_dict(
                {
                    "train/loss": loss,
                    "train/lr": optimizer.param_groups[0]["lr"],
                    "train/grad_norm": grad_norm,
                    **{f"train/{k}": v for k, v in metrics.items()},
                },
                step=global_step,
            )

            global_step += 1
            bar.update(1)

            if global_step % cfg.schedule.save_interval == 0:
                fabric.save(
                    Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
                    {
                        "model": model,
                        "optimizer": optimizer,
                        "scheduler": scheduler,
                        "global_step": global_step,
                    },
                )
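                # Note: fabric.save (like fabric.load below in main) must be
                # called on every rank; Fabric collects the distributed state.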

            if global_step >= cfg.schedule.max_steps:
                break


@hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
def main(cfg: DictConfig):
    log.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")

    if not is_flash_attn_available():
        log.warning("Flash attention is not available, using default attention")

    fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
    fabric.launch()
    log.info(f"Fabric: {fabric}")

    model = hydra.utils.instantiate(cfg.model)
    log.info(f"Model: {model!r}")

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    log.info(f"Trainable parameters: {trainable_params / 1e6:.2f}M")
    log.info(f"Frozen parameters: {frozen_params / 1e6:.2f}M")

    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
    log.info(f"Optimizer: {optimizer}")
    log.info(f"Scheduler: {scheduler}")

    # Build state
    global_step = 0

    # Restore training from the latest checkpoint, if any
    checkpoint_dir = Path(cfg.paths.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Natural sort, so step_1000.ckpt comes after step_999.ckpt
    checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
    if len(checkpoints) > 0:
        checkpoint_path = checkpoints[-1]
        log.info(f"Restoring checkpoint from {checkpoint_path}")
        # fabric.load restores the given state in place and returns the
        # remaining, unconsumed keys
        remainder = fabric.load(
            checkpoint_path,
            {
                "model": model,
                "optimizer": optimizer,
                "scheduler": scheduler,
            },
        )
        global_step = remainder["global_step"]
        log.info(f"Restored global step: {global_step}")

    log.info("Setting up Fabric model & dataloaders")
    # The scheduler needs no Fabric setup; it keeps a reference to the optimizer
    model, optimizer = fabric.setup(model, optimizer)

    train_dataloader = hydra.utils.instantiate(cfg.dataloader)
    log.info(f"Dataloader: {train_dataloader}")
    train_dataloader = fabric.setup_dataloaders(train_dataloader)

    log.info("Begin training")
    train(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        dataloader=train_dataloader,
        global_step=global_step,
        fabric=fabric,
        cfg=cfg,
    )


if __name__ == "__main__":
    main()
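
# Example launch, assuming Hydra's standard CLI override syntax and that
# cfg.trainer instantiates lightning.fabric.Fabric (so `devices` is a valid key):
#
#   python train.py trainer.devices=4 schedule.max_steps=100000
#
# Any other config value (e.g. the optimizer's lr) can be overridden the same way.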