Lengyue před 2 roky
rodič
revize
fb68b02716

+ 3 - 4
fish_speech/configs/vqgan_pretrain.yaml

@@ -7,7 +7,7 @@ project: vq_reflow_shallow_group_fsq_8x1024_wavenet
 # Lightning Trainer
 trainer:
   accelerator: gpu
-  devices: auto
+  devices: 1
   precision: bf16-mixed
   max_steps: 1_000_000
   val_check_interval: 1000
@@ -38,7 +38,7 @@ data:
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
-  batch_size: 32
+  batch_size: 128
   val_batch_size: 4
 
 # Model Configuration
@@ -52,9 +52,8 @@ model:
   freeze_encoder: false
 
   # Reflow configs
-  reflow_use_shallow: true
   reflow_inference_steps: 10
-  reflow_inference_start_t: 0.5
+  reflow_inference_start_t: 0.0
 
   encoder:
     _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet

+ 19 - 19
fish_speech/models/vqgan/lit_module.py

@@ -30,7 +30,6 @@ class VQGAN(L.LightningModule):
         weight_mel: float = 1.0,
         sampling_rate: int = 44100,
         freeze_encoder: bool = False,
-        reflow_use_shallow: bool = False,
         reflow_inference_steps: int = 10,
         reflow_inference_start_t: float = 0.5,
     ):
@@ -61,7 +60,6 @@ class VQGAN(L.LightningModule):
         self.spec_min = -12
         self.spec_max = 3
         self.sampling_rate = sampling_rate
-        self.reflow_use_shallow = reflow_use_shallow
         self.reflow_inference_steps = reflow_inference_steps
         self.reflow_inference_start_t = reflow_inference_start_t
 
@@ -133,18 +131,12 @@ class VQGAN(L.LightningModule):
 
         # Reflow, given x_1_aux, we want to reconstruct x_1
         x_1 = self.norm_spec(gt_mels)
-
-        if self.reflow_use_shallow:
-            # Detach the gradient of the generated mel
-            x_1_aux = self.norm_spec(gen_mel.detach())
-        else:
-            x_1_aux = x_1
-
-        t = torch.rand(gt_mels.shape[0], device=gt_mels.device)
+        t = torch.rand(gt_mels.shape[0], device=gt_mels.device, dtype=torch.float32)
+        t = torch.clamp(t, 1e-6, 1 - 1e-6)  # Avoid 0 and 1
         x_0 = torch.randn_like(x_1)
 
         # X_t = t * X_1 + (1 - t) * X_0
-        x_t = x_0 + t[:, None, None] * (x_1_aux - x_0)
+        x_t = x_0 + t[:, None, None] * (x_1 - x_0)
 
         v_pred = self.reflow(
             x_t,
@@ -153,13 +145,21 @@ class VQGAN(L.LightningModule):
         )
 
         # Log L2 loss with
-        weights = 0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / (1 - t)) ** 2)
-        loss_reflow = weights[:, None, None] * F.mse_loss(
-            x_1 - x_0, v_pred, reduction="none"
-        )
-        loss_reflow = (loss_reflow * mel_masks_float_conv).mean(
-            dim=1
-        ).sum() / mel_masks_float_conv.sum()
+        with torch.autocast(device_type=gt_mels.device.type, dtype=torch.float32):
+            weights = (
+                0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / (1 - t)) ** 2)
+            )
+            assert (
+                torch.isnan(weights).any() == False
+                and torch.isinf(weights).any() == False
+            ), "Found NaN or Inf in weights."
+
+            loss_reflow = weights[:, None, None] * F.mse_loss(
+                x_1 - x_0, v_pred, reduction="none"
+            )
+            loss_reflow = (loss_reflow * mel_masks_float_conv).mean(
+                dim=1
+            ).sum() / mel_masks_float_conv.sum()
 
         # Total loss
         loss = (
@@ -239,7 +239,7 @@ class VQGAN(L.LightningModule):
         )
 
         # Reflow inference
-        t_start = self.reflow_inference_start_t if self.reflow_use_shallow else 0.0
+        t_start = 0.0
 
         x_1 = self.norm_spec(gen_aux_mels)
         x_0 = torch.randn_like(x_1)