Browse Source

Fix hop length & vocoder artifact

Lengyue 2 years ago
parent
commit
38aa5ea106

+ 3 - 3
fish_speech/configs/vq_diffusion.yaml

@@ -12,7 +12,7 @@ trainer:
   gradient_clip_val: 1.0
   gradient_clip_algorithm: 'norm'
   precision: 16-mixed
-  max_steps: 1_000_000
+  max_steps: 300_000
   val_check_interval: 5000
 
 sample_rate: 24000
@@ -84,7 +84,7 @@ model:
     _target_: fish_speech.models.vq_diffusion.wavenet.WaveNet
     d_encoder: 128
     mel_channels: 100
-    residual_channels: 384
+    residual_channels: 512
     residual_layers: 20
 
   vocoder:
@@ -104,7 +104,7 @@ model:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     sample_rate: 32000
     n_fft: 2048
-    hop_length: 640
+    hop_length: 1280
     win_length: 2048
     n_mels: 128
 

+ 2 - 0
fish_speech/models/vq_diffusion/bigvgan/bigvgan.py

@@ -366,7 +366,9 @@ class BigVGAN(nn.Module):
 
     @torch.no_grad()
     def decode(self, mel):
+        mel = F.pad(mel, (0, 10), "reflect")
         y = self.model(mel)
+        y = y[:, :, : -self.h.hop_size * 10]
         return y
 
     @torch.no_grad()