Browse Source

Fix vq alignment

Lengyue 2 năm trước
mục cha
commit
f578dc3572

+ 7 - 1
fish_speech/configs/vq_diffusion.yaml

@@ -50,6 +50,12 @@ model:
   hop_length: ${hop_length}
   speaker_use_feats: true
 
+  downsample:
+    _target_: fish_speech.models.vq_diffusion.lit_module.ConvDownSample
+    dims: [128, 512, 128]
+    kernel_sizes: [3, 3]
+    strides: [2, 2]
+
   text_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
     in_channels: 128
@@ -104,7 +110,7 @@ model:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     sample_rate: 32000
     n_fft: 2048
-    hop_length: 1280
+    hop_length: 320
     win_length: 2048
     n_mels: 128
 

+ 0 - 2
fish_speech/models/vq_diffusion/bigvgan/bigvgan.py

@@ -366,9 +366,7 @@ class BigVGAN(nn.Module):
 
     @torch.no_grad()
     def decode(self, mel):
-        mel = F.pad(mel, (0, 10), "reflect")
         y = self.model(mel)
-        y = y[:, :, : -self.h.hop_size * 10]
         return y
 
     @torch.no_grad()

+ 78 - 4
fish_speech/models/vq_diffusion/lit_module.py

@@ -1,7 +1,8 @@
 import itertools
-from typing import Any, Callable
+from typing import Any, Callable, Optional
 
 import lightning as L
+import numpy as np
 import torch
 import torch.nn.functional as F
 import wandb
@@ -11,7 +12,6 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
 from tqdm import tqdm
-from transformers import HubertModel
 
 from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
 from fish_speech.models.vqgan.modules.encoders import (
@@ -22,6 +22,57 @@ from fish_speech.models.vqgan.modules.encoders import (
 from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
 
 
+class ConvDownSample(nn.Module):
+    def __init__(
+        self,
+        dims: list,
+        kernel_sizes: list,
+        strides: list,
+    ):
+        super().__init__()
+
+        self.dims = dims
+        self.kernel_sizes = kernel_sizes
+        self.strides = strides
+        self.total_strides = np.prod(self.strides)
+
+        self.convs = nn.ModuleList(
+            [
+                nn.ModuleList(
+                    [
+                        nn.Conv1d(
+                            in_channels=self.dims[i],
+                            out_channels=self.dims[i + 1],
+                            kernel_size=self.kernel_sizes[i],
+                            stride=self.strides[i],
+                            padding=(self.kernel_sizes[i] - 1) // 2,
+                        ),
+                        nn.LayerNorm(self.dims[i + 1], elementwise_affine=True),
+                        nn.GELU(),
+                    ]
+                )
+                for i in range(len(self.dims) - 1)
+            ]
+        )
+
+        self.apply(self.init_weights)
+
+    def init_weights(self, m):
+        if isinstance(m, nn.Conv1d):
+            nn.init.normal_(m.weight, std=0.02)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.ones_(m.weight)
+            nn.init.zeros_(m.bias)
+
+    def forward(self, x):
+        for conv, norm, act in self.convs:
+            x = conv(x)
+            x = norm(x.mT).mT
+            x = act(x)
+
+        return x
+
+
 class VQDiffusion(L.LightningModule):
     def __init__(
         self,
@@ -37,6 +88,7 @@ class VQDiffusion(L.LightningModule):
         hop_length: int = 640,
         sample_rate: int = 32000,
         speaker_use_feats: bool = False,
+        downsample: Optional[nn.Module] = None,
     ):
         super().__init__()
 
@@ -55,6 +107,7 @@ class VQDiffusion(L.LightningModule):
         self.speaker_encoder = speaker_encoder
         self.text_encoder = text_encoder
         self.denoiser = denoiser
+        self.downsample = downsample
 
         self.vocoder = vocoder
         self.hop_length = hop_length
@@ -100,8 +153,18 @@ class VQDiffusion(L.LightningModule):
                 audios, sample_rate=self.sampling_rate
             )
 
+        if self.downsample is not None:
+            features = self.downsample(features)
+
         mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = audio_lengths // self.hop_length // 2
+        feature_lengths = (
+            audio_lengths
+            / self.sampling_rate
+            * self.feature_mel_transform.sample_rate
+            / self.feature_mel_transform.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
         feature_masks = torch.unsqueeze(
             sequence_mask(feature_lengths, features.shape[2]), 1
         ).to(gt_mels.dtype)
@@ -181,8 +244,18 @@ class VQDiffusion(L.LightningModule):
         gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
         features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)
 
+        if self.downsample is not None:
+            features = self.downsample(features)
+
         mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = audio_lengths // self.hop_length // 2
+        feature_lengths = (
+            audio_lengths
+            / self.sampling_rate
+            * self.feature_mel_transform.sample_rate
+            / self.feature_mel_transform.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
         feature_masks = torch.unsqueeze(
             sequence_mask(feature_lengths, features.shape[2]), 1
         ).to(gt_mels.dtype)
@@ -222,6 +295,7 @@ class VQDiffusion(L.LightningModule):
             ).prev_sample
 
         sampled_mels = self.denormalize_mels(sampled_mels)
+        sampled_mels = sampled_mels * mel_masks
 
         with torch.autocast(device_type=sampled_mels.device.type, enabled=False):
             # Run vocoder on fp32