logging_utils.py

from lightning.pytorch.utilities import rank_zero_only

from fish_speech.utils import logger as log


# Runs only on global rank zero, so hyperparameters are logged once
# rather than duplicated across processes in distributed training.
@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Additionally saves:
    - Number of model parameters
    """

    hparams = {}

    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")
    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
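
For reference, a minimal usage sketch, not part of this module: it shows what `object_dict` is expected to contain. The `ToyModel` class and the inline `cfg` dict are hypothetical stand-ins, and the import path is assumed from the file's location; in the real training entrypoint the config is assembled by the project's configuration system before being passed in.

import torch
import lightning.pytorch as pl

from fish_speech.utils.logging_utils import log_hyperparameters  # assumed module path


class ToyModel(pl.LightningModule):
    """Tiny stand-in module so the parameter counts have something to count."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)


cfg = {
    "model": {"name": "toy"},       # required keys, read with cfg[...]
    "data": {"batch_size": 8},
    "trainer": {"max_epochs": 1},
    "seed": 42,                     # optional keys, read with cfg.get(...)
}

trainer = pl.Trainer(max_epochs=1)  # a default logger is created automatically
log_hyperparameters({"cfg": cfg, "model": ToyModel(), "trainer": trainer})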