Lengyue 2 年之前
父節點
當前提交
df304ac438

+ 2 - 2
fish_speech/configs/hubert_vq_diffusion.yaml

@@ -13,7 +13,7 @@ trainer:
   gradient_clip_algorithm: 'norm'
   precision: 16-mixed
   max_steps: 1_000_000
-  val_check_interval: 1000
+  val_check_interval: 5000
 
 sample_rate: 44100
 hop_length: 512
@@ -67,7 +67,7 @@ model:
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
     in_channels: 128
     vq_channels: 128
-    codebook_size: 4096
+    codebook_size: 16384
     downsample: 1
 
   speaker_encoder:

+ 1 - 1
fish_speech/models/vq_diffusion/lit_module.py

@@ -218,7 +218,7 @@ class VQDiffusion(L.LightningModule):
             # Run vocoder on fp32
             fake_audios = self.vocoder.decode(sampled_mels.float())
 
-        mel_loss = F.l1_loss(gt_mels, sampled_mels)
+        mel_loss = F.l1_loss(gt_mels * mel_masks, sampled_mels * mel_masks)
         self.log(
             "val/mel_loss",
             mel_loss,

+ 7 - 7
fish_speech/models/vqgan/modules/encoders.py

@@ -3,7 +3,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from vector_quantize_pytorch import VectorQuantize
+from vector_quantize_pytorch import LFQ, VectorQuantize
 
 from fish_speech.models.vqgan.modules.modules import WN
 from fish_speech.models.vqgan.modules.transformer import (
@@ -234,12 +234,12 @@ class VQEncoder(nn.Module):
     ):
         super().__init__()
 
-        self.vq = VectorQuantize(
+        self.vq = LFQ(
             dim=vq_channels,
             codebook_size=codebook_size,
-            threshold_ema_dead_code=2,
-            kmeans_init=False,
-            channel_last=False,
+            # threshold_ema_dead_code=2,
+            # kmeans_init=False,
+            # channel_last=False,
         )
         self.downsample = downsample
         self.conv_in = nn.Conv1d(
@@ -286,8 +286,8 @@ class VQEncoder(nn.Module):
             x_mask = F.pad(x_mask, (0, self.downsample - x_len % self.downsample))
 
         x = self.conv_in(x)
-        q, _, loss = self.vq(x)
-        x = self.conv_out(q) * x_mask
+        q, _, loss = self.vq(x.mT)
+        x = self.conv_out(q.mT) * x_mask
         x = x[:, :, :x_len]
 
         return x, loss