Lengyue 2 года назад
Родитель
Commit
639e2d047e

+ 125 - 0
fish_speech/configs/vqgan_pretrain_v2.yaml

@@ -0,0 +1,125 @@
+defaults:
+  - base
+  - _self_
+
+project: vqgan_pretrain_v2
+
+# Lightning Trainer
+trainer:
+  accelerator: gpu
+  devices: 4
+  strategy: ddp_find_unused_parameters_true
+  precision: 32
+  max_steps: 1_000_000
+  val_check_interval: 5000
+
+sample_rate: 44100
+hop_length: 512
+num_mels: 128
+n_fft: 2048
+win_length: 2048
+segment_size: 256
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/Genshin/vq_train_filelist.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  slice_frames: ${segment_size}
+
+val_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/Genshin/vq_val_filelist.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+data:
+  _target_: fish_speech.datasets.vqgan.VQGANDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 32
+  val_batch_size: 4
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vqgan.VQGAN
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  segment_size: 8192
+  mode: pretrain-stage1
+
+  downsample:
+    _target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
+    dims: ["${num_mels}", 512, 256]
+    kernel_sizes: [3, 3]
+    strides: [2, 2]
+
+  mel_encoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WN
+    hidden_channels: 256
+    kernel_size: 3
+    dilation_rate: 2
+    n_layers: 12
+
+  vq_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
+    in_channels: 256
+    vq_channels: 256
+    codebook_size: 1024
+    codebook_layers: 4
+    downsample: 1
+
+  decoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WN
+    hidden_channels: 256
+    out_channels: ${num_mels}
+    kernel_size: 3
+    dilation_rate: 2
+    n_layers: 6
+
+  generator:
+    _target_: fish_speech.models.vqgan.modules.decoder.Generator
+    initial_channel: ${num_mels}
+    resblock: "1"
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    upsample_rates: [8, 8, 2, 2, 2]
+    upsample_initial_channel: 512
+    upsample_kernel_sizes: [16, 16, 4, 4, 4]
+
+  discriminator:
+    _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
+    periods: [2, 3, 5, 7, 11, 17, 23, 37]
+
+  mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+    f_min: 0
+    f_max: 16000
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 2e-4
+    betas: [0.8, 0.99]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.ExponentialLR
+    _partial_: true
+    gamma: 0.999999  # Estimated based on the LibriTTS dataset
+
+callbacks:
+  grad_norm_monitor:
+    sub_module: 
+      - generator
+      - discriminator
+      - mel_encoder
+      - vq_encoder
+      - decoder

+ 79 - 58
fish_speech/models/vqgan/lit_module.py

@@ -1,5 +1,5 @@
 import itertools
-from typing import Any, Callable
+from typing import Any, Callable, Literal
 
 import lightning as L
 import torch
@@ -47,12 +47,15 @@ class VQGAN(L.LightningModule):
         segment_size: int = 20480,
         hop_length: int = 640,
         sample_rate: int = 32000,
-        freeze_hifigan: bool = False,
-        freeze_vq: bool = False,
+        mode: Literal["pretrain-stage1", "pretrain-stage2", "finetune"] = "finetune",
         speaker_encoder: SpeakerEncoder = None,
     ):
         super().__init__()
 
+        # pretrain-stage1: the VQ uses the GT mel as its target; HiFi-GAN takes the GT mel as input
+        # pretrain-stage2: end-to-end training; the GT mel is the HiFi-GAN target
+        # finetune: end-to-end training; the GT mel is the HiFi-GAN target, but the VQ is frozen
+
         # Model parameters
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
@@ -71,22 +74,13 @@ class VQGAN(L.LightningModule):
         self.segment_size = segment_size
         self.hop_length = hop_length
         self.sampling_rate = sample_rate
-        self.freeze_hifigan = freeze_hifigan
-        self.freeze_vq = freeze_vq
+        self.mode = mode
 
         # Disable automatic optimization
         self.automatic_optimization = False
 
-        # Stage 1: Train the VQ only
-        if self.freeze_hifigan:
-            for p in self.discriminator.parameters():
-                p.requires_grad = False
-
-            for p in self.generator.parameters():
-                p.requires_grad = False
-
-        # Stage 2: Train the HifiGAN + Decoder + Generator
-        if freeze_vq:
+        # Finetune: freeze the VQ encoder — only the remaining components are trained
+        if self.mode == "finetune":
             for p in self.vq_encoder.parameters():
                 p.requires_grad = False
 
@@ -99,7 +93,7 @@ class VQGAN(L.LightningModule):
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
         components = []
