Explore the Code

Support log grad norm of multiple components

Author: Lengyue, 2 years ago
Parent commit: 356f8fa6c8

+ 16 - 6
fish_speech/callbacks/grad_norm.py

@@ -4,7 +4,7 @@ import lightning.pytorch as pl
 import torch
 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from torch import Tensor
+from torch import Tensor, nn
 from torch.utils._foreach_utils import (
     _group_tensors_by_device_and_dtype,
     _has_foreach_support,
@@ -61,7 +61,7 @@ class GradNormMonitor(Callback):
         self,
         norm_type: float = 2.0,
         logging_interval: str = "step",
-        sub_module: Optional[str] = None,
+        sub_module: Optional[Union[str, list[str]]] = None,
     ) -> None:
         """
         Args:
@@ -85,11 +85,21 @@ class GradNormMonitor(Callback):
 
         lightning_model = model
 
-        path = ""
-        if self.sub_module is not None:
-            model = getattr(model, self.sub_module)
-            path = f"/{self.sub_module}"
+        if self.sub_module is None:
+            return self.log_sub_module_grad_norm(lightning_model, model, "")
 
+        sub_modules = self.sub_module
+        if isinstance(sub_modules, str):
+            sub_modules = [sub_modules]
+
+        for sub_module in sub_modules:
+            self.log_sub_module_grad_norm(
+                lightning_model, getattr(model, sub_module), f"/{sub_module}"
+            )
+
+    def log_sub_module_grad_norm(
+        self, lightning_model: LightningModule, model: nn.Module, path: str
+    ) -> None:
         grad_norm_val = grad_norm(model.parameters(), self.norm_type)
         if grad_norm_val is None:
             return

+ 3 - 6
fish_speech/models/vq_diffusion/lit_module.py

@@ -129,12 +129,9 @@ class VQDiffusion(L.LightningModule):
         model_output = self.denoiser(noisy_images, timesteps, mel_masks, text_features)
 
         # MSE loss without the mask
-        # noise_loss = (
-        #     (model_output * mel_masks - normalized_gt_mels * mel_masks) ** 2
-        # ).sum() / (mel_masks.sum() * gt_mels.shape[1])
-        noise_loss = torch.abs(
-            model_output * mel_masks - normalized_gt_mels * mel_masks
-        ).sum() / (mel_masks.sum() * gt_mels.shape[1])
+        noise_loss = ((model_output * mel_masks - noise * mel_masks) ** 2).sum() / (
+            mel_masks.sum() * gt_mels.shape[1]
+        )
 
         self.log(
             "train/noise_loss",