import time
from datetime import timedelta
from pathlib import Path

import hydra
import torch
from lightning.fabric import Fabric
from natsort import natsorted
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
from transformers import LlamaForCausalLM
from transformers.utils import is_flash_attn_2_available
from speech_lm.logger import RankedLogger

# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# Register the eval resolver so configs can use ${eval:...} interpolation
OmegaConf.register_new_resolver("eval", eval)

log = RankedLogger(__name__, rank_zero_only=True)

def train(
    model: LlamaForCausalLM,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    dataloader: torch.utils.data.DataLoader,
    global_step: int,
    fabric: Fabric,
    cfg: DictConfig,
):
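    """Run the training loop until ``cfg.schedule.max_steps`` optimizer steps,
    with gradient accumulation, periodic logging, and periodic checkpointing.
    """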
    model.train()

    bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
    bar.update(global_step)

    accumulate_steps = 0
    optimizer.zero_grad()
    start_time = time.time()
    while global_step < cfg.schedule.max_steps:
        for batch in dataloader:
            # Accumulate gradients: only every `gradient_accumulation_steps`-th
            # micro-batch triggers an optimizer update; the others skip the
            # gradient all-reduce via `no_backward_sync`
            is_accumulating = (
                accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
            )
            accumulate_steps += 1

            # Train one step
            with fabric.no_backward_sync(model, enabled=is_accumulating):
                outputs = model(**batch)
                loss = outputs.loss
                metrics = getattr(outputs, "metrics", {})
                fabric.backward(loss)

            if is_accumulating:
                continue
            # Perform gradient clipping
            grad_norm = fabric.clip_gradients(
                model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
            )

            # Update weights and learning rate
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            fabric.log_dict(
                {
                    "train/loss": loss,
                    "train/lr": optimizer.param_groups[0]["lr"],
                    "train/grad_norm": grad_norm,
                    **{f"train/{k}": v for k, v in metrics.items()},
                },
                step=global_step,
            )

            global_step += 1
            bar.update(1)
            if global_step % cfg.schedule.log_interval == 0:
                step_time = (time.time() - start_time) / cfg.schedule.log_interval
                eta = step_time * (cfg.schedule.max_steps - global_step)
                log.info(
                    f"[{global_step}/{cfg.schedule.max_steps}] loss: {loss:.4f} "
                    f"step time: {step_time:.2f}s "
                    f"lr: {optimizer.param_groups[0]['lr']:.2e} "
                    f"grad_norm: {grad_norm:.2f} "
                    f"ETA: {timedelta(seconds=round(eta))}"
                )
                start_time = time.time()
            if global_step % cfg.schedule.save_interval == 0:
                fabric.save(
                    Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
                    {
                        "model": model,
                        "optimizer": optimizer,
                        "scheduler": scheduler.state_dict(),
                        "global_step": global_step,
                    },
                )

            if global_step >= cfg.schedule.max_steps:
                break
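
# A sketch of the config this script expects (shape inferred from the keys read
# in this file; the real configs/pretrain.yaml may define more). The trainer,
# model, optimizer, scheduler, and dataloader nodes are built with
# hydra.utils.instantiate, so each needs a _target_:
#
#   trainer:
#     _target_: lightning.fabric.Fabric  # assumed, since cfg.trainer yields a Fabric
#     ...
#   model:
#     _target_: ...
#   optimizer:
#     _target_: ...
#   scheduler:
#     _target_: ...
#   dataloader:
#     _target_: ...
#   schedule:
#     max_steps: ...
#     gradient_accumulation_steps: ...
#     clip_grad_norm: ...
#     log_interval: ...
#     save_interval: ...
#   paths:
#     checkpoint_dir: ...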

@hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
def main(cfg: DictConfig):
    log.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
    if not is_flash_attn_2_available():
        log.warning("Flash attention is not available, using default attention")
    fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
    fabric.launch()
    log.info(f"Fabric: {fabric}")

    model = hydra.utils.instantiate(cfg.model)
    log.info(f"Model: {repr(model)}")
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    log.info(f"Trainable parameters: {trainable_params / 1e6:.2f}M")
    log.info(f"Frozen parameters: {frozen_params / 1e6:.2f}M")
    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
    log.info(f"Optimizer: {optimizer}")
    log.info(f"Scheduler: {scheduler}")

    log.info("Setting up model and optimizer with Fabric")
    model = fabric.setup_module(model)
    optimizer = fabric.setup_optimizers(optimizer)
    # Build state
    global_step = 0

    # Restore training from the latest checkpoint, if any
    checkpoint_dir = Path(cfg.paths.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Natural sort so e.g. step_10000.ckpt comes after step_2000.ckpt
    checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
    if len(checkpoints) > 0:
        checkpoint_path = checkpoints[-1]
        log.info(f"Restoring checkpoint from {checkpoint_path}")
        state = {"model": model, "optimizer": optimizer, "scheduler": scheduler}
        remainder = fabric.load(checkpoint_path, state)
        # Fabric restores modules and optimizers in place, but plain objects are
        # only swapped into the state dict, so apply the scheduler state manually
        scheduler.load_state_dict(state["scheduler"])

        global_step = remainder["global_step"]
        log.info(f"Restored global step: {global_step}")
    train_dataloader = hydra.utils.instantiate(cfg.dataloader)
    log.info(f"Dataloader: {train_dataloader}")
    train_dataloader = fabric.setup_dataloaders(train_dataloader)

    log.info("Begin training")
    train(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        dataloader=train_dataloader,
        global_step=global_step,
        fabric=fabric,
        cfg=cfg,
    )

if __name__ == "__main__":
    main()
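
# Launched through Hydra, so any config value can be overridden on the CLI,
# e.g. (hypothetical values):
#   python train.py schedule.max_steps=100000 paths.checkpoint_dir=checkpoints/run1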