Lengyue 2 года назад
Родитель
Commit
f578dc3572

+ 7 - 1
fish_speech/configs/vq_diffusion.yaml

@@ -50,6 +50,12 @@ model:
   hop_length: ${hop_length}
   hop_length: ${hop_length}
   speaker_use_feats: true
   speaker_use_feats: true
 
 
+  downsample:
+    _target_: fish_speech.models.vq_diffusion.lit_module.ConvDownSample
+    dims: [128, 512, 128]
+    kernel_sizes: [3, 3]
+    strides: [2, 2]
+
   text_encoder:
   text_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
     _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
     in_channels: 128
     in_channels: 128
@@ -104,7 +110,7 @@ model:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     sample_rate: 32000
     sample_rate: 32000
     n_fft: 2048
     n_fft: 2048
-    hop_length: 1280
+    hop_length: 320
     win_length: 2048
     win_length: 2048
     n_mels: 128
     n_mels: 128
 
 

+ 0 - 2
fish_speech/models/vq_diffusion/bigvgan/bigvgan.py

@@ -366,9 +366,7 @@ class BigVGAN(nn.Module):
 
 
     @torch.no_grad()
     @torch.no_grad()
     def decode(self, mel):
     def decode(self, mel):
-        mel = F.pad(mel, (0, 10), "reflect")
         y = self.model(mel)
         y = self.model(mel)
-        y = y[:, :, : -self.h.hop_size * 10]
         return y
         return y
 
 
     @torch.no_grad()
     @torch.no_grad()

+ 78 - 4
fish_speech/models/vq_diffusion/lit_module.py

@@ -1,7 +1,8 @@
 import itertools
 import itertools
-from typing import Any, Callable
+from typing import Any, Callable, Optional
 
 
 import lightning as L
 import lightning as L
+import numpy as np
 import torch
 import torch
 import torch.nn.functional as F
 import torch.nn.functional as F
 import wandb
 import wandb
@@ -11,7 +12,6 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from matplotlib import pyplot as plt
 from torch import nn
 from torch import nn
 from tqdm import tqdm
 from tqdm import tqdm
