instantiators.py

from typing import List

import hydra
from omegaconf import DictConfig
from pytorch_lightning import Callback
from pytorch_lightning.loggers import Logger

from .logger import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Instantiates callbacks from config."""
    callbacks: List[Callback] = []

    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping...")
        return callbacks

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    for _, cb_conf in callbacks_cfg.items():
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            callbacks.append(hydra.utils.instantiate(cb_conf))

    return callbacks
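

# Illustrative only (not part of the original module): the kind of callbacks
# config this helper expects, following Hydra's ``_target_`` convention. The
# group and key names and the values below are hypothetical; the function
# receives the sub-config (e.g. ``cfg.callbacks``) and iterates its entries.
#
#   callbacks:
#     model_checkpoint:
#       _target_: pytorch_lightning.callbacks.ModelCheckpoint
#       monitor: val/loss
#       save_top_k: 1
#     early_stopping:
#       _target_: pytorch_lightning.callbacks.EarlyStopping
#       monitor: val/loss
#       patience: 3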
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiates loggers from config."""
    logger: List[Logger] = []

    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return logger

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

    return logger
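

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module): build a
    # small DictConfig in code and run both helpers on it. The callback and logger
    # classes referenced via ``_target_`` are standard pytorch_lightning components,
    # chosen here only as an example; any instantiable class would work.
    from omegaconf import OmegaConf

    example_cfg = OmegaConf.create(
        {
            "callbacks": {
                "early_stopping": {
                    "_target_": "pytorch_lightning.callbacks.EarlyStopping",
                    "monitor": "val/loss",
                    "patience": 3,
                }
            },
            "logger": {
                "csv": {
                    "_target_": "pytorch_lightning.loggers.CSVLogger",
                    "save_dir": "logs/",
                }
            },
        }
    )

    callbacks = instantiate_callbacks(example_cfg.callbacks)
    loggers = instantiate_loggers(example_cfg.logger)
    log.info(f"Instantiated {len(callbacks)} callback(s) and {len(loggers)} logger(s)")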