Lengyue пре 2 година
родитељ
комит
e2173f82df

+ 9 - 7
fish_speech/configs/hubert_vq.yaml

@@ -58,15 +58,15 @@ model:
     inter_channels: 192
     hidden_channels: 192
     filter_channels: 768
-    n_heads: 4
-    n_layers: 8
-    n_layers_q: 8
-    n_layers_flow: 8
+    n_heads: 2
+    n_layers: 6
+    n_layers_q: 6
+    n_layers_flow: 6
+    n_layers_spk: 4
     n_flows: 4
-    n_layers_spk: 6
     kernel_size: 3
     p_dropout: 0.1
-    speaker_cond_layer: 0
+    speaker_cond_layer: 2
     resblock: "1"
     resblock_kernel_sizes: [3, 7, 11]
     resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
@@ -74,6 +74,8 @@ model:
     upsample_initial_channel: 512
     upsample_kernel_sizes: [16, 16, 8, 2, 2]
     gin_channels: 512 # basically the speaker embedding size
+    kmeans_ckpt: results/hubert-vq-pretrain/kmeans.pt
+    codebook_size: 2048
 
   discriminator:
     _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
@@ -99,7 +101,7 @@ model:
     lr_lambda:
       _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
       _partial_: true
-      num_warmup_steps: 1000
+      num_warmup_steps: 0
       num_training_steps: ${trainer.max_steps}
       final_lr_ratio: 0.05
 

+ 6 - 0
fish_speech/datasets/vqgan.py

@@ -52,6 +52,12 @@ class VQGANDataset(Dataset):
                 start * self.hop_length : (start + self.slice_frames) * self.hop_length
             ]
 
