import os
from typing import Optional

import hydra
import lightning as L
import pyrootutils
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, OmegaConf
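
# Dropping these presumably keeps Lightning from auto-selecting its SLURM
# environment plugin when the script is launched inside a SLURM allocation.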
os.environ.pop("SLURM_NTASKS", None)
os.environ.pop("SLURM_JOB_NAME", None)
os.environ.pop("SLURM_NTASKS_PER_NODE", None)

# set up the project root so absolute imports and relative paths resolve
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# register eval resolver
OmegaConf.register_new_resolver("eval", eval)
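
# The "eval" resolver registered above lets YAML configs compute values inline,
# e.g. (illustrative keys only):
#   hidden_dim: ${eval:'${model.dim} * 4'}
# Note that eval() runs arbitrary Python taken from the config.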

import fish_speech.utils as utils

log = utils.RankedLogger(__name__, rank_zero_only=True)


@utils.task_wrapper
def train(cfg: DictConfig) -> tuple[dict, dict]:
    """Train the model. Optionally evaluates on a test set, using the best weights
    obtained during training.

    This method is wrapped in the optional @task_wrapper decorator, which controls
    behavior on failure. Useful for multiruns, saving info about the crash, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """  # noqa: E501

    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=False)
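
    # NOTE: use_deterministic_algorithms(True) raises on ops that lack a
    # deterministic implementation; on CUDA, some ops additionally require
    # CUBLAS_WORKSPACE_CONFIG=":4096:8" to be set in the environment.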
    if cfg.get("deterministic"):
        torch.use_deterministic_algorithms(True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))
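
    # keyword arguments passed to hydra.utils.instantiate are merged with (and
    # take precedence over) the cfg.trainer node when building the Trainer below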
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")

        ckpt_path = cfg.get("ckpt_path")
        auto_resume = False
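
        # auto-resume: the most recent checkpoint in cfg.paths.ckpt_dir, if any,
        # takes precedence over an explicitly configured ckpt_path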
        resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
        if resume_ckpt_path is not None:
            ckpt_path = resume_ckpt_path
            auto_resume = True

        if ckpt_path is not None:
            log.info(f"Resuming from checkpoint: {ckpt_path}")

        # resume weights only is disabled for auto-resume
        if cfg.get("resume_weights_only") and auto_resume is False:
            log.info("Resuming weights only!")
            ckpt = torch.load(ckpt_path, map_location=model.device)
            if "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
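            # with strict=False, load_state_dict returns the missing/unexpected
            # keys instead of raising, so the log line below is informational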
            err = model.load_state_dict(ckpt, strict=False)
            log.info(f"Error loading state dict: {err}")
            ckpt_path = None

        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = cfg.get("ckpt_path")

        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict


@hydra.main(
    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    # train the model
    train(cfg)
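
    # NOTE: the Optional[float] return type is a hook for returning an optimized
    # metric to a hyperparameter sweeper (as in the lightning-hydra-template this
    # script appears to derive from); nothing is actually returned here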


if __name__ == "__main__":
    main()
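

# Example invocations (illustrative; the override keys assume the default config
# tree under ./configs, so check the YAML files for keys that actually exist):
#   python train.py                                  # uses llama_pretrain.yaml
#   python train.py --config-name=my_config.yaml     # hypothetical alternate config
#   python train.py trainer.max_steps=10000 seed=42  # Hydra CLI overrides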