from lightning.pytorch.utilities import rank_zero_only

from fish_speech.utils import logger as log


@rank_zero_only
def log_hyperparameters(object_dict: dict) -> None:
    """Controls which config parts are saved by lightning loggers.

    Additionally saves:
    - Number of model parameters
    """

    hparams = {}

    cfg = object_dict["cfg"]
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams["model"] = cfg["model"]

    # save number of model parameters
    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params/trainable"] = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    hparams["model/params/non_trainable"] = sum(
        p.numel() for p in model.parameters() if not p.requires_grad
    )

    hparams["data"] = cfg["data"]
    hparams["trainer"] = cfg["trainer"]

    hparams["callbacks"] = cfg.get("callbacks")
    hparams["extras"] = cfg.get("extras")

    hparams["task_name"] = cfg.get("task_name")
    hparams["tags"] = cfg.get("tags")
    hparams["ckpt_path"] = cfg.get("ckpt_path")
    hparams["seed"] = cfg.get("seed")

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
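

# Example usage (a minimal sketch of a call site; `train`, `MyModel`, and the
# logger choice are illustrative assumptions, not part of this module):
#
#     from lightning.pytorch import Trainer
#     from lightning.pytorch.loggers import TensorBoardLogger
#
#     def train(cfg: dict) -> None:
#         model = MyModel(cfg["model"])  # hypothetical LightningModule
#         trainer = Trainer(logger=TensorBoardLogger("logs/"))
#
#         # Bundle the objects the helper expects, then log once per run.
#         # The @rank_zero_only decorator makes this a no-op on non-zero
#         # ranks in distributed training.
#         object_dict = {"cfg": cfg, "model": model, "trainer": trainer}
#         log_hyperparameters(object_dict)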