| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
import random
import warnings
from importlib.util import find_spec
from typing import Callable, Optional

import numpy as np
import torch
from omegaconf import DictConfig

from .logger import RankedLogger
from .rich_utils import enforce_tags, print_config_tree
# Module-level logger; `rank_zero_only=True` suggests messages are emitted
# only on rank zero in multi-process runs — semantics live in RankedLogger.
log = RankedLogger(__name__, rank_zero_only=True)
def extras(cfg: DictConfig) -> None:
    """Apply optional utilities before the task is started.

    Each utility is toggled by a flag under the ``extras`` config group:
    - ``ignore_warnings``: silence all python warnings
    - ``enforce_tags``: prompt for tags on the command line if none are set
    - ``print_config``: pretty-print the config tree with Rich
    """
    extras_cfg = cfg.get("extras")

    # nothing to do without an `extras` config group
    if not extras_cfg:
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    if extras_cfg.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    if extras_cfg.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        enforce_tags(cfg, save_to_file=True)

    if extras_cfg.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        print_config_tree(cfg, resolve=True, save_to_file=True)
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior of a task function.

    The wrapper:
    - logs any exception to the run's `.log` file before re-raising it
      (prevents silent multirun failures)
    - always reports the output directory
    - always closes an active wandb run, on success or failure

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[dict, dict]:
        ...
        return metric_dict, object_dict
    ```
    """  # noqa: E501

    def _finalize(cfg: DictConfig) -> None:
        """Cleanup that must run whether the task succeeded or not."""
        # display output dir path in terminal
        log.info(f"Output dir: {cfg.paths.run_dir}")
        # close wandb if it is installed and a run is active, so a crashed
        # job does not leave a dangling run that fails the multirun
        if find_spec("wandb"):
            import wandb

            if wandb.run:
                log.info("Closing wandb!")
                wandb.finish()

    def wrap(cfg: DictConfig):
        try:
            metric_dict, object_dict = task_func(cfg=cfg)
        except Exception as ex:
            # save exception to `.log` file
            log.exception("")
            # NOTE: with hparam-search plugins (e.g. Optuna) some combinations
            # may be invalid or OOM — disable this re-raise if a single failed
            # trial should not abort the whole multirun
            raise ex
        finally:
            _finalize(cfg)
        return metric_dict, object_dict

    return wrap
def get_metric_value(metric_dict: dict, metric_name: str) -> Optional[float]:
    """Safely retrieve the value of a metric logged in a LightningModule.

    :param metric_dict: Mapping of metric names to logged values; each value
        must expose ``.item()`` (e.g. a scalar torch tensor).
    :param metric_name: Name of the metric to fetch; a falsy name skips
        retrieval entirely.
    :return: The metric value as a python float, or ``None`` when no metric
        name was given.
    :raises KeyError: If ``metric_name`` is not present in ``metric_dict``.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        # KeyError is the precise type for a missing mapping entry; it is a
        # subclass of Exception, so existing `except Exception` callers still work.
        raise KeyError(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
    return metric_value
def set_seed(seed: int) -> None:
    """Seed every random number generator used in the project.

    Seeds python's ``random``, numpy, and torch (CPU plus all CUDA devices
    when available), and configures cuDNN for deterministic behavior.

    :param seed: Desired seed. Negative values are mirrored to their absolute
        value, and values above ``2**31`` are clamped to ``2**31`` so the seed
        is accepted by every backend.
    """
    # normalize into a range all backends accept (numpy requires < 2**32)
    seed = min(abs(seed), 1 << 31)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    if torch.backends.cudnn.is_available():
        # trade conv-autotune speed for run-to-run reproducibility
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|