train.py

import time
from collections import defaultdict
from datetime import timedelta
from pathlib import Path
from typing import Optional

import hydra
import torch
from lightning.fabric import Fabric
from natsort import natsorted
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
from transformers import LlamaForCausalLM
from transformers.utils import is_flash_attn_available

from speech_lm.logger import RankedLogger

# Allow TF32 on Ampere (and newer) GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True
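# (Note: TF32 trades a few mantissa bits of float32 precision for a large
# matmul speedup; "high" permits TF32 for float32 matrix multiplications.)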

# Register the `eval` resolver so configs can compute values inline
OmegaConf.register_new_resolver("eval", eval)
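# For example, a config could then compute a value inline (hypothetical key):
#   max_tokens: ${eval:'2048 * 4'}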

log = RankedLogger(__name__, rank_zero_only=True)


@torch.no_grad()  # evaluation never needs autograd graphs
def valid(
    model: LlamaForCausalLM,
    valid_dataloader: Optional[torch.utils.data.DataLoader],
    global_step: int,
    fabric: Fabric,
    cfg: DictConfig,
):
    model.eval()
    log.info(f"Evaluating at step {global_step}")

    accumulate_infos = None
    for idx, batch in enumerate(tqdm(valid_dataloader, desc="Evaluating")):
        outputs = model(**batch)
        loss = outputs.loss
        metrics = getattr(outputs, "metrics", {})

        log_info = {
            "valid/loss": float(loss),
            **{f"valid/{k}": float(v) for k, v in metrics.items()},
        }
        fabric.log_dict(log_info, step=global_step + idx)

        # Accumulate per-batch values so we can report averages at the end
        if accumulate_infos is None:
            accumulate_infos = log_info
        else:
            assert set(accumulate_infos.keys()) == set(
                log_info.keys()
            ), "Log keys changed during evaluation"
            for k in accumulate_infos.keys():
                accumulate_infos[k] += log_info[k]

        # Optionally cap the number of evaluation batches
        eval_max_batches = getattr(cfg.schedule, "eval_max_batches", None)
        if eval_max_batches is not None and idx >= eval_max_batches:
            break

    # Log the average of each metric over the evaluated batches
    items = [f"{k}: {v / (idx + 1):.4f}" for k, v in accumulate_infos.items()]
    log.info(f"Average: {' | '.join(items)}")


def train(
    model: LlamaForCausalLM,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    train_dataloader: torch.utils.data.DataLoader,
    valid_dataloader: Optional[torch.utils.data.DataLoader],
    global_step: int,
    fabric: Fabric,
    cfg: DictConfig,
):
    accumulate_steps = 0
    optimizer.zero_grad()

    # Wall-clock reference: covers model compute plus data loading
    start_time = time.time()
    trackers = defaultdict(list)

    while global_step < cfg.schedule.max_steps:
        last_batch_time = time.time()

        for batch in train_dataloader:
            # Time spent waiting on the dataloader
            trackers["data_time"].append(time.time() - last_batch_time)

            # Time spent in the model forward/backward
            model_begin_time = time.time()
            model.train()

            # Accumulate gradients over several micro-batches
            accumulate_steps += 1
            is_accumulating = (
                accumulate_steps < cfg.schedule.gradient_accumulation_steps
            )

            # Train one micro-batch; while accumulating, skip the expensive
            # cross-rank gradient synchronization
            with fabric.no_backward_sync(model, enabled=is_accumulating):
                outputs = model(**batch)
                loss = outputs.loss
                metrics = getattr(outputs, "metrics", {})

                # Scale the loss so the accumulated gradient matches the
                # gradient of the mean loss over all micro-batches
                fabric.backward(loss / cfg.schedule.gradient_accumulation_steps)
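                # (Worked example, assuming gradient_accumulation_steps = 4:
                # each of the four backward passes contributes grad(loss_i) / 4,
                # so the summed gradient equals the gradient of the mean loss
                # (loss_1 + loss_2 + loss_3 + loss_4) / 4.)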

            # Update trackers
            trackers["loss"].append(float(loss))
            trackers["lr"].append(float(optimizer.param_groups[0]["lr"]))
            for k, v in metrics.items():
                trackers[f"metrics/{k}"].append(float(v))
            trackers["model_time"].append(time.time() - model_begin_time)

            if is_accumulating:
                last_batch_time = time.time()
                continue

            # Check that all trackers have the same length (grad_norm is
            # recorded once per optimizer step, so it is exempt)
            assert (
                len(set(len(v) for k, v in trackers.items() if k != "grad_norm")) == 1
            ), "Trackers have ambiguous lengths"

            # Perform gradient clipping; returns the pre-clip gradient norm
            grad_norm = fabric.clip_gradients(
                model,
                optimizer,
                max_norm=cfg.schedule.clip_grad_norm,
                norm_type=2.0,
                error_if_nonfinite=False,  # non-finite norms are handled below
            )
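            # (Note: with `error_if_nonfinite=True`, `clip_gradients` would
            # raise a RuntimeError on a NaN/Inf norm before the check below
            # ever ran; returning the norm instead lets a bad batch be
            # skipped gracefully.)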

            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                # Drop this update entirely rather than stepping on bad grads
                log.warning(f"Gradient norm is {grad_norm}, skipping update")
                optimizer.zero_grad()
            else:
                optimizer.step()
                optimizer.zero_grad()

            # grad_norm is tracked per optimizer step, not per micro-batch
            trackers["grad_norm"].append(float(grad_norm))

            scheduler.step()
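            # (Design note: the scheduler steps even when an update is
            # skipped, keeping the LR curve aligned with global_step.)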

            # Log the average of each tracker over this accumulation window
            fabric.log_dict(
                {
                    f"train/{k}": sum(v[-accumulate_steps:])
                    / len(v[-accumulate_steps:])
                    for k, v in trackers.items()
                },
                step=global_step,
            )

            accumulate_steps = 0
            global_step += 1

            if global_step % cfg.schedule.log_interval == 0:
                step_time = (time.time() - start_time) / cfg.schedule.log_interval
                eta = step_time * (cfg.schedule.max_steps - global_step)

                additional_info = [
                    f"{k}: {sum(v[-cfg.schedule.log_interval:]) / len(v[-cfg.schedule.log_interval:]):.4f}"
                    for k, v in trackers.items()
                    if k != "lr"  # lr is formatted separately with .2e below
                ]
                log.info(
                    f"[{global_step}/{cfg.schedule.max_steps}] "
                    + f"step_time: {step_time:.2f}s "
                    + f"ETA: {timedelta(seconds=round(eta))} "
                    + f"lr: {optimizer.param_groups[0]['lr']:.2e} "
                    + " ".join(additional_info)
                )

                # Reset trackers and the timing window
                trackers = defaultdict(list)
                start_time = time.time()

            if global_step % cfg.schedule.save_interval == 0:
                fabric.save(
                    Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
                    {
                        "model": model,
                        "optimizer": optimizer,
                        "scheduler": scheduler.state_dict(),
                        "global_step": global_step,
                    },
                )
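                # (Fabric converts the wrapped model and optimizer to plain
                # state dicts when saving; everything here is restored by
                # `fabric.load` in `main` below.)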

            if (
                getattr(cfg.schedule, "eval_interval", None) is not None
                and global_step % cfg.schedule.eval_interval == 0
                and valid_dataloader is not None
            ):
                valid(model, valid_dataloader, global_step, fabric, cfg)

            if global_step >= cfg.schedule.max_steps:
                break

            last_batch_time = time.time()
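

# Hydra exposes every config key as a CLI override; a typical launch
# (hypothetical values) might look like:
#   python train.py schedule.max_steps=100000 trainer.devices=4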
@hydra.main(version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml")
def main(cfg: DictConfig):
    log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")

    if not is_flash_attn_available():
        log.warning("Flash attention is not available, falling back to default attention")

    fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
    fabric.launch()
    log.info(f"Fabric: {fabric}")

    model = hydra.utils.instantiate(cfg.model)
    log.info(f"Model: {repr(model)}")

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    log.info(f"Trainable parameters: {trainable_params / 1e6:.2f}M")
    log.info(f"Frozen parameters: {frozen_params / 1e6:.2f}M")

    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
    log.info(f"Optimizer: {optimizer}")
    log.info(f"Scheduler: {scheduler}")

    log.info("Setting up Fabric model & optimizer")
    model = fabric.setup_module(model)
    optimizer = fabric.setup_optimizers(optimizer)

    # Build state
    global_step = 0

    # Restore training from the latest checkpoint, if one exists
    checkpoint_dir = Path(cfg.paths.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Natural-sort checkpoints so that step_10000 sorts after step_9000
    checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
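    # (Example: plain sorted() would place "step_10000.ckpt" before
    # "step_9000.ckpt" lexicographically; natsorted() orders them by step,
    # so checkpoints[-1] is always the most recent.)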

    if len(checkpoints) > 0:
        checkpoint_path = checkpoints[-1]
        log.info(f"Restoring checkpoint from {checkpoint_path}")

        # Fabric restores the given objects in place and returns any
        # remaining (unmatched) items from the checkpoint
        remainder = fabric.load(
            checkpoint_path,
            {
                "model": model,
                "optimizer": optimizer,
                "scheduler": scheduler,
            },
        )
        global_step = remainder["global_step"]
        log.info(f"Restored global step: {global_step}")

    train_dataloader = hydra.utils.instantiate(cfg.train_dataloader)
    log.info(f"Train Dataloader: {train_dataloader}")

    valid_dataloader = None
    if getattr(cfg, "valid_dataloader", None) is not None:
        valid_dataloader = hydra.utils.instantiate(cfg.valid_dataloader)
        log.info(f"Valid Dataloader: {valid_dataloader}")

    train_dataloader = fabric.setup_dataloaders(train_dataloader)
    if valid_dataloader is not None:
        valid_dataloader = fabric.setup_dataloaders(valid_dataloader)

    log.info("Begin training")
    train(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        global_step=global_step,
        fabric=fabric,
        cfg=cfg,
    )


if __name__ == "__main__":
    main()