grad_norm.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from typing import Optional, Union
  2. import lightning.pytorch as pl
  3. import torch
  4. from lightning import LightningModule, Trainer
  5. from lightning.pytorch.callbacks import Callback
  6. from torch import Tensor, nn
  7. from torch.utils._foreach_utils import (
  8. _group_tensors_by_device_and_dtype,
  9. _has_foreach_support,
  10. )
  11. @torch.no_grad()
  12. def grad_norm(
  13. parameters: Union[Tensor, list[Tensor]],
  14. norm_type: float = 2.0,
  15. ) -> float:
  16. """
  17. Returns the norm of the gradients of the given parameters.
  18. Args:
  19. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  20. single Tensor that will have gradients normalized
  21. norm_type (float): type of the used p-norm.
  22. Returns:
  23. Total norm of the parameter gradients (viewed as a single vector).
  24. """ # noqa: E501
  25. if isinstance(parameters, Tensor):
  26. parameters = [parameters]
  27. grads = [p.grad for p in parameters if p.grad is not None]
  28. if len(grads) == 0:
  29. return None
  30. first_device = grads[0].device
  31. grouped_grads: dict[
  32. tuple[torch.device, torch.dtype], list[list[Tensor]]
  33. ] = _group_tensors_by_device_and_dtype(
  34. [[g.detach() for g in grads]]
  35. ) # type: ignore[assignment]
  36. norms = []
  37. for (device, _), ([grads], _) in grouped_grads.items():
  38. if _has_foreach_support(grads, device=device):
  39. norms.extend(torch._foreach_norm(grads, norm_type))
  40. else:
  41. norms.extend([torch.norm(g, norm_type) for g in grads])
  42. return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
  class GradNormMonitor(Callback):
      """
      Callback that computes the gradient norm of the model parameters.

      The norm is logged as ``train/grad_norm``. When ``sub_module`` is given,
      it is logged as ``train/<sub_module>/grad_norm`` for each requested
      sub-module instead.

      NOTE(review): the norm is computed in ``on_after_backward``. With AMP loss
      scaling the gradients may still be scaled at that point, and with
      gradient accumulation this hook fires once per micro-batch. Lightning's
      ``on_before_optimizer_step`` sees unscaled gradients; confirm which hook
      is intended.
      """

      # Accepted values for ``logging_interval``.
      _VALID_INTERVALS = ("step", "epoch")

      def __init__(
          self,
          norm_type: float = 2.0,
          logging_interval: str = "step",
          sub_module: Optional[Union[str, list[str]]] = None,
      ) -> None:
          """
          Args:
              norm_type (float): type of the used p-norm.
              logging_interval (str): "step" or "epoch". Selects whether the
                  metric is logged with ``on_step=True`` or ``on_epoch=True``.
              sub_module (str or list[str], optional): name(s) of sub-modules
                  whose gradient norm is logged separately. Dotted paths such
                  as ``"encoder.layer1"`` are accepted. If ``None``, a single
                  norm over all parameters of the LightningModule is logged.

          Raises:
              ValueError: if ``logging_interval`` is not "step" or "epoch".
          """
          super().__init__()
          if logging_interval not in self._VALID_INTERVALS:
              raise ValueError(
                  f"logging_interval must be one of {self._VALID_INTERVALS}, "
                  f"got {logging_interval!r}"
              )
          self.norm_type = norm_type
          self.logging_interval = logging_interval
          self.sub_module = sub_module

      def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
          """
          Computes the gradient norm of the model parameters and logs it.

          Args:
              trainer (Trainer): The trainer object (unused).
              model (LightningModule): The current LightningModule. It is both
                  the source of the parameters and the object used for logging.
          """
          if self.sub_module is None:
              self.log_sub_module_grad_norm(model, model, "")
              return
          names = (
              [self.sub_module] if isinstance(self.sub_module, str) else self.sub_module
          )
          for name in names:
              # get_submodule resolves dotted paths ("encoder.layer1") as well as
              # direct children, which a plain getattr cannot do.
              self.log_sub_module_grad_norm(
                  model, model.get_submodule(name), f"/{name}"
              )

      def log_sub_module_grad_norm(
          self, lightning_model: LightningModule, model: nn.Module, path: str
      ) -> None:
          """
          Logs the gradient norm of ``model``'s parameters as ``train{path}/grad_norm``.

          Nothing is logged when no parameter of ``model`` has a gradient.

          Args:
              lightning_model (LightningModule): module whose ``log`` is called.
              model (nn.Module): module whose parameter gradients are measured.
              path (str): segment inserted after ``train`` in the metric name,
                  e.g. ``""`` or ``"/encoder"``.
          """
          norm = grad_norm(model.parameters(), self.norm_type)
          if norm is None:
              # e.g. a frozen sub-module, whose parameters have no gradients.
              return
          on_step = self.logging_interval == "step"
          lightning_model.log(
              f"train{path}/grad_norm",
              norm,
              on_step=on_step,
              on_epoch=not on_step,
          )