Lengyue · 2 years ago
Commit 27d532f685
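Summary: this commit simplifies the VQGAN pipeline. The separate `text_encoder` and `speaker_encoder` are replaced by a single WN-based `mel_encoder`; the quantizer moves from one 4096-entry codebook to a grouped residual VQ (4 groups of 160 entries), dropping the LFQ branch and the k-means warm-start; the dedicated `feature_mel_transform` is removed in favor of reusing `mel_transform` for both ground-truth mels and encoder features; HiFi-GAN is frozen; and training drops from 8 GPUs to 4.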

+ 14 - 43
fish_speech/configs/vqgan.yaml

@@ -7,7 +7,7 @@ project: vqgan
 # Lightning Trainer
 trainer:
   accelerator: gpu
-  devices: 8
+  devices: 4
   strategy: ddp_find_unused_parameters_true
   precision: 32
   max_steps: 1_000_000
@@ -48,59 +48,38 @@ model:
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   segment_size: 8192
-  freeze_hifigan: false
+  freeze_hifigan: true
 
   downsample:
     _target_: fish_speech.models.vq_diffusion.lit_module.ConvDownSample
-    dims: [128, 512, 128]
+    dims: ["${num_mels}", 512, 256]
     kernel_sizes: [3, 3]
     strides: [2, 2]
 
-  text_encoder:
-    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
-    in_channels: 128
-    out_channels: 256
-    hidden_channels: 192
-    hidden_channels_ffn: 768
-    n_heads: 2
+  mel_encoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WN
+    hidden_channels: 256
+    kernel_size: 3
+    dilation_rate: 2
     n_layers: 6
-    kernel_size: 1
-    dropout: 0.1
-    use_vae: false
 
   vq_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
     in_channels: 256
     vq_channels: 256
-    codebook_size: 4096
+    codebook_size: 160
+    codebook_groups: 4
     downsample: 1
 
-  speaker_encoder:
-    _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
-    in_channels: 128
-    hidden_channels: 192
-    out_channels: 256
-    num_heads: 2
-    num_layers: 4
-    p_dropout: 0.1
-
   decoder:
-    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
-    in_channels: 256
-    out_channels: ${num_mels}
-    hidden_channels: 192
-    hidden_channels_ffn: 768
-    n_heads: 2
+    hidden_channels: 256
+    kernel_size: 3
+    dilation_rate: 2
     n_layers: 6
-    kernel_size: 1
-    use_vae: false
-    dropout: 0
-    gin_channels: 256
-    speaker_cond_layer: 0
 
   generator:
     _target_: fish_speech.models.vqgan.modules.decoder.Generator
-    initial_channel: ${num_mels}
+    initial_channel: 256
     resblock: "1"
     resblock_kernel_sizes: [3, 7, 11]
     resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
@@ -123,14 +102,6 @@ model:
     f_min: 0
     f_max: 8000
 
-  feature_mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
-    sample_rate: 32000
-    n_fft: 2048
-    hop_length: 320
-    win_length: 2048
-    n_mels: 128
-
   optimizer:
     _target_: torch.optim.AdamW
     _partial_: true
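A note on the quantizer change above: a single 4096-entry codebook can emit at most 4096 distinct tokens per frame, while 4 independent groups of 160 entries jointly index 160^4 combinations (each group quantizing a 64-dim slice of the 256 `vq_channels`). A quick sanity check of that arithmetic, in plain Python with no project code assumed:

```python
# Per-frame index capacity: old single codebook vs. new grouped codebooks.
old_capacity = 4096                # codebook_size: 4096, one codebook
groups, codebook_size = 4, 160     # codebook_groups: 4, codebook_size: 160
new_capacity = codebook_size ** groups

print(f"old: {old_capacity} codes/frame")
print(f"new: {new_capacity:,} joint combinations/frame")  # 655,360,000
```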

+ 12 - 26
fish_speech/models/vqgan/lit_module.py

@@ -40,12 +40,11 @@ class VQGAN(L.LightningModule):
         downsample: ConvDownSampler,
         vq_encoder: VQEncoder,
         speaker_encoder: SpeakerEncoder,
-        text_encoder: TextEncoder,
+        mel_encoder: TextEncoder,
         decoder: TextEncoder,
         generator: Generator,
         discriminator: EnsembleDiscriminator,
         mel_transform: nn.Module,
-        feature_mel_transform: nn.Module,
         segment_size: int = 20480,
         hop_length: int = 640,
         sample_rate: int = 32000,
@@ -61,13 +60,11 @@ class VQGAN(L.LightningModule):
         # Generator and discriminators
         self.downsample = downsample
         self.vq_encoder = vq_encoder
-        self.speaker_encoder = speaker_encoder
-        self.text_encoder = text_encoder
+        self.mel_encoder = mel_encoder
         self.decoder = decoder
         self.generator = generator
         self.discriminator = discriminator
         self.mel_transform = mel_transform
-        self.feature_mel_transform = feature_mel_transform
 
         # Crop length for saving memory
         self.segment_size = segment_size
@@ -91,7 +88,7 @@ class VQGAN(L.LightningModule):
             for p in self.vq_encoder.parameters():
                 p.requires_grad = False
 
-            for p in self.text_encoder.parameters():
+            for p in self.mel_encoder.parameters():
                 p.requires_grad = False
 
             for p in self.downsample.parameters():
@@ -103,8 +100,7 @@ class VQGAN(L.LightningModule):
             itertools.chain(
                 self.downsample.parameters(),
                 self.vq_encoder.parameters(),
-                self.speaker_encoder.parameters(),
-                self.text_encoder.parameters(),
+                self.mel_encoder.parameters(),
                 self.decoder.parameters(),
                 self.generator.parameters(),
             )
@@ -144,8 +140,7 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]
 
         with torch.no_grad():
-            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
-            features = self.feature_mel_transform(
+            features = gt_mels = self.mel_transform(
                 audios, sample_rate=self.sampling_rate
             )
 
@@ -155,9 +150,7 @@ class VQGAN(L.LightningModule):
         mel_lengths = audio_lengths // self.hop_length
         feature_lengths = (
             audio_lengths
-            / self.sampling_rate
-            * self.feature_mel_transform.sample_rate
-            / self.feature_mel_transform.hop_length
+            / self.hop_length
             / (self.downsample.total_strides if self.downsample is not None else 1)
         ).long()
 
@@ -168,17 +161,15 @@ class VQGAN(L.LightningModule):
             gt_mels.dtype
         )
 
-        speaker_features = self.speaker_encoder(features, feature_masks)
-
         # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.text_encoder(features, feature_masks)
+        text_features = self.mel_encoder(features, feature_masks)
         text_features, loss_vq = self.vq_encoder(text_features, feature_masks)
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
 
         # Sample mels
-        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        decoded_mels = self.decoder(text_features, mel_masks)
         fake_audios = self.generator(decoded_mels)
 
         y_hat_mels = self.mel_transform(fake_audios.squeeze(1))
@@ -299,8 +290,7 @@ class VQGAN(L.LightningModule):
         audios = audios.float()
         audios = audios[:, None, :]
 
-        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
-        features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)
+        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
 
         if self.downsample is not None:
             features = self.downsample(features)
@@ -308,9 +298,7 @@ class VQGAN(L.LightningModule):
         mel_lengths = audio_lengths // self.hop_length
         feature_lengths = (
             audio_lengths
-            / self.sampling_rate
-            * self.feature_mel_transform.sample_rate
-            / self.feature_mel_transform.hop_length
+            / self.hop_length
             / (self.downsample.total_strides if self.downsample is not None else 1)
         ).long()
 
@@ -321,17 +309,15 @@ class VQGAN(L.LightningModule):
             gt_mels.dtype
         )
 
-        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-
         # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.text_encoder(features, feature_masks)
+        text_features = self.mel_encoder(features, feature_masks)
         text_features, _ = self.vq_encoder(text_features, feature_masks)
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
 
         # Sample mels
-        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        decoded_mels = self.decoder(text_features, mel_masks)
         fake_audios = self.generator(decoded_mels)
 
         fake_mels = self.mel_transform(fake_audios.squeeze(1))
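Because ground-truth mels and encoder features now come from the same `mel_transform`, the frame-rate conversion through `feature_mel_transform` is gone and the length bookkeeping collapses to two divisions. A minimal sketch with illustrative values (`hop_length` 640 is the module default in this diff; `total_strides` 4 follows from the config's `strides: [2, 2]`):

```python
import torch

hop_length = 640    # module default, config ${hop_length}
total_strides = 4   # ConvDownSample strides [2, 2] -> 2 * 2

audio_lengths = torch.tensor([32000, 20480])  # samples per clip

# Mel frames per clip, and VQ feature frames after downsampling.
mel_lengths = audio_lengths // hop_length
feature_lengths = (audio_lengths / hop_length / total_strides).long()

print(mel_lengths)      # tensor([50, 32])
print(feature_lengths)  # tensor([12, 8])
```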

+ 10 - 47
fish_speech/models/vqgan/modules/encoders.py

@@ -4,7 +4,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from vector_quantize_pytorch import LFQ, VectorQuantize
+from vector_quantize_pytorch import LFQ, GroupedResidualVQ, VectorQuantize
 
 from fish_speech.models.vqgan.modules.modules import WN
 from fish_speech.models.vqgan.modules.transformer import (
@@ -283,24 +283,19 @@ class VQEncoder(nn.Module):
         vq_channels: int = 1024,
         codebook_size: int = 2048,
         downsample: int = 1,
-        kmeans_ckpt: Optional[str] = None,
-        use_lfq: bool = False,
+        codebook_groups: int = 1,
     ):
         super().__init__()
 
-        if use_lfq:
-            assert 2**vq_channels == codebook_size, (
-                "LFQ requires 2 ** vq_channels == codebook_size. "
-                f"Got vq_channels={vq_channels} and codebook_size={codebook_size}"
-            )
-
-            self.ln = nn.LayerNorm(vq_channels, eps=1e-5)
-            self.vq = LFQ(
+        if codebook_groups > 1:
+            self.vq = GroupedResidualVQ(
                 dim=vq_channels,
                 codebook_size=codebook_size,
-                entropy_loss_weight=0.1,
-                commitment_loss_weight=1,
-                diversity_gamma=2.5,
+                threshold_ema_dead_code=2,
+                kmeans_init=False,
+                channel_last=False,
+                groups=codebook_groups,
+                num_quantizers=1,
             )
         else:
             self.vq = VectorQuantize(
@@ -311,7 +306,6 @@ class VQEncoder(nn.Module):
                 channel_last=False,
             )
 
-        self.use_lfq = use_lfq
         self.downsample = downsample
         self.conv_in = nn.Conv1d(
             in_channels, vq_channels, kernel_size=downsample, stride=downsample
@@ -323,31 +317,6 @@ class VQEncoder(nn.Module):
             nn.Conv1d(vq_channels, in_channels, kernel_size=1, stride=1),
         )
 
-        if kmeans_ckpt is not None:
-            self.init_weights(kmeans_ckpt)
-
-    def init_weights(self, kmeans_ckpt):
-        torch.nn.init.normal_(
-            self.conv_in.weight,
-            mean=1 / (self.conv_in.weight.shape[0] * self.conv_in.weight.shape[-1]),
-            std=1e-2,
-        )
-        self.conv_in.bias.data.zero_()
-
-        kmeans_ckpt = "results/hubert-vq-pretrain/kmeans.pt"
-        kmeans_ckpt = torch.load(kmeans_ckpt, map_location="cpu")
-
-        centroids = kmeans_ckpt["centroids"]
-        bins = kmeans_ckpt["bins"]
-        state_dict = {
-            "_codebook.initted": torch.Tensor([True]),
-            "_codebook.cluster_size": bins,
-            "_codebook.embed": centroids,
-            "_codebook.embed_avg": centroids.clone(),
-        }
-
-        self.vq.load_state_dict(state_dict, strict=True)
-
     def forward(self, x, x_mask):
         # x: [B, C, T], x_mask: [B, 1, T]
         x_len = x.shape[2]
@@ -357,13 +326,7 @@ class VQEncoder(nn.Module):
             x_mask = F.pad(x_mask, (0, self.downsample - x_len % self.downsample))
 
         x = self.conv_in(x)
-
-        if self.use_lfq:
-            x = self.ln(x.mT)
-            q, indices, loss = self.vq(x)
-            q = q.mT
-        else:
-            q, indices, loss = self.vq(x)
+        q, indices, loss = self.vq(x)
 
         x = self.conv_out(q) * x_mask
         x = x[:, :, :x_len]
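With the LFQ branch and k-means initialization removed, `VQEncoder` delegates to `GroupedResidualVQ` from vector_quantize_pytorch whenever `codebook_groups > 1`. A minimal standalone sketch of that path, assuming the library's API as used in this diff (extra kwargs such as `threshold_ema_dead_code` and `kmeans_init` are forwarded to the underlying `VectorQuantize` layers):

```python
import torch
from vector_quantize_pytorch import GroupedResidualVQ

vq = GroupedResidualVQ(
    dim=256,                   # vq_channels; split into 4 slices of 64
    codebook_size=160,
    groups=4,
    num_quantizers=1,          # one residual stage, i.e. plain grouped VQ
    threshold_ema_dead_code=2,
    kmeans_init=False,
    channel_last=False,        # accept [B, C, T] straight from conv_in
)

x = torch.randn(2, 256, 100)   # [batch, channels, time]
quantized, indices, loss = vq(x)
print(quantized.shape)         # torch.Size([2, 256, 100])
```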

+ 10 - 4
tools/infer_vq.py

@@ -1,4 +1,5 @@
 import librosa
+import numpy as np
 import soundfile as sf
 import torch
 from hydra import compose, initialize
@@ -15,19 +16,21 @@ OmegaConf.register_new_resolver("eval", eval)
 @torch.autocast(device_type="cuda", enabled=True)
 def main():
     with initialize(version_base="1.3", config_path="../fish_speech/configs"):
-        cfg = compose(config_name="vq_naive_50hz")
+        cfg = compose(config_name="vq_naive_40hz")
 
     model: LightningModule = instantiate(cfg.model)
     state_dict = torch.load(
-        "results/vq_naive_25hz/checkpoints/step_000100000.ckpt",
+        "results/vq_naive_40hz/checkpoints/step_000675000.ckpt",
+        # "results/vq_naive_25hz/checkpoints/step_000100000.ckpt",
         map_location=model.device,
     )["state_dict"]
     model.load_state_dict(state_dict, strict=True)
     model.eval()
+    model.cuda()
     logger.info("Restored model from checkpoint")
 
     # Load audio
-    audio = librosa.load("record.wav", sr=model.sampling_rate, mono=True)[0]
+    audio = librosa.load("record1.wav", sr=model.sampling_rate, mono=True)[0]
     audios = torch.from_numpy(audio).to(model.device)[None, None, :]
     logger.info(
         f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
@@ -56,11 +59,14 @@ def main():
     mel1_masks = torch.ones([mel.shape[0], 1, mel.shape[2]], device=model.device)
 
     speaker_features = model.speaker_encoder(mel, mel1_masks)
+
+    speaker_features = model.speaker_encoder(gt_mels, mel_masks)
+    speaker_features = torch.zeros_like(speaker_features)
     decoded_mels = model.vq_decode(text_features, speaker_features, gt_mels, mel_masks)
     fake_audios = model.vocoder(decoded_mels)
 
     # Save audio
-    fake_audio = fake_audios[0, 0].cpu().numpy()
+    fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
     sf.write("fake.wav", fake_audio, model.sampling_rate)
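Two small but consequential tweaks in this script: the freshly computed speaker embedding is overwritten with zeros, so the decode exercises only the quantized content path, and the output is cast to float32 before writing. The cast likely matters because `main` runs under `torch.autocast`, so `fake_audios` can come back as float16, which `soundfile` will refuse to write. A minimal illustration of the failure mode this guards against (assumed behavior, silent one-second clip):

```python
import numpy as np
import soundfile as sf
import torch

# Under autocast on CUDA, generator output may be half precision.
fake_audios = torch.zeros(1, 1, 32000, dtype=torch.float16)

fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
sf.write("fake.wav", fake_audio, 32000)  # a float16 array would raise here
```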