Lengyue, 2 years ago
parent
commit
a1c3890b6d

+ 4 - 3
fish_speech/configs/vqgan.yaml

@@ -72,14 +72,16 @@ model:
     downsample: 1
 
   decoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WN
     hidden_channels: 256
+    out_channels: ${num_mels}
     kernel_size: 3
     dilation_rate: 2
     n_layers: 6
 
   generator:
     _target_: fish_speech.models.vqgan.modules.decoder.Generator
-    initial_channel: 256
+    initial_channel: ${num_mels}
     resblock: "1"
     resblock_kernel_sizes: [3, 7, 11]
     resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
@@ -119,7 +121,6 @@ callbacks:
     sub_module: 
       - generator
       - discriminator
-      - text_encoder
+      - mel_encoder
       - vq_encoder
-      - speaker_encoder
       - decoder

+ 2 - 3
fish_speech/models/vqgan/lit_module.py

@@ -39,7 +39,6 @@ class VQGAN(L.LightningModule):
         lr_scheduler: Callable,
         downsample: ConvDownSampler,
         vq_encoder: VQEncoder,
-        speaker_encoder: SpeakerEncoder,
         mel_encoder: TextEncoder,
         decoder: TextEncoder,
         generator: Generator,
@@ -163,7 +162,7 @@ class VQGAN(L.LightningModule):
 
         # vq_features is 50 hz, need to convert to true mel size
         text_features = self.mel_encoder(features, feature_masks)
-        text_features, loss_vq = self.vq_encoder(text_features, feature_masks)
+        text_features, _, loss_vq = self.vq_encoder(text_features, feature_masks)
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
@@ -311,7 +310,7 @@ class VQGAN(L.LightningModule):
 
         # vq_features is 50 hz, need to convert to true mel size
         text_features = self.mel_encoder(features, feature_masks)
-        text_features, _ = self.vq_encoder(text_features, feature_masks)
+        text_features, _, _ = self.vq_encoder(text_features, feature_masks)
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )

+ 6 - 3
fish_speech/models/vqgan/modules/encoders.py

@@ -293,7 +293,6 @@ class VQEncoder(nn.Module):
                 codebook_size=codebook_size,
                 threshold_ema_dead_code=2,
                 kmeans_init=False,
-                channel_last=False,
                 groups=codebook_groups,
                 num_quantizers=1,
             )
@@ -303,9 +302,9 @@ class VQEncoder(nn.Module):
                 codebook_size=codebook_size,
                 threshold_ema_dead_code=2,
                 kmeans_init=False,
-                channel_last=False,
             )
 
+        self.codebook_groups = codebook_groups
         self.downsample = downsample
         self.conv_in = nn.Conv1d(
             in_channels, vq_channels, kernel_size=downsample, stride=downsample
@@ -326,7 +325,11 @@ class VQEncoder(nn.Module):
             x_mask = F.pad(x_mask, (0, self.downsample - x_len % self.downsample))
 
         x = self.conv_in(x)
-        q, indices, loss = self.vq(x)
+        q, indices, loss = self.vq(x.mT)
+        q = q.mT
+
+        if self.codebook_groups > 1:
+            loss = loss.mean()
 
         x = self.conv_out(q) * x_mask
         x = x[:, :, :x_len]

+ 12 - 1
fish_speech/models/vqgan/modules/modules.py

@@ -19,6 +19,7 @@ class WN(nn.Module):
         n_layers,
         gin_channels=0,
         p_dropout=0,
+        out_channels=None,
     ):
         super(WN, self).__init__()
         assert kernel_size % 2 == 1
@@ -56,6 +57,10 @@ class WN(nn.Module):
             res_skip_layer = weight_norm(res_skip_layer, name="weight")
             self.res_skip_layers.append(res_skip_layer)
 
+        self.out_channels = out_channels
+        if out_channels is not None:
+            self.out_layer = nn.Conv1d(hidden_channels, out_channels, 1)
+
     def forward(self, x, x_mask, g=None, **kwargs):
         output = torch.zeros_like(x)
         n_channels_tensor = torch.IntTensor([self.hidden_channels])
@@ -81,7 +86,13 @@ class WN(nn.Module):
                 output = output + res_skip_acts[:, self.hidden_channels :, :]
             else:
                 output = output + res_skip_acts
-        return output * x_mask
+
+        x = output * x_mask
+
+        if self.out_channels is not None:
+            x = self.out_layer(x)
+
+        return x
 
     def remove_weight_norm(self):
         if self.gin_channels != 0: