train.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. import time
  2. from collections import defaultdict
  3. from datetime import timedelta
  4. from pathlib import Path
  5. from typing import Optional
  6. import hydra
  7. import torch
  8. from lightning.fabric import Fabric
  9. from natsort import natsorted
  10. from omegaconf import DictConfig, OmegaConf
  11. from tqdm import tqdm
  12. from transformers import LlamaForCausalLM
  13. from transformers.utils import is_flash_attn_available
  14. from fish_speech.logger import RankedLogger
# Allow TF32 on Ampere GPUs: trades a small amount of float32 matmul
# precision for a significant throughput gain.
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# register eval resolver: lets configs embed Python expressions,
# e.g. ${eval:'2 ** 10'}.
# NOTE(review): `eval` executes arbitrary code from config files — acceptable
# only because configs are repo-controlled; never feed untrusted configs here.
OmegaConf.register_new_resolver("eval", eval)

# Module-level logger; rank_zero_only presumably restricts output to rank 0
# in multi-process runs — confirm against RankedLogger's implementation.
log = RankedLogger(__name__, rank_zero_only=True)
  21. def valid(
  22. model: LlamaForCausalLM,
  23. valid_dataloader: Optional[torch.utils.data.DataLoader],
  24. global_step: int,
  25. fabric: Fabric,
  26. cfg: DictConfig,
  27. ):
  28. model.eval()
  29. log.info(f"Evaluating at step {global_step}")
  30. accumulate_infos = None
  31. for idx, batch in enumerate(tqdm(valid_dataloader, desc="Evaluating")):
  32. outputs = model(**batch)
  33. loss = outputs.loss
  34. metrics = getattr(outputs, "metrics", {})
  35. log_info = {
  36. "valid/loss": float(loss),
  37. **{f"valid/{k}": float(v) for k, v in metrics.items()},
  38. }
  39. fabric.log_dict(
  40. log_info,
  41. step=global_step + idx,
  42. )
  43. # Update log info
  44. if accumulate_infos is None:
  45. accumulate_infos = log_info
  46. else:
  47. assert set(accumulate_infos.keys()) == set(
  48. log_info.keys()
  49. ), "Log keys changed during evaluation"
  50. for k in accumulate_infos.keys():
  51. accumulate_infos[k] += log_info[k]
  52. if idx == getattr(cfg.schedule, "eval_max_batches", None):
  53. break
  54. # Log average
  55. items = []
  56. for k in accumulate_infos.keys():
  57. items.append(f"{k}: {accumulate_infos[k] / (idx + 1):.4f}")
  58. log.info(f"Average: {' | '.join(items)}")
