@@ -4,7 +4,7 @@ import lightning.pytorch as pl
 import torch
 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from torch import Tensor
+from torch import Tensor, nn
 from torch.utils._foreach_utils import (
     _group_tensors_by_device_and_dtype,
     _has_foreach_support,
@@ -61,7 +61,7 @@ class GradNormMonitor(Callback):
         self,
         norm_type: float = 2.0,
         logging_interval: str = "step",
-        sub_module: Optional[str] = None,
+        sub_module: Optional[Union[str, list[str]]] = None,
     ) -> None:
         """
         Args:
@@ -85,11 +85,21 @@ class GradNormMonitor(Callback):
 
         lightning_model = model
 
-        path = ""
-        if self.sub_module is not None:
-            model = getattr(model, self.sub_module)
-            path = f"/{self.sub_module}"
+        if self.sub_module is None:
+            return self.log_sub_module_grad_norm(lightning_model, model, "")
 
+        sub_modules = self.sub_module
+        if isinstance(sub_modules, str):
+            sub_modules = [sub_modules]
+
+        for sub_module in sub_modules:
+            self.log_sub_module_grad_norm(
+                lightning_model, getattr(model, sub_module), f"/{sub_module}"
+            )
+
+    def log_sub_module_grad_norm(
+        self, lightning_model: LightningModule, model: nn.Module, path: str
+    ) -> None:
         grad_norm_val = grad_norm(model.parameters(), self.norm_type)
         if grad_norm_val is None:
             return
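
For reference, a minimal usage sketch of the extended sub_module argument introduced above. The import path and the encoder/decoder submodule names are assumptions for illustration only; adjust them to the actual package layout and model definition.

from lightning import Trainer

from grad_norm_monitor import GradNormMonitor  # assumed module path, adjust as needed

trainer = Trainer(
    callbacks=[
        # With this change, sub_module may be None, a single name, or a list of
        # names; each listed submodule's gradient norm is logged separately
        # under its own "/<name>" suffix.
        GradNormMonitor(norm_type=2.0, sub_module=["encoder", "decoder"]),
    ]
)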