| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- import logging
- from typing import Mapping, Optional
- from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger.

    Every emitted message is prefixed with the rank of the process that logged it
    (via ``rank_prefixed_message``). Emission can be restricted to rank zero, to an
    explicitly requested rank, or allowed on every rank.
    """

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = True,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Initializes a multi-GPU-friendly python command line logger that logs on all processes
        with their rank prefixed in the log message.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process.
            Default is ``True``.
        :param extra: (Optional) A dict-like object which provides contextual information. See
            ``logging.LoggerAdapter``.
        """
        logger = logging.getLogger(name)
        super().__init__(logger=logger, extra=extra)
        # NOTE: the ``rank_zero_only`` parameter shadows the imported ``rank_zero_only``
        # utility only inside ``__init__``; ``log`` reads the module-level utility.
        self.rank_zero_only = rank_zero_only

    def log(
        self, level: int, msg: str, *args, rank: Optional[int] = None, **kwargs
    ) -> None:
        """Delegate a log call to the underlying logger, after prefixing its message with the rank
        of the process it's being logged from.

        If ``self.rank_zero_only`` is ``True`` the message is only emitted on rank 0 and ``rank``
        is ignored. Otherwise, if ``rank`` is provided, the message is only emitted on that
        rank/process; if it is ``None`` the message is emitted on every rank.

        ``rank`` is keyword-only: ``LoggerAdapter.info``/``debug``/``warning``/... forward their
        positional formatting args as ``self.log(level, msg, *args, **kwargs)``, so a positional
        ``rank`` slot would silently swallow the first formatting argument.

        :param level: The level to log at. Look at ``logging.__init__.py`` for more information.
        :param msg: The message to log (may contain %-style placeholders filled from ``args``).
        :param args: Additional args to pass to the underlying logging function.
        :param rank: The rank to log at (keyword-only). Ignored when ``rank_zero_only`` is set.
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        :raises RuntimeError: If ``rank_zero_only.rank`` has not been set when a log at an
            enabled level is attempted.
        """
        if not self.isEnabledFor(level):
            return

        msg, kwargs = self.process(msg, kwargs)

        # The current process rank is read from the ``rank`` attribute of the imported
        # ``rank_zero_only`` utility (not the instance flag of the same name); presumably
        # set by the launcher/framework at startup — confirm for your entry point.
        current_rank = getattr(rank_zero_only, "rank", None)
        if current_rank is None:
            raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
        msg = rank_prefixed_message(msg, current_rank)

        if self.rank_zero_only:
            should_emit = current_rank == 0
        else:
            should_emit = rank is None or current_rank == rank

        if should_emit:
            self.logger.log(level, msg, *args, **kwargs)
|