-from transformers import HubertModel
 
 
 from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
 from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
 from fish_speech.models.vqgan.modules.encoders import (
 from fish_speech.models.vqgan.modules.encoders import (
@@ -22,6 +22,57 @@ from fish_speech.models.vqgan.modules.encoders import (
 from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
 from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
 
 
 
 
class ConvDownSample(nn.Module):
    """Temporally downsample a (batch, channels, time) feature map.

    Each stage applies Conv1d -> LayerNorm (over channels) -> GELU.
    The overall temporal downsampling factor is the product of ``strides``
    and is exposed as ``total_strides`` for length bookkeeping by callers.

    Args:
        dims: Channel sizes per stage boundary; length N + 1 for N stages
            (``dims[i]`` in-channels, ``dims[i + 1]`` out-channels of stage i).
        kernel_sizes: Convolution kernel size per stage; length N.
        strides: Convolution stride per stage; length N.

    Raises:
        ValueError: If ``kernel_sizes`` or ``strides`` does not have
            exactly ``len(dims) - 1`` entries.
    """

    def __init__(
        self,
        dims: list,
        kernel_sizes: list,
        strides: list,
    ):
        super().__init__()

        n_stages = len(dims) - 1
        if len(kernel_sizes) != n_stages or len(strides) != n_stages:
            raise ValueError(
                "kernel_sizes and strides must each have len(dims) - 1 entries"
            )

        self.dims = dims
        self.kernel_sizes = kernel_sizes
        self.strides = strides
        # Plain int so downstream length arithmetic stays in native Python
        # types instead of a numpy scalar.
        self.total_strides = int(np.prod(self.strides))

        self.convs = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        nn.Conv1d(
                            in_channels=self.dims[i],
                            out_channels=self.dims[i + 1],
                            kernel_size=self.kernel_sizes[i],
                            stride=self.strides[i],
                            # "Same"-style padding for odd kernel sizes, so
                            # output length is ceil-divided only by the stride.
                            padding=(self.kernel_sizes[i] - 1) // 2,
                        ),
                        nn.LayerNorm(self.dims[i + 1], elementwise_affine=True),
                        nn.GELU(),
                    ]
                )
                for i in range(n_stages)
            ]
        )

        self.apply(self.init_weights)

    def init_weights(self, m):
        # Small-std normal init for convs; identity-style init for norms.
        if isinstance(m, nn.Conv1d):
            nn.init.normal_(m.weight, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """Map (batch, dims[0], T) to (batch, dims[-1], ~T / total_strides).

        LayerNorm normalizes over the channel dimension, so the tensor is
        transposed to (batch, T, channels) around each norm via ``.mT``.
        """
        for conv, norm, act in self.convs:
            x = conv(x)
            x = norm(x.mT).mT
            x = act(x)

        return x
+
+
 class VQDiffusion(L.LightningModule):
 class VQDiffusion(L.LightningModule):
     def __init__(
     def __init__(
         self,
         self,
@@ -37,6 +88,7 @@ class VQDiffusion(L.LightningModule):
         hop_length: int = 640,
         hop_length: int = 640,
         sample_rate: int = 32000,
         sample_rate: int = 32000,
         speaker_use_feats: bool = False,
         speaker_use_feats: bool = False,
+        downsample: Optional[nn.Module] = None,
     ):
     ):
         super().__init__()
         super().__init__()
 
 
@@ -55,6 +107,7 @@ class VQDiffusion(L.LightningModule):
         self.speaker_encoder = speaker_encoder
         self.speaker_encoder = speaker_encoder
         self.text_encoder = text_encoder
         self.text_encoder = text_encoder
         self.denoiser = denoiser
         self.denoiser = denoiser
+        self.downsample = downsample
 
 
         self.vocoder = vocoder
         self.vocoder = vocoder
         self.hop_length = hop_length
         self.hop_length = hop_length
@@ -100,8 +153,18 @@ class VQDiffusion(L.LightningModule):
                 audios, sample_rate=self.sampling_rate
                 audios, sample_rate=self.sampling_rate
             )
             )
 
 
+        if self.downsample is not None:
+            features = self.downsample(features)
+
         mel_lengths = audio_lengths // self.hop_length
         mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = audio_lengths // self.hop_length // 2
+        feature_lengths = (
+            audio_lengths
+            / self.sampling_rate
+            * self.feature_mel_transform.sample_rate
+            / self.feature_mel_transform.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
         feature_masks = torch.unsqueeze(
         feature_masks = torch.unsqueeze(
             sequence_mask(feature_lengths, features.shape[2]), 1
             sequence_mask(feature_lengths, features.shape[2]), 1
         ).to(gt_mels.dtype)
         ).to(gt_mels.dtype)
@@ -181,8 +244,18 @@ class VQDiffusion(L.LightningModule):
         gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
         gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
         features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)
         features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)
 
 
+        if self.downsample is not None:
+            features = self.downsample(features)
+
         mel_lengths = audio_lengths // self.hop_length
         mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = audio_lengths // self.hop_length // 2
+        feature_lengths = (
+            audio_lengths
+            / self.sampling_rate
+            * self.feature_mel_transform.sample_rate
+            / self.feature_mel_transform.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
         feature_masks = torch.unsqueeze(
         feature_masks = torch.unsqueeze(
             sequence_mask(feature_lengths, features.shape[2]), 1
             sequence_mask(feature_lengths, features.shape[2]), 1
         ).to(gt_mels.dtype)
         ).to(gt_mels.dtype)
@@ -222,6 +295,7 @@ class VQDiffusion(L.LightningModule):
             ).prev_sample
             ).prev_sample
 
 
         sampled_mels = self.denormalize_mels(sampled_mels)
         sampled_mels = self.denormalize_mels(sampled_mels)
+        sampled_mels = sampled_mels * mel_masks
 
 
         with torch.autocast(device_type=sampled_mels.device.type, enabled=False):
         with torch.autocast(device_type=sampled_mels.device.type, enabled=False):
             # Run vocoder on fp32
             # Run vocoder on fp32