def train(
    model: LlamaForCausalLM,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    train_dataloader: torch.utils.data.DataLoader,
    valid_dataloader: Optional[torch.utils.data.DataLoader],
    global_step: int,
    fabric: Fabric,
    cfg: DictConfig,
) -> None:
    """Main training loop with gradient accumulation.

    Runs until ``cfg.schedule.max_steps`` optimizer steps have been taken.
    One "step" here is one optimizer update, i.e. every
    ``cfg.schedule.gradient_accumulation_steps`` micro-batches. Handles
    periodic console logging, checkpointing, and (optionally) validation.

    Args:
        model: Model to train (set to ``.train()`` before every forward).
        optimizer: Optimizer stepped once per accumulation window.
        scheduler: LR scheduler stepped once per optimizer step.
        train_dataloader: Training batches; iterated repeatedly until
            ``max_steps`` is reached.
        valid_dataloader: Optional validation batches for periodic eval.
        global_step: Step to resume counting from (0 for a fresh run).
        fabric: Lightning Fabric handle (backward, clipping, logging, saving).
        cfg: Hydra config; reads ``cfg.schedule.*`` and
            ``cfg.paths.checkpoint_dir``.
    """
    accumulate_steps = 0
    optimizer.zero_grad()
    # Start time is ~model forward time + data loading time
    start_time = time.time()
    # Maps metric name -> list of per-micro-batch values; reset every
    # log_interval steps so logged values are windowed averages.
    trackers = defaultdict(list)

    while global_step < cfg.schedule.max_steps:
        last_batch_time = time.time()
        for batch in train_dataloader:
            # Measure time used by data loading
            trackers["data_time"].append(time.time() - last_batch_time)

            # Measure time used by model forward
            model_begin_time = time.time()
            model.train()

            # Accumulate gradients: the first micro-batch of every window has
            # accumulate_steps % gas == 0, so it is NOT an accumulating step
            # and the window's optimizer update happens on it.
            gradient_accumulation_steps = cfg.schedule.gradient_accumulation_steps
            is_accumulating = accumulate_steps % gradient_accumulation_steps != 0
            accumulate_steps += 1

            # Train one step. no_backward_sync skips gradient all-reduce on
            # accumulating micro-batches (a DDP communication optimization).
            with fabric.no_backward_sync(model, enabled=is_accumulating):
                outputs = model(**batch)
                loss = outputs.loss
                metrics = getattr(outputs, "metrics", {})
                # Need to divide loss by accumulation steps
                fabric.backward(loss / gradient_accumulation_steps)

            # Update trackers (undivided loss, so logs show per-batch loss)
            trackers["loss"].append(float(loss))
            trackers["lr"].append(float(optimizer.param_groups[0]["lr"]))
            for k, v in metrics.items():
                trackers[f"metrics/{k}"].append(float(v))
            trackers["model_time"].append(time.time() - model_begin_time)

            # Mid-window micro-batch: keep accumulating, no optimizer update.
            if is_accumulating:
                last_batch_time = time.time()
                continue

            # Check all trackers has the same length (grad_norm is exempt:
            # it is appended once per update, not once per micro-batch)
            assert (
                len(set(len(v) for k, v in trackers.items() if k != "grad_norm")) == 1
            ), "Trackers has ambiguous length"

            # Perform gradient clipping
            grad_norm = fabric.clip_gradients(
                model,
                optimizer,
                max_norm=cfg.schedule.clip_grad_norm,
                norm_type=2.0,
                error_if_nonfinite=True,
            )
            # NOTE(review): with error_if_nonfinite=True, clip_gradients is
            # expected to raise on a nan/inf norm, which would make this
            # branch unreachable — confirm against the Fabric/torch version.
            # Also, zeroing grads "skips" the update only insofar as
            # optimizer.step() with cleared grads is a no-op for the params.
            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                log.warning(f"Gradient norm is {grad_norm}, skipping update")
                optimizer.zero_grad()

            # We can't average gradients across multiple steps
            trackers["grad_norm"].append(float(grad_norm))

            # Update
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            # Log the mean over the just-finished accumulation window.
            fabric.log_dict(
                {
                    f"train/{k}": sum(v[-gradient_accumulation_steps:])
                    / len(v[-gradient_accumulation_steps:])
                    for k, v in trackers.items()
                },
                step=global_step,
            )
            # accumulate_steps = 0
            global_step += 1

            if global_step % cfg.schedule.log_interval == 0:
                # Wall-clock per optimizer step over the last log window;
                # includes any save/eval time that fell inside the window.
                step_time = (time.time() - start_time) / cfg.schedule.log_interval
                eta = step_time * (cfg.schedule.max_steps - global_step)
                additional_info = [
                    f"{k}: {sum(v[-cfg.schedule.log_interval:]) / len(v[-cfg.schedule.log_interval:]):.4f}"
                    for k, v in trackers.items()
                    if k != "lr"  # lr use .2e format
                ]
                log.info(
                    f"[{global_step}/{cfg.schedule.max_steps}] "
                    + f"step_time: {step_time:.2f}s "
                    + f"ETA: {timedelta(seconds=round(eta))}s "
                    f"lr: {optimizer.param_groups[0]['lr']:.2e} "
                    + " ".join(additional_info)
                )
                # Reset trackers
                trackers = defaultdict(list)
                start_time = time.time()

            if global_step % cfg.schedule.save_interval == 0:
                # Model/optimizer are passed as objects (Fabric extracts their
                # state); scheduler is saved as a raw state dict.
                fabric.save(
                    Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
                    {
                        "model": model,
                        "optimizer": optimizer,
                        "scheduler": scheduler.state_dict(),
                        "global_step": global_step,
                    },
                )

            if (
                getattr(cfg.schedule, "eval_interval", None) is not None
                and global_step % cfg.schedule.eval_interval == 0
                and valid_dataloader is not None
            ):
                valid(model, valid_dataloader, global_step, fabric, cfg)

            if global_step >= cfg.schedule.max_steps:
                break

            last_batch_time = time.time()
  170. @hydra.main(
  171. version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
  172. )
  173. def main(cfg: DictConfig):
  174. log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")
  175. if is_flash_attn_available() is False:
  176. log.warning("Flash attention is not available, using default attention")
  177. fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
  178. fabric.launch()
  179. log.info(f"Fabric: {fabric}")
  180. model = hydra.utils.instantiate(cfg.model)
  181. log.info(f"Model: {repr(model)}")
  182. trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  183. freeze_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
  184. log.info(f"Trainable parameters: {trainable_params/1e6:.2f}M")
  185. log.info(f"Freeze parameters: {freeze_params/1e6:.2f}M")
  186. optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
  187. scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
  188. log.info(f"Optimizer: {optimizer}")
  189. log.info(f"Scheduler: {scheduler}")
  190. log.info(f"Setup fabric model & dataset")
  191. model = fabric.setup_module(model)
  192. optimizer = fabric.setup_optimizers(optimizer)
  193. # Build state
  194. global_step = 0
  195. # Restore training from checkpoint
  196. checkpoint_dir = Path(cfg.paths.checkpoint_dir)
  197. checkpoint_dir.mkdir(parents=True, exist_ok=True)
  198. # Alphabetically sort checkpoints
  199. checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
  200. if len(checkpoints) > 0:
  201. checkpoint_path = checkpoints[-1]
  202. log.info(f"Restoring checkpoint from {checkpoint_path}")
  203. remainder = fabric.load(
  204. checkpoint_path,
  205. {
  206. "model": model,
  207. "optimizer": optimizer,
  208. "scheduler": scheduler,
  209. },
  210. )
  211. global_step = remainder["global_step"]
  212. log.info(f"Restored global step: {global_step}")
  213. train_dataloader = hydra.utils.instantiate(cfg.train_dataloader)
  214. log.info(f"Train Dataloader: {train_dataloader}")
  215. valid_dataloader = None
  216. if getattr(cfg, "valid_dataloader", None) is not None:
  217. valid_dataloader = hydra.utils.instantiate(cfg.valid_dataloader)
  218. log.info(f"Valid Dataloader: {valid_dataloader}")
  219. train_dataloader = fabric.setup_dataloaders(train_dataloader)
  220. if valid_dataloader is not None:
  221. valid_dataloader = fabric.setup_dataloaders(valid_dataloader)
  222. log.info(f"Begin training")
  223. train(
  224. model=model,
  225. optimizer=optimizer,
  226. scheduler=scheduler,
  227. train_dataloader=train_dataloader,
  228. valid_dataloader=valid_dataloader,
  229. global_step=global_step,
  230. fabric=fabric,
  231. cfg=cfg,
  232. )
if __name__ == "__main__":
    # hydra.main parses CLI overrides and injects the composed config.
    main()