# utils.py
  1. import random
  2. import warnings
  3. from importlib.util import find_spec
  4. from typing import Callable
  5. import numpy as np
  6. import torch
  7. from omegaconf import DictConfig
  8. from .logger import RankedLogger
  9. from .rich_utils import enforce_tags, print_config_tree
# Module-level logger; rank_zero_only=True presumably restricts output to the
# rank-0 process in distributed runs — confirm against RankedLogger.
log = RankedLogger(__name__, rank_zero_only=True)
  11. def extras(cfg: DictConfig) -> None:
  12. """Applies optional utilities before the task is started.
  13. Utilities:
  14. - Ignoring python warnings
  15. - Setting tags from command line
  16. - Rich config printing
  17. """
  18. # return if no `extras` config
  19. if not cfg.get("extras"):
  20. log.warning("Extras config not found! <cfg.extras=null>")
  21. return
  22. # disable python warnings
  23. if cfg.extras.get("ignore_warnings"):
  24. log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
  25. warnings.filterwarnings("ignore")
  26. # prompt user to input tags from command line if none are provided in the config
  27. if cfg.extras.get("enforce_tags"):
  28. log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
  29. enforce_tags(cfg, save_to_file=True)
  30. # pretty print config tree using Rich library
  31. if cfg.extras.get("print_config"):
  32. log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
  33. print_config_tree(cfg, resolve=True, save_to_file=True)
  34. def task_wrapper(task_func: Callable) -> Callable:
  35. """Optional decorator that controls the failure behavior when executing the task function.
  36. This wrapper can be used to:
  37. - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
  38. - save the exception to a `.log` file
  39. - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
  40. - etc. (adjust depending on your needs)
  41. Example:
  42. ```
  43. @utils.task_wrapper
  44. def train(cfg: DictConfig) -> Tuple[dict, dict]:
  45. ...
  46. return metric_dict, object_dict
  47. ```
  48. """ # noqa: E501
  49. def wrap(cfg: DictConfig):
  50. # execute the task
  51. try:
  52. metric_dict, object_dict = task_func(cfg=cfg)
  53. # things to do if exception occurs
  54. except Exception as ex:
  55. # save exception to `.log` file
  56. log.exception("")
  57. # some hyperparameter combinations might be invalid or
  58. # cause out-of-memory errors so when using hparam search
  59. # plugins like Optuna, you might want to disable
  60. # raising the below exception to avoid multirun failure
  61. raise ex
  62. # things to always do after either success or exception
  63. finally:
  64. # display output dir path in terminal
  65. log.info(f"Output dir: {cfg.paths.run_dir}")
  66. # always close wandb run (even if exception occurs so multirun won't fail)
  67. if find_spec("wandb"): # check if wandb is installed
  68. import wandb
  69. if wandb.run:
  70. log.info("Closing wandb!")
  71. wandb.finish()
  72. return metric_dict, object_dict
  73. return wrap
  74. def get_metric_value(metric_dict: dict, metric_name: str) -> float:
  75. """Safely retrieves value of the metric logged in LightningModule."""
  76. if not metric_name:
  77. log.info("Metric name is None! Skipping metric value retrieval...")
  78. return None
  79. if metric_name not in metric_dict:
  80. raise Exception(
  81. f"Metric value not found! <metric_name={metric_name}>\n"
  82. "Make sure metric name logged in LightningModule is correct!\n"
  83. "Make sure `optimized_metric` name in `hparams_search` config is correct!"
  84. )
  85. metric_value = metric_dict[metric_name].item()
  86. log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
  87. return metric_value
  88. def set_seed(seed: int):
  89. if seed < 0:
  90. seed = -seed
  91. if seed > (1 << 31):
  92. seed = 1 << 31
  93. random.seed(seed)
  94. np.random.seed(seed)
  95. torch.manual_seed(seed)
  96. if torch.cuda.is_available():
  97. torch.cuda.manual_seed(seed)
  98. torch.cuda.manual_seed_all(seed)
  99. if torch.backends.cudnn.is_available():
  100. torch.backends.cudnn.deterministic = True
  101. torch.backends.cudnn.benchmark = False