# train.py

from pathlib import Path
import time

import hydra
import torch
from lightning.fabric import Fabric
from natsort import natsorted
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
from transformers import LlamaForCausalLM
from transformers.utils import is_flash_attn_available

from speech_lm.logger import RankedLogger

# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# Register the eval resolver
OmegaConf.register_new_resolver("eval", eval)
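# With the resolver above, a config can compute values inline, e.g. a
# hypothetical key `max_steps: ${eval:"100 * 1000"}`.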

log = RankedLogger(__name__, rank_zero_only=True)


def train(
    model: LlamaForCausalLM,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    dataloader: torch.utils.data.DataLoader,
    global_step: int,
    fabric: Fabric,
    cfg: DictConfig,
):
    model.train()

    bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
    bar.update(global_step)

    accumulate_steps = 0
    optimizer.zero_grad()
    start_time = time.time()

    while global_step < cfg.schedule.max_steps:
        for batch in dataloader:
            # Count this micro-batch first so the first optimizer update also
            # accumulates a full window of gradients
            accumulate_steps += 1
            is_accumulating = (
                accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
            )

            # Train one step, skipping gradient sync while accumulating
            with fabric.no_backward_sync(model, enabled=is_accumulating):
                outputs = model(**batch)
                loss = outputs.loss
                metrics = getattr(outputs, "metrics", {})
                fabric.backward(loss)

            if is_accumulating:
                continue
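
            # Gradients are summed over the accumulation window (the loss is
            # not rescaled by the accumulation factor), so the effective batch
            # size is per-device batch size * accumulation steps * world size.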

            # Perform gradient clipping
            grad_norm = fabric.clip_gradients(
                model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
            )

            # Update
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
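
            # NB: `scheduler.step()` runs once per optimizer update (not once
            # per micro-batch), and after `optimizer.step()`, as PyTorch expects.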

            fabric.log_dict(
                {
                    "train/loss": loss,
                    "train/lr": optimizer.param_groups[0]["lr"],
                    "train/grad_norm": grad_norm,
                    **{f"train/{k}": v for k, v in metrics.items()},
                },
                step=global_step,
            )

            global_step += 1
            bar.update(1)

            if global_step % cfg.schedule.log_interval == 0:
                step_time = (time.time() - start_time) / cfg.schedule.log_interval
                log.info(
                    f"[{global_step}/{cfg.schedule.max_steps}] "
                    f"loss: {loss:.4f} "
                    f"step time: {step_time:.2f}s "
                    f"lr: {optimizer.param_groups[0]['lr']:.2e} "
                    f"grad_norm: {grad_norm:.2f} "
                    f"ETA: {step_time * (cfg.schedule.max_steps - global_step):.2f}s"
                )
                start_time = time.time()

            if global_step % cfg.schedule.save_interval == 0:
                # Pass the scheduler object itself, matching the restore path
                # in `main`; Fabric serializes state_dicts of stateful objects
                fabric.save(
                    Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
                    {
                        "model": model,
                        "optimizer": optimizer,
                        "scheduler": scheduler,
                        "global_step": global_step,
                    },
                )

            if global_step >= cfg.schedule.max_steps:
                break
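

# A minimal sketch of the keys `main` reads from "pretrain.yaml" (inferred from
# the code below; targets and values are placeholders, not the project's
# actual defaults):
#
#   trainer:    _target_: lightning.fabric.Fabric (plus accelerator/devices/...)
#   model:      _target_ of the LlamaForCausalLM factory
#   optimizer:  _target_ of a torch.optim optimizer (given `params`)
#   scheduler:  _target_ of an LR scheduler (given `optimizer`)
#   dataloader: _target_ of the training DataLoader factory
#   paths:      checkpoint_dir
#   schedule:   max_steps, gradient_accumulation_steps, clip_grad_norm,
#               log_interval, save_interval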


@hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
def main(cfg: DictConfig):
    log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")

    if not is_flash_attn_available():
        log.warning("Flash attention is not available, using default attention")

    fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
    fabric.launch()
    log.info(f"Fabric: {fabric}")

    model = hydra.utils.instantiate(cfg.model)
    log.info(f"Model: {repr(model)}")

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    log.info(f"Trainable parameters: {trainable_params / 1e6:.2f}M")
    log.info(f"Frozen parameters: {frozen_params / 1e6:.2f}M")

    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
    log.info(f"Optimizer: {optimizer}")
    log.info(f"Scheduler: {scheduler}")

    # Wrap the model and optimizer for the configured strategy/devices
    log.info("Setting up Fabric model & optimizer")
    model = fabric.setup_module(model)
    optimizer = fabric.setup_optimizers(optimizer)

    # Build state
    global_step = 0

    # Restore training from the latest checkpoint, if one exists
    checkpoint_dir = Path(cfg.paths.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Natural sort so that e.g. step_10000.ckpt orders after step_9000.ckpt
    checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
    if len(checkpoints) > 0:
        checkpoint_path = checkpoints[-1]
        log.info(f"Restoring checkpoint from {checkpoint_path}")
        # Objects passed here are loaded in place; remaining keys are returned
        remainder = fabric.load(
            checkpoint_path,
            {
                "model": model,
                "optimizer": optimizer,
                "scheduler": scheduler,
            },
        )
        global_step = remainder["global_step"]
        log.info(f"Restored global step: {global_step}")

    train_dataloader = hydra.utils.instantiate(cfg.dataloader)
    log.info(f"Dataloader: {train_dataloader}")
    train_dataloader = fabric.setup_dataloaders(train_dataloader)

    log.info("Begin training")
    train(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        dataloader=train_dataloader,
        global_step=global_step,
        fabric=fabric,
        cfg=cfg,
    )


if __name__ == "__main__":
    main()
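

# Example invocation using standard Hydra CLI overrides (keys must match the
# actual configs):
#   python train.py schedule.max_steps=100000 paths.checkpoint_dir=ckpts/run1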