utils.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
import warnings
from functools import wraps
from importlib.util import find_spec
from typing import Callable, Optional

from omegaconf import DictConfig

from .logger import RankedLogger
from .rich_utils import enforce_tags, print_config_tree
# Module-level logger used by every helper in this file.
# `rank_zero_only=True` is passed to RankedLogger; presumably this limits
# output to the rank-0 process in distributed runs — confirm against
# `.logger.RankedLogger`.
log = RankedLogger(__name__, rank_zero_only=True)
  8. def extras(cfg: DictConfig) -> None:
  9. """Applies optional utilities before the task is started.
  10. Utilities:
  11. - Ignoring python warnings
  12. - Setting tags from command line
  13. - Rich config printing
  14. """
  15. # return if no `extras` config
  16. if not cfg.get("extras"):
  17. log.warning("Extras config not found! <cfg.extras=null>")
  18. return
  19. # disable python warnings
  20. if cfg.extras.get("ignore_warnings"):
  21. log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
  22. warnings.filterwarnings("ignore")
  23. # prompt user to input tags from command line if none are provided in the config
  24. if cfg.extras.get("enforce_tags"):
  25. log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
  26. enforce_tags(cfg, save_to_file=True)
  27. # pretty print config tree using Rich library
  28. if cfg.extras.get("print_config"):
  29. log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
  30. print_config_tree(cfg, resolve=True, save_to_file=True)
  31. def task_wrapper(task_func: Callable) -> Callable:
  32. """Optional decorator that controls the failure behavior when executing the task function.
  33. This wrapper can be used to:
  34. - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
  35. - save the exception to a `.log` file
  36. - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
  37. - etc. (adjust depending on your needs)
  38. Example:
  39. ```
  40. @utils.task_wrapper
  41. def train(cfg: DictConfig) -> Tuple[dict, dict]:
  42. ...
  43. return metric_dict, object_dict
  44. ```
  45. """ # noqa: E501
  46. def wrap(cfg: DictConfig):
  47. # execute the task
  48. try:
  49. metric_dict, object_dict = task_func(cfg=cfg)
  50. # things to do if exception occurs
  51. except Exception as ex:
  52. # save exception to `.log` file
  53. log.exception("")
  54. # some hyperparameter combinations might be invalid or
  55. # cause out-of-memory errors so when using hparam search
  56. # plugins like Optuna, you might want to disable
  57. # raising the below exception to avoid multirun failure
  58. raise ex
  59. # things to always do after either success or exception
  60. finally:
  61. # display output dir path in terminal
  62. log.info(f"Output dir: {cfg.paths.run_dir}")
  63. # always close wandb run (even if exception occurs so multirun won't fail)
  64. if find_spec("wandb"): # check if wandb is installed
  65. import wandb
  66. if wandb.run:
  67. log.info("Closing wandb!")
  68. wandb.finish()
  69. return metric_dict, object_dict
  70. return wrap
  71. def get_metric_value(metric_dict: dict, metric_name: str) -> float:
  72. """Safely retrieves value of the metric logged in LightningModule."""
  73. if not metric_name:
  74. log.info("Metric name is None! Skipping metric value retrieval...")
  75. return None
  76. if metric_name not in metric_dict:
  77. raise Exception(
  78. f"Metric value not found! <metric_name={metric_name}>\n"
  79. "Make sure metric name logged in LightningModule is correct!\n"
  80. "Make sure `optimized_metric` name in `hparams_search` config is correct!"
  81. )
  82. metric_value = metric_dict[metric_name].item()
  83. log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
  84. return metric_value