# train.py — pretraining entry point (Lightning Fabric + Hydra).
  1. from collections import defaultdict
  2. import time
  3. from datetime import timedelta
  4. from pathlib import Path
  5. from typing import Optional
  6. import hydra
  7. import torch
  8. from lightning.fabric import Fabric
  9. from natsort import natsorted
  10. from omegaconf import DictConfig, OmegaConf
  11. from tqdm import tqdm
  12. from transformers import LlamaForCausalLM
  13. from transformers.utils import is_flash_attn_available
  14. from speech_lm.logger import RankedLogger
# Allow TF32 on Ampere GPUs (faster matmul/conv at slightly reduced precision).
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True
# register eval resolver
# SECURITY NOTE(review): this registers the builtin eval() as an OmegaConf
# resolver, so any ${eval:...} expression in a config file executes arbitrary
# Python. Acceptable only for fully trusted config files.
OmegaConf.register_new_resolver("eval", eval)
# Rank-zero-only logger so multi-process runs don't emit duplicate lines.
log = RankedLogger(__name__, rank_zero_only=True)
  21. def valid(
  22. model: LlamaForCausalLM,
  23. valid_dataloader: Optional[torch.utils.data.DataLoader],
  24. global_step: int,
  25. fabric: Fabric,
  26. cfg: DictConfig,
  27. ):
  28. model.eval()
  29. log.info(f"Evaluating at step {global_step}")
  30. accumulate_infos = None
  31. for idx, batch in tqdm(enumerate(valid_dataloader), desc="Evaluating"):
  32. outputs = model(**batch)
  33. loss = outputs.loss
  34. metrics = getattr(outputs, "metrics", {})
  35. log_info = {
  36. "valid/loss": float(loss),
  37. **{f"valid/{k}": float(v) for k, v in metrics.items()},
  38. }
  39. fabric.log_dict(
  40. log_info,
  41. step=global_step + idx,
  42. )
  43. # Update log info
  44. if accumulate_infos is None:
  45. accumulate_infos = log_info
  46. else:
  47. assert set(accumulate_infos.keys()) == set(
  48. log_info.keys()
  49. ), "Log keys changed during evaluation"
  50. for k in accumulate_infos.keys():
  51. accumulate_infos[k] += log_info[k]
  52. if idx == getattr(cfg.schedule, "eval_max_batches", None):
  53. break
  54. # Log average
  55. items = []
  56. for k in accumulate_infos.keys():
  57. items.append(f"{k}: {accumulate_infos[k] / (idx + 1):.4f}")
  58. log.info(f"Average: {' | '.join(items)}")
  59. def train(
  60. model: LlamaForCausalLM,
  61. optimizer: torch.optim.Optimizer,
  62. scheduler: torch.optim.lr_scheduler._LRScheduler,
  63. train_dataloader: torch.utils.data.DataLoader,
  64. valid_dataloader: Optional[torch.utils.data.DataLoader],
  65. global_step: int,
  66. fabric: Fabric,
  67. cfg: DictConfig,
  68. ):
  69. bar = tqdm(total=cfg.schedule.max_steps, desc="Training")
  70. bar.update(global_step)
  71. accumulate_steps = 0
  72. optimizer.zero_grad()
  73. # Start time is ~model forward time + data loading time
  74. start_time = time.time()
  75. trackers = defaultdict(list)
  76. while global_step < cfg.schedule.max_steps:
  77. last_batch_time = time.time()
  78. for batch in train_dataloader:
  79. # Measure time used by data loading
  80. trackers["data_time"].append(time.time() - last_batch_time)
  81. # Measure time used by model forward
  82. model_begin_time = time.time()
  83. model.train()
  84. # Accumulate gradients
  85. is_accumulating = (
  86. accumulate_steps % cfg.schedule.gradient_accumulation_steps != 0
  87. )
  88. accumulate_steps += 1
  89. # Train one step
  90. with fabric.no_backward_sync(model, enabled=is_accumulating):
  91. outputs = model(**batch)
  92. loss = outputs.loss
  93. metrics = getattr(outputs, "metrics", {})
  94. fabric.backward(loss)
  95. # Update trackers
  96. trackers["loss"].append(float(loss))
  97. trackers["lr"].append(float(optimizer.param_groups[0]["lr"]))
  98. trackers["grad_norm"].append(
  99. trackers.get("grad_norm", 0) + float(grad_norm)
  100. )
  101. for k, v in metrics.items():
  102. trackers[f"metrics/{k}"].append(float(v))
  103. trackers["model_time"].append(time.time() - model_begin_time)
  104. if is_accumulating:
  105. continue
  106. # Check all trackers has the same length
  107. assert (
  108. len(set(len(v) for v in trackers.values())) == 1
  109. ), "Trackers has ambiguous length"
  110. # Perform gradient clipping
  111. grad_norm = fabric.clip_gradients(
  112. model, optimizer, max_norm=cfg.schedule.clip_grad_norm, norm_type=2.0
  113. )
  114. # Update
  115. optimizer.step()
  116. optimizer.zero_grad()
  117. scheduler.step()
  118. fabric.log_dict(
  119. {
  120. f"train/{k}": sum(v[-accumulate_steps:])
  121. / len(v[-accumulate_steps:])
  122. for k, v in trackers.items()
  123. },
  124. step=global_step,
  125. )
  126. global_step += 1
  127. bar.update(1)
  128. if global_step % cfg.schedule.log_interval == 0:
  129. step_time = (time.time() - start_time) / cfg.schedule.log_interval
  130. eta = step_time * (cfg.schedule.max_steps - global_step)
  131. additional_info = [
  132. f"{k}: {sum(v[-cfg.schedule.log_interval:]) / len(v[-cfg.schedule.log_interval:]):.4f}"
  133. for k, v in trackers.items()
  134. if k != "lr" # lr use .2e format
  135. ]
  136. log.info(
  137. f"[{global_step}/{cfg.schedule.max_steps}] "
  138. + f"step time: {step_time:.2f}s "
  139. + f"ETA: {timedelta(round(eta))}s "
  140. f"lr: {optimizer.param_groups[0]['lr']:.2e} "
  141. + " ".join(additional_info)
  142. )
  143. start_time = time.time()
  144. if global_step % cfg.schedule.save_interval == 0:
  145. fabric.save(
  146. Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
  147. {
  148. "model": model,
  149. "optimizer": optimizer,
  150. "scheduler": scheduler.state_dict(),
  151. "global_step": global_step,
  152. },
  153. )
  154. if (
  155. global_step % cfg.schedule.eval_interval == 0
  156. and valid_dataloader is not None
  157. ):
  158. valid(model, valid_dataloader, fabric, global_step, cfg)
  159. if global_step >= cfg.schedule.max_steps:
  160. break
  161. @hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
  162. def main(cfg: DictConfig):
  163. log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")
  164. if is_flash_attn_available() is False:
  165. log.warning("Flash attention is not available, using default attention")
  166. fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
  167. fabric.launch()
  168. log.info(f"Fabric: {fabric}")
  169. model = hydra.utils.instantiate(cfg.model)
  170. log.info(f"Model: {repr(model)}")
  171. trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  172. freeze_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
  173. log.info(f"Trainable parameters: {trainable_params/1e6:.2f}M")
  174. log.info(f"Freeze parameters: {freeze_params/1e6:.2f}M")
  175. optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
  176. scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
  177. log.info(f"Optimizer: {optimizer}")
  178. log.info(f"Scheduler: {scheduler}")
  179. log.info(f"Setup fabric model & dataset")
  180. model = fabric.setup_module(model)
  181. optimizer = fabric.setup_optimizers(optimizer)
  182. # Build state
  183. global_step = 0
  184. # Restore training from checkpoint
  185. checkpoint_dir = Path(cfg.paths.checkpoint_dir)
  186. checkpoint_dir.mkdir(parents=True, exist_ok=True)
  187. # Alphabetically sort checkpoints
  188. checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
  189. if len(checkpoints) > 0:
  190. checkpoint_path = checkpoints[-1]
  191. log.info(f"Restoring checkpoint from {checkpoint_path}")
  192. remainder = fabric.load(
  193. checkpoint_path,
  194. {
  195. "model": model,
  196. "optimizer": optimizer,
  197. "scheduler": scheduler,
  198. },
  199. )
  200. global_step = remainder["global_step"]
  201. log.info(f"Restored global step: {global_step}")
  202. train_dataloader = hydra.utils.instantiate(cfg.train_dataloader)
  203. log.info(f"Train Dataloader: {train_dataloader}")
  204. valid_dataloader = None
  205. if getattr(train_dataloader, "valid_dataloader", None) is not None:
  206. valid_dataloader = hydra.utils.instantiate(train_dataloader.valid_dataloader)
  207. log.info(f"Valid Dataloader: {valid_dataloader}")
  208. train_dataloader = fabric.setup_dataloaders(train_dataloader)
  209. if valid_dataloader is not None:
  210. valid_dataloader = fabric.setup_dataloaders(valid_dataloader)
  211. log.info(f"Begin training")
  212. train(
  213. model=model,
  214. optimizer=optimizer,
  215. scheduler=scheduler,
  216. train_dataloader=train_dataloader,
  217. valid_dataloader=valid_dataloader,
  218. global_step=global_step,
  219. fabric=fabric,
  220. cfg=cfg,
  221. )
# Hydra parses CLI overrides and injects the composed config into main().
if __name__ == "__main__":
    main()