# grad_norm.py
from typing import Union

import lightning.pytorch as pl
import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from torch import Tensor
from torch.utils._foreach_utils import (
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)
  11. @torch.no_grad()
  12. def grad_norm(
  13. parameters: Union[Tensor, list[Tensor]],
  14. norm_type: float = 2.0,
  15. ) -> float:
  16. """
  17. Returns the norm of the gradients of the given parameters.
  18. Args:
  19. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  20. single Tensor that will have gradients normalized
  21. norm_type (float): type of the used p-norm.
  22. Returns:
  23. Total norm of the parameter gradients (viewed as a single vector).
  24. """ # noqa: E501
  25. if isinstance(parameters, Tensor):
  26. parameters = [parameters]
  27. grads = [p.grad for p in parameters if p.grad is not None]
  28. first_device = grads[0].device
  29. grouped_grads: dict[
  30. tuple[torch.device, torch.dtype], list[list[Tensor]]
  31. ] = _group_tensors_by_device_and_dtype(
  32. [[g.detach() for g in grads]]
  33. ) # type: ignore[assignment]
  34. norms = []
  35. for (device, _), ([grads], _) in grouped_grads.items():
  36. if _has_foreach_support(grads, device=device):
  37. norms.extend(torch._foreach_norm(grads, norm_type))
  38. else:
  39. norms.extend([torch.norm(g, norm_type) for g in grads])
  40. return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
  41. class GradNormMonitor(Callback):
  42. """
  43. Callback that computes the gradient norm of the model parameters.
  44. """
  45. def __init__(self, norm_type: float = 2.0, logging_interval: str = "step") -> None:
  46. """
  47. Args:
  48. norm_type (float): type of the used p-norm.
  49. logging_interval (str): "step" or "epoch".
  50. """
  51. super().__init__()
  52. self.norm_type = norm_type
  53. self.logging_interval = logging_interval
  54. def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
  55. """
  56. Computes the gradient norm of the model parameters and logs it to the logger.
  57. Args:
  58. trainer (Trainer): The trainer object
  59. model (LightningModule): The current lightningModule
  60. """
  61. grad_norm_val = grad_norm(model.parameters(), self.norm_type)
  62. model_name = model.__class__.__name__.lower()
  63. on_step = self.logging_interval == "step"
  64. model.log(
  65. f"train/{model_name}/grad_norm",
  66. grad_norm_val,
  67. on_step=on_step,
  68. on_epoch=not on_step,
  69. )