+        if features.shape[0] % 2 != 0:
+            features = features[:-1]
+
+        if len(audio) > len(features) * self.hop_length:
+            audio = audio[: features.shape[0] * self.hop_length]
+
         if len(audio) < len(features) * self.hop_length:
             audio = np.pad(
                 audio,

+ 21 - 14
fish_speech/models/vqgan/lit_module.py

@@ -8,7 +8,7 @@ import wandb
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
-from vector_quantize_pytorch import ResidualLFQ
+from vector_quantize_pytorch import VectorQuantize
 
 from fish_speech.models.vqgan.losses import (
     discriminator_loss,
@@ -93,9 +93,7 @@ class VQGAN(L.LightningModule):
 
         with torch.no_grad():
             gt_mels = self.mel_transform(audios)
-            assert (
-                gt_mels.shape[2] == features.shape[1]
-            ), f"Shapes do not match: {gt_mels.shape}, {features.shape}"
+            gt_mels = gt_mels[:, :, : features.shape[1]]
 
         (
             y_hat,
@@ -105,6 +103,7 @@ class VQGAN(L.LightningModule):
             (z_q, z_p),
             (m_p, logs_p),
             (m_q, logs_q),
+            vq_loss,
         ) = self.generator(features, feature_lengths, gt_mels)
 
         y_hat_mel = self.mel_transform(y_hat.squeeze(1))
@@ -148,7 +147,15 @@ class VQGAN(L.LightningModule):
                 z_mask=x_mask,
             )
 
-            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + loss_kl * 1
+            # Cyclical kl loss
+            # then 500 steps linear: 0.1
+            # then 500 steps 0.1
+            # then go back to 0
+
+            beta = self.global_step % 1000
+            beta = min(beta, 500) / 500 * 0.1 + 1e-6
+
+            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + loss_kl * beta + vq_loss
 
         self.log(
             "train/generator/loss",
@@ -195,15 +202,15 @@ class VQGAN(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
-        # self.log(
-        #     "train/generator/loss_vq",
-        #     prior.loss,
-        #     on_step=True,
-        #     on_epoch=False,
-        #     prog_bar=False,
-        #     logger=True,
-        #     sync_dist=True,
-        # )
+        self.log(
+            "train/generator/loss_vq",
+            vq_loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
 
         optim_g.zero_grad()
         self.manual_backward(loss_gen_all)

+ 79 - 27
fish_speech/models/vqgan/modules/encoders.py

@@ -1,8 +1,8 @@
-import math
-from dataclasses import dataclass
+from typing import Optional
 
 import torch
 import torch.nn as nn
+from vector_quantize_pytorch import VectorQuantize
 
 from fish_speech.models.vqgan.modules.modules import WN
 from fish_speech.models.vqgan.modules.transformer import RelativePositionTransformer
@@ -158,7 +158,6 @@ class SpeakerEncoder(nn.Module):
     ) -> None:
         super().__init__()
 
-        self.query = nn.Parameter(torch.randn(1, 1, hidden_channels))
         self.in_proj = nn.Sequential(
             nn.Conv1d(in_channels, hidden_channels, 1),
             nn.SiLU(),
@@ -168,17 +167,16 @@ class SpeakerEncoder(nn.Module):
             nn.SiLU(),
             nn.Dropout(p_dropout),
         )
-
-        self.blocks = nn.ModuleList(
-            [
-                nn.MultiheadAttention(
-                    embed_dim=hidden_channels,
-                    num_heads=num_heads,
-                    dropout=p_dropout,
-                    batch_first=True,
-                )
-                for _ in range(num_layers)
-            ]
+        self.encoder = RelativePositionTransformer(
+            in_channels=hidden_channels,
+            out_channels=hidden_channels,
+            hidden_channels=hidden_channels,
+            hidden_channels_ffn=hidden_channels,
+            n_heads=num_heads,
+            n_layers=num_layers,
+            kernel_size=5,
+            dropout=p_dropout,
+            window_size=4,
         )
         self.out_proj = nn.Linear(hidden_channels, out_channels)
 
@@ -189,22 +187,76 @@ class SpeakerEncoder(nn.Module):
             - x_lengths: :math:`[B, 1]`
         """
 
-        x_mask = ~(sequence_mask(mel_lengths, mels.size(2)).bool())
+        x_mask = torch.unsqueeze(sequence_mask(mel_lengths, mels.size(2)), 1).to(
+            mels.dtype
+        )
+        x = self.in_proj(mels) * x_mask
+        x = self.encoder(x, x_mask)
+
+        # Avg Pooling
+        x = x * x_mask
+        x = torch.sum(x, dim=2) / torch.sum(x_mask, dim=2)
+        x = self.out_proj(x)[..., None]
+
+        return x
+
+
+class VQEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 1024,
+        vq_channels: int = 1024,
+        codebook_size: int = 2048,
+        downsample: int = 2,
+        kmeans_ckpt: Optional[str] = None,
+    ):
+        super().__init__()
+
+        self.vq = VectorQuantize(
+            dim=vq_channels,
+            codebook_size=codebook_size,
+            threshold_ema_dead_code=2,
+            kmeans_init=False,
+            channel_last=False,
+        )
 
-        x = self.in_proj(mels).transpose(1, 2)
-        x = torch.cat([self.query.expand(x.shape[0], -1, -1), x], dim=1)
+        self.conv_in = nn.Conv1d(
+            in_channels, vq_channels, kernel_size=downsample, stride=downsample
+        )
+        self.conv_out = nn.Sequential(
+            nn.Upsample(scale_factor=downsample, mode="nearest"),
+            nn.Conv1d(vq_channels, in_channels, kernel_size=1, stride=1),
+        )
+
+        if kmeans_ckpt is not None:
+            self.init_weights(kmeans_ckpt)
 
-        x_mask = torch.cat(
-            [
-                torch.zeros(x.shape[0], 1, dtype=torch.bool, device=x.device),
-                x_mask,
-            ],
-            dim=1,
+    def init_weights(self, kmeans_ckpt):
+        torch.nn.init.normal_(
+            self.conv_in.weight,
+            mean=1 / (self.conv_in.weight.shape[0] * self.conv_in.weight.shape[-1]),
+            std=1e-2,
         )
+        self.conv_in.bias.data.zero_()
 
-        for block in self.blocks:
-            x = block(x, x, x, key_padding_mask=x_mask)[0]
+        kmeans_ckpt = "results/hubert-vq-pretrain/kmeans.pt"
+        kmeans_ckpt = torch.load(kmeans_ckpt, map_location="cpu")
 
-        x = self.out_proj(x[:, :1, :]).mT
+        centroids = kmeans_ckpt["centroids"]
+        bins = kmeans_ckpt["bins"]
+        state_dict = {
+            "_codebook.initted": torch.Tensor([True]),
+            "_codebook.cluster_size": bins,
+            "_codebook.embed": centroids,
+            "_codebook.embed_avg": centroids.clone(),
+        }
 
-        return x
+        self.vq.load_state_dict(state_dict, strict=True)
+
+    def forward(self, x):
+        # x: [B, T, C]
+        x = self.conv_in(x.mT)
+        q, _, loss = self.vq(x)
+        x = self.conv_out(q).mT
+
+        return x, loss

+ 104 - 1
fish_speech/models/vqgan/modules/flow.py

@@ -1,11 +1,114 @@
 import torch
 from torch import nn
 
-from fish_speech.models.vqgan.modules.modules import Flip
+from fish_speech.models.vqgan.modules.modules import WN, Flip
 from fish_speech.models.vqgan.modules.normalization import LayerNorm
 from fish_speech.models.vqgan.modules.transformer import FFN, MultiHeadAttention
 
 
+class ResidualCouplingBlock(nn.Module):
+    def __init__(
+        self,
+        channels,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        n_flows=4,
+        gin_channels=0,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.n_flows = n_flows
+        self.gin_channels = gin_channels
+
+        self.flows = nn.ModuleList()
+
+        for i in range(n_flows):
+            self.flows.append(
+                ResidualCouplingLayer(
+                    channels,
+                    hidden_channels,
+                    kernel_size,
+                    dilation_rate,
+                    n_layers,
+                    gin_channels=gin_channels,
+                    mean_only=True,
+                )
+            )
+            self.flows.append(Flip())
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        if not reverse:
+            for flow in self.flows:
+                x, _ = flow(x, x_mask, g=g, reverse=reverse)
+        else:
+            for flow in reversed(self.flows):
+                x = flow(x, x_mask, g=g, reverse=reverse)
+        return x
+
+
+class ResidualCouplingLayer(nn.Module):
+    def __init__(
+        self,
+        channels,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        p_dropout=0,
+        gin_channels=0,
+        mean_only=False,
+    ):
+        assert channels % 2 == 0, "channels should be divisible by 2"
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.half_channels = channels // 2
+        self.mean_only = mean_only
+
+        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+        self.enc = WN(
+            hidden_channels,
+            kernel_size,
+            dilation_rate,
+            n_layers,
+            p_dropout=p_dropout,
+            gin_channels=gin_channels,
+        )
+        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+        self.post.weight.data.zero_()
+        self.post.bias.data.zero_()
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+        h = self.pre(x0) * x_mask
+        h = self.enc(h, x_mask, g=g)
+        stats = self.post(h) * x_mask
+        if not self.mean_only:
+            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+        else:
+            m = stats
+            logs = torch.zeros_like(m)
+
+        if not reverse:
+            x1 = m + x1 * torch.exp(logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            logdet = torch.sum(logs, [1, 2])
+            return x, logdet
+        else:
+            x1 = (x1 - m) * torch.exp(-logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            return x
+
+
 class TransformerCouplingBlock(nn.Module):
     def __init__(
         self,

+ 16 - 6
fish_speech/models/vqgan/modules/models.py

@@ -6,8 +6,9 @@ from fish_speech.models.vqgan.modules.encoders import (
     PosteriorEncoder,
     SpeakerEncoder,
     TextEncoder,
+    VQEncoder,
 )
-from fish_speech.models.vqgan.modules.flow import TransformerCouplingBlock
+from fish_speech.models.vqgan.modules.flow import ResidualCouplingBlock
 from fish_speech.models.vqgan.utils import rand_slice_segments
 
 
@@ -41,11 +42,19 @@ class SynthesizerTrn(nn.Module):
         upsample_initial_channel,
         upsample_kernel_sizes,
         gin_channels,
+        codebook_size,
+        kmeans_ckpt=None,
     ):
         super().__init__()
 
         self.segment_size = segment_size
 
+        self.vq = VQEncoder(
+            in_channels=in_channels,
+            vq_channels=in_channels,
+            codebook_size=codebook_size,
+            kmeans_ckpt=kmeans_ckpt,
+        )
         self.enc_p = TextEncoder(
             in_channels,
             inter_channels,
@@ -66,14 +75,12 @@ class SynthesizerTrn(nn.Module):
             num_layers=n_layers_spk,
             p_dropout=p_dropout,
         )
-        self.flow = TransformerCouplingBlock(
+        self.flow = ResidualCouplingBlock(
             channels=inter_channels,
             hidden_channels=hidden_channels,
-            filter_channels=filter_channels,
-            n_heads=n_heads,
-            n_layers=n_layers_flow,
             kernel_size=5,
-            p_dropout=p_dropout,
+            dilation_rate=1,
+            n_layers=n_layers_flow,
             n_flows=n_flows,
             gin_channels=gin_channels,
         )
@@ -99,6 +106,7 @@ class SynthesizerTrn(nn.Module):
 
     def forward(self, x, x_lengths, y):
         g = self.enc_spk(y, x_lengths)
+        x, vq_loss = self.vq(x)
 
         _, m_p, logs_p, _, x_mask = self.enc_p(x, x_lengths, g=g)
         z_q, m_q, logs_q, y_mask = self.enc_q(y, x_lengths, g=g)
@@ -115,10 +123,12 @@ class SynthesizerTrn(nn.Module):
             (z_q, z_p),
             (m_p, logs_p),
             (m_q, logs_q),
+            vq_loss,
         )
 
     def infer(self, x, x_lengths, y, max_len=None, noise_scale=0.35):
         g = self.enc_spk(y, x_lengths)
+        x, vq_loss = self.vq(x)
         z_p, m_p, logs_p, h_text, x_mask = self.enc_p(
             x, x_lengths, g=g, noise_scale=noise_scale
         )