train.py

import os

# Disable the libuv TCPStore backend in torch.distributed; must be set
# before torch is imported.
os.environ["USE_LIBUV"] = "0"

from typing import Optional

import hydra
import lightning as L
import pyrootutils
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, OmegaConf

# Drop SLURM variables so Lightning does not auto-detect a SLURM-managed
# cluster environment.
os.environ.pop("SLURM_NTASKS", None)
os.environ.pop("SLURM_JOB_NAME", None)
os.environ.pop("SLURM_NTASKS_PER_NODE", None)

# register project root (adds it to sys.path so fish_speech is importable)
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# register eval resolver
OmegaConf.register_new_resolver("eval", eval)
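# The "eval" resolver lets YAML values be computed at resolution time.
# A hypothetical config snippet (key names are illustrative, not taken
# from the actual configs under ./configs):
#
#   max_steps: 100000
#   warmup_steps: ${eval:'int(${max_steps} * 0.1)'}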

import fish_speech.utils as utils

log = utils.RankedLogger(__name__, rank_zero_only=True)


@utils.task_wrapper
def train(cfg: DictConfig) -> tuple[dict, dict]:
    """Trains the model. Can additionally evaluate on a testset, using the best weights
    obtained during training.

    This method is wrapped in an optional @task_wrapper decorator, which controls the
    behavior during failure. Useful for multiruns, saving info about the crash, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """  # noqa: E501
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=False)

    if cfg.get("deterministic"):
        torch.use_deterministic_algorithms(True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=logger,
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")

        ckpt_path = cfg.get("ckpt_path")
        auto_resume = False

        # an existing checkpoint in ckpt_dir takes precedence over cfg.ckpt_path
        resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
        if resume_ckpt_path is not None:
            ckpt_path = resume_ckpt_path
            auto_resume = True

        if ckpt_path is not None:
            log.info(f"Resuming from checkpoint: {ckpt_path}")

        # resuming weights only is disabled for auto-resume
        if cfg.get("resume_weights_only") and auto_resume is False:
            log.info("Resuming weights only!")
            ckpt = torch.load(ckpt_path, map_location=model.device)
            if "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            err = model.load_state_dict(ckpt, strict=False)
            log.info(f"Mismatched keys when loading state dict: {err}")
            ckpt_path = None
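
        # The weights-only path above is driven by config flags along these
        # lines (values are illustrative, not from the actual configs):
        #
        #   resume_weights_only: true
        #   ckpt_path: checkpoints/pretrained.ckpt  # hypothetical path
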
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = cfg.get("ckpt_path")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict


@hydra.main(
    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    # train the model
    train(cfg)


if __name__ == "__main__":
    main()
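
# Example invocations (a sketch; the available override keys depend on the
# configs under ./configs, which are not shown here):
#
#   python train.py                     # composes llama_pretrain.yaml
#   python train.py seed=42 train=true  # hypothetical Hydra overrides
#   python train.py -m seed=0,1,2       # Hydra multirun, one job per seed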