train.py

from typing import Optional

import hydra
import lightning as L
import torch
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, OmegaConf

import fish_speech.utils as utils
# Allow TF32 on Ampere GPUs
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True

# Register eval resolver
OmegaConf.register_new_resolver("eval", eval)
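# Note: the resolver runs Python's eval(), so configs can compute values
# inline (e.g. ${eval:"2 ** 8"}), but must come from trusted sources.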

log = utils.RankedLogger(__name__, rank_zero_only=True)

@utils.task_wrapper
def train(cfg: DictConfig) -> tuple[dict, dict]:
    """Train the model. Optionally evaluates on a test set, using the best weights obtained
    during training.

    This method is wrapped in an optional @task_wrapper decorator that controls behavior on
    failure. Useful for multiruns, saving info about the crash, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """  # noqa: E501
    # Set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=False)
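
    # Deterministic algorithms improve reproducibility at some speed cost;
    # ops that lack a deterministic implementation will raise an error.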
    if cfg.get("deterministic"):
        torch.use_deterministic_algorithms(True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )
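
    # Bundle all instantiated objects, both for hyperparameter logging below
    # and for the caller (this dict is returned alongside the metrics).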
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")

        ckpt_path = cfg.get("ckpt_path")
        auto_resume = False
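
        # No checkpoint given explicitly: fall back to the newest checkpoint
        # in ckpt_dir, if one exists.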
        if ckpt_path is None:
            ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
            auto_resume = True

        if ckpt_path is not None:
            log.info(f"Resuming from checkpoint: {ckpt_path}")

        # Resuming weights only is disabled for auto-resume
        if cfg.get("resume_weights_only") and auto_resume is False:
            log.info("Resuming weights only!")
            ckpt = torch.load(ckpt_path, map_location=model.device)
            if "state_dict" in ckpt:
                ckpt = ckpt["state_dict"]
            err = model.load_state_dict(ckpt, strict=False)
            log.info(f"Incompatible keys when loading state dict: {err}")
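            # Clear ckpt_path so trainer.fit() starts a fresh run with the
            # loaded weights instead of also restoring optimizer/trainer state.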
            ckpt_path = None

        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
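
    # callback_metrics holds the most recent values logged via self.log()
    # during the preceding loop.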
    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = cfg.get("ckpt_path")

        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # Merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
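

# Hydra composes the run config from ./configs/llama_pretrain.yaml, merged
# with any command-line overrides (e.g. `python train.py trainer.max_steps=1000`,
# assuming the trainer config exposes that key).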
@hydra.main(
    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    # Train the model
    train(cfg)


if __name__ == "__main__":
    main()