# train.py

from pathlib import Path

import hydra
import torch
from lightning.fabric import Fabric
from natsort import natsorted
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
from transformers import LlamaForCausalLM
from transformers.utils import is_flash_attn_available
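
# NOTE: recent transformers releases renamed this helper to
# is_flash_attn_2_available; adjust the import if pinned to a newer version.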

from speech_lm.logger import RankedLogger

# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# Register eval resolver
OmegaConf.register_new_resolver("eval", eval)
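# Example: a config value like `max_steps: ${eval:"100 * 1000"}` now resolves
# to 100000 when accessed (the resolver simply calls Python's eval()).
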
log = RankedLogger(__name__, rank_zero_only=True)


def train(
    model: LlamaForCausalLM,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    dataloader: torch.utils.data.DataLoader,
    global_step: int,
    fabric: Fabric,
    cfg: DictConfig,
):
    model.train()

    bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
    bar.update(global_step)

    accumulate_steps = 0
    optimizer.zero_grad()
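
    # Step-based rather than epoch-based: keep cycling the dataloader until
    # `max_steps` optimizer updates have been performed.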
    while global_step < cfg.schedule.max_steps:
        for batch in dataloader:
            # Accumulate gradients
            is_accumulating = (
                accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
            )
            accumulate_steps += 1
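
            # `no_backward_sync` skips the distributed gradient all-reduce on
            # accumulation steps, so gradients are synchronized only once per
            # effective batch (a no-op when running on a single device).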
            # Train one step
            with fabric.no_backward_sync(model, enabled=is_accumulating):
                outputs = model(**batch)
                loss = outputs.loss
                metrics = getattr(outputs, "metrics", {})
                fabric.backward(loss)

            if is_accumulating:
                continue
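
            # Everything below runs once per effective batch, i.e. after every
            # `gradient_accumulation_steps` forward/backward passes.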
            # Perform gradient clipping
            grad_norm = fabric.clip_gradients(
                model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
            )

            # Update
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            fabric.log_dict(
                {
                    "train/loss": loss,
                    "train/lr": optimizer.param_groups[0]["lr"],
                    "train/grad_norm": grad_norm,
                    **{f"train/{k}": v for k, v in metrics.items()},
                },
                step=global_step,
            )

            global_step += 1
            bar.update(1)

            if global_step % cfg.schedule.save_interval == 0:
                fabric.save(
                    Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
                    {
                        "model": model,
                        "optimizer": optimizer,
                        "scheduler": scheduler,
                        "global_step": global_step,
                    },
                )

            if global_step >= cfg.schedule.max_steps:
                break
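

# A rough sketch of the config surface this script reads (assumed layout; the
# real configs/pretrain.yaml may differ beyond these keys):
#   trainer    -> must instantiate to a lightning.fabric.Fabric
#   model      -> must instantiate to a LlamaForCausalLM
#   optimizer / scheduler / dataloader -> built via hydra.utils.instantiate
#   schedule.{max_steps, gradient_accumulation_steps, clip_grad_norm, save_interval}
#   paths.checkpoint_dir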
@hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
def main(cfg: DictConfig):
    log.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")

    if not is_flash_attn_available():
        log.warning("Flash attention is not available, using default attention")

    fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
    fabric.launch()
    log.info(f"Fabric: {fabric}")

    model = hydra.utils.instantiate(cfg.model)
    log.info(f"Model: {model!r}")

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    log.info(f"Trainable parameters: {trainable_params / 1e6:.2f}M")
    log.info(f"Frozen parameters: {frozen_params / 1e6:.2f}M")

    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
    log.info(f"Optimizer: {optimizer}")
    log.info(f"Scheduler: {scheduler}")

    # Build state
    global_step = 0

    # Restore training from the latest checkpoint, if any
    checkpoint_dir = Path(cfg.paths.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Natural sort so step_1000.ckpt comes after step_999.ckpt
    checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
    if len(checkpoints) > 0:
        checkpoint_path = checkpoints[-1]
        log.info(f"Restoring checkpoint from {checkpoint_path}")
        remainder = fabric.load(
            checkpoint_path,
            {
                "model": model,
                "optimizer": optimizer,
                "scheduler": scheduler,
            },
        )
        global_step = remainder["global_step"]
        log.info(f"Restored global step: {global_step}")
  115. log.info(f"Setup fabric model & dataset")
  116. model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler)
  117. train_dataloader = hydra.utils.instantiate(cfg.dataloader)
  118. log.info(f"Dataloader: {train_dataloader}")
  119. train_dataloader = fabric.setup_dataloaders(train_dataloader)
  120. log.info(f"Begin training")
  121. train(
  122. model=model,
  123. optimizer=optimizer,
  124. scheduler=scheduler,
  125. dataloader=train_dataloader,
  126. global_step=global_step,
  127. fabric=fabric,
  128. cfg=cfg,
  129. )
  130. if __name__ == "__main__":
  131. main()
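
# Example launch (Hydra CLI overrides; values here are illustrative):
#   python train.py schedule.max_steps=100000 paths.checkpoint_dir=./checkpoints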