Update vqgan default config & fix vqgan inference

Lengyue 2 years ago
parent
commit
e2413e25b1

+ 30 - 30
fish_speech/configs/vqgan_finetune.yaml

@@ -49,51 +49,53 @@ model:
   _target_: fish_speech.models.vqgan.VQGAN
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
-  segment_size: 8192
-  mode: finetune
 
-  downsample:
-    _target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
-    dims: ["${num_mels}", 512, 256]
-    kernel_sizes: [3, 3]
-    strides: [2, 2]
-
-  mel_encoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WN
-    hidden_channels: 256
+  encoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
+    hidden_channels: 512
     kernel_size: 3
     dilation_rate: 2
-    n_layers: 6
+    n_layers: 20
+    in_channels: ${num_mels}
 
-  vq_encoder:
+  vq:
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
-    in_channels: 256
-    vq_channels: 256
-    codebook_size: 160
+    in_channels: 512
+    vq_channels: 512
+    codebook_size: 256
     codebook_groups: 4
-    downsample: 1
-    threshold_ema_dead_code: 0  # Disable dead code removal
+    codebook_layers: 2
+    downsample: 4
 
   decoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WN
-    hidden_channels: 256
-    out_channels: ${num_mels}
+    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
+    hidden_channels: 512
     kernel_size: 3
     dilation_rate: 2
-    n_layers: 6
+    n_layers: 20
+    out_channels: ${num_mels}
 
   generator:
-    _target_: fish_speech.models.vqgan.modules.decoder.Generator
-    initial_channel: ${num_mels}
-    resblock: "1"
+    _target_: fish_speech.models.vqgan.modules.decoder_v2.HiFiGANGenerator
+    hop_length: ${hop_length}
+    upsample_rates: [8, 8, 2, 2, 2]  # aka. strides
+    upsample_kernel_sizes: [16, 16, 4, 4, 4]
     resblock_kernel_sizes: [3, 7, 11]
     resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    upsample_rates: [8, 8, 2, 2]
+    num_mels: ${num_mels}
     upsample_initial_channel: 512
-    upsample_kernel_sizes: [16, 16, 4, 4]
+    use_template: true
+    pre_conv_kernel_size: 7
+    post_conv_kernel_size: 7
+    ckpt_path: checkpoints/hifi-gan-base-002000000.ckpt
 
   discriminator:
-    _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
+    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
+    hidden_channels: 256
+    kernel_size: 3
+    dilation_rate: 2
+    n_layers: 6
+    in_channels: ${num_mels}
 
   mel_transform:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
@@ -102,8 +104,6 @@ model:
     hop_length: ${hop_length}
     win_length: ${win_length}
     n_mels: ${num_mels}
-    f_min: 0
-    f_max: 8000
 
   optimizer:
     _target_: torch.optim.AdamW
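
The extra x2 upsample stage keeps the generator consistent with the larger hop: in HiFi-GAN, the product of upsample_rates must equal hop_length so that each mel frame expands to exactly one hop of output samples. A quick sanity check in plain Python, using the values from these configs (the 256 -> 512 hop change is visible in the pretrain file below):

import math

# Old generator: four stages, 8*8*2*2 = 256, matching the old 256-sample hop.
assert math.prod([8, 8, 2, 2]) == 256

# New generator: five stages, 8*8*2*2*2 = 512, matching the new 512-sample hop.
assert math.prod([8, 8, 2, 2, 2]) == 512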

+ 61 - 49
fish_speech/configs/vqgan_pretrain.yaml

@@ -2,35 +2,49 @@ defaults:
   - base
   - _self_
 
-project: vqgan
+project: vqgan_pretrain_v2_large_30
 
 # Lightning Trainer
 trainer:
   accelerator: gpu
   devices: auto
   strategy: ddp_find_unused_parameters_true
-  precision: 32
-  max_steps: 1_000_000
+  precision: bf16-mixed
+  max_steps: 10_000_000
   val_check_interval: 5000
 
-sample_rate: 22050
-hop_length: 256
-num_mels: 80
-n_fft: 1024
-win_length: 1024
+sample_rate: 44100
+hop_length: 512
+num_mels: 160
+n_fft: 2048
+win_length: 2048
 segment_size: 256
 
 # Dataset Configuration
 train_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/filelist.split.train
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  slice_frames: ${segment_size}
+  _target_: fish_speech.datasets.vqgan.MixDatast
+  datasets:
+    high-quality-441:
+      prob: 0.5
+      dataset:
+        _target_: fish_speech.datasets.vqgan.VQGANDataset
+        filelist: data/vocoder_data_441/vq_train_filelist.txt
+        sample_rate: ${sample_rate}
+        hop_length: ${hop_length}
+        slice_frames: ${segment_size}
+    
+    common-voice:
+      prob: 0.5
+      dataset:
+        _target_: fish_speech.datasets.vqgan.VQGANDataset
+        filelist: data/cv-corpus-16.0-2023-12-06/vq_train_filelist.txt
+        sample_rate: ${sample_rate}
+        hop_length: ${hop_length}
+        slice_frames: ${segment_size}
 
 val_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/filelist.split.valid
+  filelist: data/vocoder_data_441/vq_val_filelist.txt
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
 
@@ -47,52 +61,53 @@ model:
   _target_: fish_speech.models.vqgan.VQGAN
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
-  segment_size: 8192
-  mode: pretrain-stage1
 
-  downsample:
-    _target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
-    dims: ["${num_mels}", 512, 256]
-    kernel_sizes: [3, 3]
-    strides: [2, 2]
-
-  mel_encoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WN
-    hidden_channels: 256
+  encoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
+    hidden_channels: 512
     kernel_size: 3
     dilation_rate: 2
-    n_layers: 6
+    n_layers: 20
+    in_channels: ${num_mels}
 
-  vq_encoder:
+  vq:
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
-    in_channels: 256
-    vq_channels: 256
-    codebook_size: 160
+    in_channels: 512
+    vq_channels: 512
+    codebook_size: 256
     codebook_groups: 4
-    downsample: 1
+    codebook_layers: 2
+    downsample: 4
 
   decoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WN
-    hidden_channels: 256
-    out_channels: ${num_mels}
+    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
+    hidden_channels: 512
     kernel_size: 3
     dilation_rate: 2
-    n_layers: 6
+    n_layers: 20
+    out_channels: ${num_mels}
 
   generator:
-    _target_: fish_speech.models.vqgan.modules.decoder.Generator
-    initial_channel: ${num_mels}
-    resblock: "1"
+    _target_: fish_speech.models.vqgan.modules.decoder_v2.HiFiGANGenerator
+    hop_length: ${hop_length}
+    upsample_rates: [8, 8, 2, 2, 2]  # aka. strides
+    upsample_kernel_sizes: [16, 16, 4, 4, 4]
     resblock_kernel_sizes: [3, 7, 11]
     resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    upsample_rates: [8, 8, 2, 2]
+    num_mels: ${num_mels}
     upsample_initial_channel: 512
-    upsample_kernel_sizes: [16, 16, 4, 4]
-    # ckpt_path: "checkpoints/hifigan-v1-universal-22050/g_02500000"
+    use_template: true
+    pre_conv_kernel_size: 7
+    post_conv_kernel_size: 7
+    ckpt_path: checkpoints/hifi-gan-base-002000000.ckpt
 
   discriminator:
-    _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
-    # ckpt_path: checkpoints/hifigan-v1-universal-22050/do_02500000
+    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
+    hidden_channels: 256
+    kernel_size: 3
+    dilation_rate: 2
+    n_layers: 6
+    in_channels: ${num_mels}
 
   mel_transform:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
@@ -101,8 +116,6 @@ model:
     hop_length: ${hop_length}
     win_length: ${win_length}
     n_mels: ${num_mels}
-    f_min: 0
-    f_max: 8000
 
   optimizer:
     _target_: torch.optim.AdamW
@@ -119,8 +132,7 @@ model:
 callbacks:
   grad_norm_monitor:
     sub_module: 
-      - generator
-      - discriminator
-      - mel_encoder
-      - vq_encoder
+      - encoder
+      - vq
       - decoder
+      - discriminator
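
The new MixDatast target (spelled as it appears in the repo) draws each training sample from one of the two sub-datasets according to its prob weight, 0.5 each here. The class itself is not part of this diff, so the following is only a hypothetical sketch of the weighted-mixing idea; WeightedMixDataset and its fields are illustrative names, not the repo's API:

import random
from torch.utils.data import Dataset

class WeightedMixDataset(Dataset):
    # Hypothetical stand-in for MixDatast: pick a sub-dataset with the
    # configured probability, then return a random item from it.
    def __init__(self, datasets, probs):
        assert len(datasets) == len(probs)
        self.datasets = datasets
        self.probs = probs

    def __len__(self):
        return sum(len(d) for d in self.datasets)

    def __getitem__(self, idx):
        ds = random.choices(self.datasets, weights=self.probs, k=1)[0]
        return ds[random.randrange(len(ds))]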

+ 0 - 138
fish_speech/configs/vqgan_pretrain_v2.yaml
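
(Deleted outright: after the update above, vqgan_pretrain.yaml carries these same settings, so the separate v2 file is redundant.)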

@@ -1,138 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vqgan_pretrain_v2_large_30
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: auto
-  strategy: ddp_find_unused_parameters_true
-  precision: bf16-mixed
-  max_steps: 10_000_000
-  val_check_interval: 5000
-
-sample_rate: 44100
-hop_length: 512
-num_mels: 160
-n_fft: 2048
-win_length: 2048
-segment_size: 256
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.vqgan.MixDatast
-  datasets:
-    high-quality-441:
-      prob: 0.5
-      dataset:
-        _target_: fish_speech.datasets.vqgan.VQGANDataset
-        filelist: data/vocoder_data_441/vq_train_filelist.txt
-        sample_rate: ${sample_rate}
-        hop_length: ${hop_length}
-        slice_frames: ${segment_size}
-    
-    common-voice:
-      prob: 0.5
-      dataset:
-        _target_: fish_speech.datasets.vqgan.VQGANDataset
-        filelist: data/cv-corpus-16.0-2023-12-06/vq_train_filelist.txt
-        sample_rate: ${sample_rate}
-        hop_length: ${hop_length}
-        slice_frames: ${segment_size}
-
-val_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vocoder_data_441/vq_val_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-data:
-  _target_: fish_speech.datasets.vqgan.VQGANDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 32
-  val_batch_size: 4
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vqgan.VQGAN
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-  encoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
-    hidden_channels: 512
-    kernel_size: 3
-    dilation_rate: 2
-    n_layers: 20
-    in_channels: ${num_mels}
-
-  vq:
-    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
-    in_channels: 512
-    vq_channels: 512
-    codebook_size: 256
-    codebook_groups: 4
-    codebook_layers: 2
-    downsample: 4
-
-  decoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
-    hidden_channels: 512
-    kernel_size: 3
-    dilation_rate: 2
-    n_layers: 20
-    out_channels: ${num_mels}
-
-  generator:
-    _target_: fish_speech.models.vqgan.modules.decoder_v2.HiFiGANGenerator
-    hop_length: ${hop_length}
-    upsample_rates: [8, 8, 2, 2, 2]  # aka. strides
-    upsample_kernel_sizes: [16, 16, 4, 4, 4]
-    resblock_kernel_sizes: [3, 7, 11]
-    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    num_mels: ${num_mels}
-    upsample_initial_channel: 512
-    use_template: true
-    pre_conv_kernel_size: 7
-    post_conv_kernel_size: 7
-    ckpt_path: checkpoints/hifi-gan-base-002000000.ckpt
-
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
-    hidden_channels: 256
-    kernel_size: 3
-    dilation_rate: 2
-    n_layers: 6
-    in_channels: ${num_mels}
-
-  mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 2e-4
-    betas: [0.8, 0.99]
-    eps: 1e-5
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.ExponentialLR
-    _partial_: true
-    gamma: 0.999999  # Estimated base on LibriTTS dataset
-
-callbacks:
-  grad_norm_monitor:
-    sub_module: 
-      - encoder
-      - vq
-      - decoder
-      - discriminator

+ 3 - 0
fish_speech/models/vqgan/lit_module.py

@@ -33,6 +33,7 @@ class VQEncodeResult:
 @dataclass
 class VQDecodeResult:
     mels: torch.Tensor
+    audios: Optional[torch.Tensor] = None
 
 
 class VQGAN(L.LightningModule):
@@ -417,6 +418,7 @@ class VQGAN(L.LightningModule):
         features=None,
         audio_lengths=None,
         feature_lengths=None,
+        return_audios=False,
     ):
         assert (
             indices is not None or features is not None
@@ -441,4 +443,5 @@ class VQGAN(L.LightningModule):
 
         return VQDecodeResult(
             mels=decoded,
+            audios=self.generator(decoded) if return_audios else None,
         )
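
The new return_audios flag is the inference fix in this commit: VQDecodeResult previously carried only mels, so callers had no waveform to read back. A minimal usage sketch, assuming model is a loaded VQGAN and indices is a 2-D tensor of codebook indices as in tools/vqgan/inference.py:

import torch

feature_lengths = torch.tensor([indices.shape[1]], device=model.device)

# audios stays None unless return_audios=True, in which case the decoded
# mels are additionally passed through the HiFi-GAN generator.
result = model.decode(
    indices=indices[None], feature_lengths=feature_lengths, return_audios=True
)
mels, audios = result.mels, result.audios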

+ 3 - 1
tools/vqgan/inference.py

@@ -84,7 +84,9 @@ def main(input_path, output_path, config_name, checkpoint_path):
 
     # Restore
     feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
-    decoded = model.decode(indices=indices[None], feature_lengths=feature_lengths)
+    decoded = model.decode(
+        indices=indices[None], feature_lengths=feature_lengths, return_audios=True
+    )
     fake_audios = decoded.audios
     audio_time = fake_audios.shape[-1] / model.sampling_rate
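
With return_audios=True the result now carries a waveform, so the script can time the synthesis and write output. A minimal follow-up sketch, assuming torchaudio as the writer (the actual script may use another) and the usual (batch, channels, samples) generator output:

import torchaudio

# Take the first batch item and move it to CPU before writing; the
# sampling rate comes from the loaded model, as in audio_time above.
torchaudio.save("fake.wav", fake_audios[0].cpu(), model.sampling_rate)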