Procházet zdrojové kódy

Update vqgan finetune config

Lengyue před 2 roky
rodič
revize
39e2a96aaa

+ 129 - 0
fish_speech/configs/vqgan_finetune.yaml

@@ -0,0 +1,129 @@
+defaults:
+  - base
+  - _self_
+
+project: vqgan_finetune
+ckpt_path: checkpoints/vqgan-v1.pth
+resume_weights_only: true
+
+# Lightning Trainer
+trainer:
+  accelerator: gpu
+  devices: 4
+  strategy: ddp_find_unused_parameters_true
+  precision: 32
+  max_steps: 100000
+  val_check_interval: 1000
+
+sample_rate: 22050
+hop_length: 256
+num_mels: 80
+n_fft: 1024
+win_length: 1024
+segment_size: 256
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/demo/vq_train_filelist.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  slice_frames: ${segment_size}
+
+val_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/demo/vq_val_filelist.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+data:
+  _target_: fish_speech.datasets.vqgan.VQGANDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 32
+  val_batch_size: 4
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vqgan.VQGAN
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  segment_size: 8192
+  freeze_hifigan: false
+  freeze_vq: true
+
+  downsample:
+    _target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
+    dims: ["${num_mels}", 512, 256]
+    kernel_sizes: [3, 3]
+    strides: [2, 2]
+
+  mel_encoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WN
+    hidden_channels: 256
+    kernel_size: 3
+    dilation_rate: 2
+    n_layers: 6
+
+  vq_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
+    in_channels: 256
+    vq_channels: 256
+    codebook_size: 160
+    codebook_groups: 4
+    downsample: 1
+    threshold_ema_dead_code: 0  # Disable dead code removal
+
+  decoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WN
+    hidden_channels: 256
+    out_channels: ${num_mels}
+    kernel_size: 3
+    dilation_rate: 2
+    n_layers: 6
+
+  generator:
+    _target_: fish_speech.models.vqgan.modules.decoder.Generator
+    initial_channel: ${num_mels}
+    resblock: "1"
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    upsample_rates: [8, 8, 2, 2]
+    upsample_initial_channel: 512
+    upsample_kernel_sizes: [16, 16, 4, 4]
+
+  discriminator:
+    _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
+
+  mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+    f_min: 0
+    f_max: 8000
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 2e-5
+    betas: [0.8, 0.99]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.ExponentialLR
+    _partial_: true
+    gamma: 0.999999  # Estimated base on LibriTTS dataset
+
+callbacks:
+  grad_norm_monitor:
+    sub_module: 
+      - generator
+      - discriminator
+      - decoder
+
+  model_checkpoint:
+    every_n_train_steps: 1000

+ 29 - 8
fish_speech/models/vqgan/lit_module.py

@@ -98,15 +98,25 @@ class VQGAN(L.LightningModule):
 
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
-        optimizer_generator = self.optimizer_builder(
-            itertools.chain(
-                self.downsample.parameters(),
-                self.vq_encoder.parameters(),
-                self.mel_encoder.parameters(),
-                self.decoder.parameters(),
-                self.generator.parameters(),
+        components = []
+        if self.freeze_vq is False:
+            components.extend(
+                [
+                    self.downsample.parameters(),
+                    self.vq_encoder.parameters(),
+                    self.mel_encoder.parameters(),
+                ]
             )
-        )
+
+        if self.speaker_encoder is not None:
+            components.append(self.speaker_encoder.parameters())
+
+        components.append(self.decoder.parameters())
+
+        if self.freeze_hifigan is False:
+            components.append(self.generator.parameters())
+
+        optimizer_generator = self.optimizer_builder(itertools.chain(*components))
         optimizer_discriminator = self.optimizer_builder(
             self.discriminator.parameters()
         )
@@ -146,6 +156,13 @@ class VQGAN(L.LightningModule):
                 audios, sample_rate=self.sampling_rate
             )
 
+        if self.freeze_vq:
+            # Disable gradient computation for VQ
+            torch.set_grad_enabled(False)
+            self.vq_encoder.eval()
+            self.mel_encoder.eval()
+            self.downsample.eval()
+
         if self.downsample is not None:
             features = self.downsample(features)
 
@@ -175,6 +192,10 @@ class VQGAN(L.LightningModule):
         if loss_vq.ndim > 1:
             loss_vq = loss_vq.mean()
 
+        if self.freeze_vq:
+            # Enable gradient computation
+            torch.set_grad_enabled(True)
+
         # Sample mels
         speaker_features = (
             self.speaker_encoder(gt_mels, mel_masks)

+ 3 - 2
fish_speech/models/vqgan/modules/encoders.py

@@ -276,6 +276,7 @@ class VQEncoder(nn.Module):
         downsample: int = 1,
         codebook_groups: int = 1,
         codebook_layers: int = 1,
+        threshold_ema_dead_code: int = 2,
     ):
         super().__init__()
 
@@ -283,7 +284,7 @@ class VQEncoder(nn.Module):
             self.vq = GroupedResidualVQ(
                 dim=vq_channels,
                 codebook_size=codebook_size,
-                threshold_ema_dead_code=2,
+                threshold_ema_dead_code=threshold_ema_dead_code,
                 kmeans_init=False,
                 groups=codebook_groups,
                 num_quantizers=codebook_layers,
@@ -292,7 +293,7 @@ class VQEncoder(nn.Module):
             self.vq = VectorQuantize(
                 dim=vq_channels,
                 codebook_size=codebook_size,
-                threshold_ema_dead_code=2,
+                threshold_ema_dead_code=threshold_ema_dead_code,
                 kmeans_init=False,
             )
 

+ 1 - 6
tools/vqgan/create_train_split.py

@@ -13,12 +13,7 @@ def main(root):
     files = list_files(root, AUDIO_EXTENSIONS, recursive=True)
     print(f"Found {len(files)} files")
 
-    files = [
-        str(file.relative_to(root))
-        for file in tqdm(files)
-        if file.with_suffix(".npy").exists()
-    ]
-    print(f"Found {len(files)} files with features")
+    files = [str(file.relative_to(root)) for file in tqdm(files)]
 
     Random(42).shuffle(files)