-        if self.freeze_vq is False:
+        if self.mode != "finetune":
             components.extend(
                 [
                     self.downsample.parameters(),
@@ -114,9 +108,7 @@ class VQGAN(L.LightningModule):
         if self.decoder is not None:
             components.append(self.decoder.parameters())
 
-        if self.freeze_hifigan is False:
-            components.append(self.generator.parameters())
-
+        components.append(self.generator.parameters())
         optimizer_generator = self.optimizer_builder(itertools.chain(*components))
         optimizer_discriminator = self.optimizer_builder(
             self.discriminator.parameters()
@@ -157,7 +149,7 @@ class VQGAN(L.LightningModule):
                 audios, sample_rate=self.sampling_rate
             )
 
-        if self.freeze_vq:
+        if self.mode == "finetune":
             # Disable gradient computation for VQ
             torch.set_grad_enabled(False)
             self.vq_encoder.eval()
@@ -183,9 +175,7 @@ class VQGAN(L.LightningModule):
 
         # vq_features is 50 hz, need to convert to true mel size
         text_features = self.mel_encoder(features, feature_masks)
-        text_features, _, loss_vq = self.vq_encoder(
-            text_features, feature_masks, freeze_codebook=self.freeze_vq
-        )
+        text_features, _, loss_vq = self.vq_encoder(text_features, feature_masks)
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
@@ -193,7 +183,7 @@ class VQGAN(L.LightningModule):
         if loss_vq.ndim > 1:
             loss_vq = loss_vq.mean()
 
-        if self.freeze_vq:
+        if self.mode == "finetune":
             # Enable gradient computation
             torch.set_grad_enabled(True)
 
@@ -208,55 +198,69 @@ class VQGAN(L.LightningModule):
         else:
             decoded_mels = text_features
 
-        fake_audios = self.generator(decoded_mels)
-
-        y_hat_mels = self.mel_transform(fake_audios.squeeze(1))
-
-        y, ids_slice = rand_slice_segments(audios, audio_lengths, self.segment_size)
-        y_hat = slice_segments(fake_audios, ids_slice, self.segment_size)
+        input_mels = gt_mels if self.mode == "pretrain-stage1" else decoded_mels
+        if self.segment_size is not None:
+            audios, ids_slice = rand_slice_segments(
+                audios, audio_lengths, self.segment_size
+            )
+            input_mels = slice_segments(
+                input_mels,
+                ids_slice // self.hop_length,
+                self.segment_size // self.hop_length,
+            )
+            gen_mel_masks = slice_segments(
+                mel_masks,
+                ids_slice // self.hop_length,
+                self.segment_size // self.hop_length,
+            )
 
-        assert y.shape == y_hat.shape, f"{y.shape} != {y_hat.shape}"
+        fake_audios = self.generator(input_mels)
+        fake_audio_mels = self.mel_transform(fake_audios.squeeze(1))
+        assert (
+            audios.shape == fake_audios.shape
+        ), f"{audios.shape} != {fake_audios.shape}"
 
-        # Since we don't want to update the discriminator, we skip the backward pass
-        if self.freeze_hifigan is False:
-            # Discriminator
-            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())
+        # Discriminator
+        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios.detach())
 
-            with torch.autocast(device_type=audios.device.type, enabled=False):
-                loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
+        with torch.autocast(device_type=audios.device.type, enabled=False):
+            loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
 
-            self.log(
-                "train/discriminator/loss",
-                loss_disc_all,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=True,
-                logger=True,
-                sync_dist=True,
-            )
+        self.log(
+            "train/discriminator/loss",
+            loss_disc_all,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
 
-            optim_d.zero_grad()
-            self.manual_backward(loss_disc_all)
-            self.clip_gradients(
-                optim_d, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
-            )
-            optim_d.step()
+        optim_d.zero_grad()
+        self.manual_backward(loss_disc_all)
+        self.clip_gradients(
+            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+        )
+        optim_d.step()
 
-        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)
+        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(audios, fake_audios)
 
         with torch.autocast(device_type=audios.device.type, enabled=False):
             loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
-            loss_mel = F.l1_loss(gt_mels * mel_masks, y_hat_mels * mel_masks)
+            loss_mel = F.l1_loss(
+                input_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
+            )
             loss_adv, _ = generator_loss(y_d_hat_g)
             loss_fm = feature_loss(fmap_r, fmap_g)
 
-            if self.freeze_hifigan is True:
-                loss_gen_all = loss_decoded_mel + loss_vq
+            if self.mode == "pretrain-stage1":
+                loss_vq_all = loss_decoded_mel + loss_vq
+                loss_gen_all = loss_mel * 45 + loss_fm + loss_adv
             else:
                 loss_gen_all = loss_mel * 45 + loss_vq * 45 + loss_fm + loss_adv
 
         self.log(
-            "train/generator/loss",
+            "train/generator/loss_gen_all",
             loss_gen_all,
             on_step=True,
             on_epoch=False,
@@ -264,6 +268,18 @@ class VQGAN(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
+
+        if self.mode == "pretrain-stage1":
+            self.log(
+                "train/generator/loss_vq_all",
+                loss_vq_all,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=True,
+                logger=True,
+                sync_dist=True,
+            )
+
         self.log(
             "train/generator/loss_decoded_mel",
             loss_decoded_mel,
@@ -311,9 +327,14 @@ class VQGAN(L.LightningModule):
         )
 
         optim_g.zero_grad()
+
+        # Only backpropagate loss_vq_all in pretrain-stage1
+        if self.mode == "pretrain-stage1":
+            self.manual_backward(loss_vq_all)
+
         self.manual_backward(loss_gen_all)
         self.clip_gradients(
-            optim_g, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
+            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
         )
         optim_g.step()
 

+ 1 - 2
fish_speech/models/vqgan/modules/discriminator.py

@@ -117,9 +117,8 @@ class DiscriminatorS(nn.Module):
 
 
 class EnsembleDiscriminator(nn.Module):
-    def __init__(self, ckpt_path=None):
+    def __init__(self, ckpt_path=None, periods=(2, 3, 5, 7, 11)):
         super(EnsembleDiscriminator, self).__init__()
-        periods = [2, 3, 5, 7, 11]  # [1, 2, 3, 5, 7, 11]
 
         discs = [DiscriminatorS(use_spectral_norm=True)]
         discs = discs + [DiscriminatorP(i, use_spectral_norm=False) for i in periods]

+ 1 - 1
fish_speech/models/vqgan/modules/encoders.py

@@ -309,7 +309,7 @@ class VQEncoder(nn.Module):
             nn.Conv1d(vq_channels, in_channels, kernel_size=1, stride=1),
         )
 
-    def forward(self, x, x_mask, freeze_codebook=False):
+    def forward(self, x, x_mask):
         # x: [B, C, T], x_mask: [B, 1, T]
         x_len = x.shape[2]