瀏覽代碼

Implement new vqgan

Lengyue 2 年之前
父節點
當前提交
a89035df75
共有 2 個文件被更改,包括 100 次插入28 次删除
  1. 88 28
      fish_speech/models/vqgan/lit_module.py
  2. 12 0
      fish_speech/models/vqgan/utils.py

+ 88 - 28
fish_speech/models/vqgan/lit_module.py

@@ -1,3 +1,4 @@
+import itertools
 import math
 from typing import Any, Callable
 
@@ -9,9 +10,9 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
 
-from fish_speech.models.vqgan.modules.wavenet import WaveNet
-from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
 from fish_speech.models.vqgan.modules.discriminator import Discriminator
+from fish_speech.models.vqgan.modules.wavenet import WaveNet
+from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask
 
 
 class VQGAN(L.LightningModule):
@@ -55,8 +56,6 @@ class VQGAN(L.LightningModule):
         self.weight_mel = weight_mel
 
         # Other parameters
-        self.spec_min = -12
-        self.spec_max = 3
         self.sampling_rate = sampling_rate
 
         # Disable strict loading
@@ -69,7 +68,7 @@ class VQGAN(L.LightningModule):
 
             for param in self.quantizer.parameters():
                 param.requires_grad = False
-            
+
         self.automatic_optimization = False
 
     def on_save_checkpoint(self, checkpoint):
@@ -80,24 +79,42 @@ class VQGAN(L.LightningModule):
                 state_dict.pop(name)
 
     def configure_optimizers(self):
-        # optimizer = self.optimizer_builder(self.parameters())
-        # lr_scheduler = self.lr_scheduler_builder(optimizer)
-
-        # return {
-        #     "optimizer": optimizer,
-        #     "lr_scheduler": {
-        #         "scheduler": lr_scheduler,
-        #         "interval": "step",
-        #     },
-        # }
-
-    def norm_spec(self, x):
-        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
+        optimizer_generator = self.optimizer_builder(
+            itertools.chain(
+                self.encoder.parameters(),
+                self.quantizer.parameters(),
+                self.decoder.parameters(),
+            )
+        )
+        optimizer_discriminator = self.optimizer_builder(
+            self.discriminator.parameters()
+        )
 
-    def denorm_spec(self, x):
-        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
+        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
+        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
+
+        return (
+            {
+                "optimizer": optimizer_generator,
+                "lr_scheduler": {
+                    "scheduler": lr_scheduler_generator,
+                    "interval": "step",
+                    "name": "optimizer/generator",
+                },
+            },
+            {
+                "optimizer": optimizer_discriminator,
+                "lr_scheduler": {
+                    "scheduler": lr_scheduler_discriminator,
+                    "interval": "step",
+                    "name": "optimizer/discriminator",
+                },
+            },
+        )
 
     def training_step(self, batch, batch_idx):
+        optim_g, optim_d = self.optimizers()
+
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
 
         audios = audios.float()
@@ -128,13 +145,48 @@ class VQGAN(L.LightningModule):
             * mel_masks_float_conv
         )
 
+        # Discriminator
+        real_logits = self.discriminator(gt_mels)
+        fake_logits = self.discriminator(gen_mel.detach())
+        d_mask = F.interpolate(
+            mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
+        )
+
+        loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
+        loss_fake = avg_with_mask(fake_logits**2, d_mask)
+
+        loss_d = loss_real + loss_fake
+
+        self.log(
+            "train/discriminator/loss",
+            loss_d,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+        )
+
+        # Discriminator backward
+        optim_d.zero_grad()
+        self.manual_backward(loss_d)
+        self.clip_gradients(
+            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+        )
+        optim_d.step()
+
         # Mel Loss
-        loss_mel = (gen_mel - gt_mels).abs().mean(
-            dim=1, keepdim=True
-        ).sum() / mel_masks_float_conv.sum()
+        loss_mel = avg_with_mask((gen_mel - gt_mels).abs(), mel_masks_float_conv)
+
+        # Adversarial Loss
+        fake_logits = self.discriminator(gen_mel)
+        loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)
 
         # Total loss
-        loss = self.weight_vq * loss_vq + self.weight_mel * loss_mel
+        loss = (
+            self.weight_vq * loss_vq
+            + self.weight_mel * loss_mel
+            + self.weight_adv * loss_adv
+        )
 
         # Log losses
         self.log(
@@ -162,7 +214,17 @@ class VQGAN(L.LightningModule):
             logger=True,
         )
 
-        return loss
+        # Generator backward
+        optim_g.zero_grad()
+        self.manual_backward(loss)
+        self.clip_gradients(
+            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+        )
+        optim_g.step()
+
+        scheduler_g, scheduler_d = self.lr_schedulers()
+        scheduler_g.step()
+        scheduler_d.step()
 
     def validation_step(self, batch: Any, batch_idx: int):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
@@ -190,9 +252,7 @@ class VQGAN(L.LightningModule):
             )
             * mel_masks_float_conv
         )
-        loss_mel = (gen_aux_mels - gt_mels).abs().mean(
-            dim=1, keepdim=True
-        ).sum() / mel_masks_float_conv.sum()
+        loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)
 
         self.log(
             "val/loss_mel",

+ 12 - 0
fish_speech/models/vqgan/utils.py

@@ -80,3 +80,15 @@ def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
     acts = t_act * s_act
 
     return acts
+
+
+def avg_with_mask(x, mask):
+    assert mask.dtype == torch.float, "Mask should be float"
+
+    if mask.ndim == 2:
+        mask = mask.unsqueeze(1)
+
+    if mask.shape[1] == 1:
+        mask = mask.expand_as(x)
+
+    return (x * mask).sum() / mask.sum()