"""Pretraining entry point: Lightning Fabric training loop for a Llama-style causal LM."""
  1. import time
  2. from collections import defaultdict
  3. from datetime import timedelta
  4. from pathlib import Path
  5. from typing import Optional
  6. import hydra
  7. import torch
  8. from lightning.fabric import Fabric
  9. from natsort import natsorted
  10. from omegaconf import DictConfig, OmegaConf
  11. from tqdm import tqdm
  12. from transformers import LlamaForCausalLM
  13. from transformers.utils import is_flash_attn_available
  14. from speech_lm.logger import RankedLogger
# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# register eval resolver
# NOTE(review): registering Python `eval` as an OmegaConf resolver executes
# arbitrary code found in config files — configs must come from trusted sources.
OmegaConf.register_new_resolver("eval", eval)

# Rank-aware logger; only rank 0 emits (see speech_lm.logger.RankedLogger).
log = RankedLogger(__name__, rank_zero_only=True)
  21. def valid(
  22. model: LlamaForCausalLM,
  23. valid_dataloader: Optional[torch.utils.data.DataLoader],
  24. global_step: int,
  25. fabric: Fabric,
  26. cfg: DictConfig,
  27. ):
  28. model.eval()
  29. log.info(f"Evaluating at step {global_step}")
  30. accumulate_infos = None
  31. for idx, batch in tqdm(enumerate(valid_dataloader), desc="Evaluating"):
  32. outputs = model(**batch)
  33. loss = outputs.loss
  34. metrics = getattr(outputs, "metrics", {})
  35. log_info = {
  36. "valid/loss": float(loss),
  37. **{f"valid/{k}": float(v) for k, v in metrics.items()},
  38. }
  39. fabric.log_dict(
  40. log_info,
  41. step=global_step + idx,
  42. )
  43. # Update log info
  44. if accumulate_infos is None:
  45. accumulate_infos = log_info
  46. else:
  47. assert set(accumulate_infos.keys()) == set(
  48. log_info.keys()
  49. ), "Log keys changed during evaluation"
  50. for k in accumulate_infos.keys():
  51. accumulate_infos[k] += log_info[k]
  52. if idx == getattr(cfg.schedule, "eval_max_batches", None):
  53. break
  54. # Log average
  55. items = []
  56. for k in accumulate_infos.keys():
  57. items.append(f"{k}: {accumulate_infos[k] / (idx + 1):.4f}")
  58. log.info(f"Average: {' | '.join(items)}")
def train(
    model: LlamaForCausalLM,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    train_dataloader: torch.utils.data.DataLoader,
    valid_dataloader: Optional[torch.utils.data.DataLoader],
    global_step: int,
    fabric: Fabric,
    cfg: DictConfig,
):
    """Main training loop with gradient accumulation, logging, checkpointing
    and optional periodic evaluation.

    Iterates the dataloader repeatedly until ``global_step`` reaches
    ``cfg.schedule.max_steps``.  One optimizer step is taken every
    ``cfg.schedule.gradient_accumulation_steps`` micro-batches; gradient
    sync is suppressed on accumulation steps via ``fabric.no_backward_sync``.

    Args:
        model: Model to train (already wrapped by ``fabric.setup_module``).
        optimizer: Optimizer (already wrapped by ``fabric.setup_optimizers``).
        scheduler: LR scheduler, stepped once per optimizer step.
        train_dataloader: Training batches as model keyword arguments.
        valid_dataloader: Optional validation dataloader for periodic eval.
        global_step: Step to resume counting from (0 for a fresh run).
        fabric: Fabric used for backward, clipping, metric logging, saving.
        cfg: Hydra config; reads ``cfg.schedule.*`` and
            ``cfg.paths.checkpoint_dir``.
    """
    bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
    bar.update(global_step)

    # Count of micro-batches seen; used modulo gradient_accumulation_steps
    # (and as a slice length when logging — see NOTE below).
    accumulate_steps = 0
    optimizer.zero_grad()

    # Start time is ~model forward time + data loading time
    start_time = time.time()
    trackers = defaultdict(list)  # per-micro-batch metric histories

    while global_step < cfg.schedule.max_steps:
        last_batch_time = time.time()
        for batch in train_dataloader:
            # Measure time used by data loading
            trackers["data_time"].append(time.time() - last_batch_time)

            # Measure time used by model forward
            model_begin_time = time.time()
            model.train()

            # Accumulate gradients: only sync on every
            # gradient_accumulation_steps-th micro-batch.
            is_accumulating = (
                accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
            )
            accumulate_steps += 1

            # Train one step
            with fabric.no_backward_sync(model, enabled=is_accumulating):
                outputs = model(**batch)
                loss = outputs.loss
                metrics = getattr(outputs, "metrics", {})

                # Need to divide loss by accumulation steps
                fabric.backward(loss / cfg.schedule.gradient_accumulation_steps)

            # Update trackers
            trackers["loss"].append(float(loss))
            trackers["lr"].append(float(optimizer.param_groups[0]["lr"]))
            for k, v in metrics.items():
                trackers[f"metrics/{k}"].append(float(v))
            trackers["model_time"].append(time.time() - model_begin_time)

            # Still accumulating: skip the optimizer step for this micro-batch.
            if is_accumulating:
                last_batch_time = time.time()
                continue

            # Check all trackers has the same length
            assert (
                len(set(len(v) for k, v in trackers.items() if k != "grad_norm")) == 1
            ), "Trackers has ambiguous length"

            # Perform gradient clipping
            # NOTE(review): with error_if_nonfinite=True this call raises on
            # NaN/Inf gradients, which makes the isnan/isinf skip branch below
            # appear unreachable — confirm which behavior is intended.
            grad_norm = fabric.clip_gradients(
                model,
                optimizer,
                max_norm=cfg.schedule.clip_grad_norm,
                norm_type=2.0,
                error_if_nonfinite=True,
            )

            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                log.warning(f"Gradient norm is {grad_norm}, skipping update")
                # NOTE(review): grads are zeroed here but optimizer.step()
                # still runs below; stateful optimizers (e.g. Adam momentum)
                # may still move parameters on a "skipped" update — confirm.
                optimizer.zero_grad()

            # We can't average gradients across multiple steps
            trackers["grad_norm"].append(float(grad_norm))

            # Update
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            # NOTE(review): accumulate_steps grows without ever being reset,
            # so `v[-accumulate_steps:]` effectively averages everything since
            # the last tracker reset, not just the last accumulation window —
            # presumably the latter was intended; confirm.
            fabric.log_dict(
                {
                    f"train/{k}": sum(v[-accumulate_steps:])
                    / len(v[-accumulate_steps:])
                    for k, v in trackers.items()
                },
                step=global_step,
            )

            global_step += 1
            bar.update(1)

            if global_step % cfg.schedule.log_interval == 0:
                # Average wall-clock time per optimizer step since last log.
                step_time = (time.time() - start_time) / cfg.schedule.log_interval
                eta = step_time * (cfg.schedule.max_steps - global_step)
                additional_info = [
                    f"{k}: {sum(v[-cfg.schedule.log_interval:]) / len(v[-cfg.schedule.log_interval:]):.4f}"
                    for k, v in trackers.items()
                    if k != "lr"  # lr use .2e format
                ]
                log.info(
                    f"[{global_step}/{cfg.schedule.max_steps}] "
                    + f"step_time: {step_time:.2f}s "
                    + f"ETA: {timedelta(seconds=round(eta))}s "
                    f"lr: {optimizer.param_groups[0]['lr']:.2e} "
                    + " ".join(additional_info)
                )

                # Reset trackers
                trackers = defaultdict(list)
                start_time = time.time()

            if global_step % cfg.schedule.save_interval == 0:
                # Save full training state so runs can be resumed from disk.
                fabric.save(
                    Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
                    {
                        "model": model,
                        "optimizer": optimizer,
                        "scheduler": scheduler.state_dict(),
                        "global_step": global_step,
                    },
                )

            # Periodic evaluation, only when an interval and a dataloader exist.
            if (
                getattr(cfg.schedule, "eval_interval", None) is not None
                and global_step % cfg.schedule.eval_interval == 0
                and valid_dataloader is not None
            ):
                valid(model, valid_dataloader, global_step, fabric, cfg)

            if global_step >= cfg.schedule.max_steps:
                break

            last_batch_time = time.time()
@hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
def main(cfg: DictConfig):
    """Entry point: build Fabric, model, optimizer and dataloaders from the
    hydra config, restore the latest checkpoint if one exists, then train.
    """
    log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")

    if is_flash_attn_available() is False:
        log.warning("Flash attention is not available, using default attention")

    fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
    fabric.launch()
    log.info(f"Fabric: {fabric}")

    model = hydra.utils.instantiate(cfg.model)
    log.info(f"Model: {repr(model)}")

    # Parameter-count summary (trainable vs frozen), reported in millions.
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    freeze_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    log.info(f"Trainable parameters: {trainable_params/1e6:.2f}M")
    log.info(f"Freeze parameters: {freeze_params/1e6:.2f}M")

    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
    log.info(f"Optimizer: {optimizer}")
    log.info(f"Scheduler: {scheduler}")

    # Fabric must wrap model/optimizer BEFORE the checkpoint restore below so
    # state is loaded into the wrapped objects.
    log.info(f"Setup fabric model & dataset")
    model = fabric.setup_module(model)
    optimizer = fabric.setup_optimizers(optimizer)

    # Build state
    global_step = 0

    # Restore training from checkpoint
    checkpoint_dir = Path(cfg.paths.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Natural-sort checkpoints so step_10 follows step_2; [-1] is the latest.
    checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
    if len(checkpoints) > 0:
        checkpoint_path = checkpoints[-1]
        log.info(f"Restoring checkpoint from {checkpoint_path}")
        # fabric.load restores model/optimizer/scheduler in place and returns
        # the remaining (unconsumed) entries of the saved state dict.
        remainder = fabric.load(
            checkpoint_path,
            {
                "model": model,
                "optimizer": optimizer,
                "scheduler": scheduler,
            },
        )
        global_step = remainder["global_step"]
        log.info(f"Restored global step: {global_step}")

    train_dataloader = hydra.utils.instantiate(cfg.train_dataloader)
    log.info(f"Train Dataloader: {train_dataloader}")

    # Validation dataloader is optional in the config.
    valid_dataloader = None
    if getattr(cfg, "valid_dataloader", None) is not None:
        valid_dataloader = hydra.utils.instantiate(cfg.valid_dataloader)
        log.info(f"Valid Dataloader: {valid_dataloader}")

    train_dataloader = fabric.setup_dataloaders(train_dataloader)
    if valid_dataloader is not None:
        valid_dataloader = fabric.setup_dataloaders(valid_dataloader)

    log.info(f"Begin training")
    train(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        global_step=global_step,
        fabric=fabric,
        cfg=cfg,
    )


if __name__ == "__main__":
    main()