Browse Source

Add downsample VQ

Lengyue 2 years ago
parent
commit
e2173f82df

+ 9 - 7
fish_speech/configs/hubert_vq.yaml

@@ -58,15 +58,15 @@ model:
     inter_channels: 192
     hidden_channels: 192
     filter_channels: 768
-    n_heads: 4
-    n_layers: 8
-    n_layers_q: 8
-    n_layers_flow: 8
+    n_heads: 2
+    n_layers: 6
+    n_layers_q: 6
+    n_layers_flow: 6
+    n_layers_spk: 4
     n_flows: 4
-    n_layers_spk: 6
     kernel_size: 3
     p_dropout: 0.1
-    speaker_cond_layer: 0
+    speaker_cond_layer: 2
     resblock: "1"
     resblock_kernel_sizes: [3, 7, 11]
     resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
@@ -74,6 +74,8 @@ model:
     upsample_initial_channel: 512
     upsample_kernel_sizes: [16, 16, 8, 2, 2]
     gin_channels: 512 # basically the speaker embedding size
+    kmeans_ckpt: results/hubert-vq-pretrain/kmeans.pt
+    codebook_size: 2048

   discriminator:
     _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
@@ -99,7 +101,7 @@ model:
     lr_lambda:
       _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
       _partial_: true
-      num_warmup_steps: 1000
+      num_warmup_steps: 0
       num_training_steps: ${trainer.max_steps}
       final_lr_ratio: 0.05

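The LR warmup is dropped here (num_warmup_steps: 1000 → 0), presumably viable because the VQ codebook is now k-means initialized rather than random. For orientation, a minimal sketch of the multiplier such a cosine-with-warmup lambda typically computes — the signature and the role of final_lr_ratio as the decay floor are assumptions about fish_speech.scheduler, not verified against the repo:

    import math

    # Hedged sketch of a standard cosine-with-warmup LR lambda; returns a
    # multiplier on the base LR. final_lr_ratio is assumed to be the decay floor.
    def cosine_warmup_lambda(step, num_warmup_steps, num_training_steps, final_lr_ratio):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)  # linear ramp up to 1.0
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # decays 1.0 -> 0.0
        return final_lr_ratio + (1.0 - final_lr_ratio) * cosine

    # With num_warmup_steps=0, as in this commit, step 0 already yields 1.0.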

+ 6 - 0
fish_speech/datasets/vqgan.py

@@ -52,6 +52,12 @@ class VQGANDataset(Dataset):
                 start * self.hop_length : (start + self.slice_frames) * self.hop_length
             ]

+        if features.shape[0] % 2 != 0:
+            features = features[:-1]
+
+        if len(audio) > len(features) * self.hop_length:
+            audio = audio[: features.shape[0] * self.hop_length]
+
         if len(audio) < len(features) * self.hop_length:
             audio = np.pad(
                 audio,
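Both new guards keep HuBERT frames and raw audio aligned: the frame count is forced even so the new VQ encoder's downsample-by-2 convolution divides it cleanly, and overlong audio is trimmed to exactly frames * hop_length (the pre-existing branch above still pads audio that is too short). A toy check of the invariant, with an illustrative hop_length and shapes:

    import numpy as np

    hop_length = 320                       # illustrative value, not from the repo
    features = np.random.randn(51, 1024)   # odd frame count on purpose
    audio = np.random.randn(51 * hop_length + 17)

    if features.shape[0] % 2 != 0:         # even count: downsample=2 divides evenly
        features = features[:-1]
    if len(audio) > len(features) * hop_length:
        audio = audio[: features.shape[0] * hop_length]

    assert features.shape[0] == 50 and len(audio) == 50 * hop_length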

+ 21 - 14
fish_speech/models/vqgan/lit_module.py

@@ -8,7 +8,7 @@ import wandb
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
-from vector_quantize_pytorch import ResidualLFQ
+from vector_quantize_pytorch import VectorQuantize

 from fish_speech.models.vqgan.losses import (
     discriminator_loss,
@@ -93,9 +93,7 @@ class VQGAN(L.LightningModule):

         with torch.no_grad():
             gt_mels = self.mel_transform(audios)
-            assert (
-                gt_mels.shape[2] == features.shape[1]
-            ), f"Shapes do not match: {gt_mels.shape}, {features.shape}"
+            gt_mels = gt_mels[:, :, : features.shape[1]]

         (
             y_hat,
@@ -105,6 +103,7 @@ class VQGAN(L.LightningModule):
             (z_q, z_p),
             (m_p, logs_p),
             (m_q, logs_q),
+            vq_loss,
         ) = self.generator(features, feature_lengths, gt_mels)

         y_hat_mel = self.mel_transform(y_hat.squeeze(1))
@@ -148,7 +147,15 @@ class VQGAN(L.LightningModule):
                 z_mask=x_mask,
             )

-            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + loss_kl * 1
+            # Cyclical KL weight: per 1000-step cycle, ramp linearly to 0.1
+            # over the first 500 steps, hold at 0.1 for the next 500 steps,
+            # then reset to ~0 at the next cycle boundary.
+
+            beta = self.global_step % 1000
+            beta = min(beta, 500) / 500 * 0.1 + 1e-6
+
+            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + loss_kl * beta + vq_loss

         self.log(
             "train/generator/loss",
@@ -195,15 +202,15 @@ class VQGAN(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
-        # self.log(
-        #     "train/generator/loss_vq",
-        #     prior.loss,
-        #     on_step=True,
-        #     on_epoch=False,
-        #     prog_bar=False,
-        #     logger=True,
-        #     sync_dist=True,
-        # )
+        self.log(
+            "train/generator/loss_vq",
+            vq_loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )

         optim_g.zero_grad()
         self.manual_backward(loss_gen_all)
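The fixed KL weight of 1 becomes a cyclical beta: within every 1000-step cycle the weight ramps linearly from ~0 to 0.1 over the first 500 steps, holds at 0.1 for the next 500, then drops back at the cycle boundary. The schedule, using the same expression as the diff:

    def kl_beta(global_step: int) -> float:
        beta = global_step % 1000
        return min(beta, 500) / 500 * 0.1 + 1e-6

    for step in (0, 250, 500, 750, 999, 1000):
        print(step, kl_beta(step))
    # ~0.0, 0.05, 0.1, 0.1, 0.1, then back to ~0.0 (the 1e-6 keeps it nonzero)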

+ 79 - 27
fish_speech/models/vqgan/modules/encoders.py

@@ -1,8 +1,8 @@
-import math
-from dataclasses import dataclass
+from typing import Optional

 import torch
 import torch.nn as nn
+from vector_quantize_pytorch import VectorQuantize

 from fish_speech.models.vqgan.modules.modules import WN
 from fish_speech.models.vqgan.modules.transformer import RelativePositionTransformer
@@ -158,7 +158,6 @@ class SpeakerEncoder(nn.Module):
     ) -> None:
         super().__init__()

-        self.query = nn.Parameter(torch.randn(1, 1, hidden_channels))
         self.in_proj = nn.Sequential(
             nn.Conv1d(in_channels, hidden_channels, 1),
             nn.SiLU(),
@@ -168,17 +167,16 @@ class SpeakerEncoder(nn.Module):
             nn.SiLU(),
             nn.Dropout(p_dropout),
         )
-
-        self.blocks = nn.ModuleList(
-            [
-                nn.MultiheadAttention(
-                    embed_dim=hidden_channels,
-                    num_heads=num_heads,
-                    dropout=p_dropout,
-                    batch_first=True,
-                )
-                for _ in range(num_layers)
-            ]
+        self.encoder = RelativePositionTransformer(
+            in_channels=hidden_channels,
+            out_channels=hidden_channels,
+            hidden_channels=hidden_channels,
+            hidden_channels_ffn=hidden_channels,
+            n_heads=num_heads,
+            n_layers=num_layers,
+            kernel_size=5,
+            dropout=p_dropout,
+            window_size=4,
         )
         self.out_proj = nn.Linear(hidden_channels, out_channels)

@@ -189,22 +187,76 @@ class SpeakerEncoder(nn.Module):
             - x_lengths: :math:`[B, 1]`
         """

-        x_mask = ~(sequence_mask(mel_lengths, mels.size(2)).bool())
+        x_mask = torch.unsqueeze(sequence_mask(mel_lengths, mels.size(2)), 1).to(
+            mels.dtype
+        )
+        x = self.in_proj(mels) * x_mask
+        x = self.encoder(x, x_mask)
+
+        # Avg Pooling
+        x = x * x_mask
+        x = torch.sum(x, dim=2) / torch.sum(x_mask, dim=2)
+        x = self.out_proj(x)[..., None]
+
+        return x
+
+
+class VQEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 1024,
+        vq_channels: int = 1024,
+        codebook_size: int = 2048,
+        downsample: int = 2,
+        kmeans_ckpt: Optional[str] = None,
+    ):
+        super().__init__()
+
+        self.vq = VectorQuantize(
+            dim=vq_channels,
+            codebook_size=codebook_size,
+            threshold_ema_dead_code=2,
+            kmeans_init=False,
+            channel_last=False,
+        )

-        x = self.in_proj(mels).transpose(1, 2)
-        x = torch.cat([self.query.expand(x.shape[0], -1, -1), x], dim=1)
+        self.conv_in = nn.Conv1d(
+            in_channels, vq_channels, kernel_size=downsample, stride=downsample
+        )
+        self.conv_out = nn.Sequential(
+            nn.Upsample(scale_factor=downsample, mode="nearest"),
+            nn.Conv1d(vq_channels, in_channels, kernel_size=1, stride=1),
+        )
+
+        if kmeans_ckpt is not None:
+            self.init_weights(kmeans_ckpt)

-        x_mask = torch.cat(
-            [
-                torch.zeros(x.shape[0], 1, dtype=torch.bool, device=x.device),
-                x_mask,
-            ],
-            dim=1,
+    def init_weights(self, kmeans_ckpt):
+        torch.nn.init.normal_(
+            self.conv_in.weight,
+            mean=1 / (self.conv_in.weight.shape[0] * self.conv_in.weight.shape[-1]),
+            std=1e-2,
         )
+        self.conv_in.bias.data.zero_()

-        for block in self.blocks:
-            x = block(x, x, x, key_padding_mask=x_mask)[0]
+        kmeans_ckpt = "results/hubert-vq-pretrain/kmeans.pt"
+        kmeans_ckpt = torch.load(kmeans_ckpt, map_location="cpu")

-        x = self.out_proj(x[:, :1, :]).mT
+        centroids = kmeans_ckpt["centroids"]
+        bins = kmeans_ckpt["bins"]
+        state_dict = {
+            "_codebook.initted": torch.Tensor([True]),
+            "_codebook.cluster_size": bins,
+            "_codebook.embed": centroids,
+            "_codebook.embed_avg": centroids.clone(),
+        }
 
 
-        return x
+        self.vq.load_state_dict(state_dict, strict=True)
+
+    def forward(self, x):
+        # x: [B, T, C]
+        x = self.conv_in(x.mT)
+        q, _, loss = self.vq(x)
+        x = self.conv_out(q).mT
+
+        return x, loss
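A shape-level sketch of the new VQEncoder round trip (batch and time sizes are arbitrary; kmeans_ckpt is omitted so the codebook starts randomly initialized): conv_in downsamples time by 2 into the VQ space, VectorQuantize with channel_last=False quantizes the [B, C, T/2] tensor, and conv_out nearest-upsamples back, so the output shape matches the input:

    import torch
    from fish_speech.models.vqgan.modules.encoders import VQEncoder

    enc = VQEncoder(in_channels=1024, vq_channels=1024, codebook_size=2048, downsample=2)
    x = torch.randn(2, 50, 1024)  # [B, T, C]; T must be even (see the dataset change)
    y, vq_loss = enc(x)
    print(y.shape)                # torch.Size([2, 50, 1024])
    print(vq_loss)                # commitment loss from vector_quantize_pytorch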

+ 104 - 1
fish_speech/models/vqgan/modules/flow.py

@@ -1,11 +1,114 @@
 import torch
 from torch import nn

-from fish_speech.models.vqgan.modules.modules import Flip
+from fish_speech.models.vqgan.modules.modules import WN, Flip
 from fish_speech.models.vqgan.modules.normalization import LayerNorm
 from fish_speech.models.vqgan.modules.transformer import FFN, MultiHeadAttention


+class ResidualCouplingBlock(nn.Module):
+    def __init__(
+        self,
+        channels,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        n_flows=4,
+        gin_channels=0,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.n_flows = n_flows
+        self.gin_channels = gin_channels
+
+        self.flows = nn.ModuleList()
+
+        for i in range(n_flows):
+            self.flows.append(
+                ResidualCouplingLayer(
+                    channels,
+                    hidden_channels,
+                    kernel_size,
+                    dilation_rate,
+                    n_layers,
+                    gin_channels=gin_channels,
+                    mean_only=True,
+                )
+            )
+            self.flows.append(Flip())
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        if not reverse:
+            for flow in self.flows:
+                x, _ = flow(x, x_mask, g=g, reverse=reverse)
+        else:
+            for flow in reversed(self.flows):
+                x = flow(x, x_mask, g=g, reverse=reverse)
+        return x
+
+
+class ResidualCouplingLayer(nn.Module):
+    def __init__(
+        self,
+        channels,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        p_dropout=0,
+        gin_channels=0,
+        mean_only=False,
+    ):
+        assert channels % 2 == 0, "channels should be divisible by 2"
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.half_channels = channels // 2
+        self.mean_only = mean_only
+
+        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+        self.enc = WN(
+            hidden_channels,
+            kernel_size,
+            dilation_rate,
+            n_layers,
+            p_dropout=p_dropout,
+            gin_channels=gin_channels,
+        )
+        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+        self.post.weight.data.zero_()
+        self.post.bias.data.zero_()
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+        h = self.pre(x0) * x_mask
+        h = self.enc(h, x_mask, g=g)
+        stats = self.post(h) * x_mask
+        if not self.mean_only:
+            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+        else:
+            m = stats
+            logs = torch.zeros_like(m)
+
+        if not reverse:
+            x1 = m + x1 * torch.exp(logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            logdet = torch.sum(logs, [1, 2])
+            return x, logdet
+        else:
+            x1 = (x1 - m) * torch.exp(-logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            return x
+
+
 class TransformerCouplingBlock(nn.Module):
     def __init__(
         self,
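For reference, the transform ResidualCouplingLayer implements is the standard affine coupling: split the channels into halves (x0, x1), predict stats from x0 with the WN stack, then

    % forward direction and its exact inverse; with mean_only=True (as used by
    % ResidualCouplingBlock above), \log s \equiv 0 and the log-determinant vanishes
    \begin{aligned}
      x_1' &= m(x_0) + x_1 \odot \exp\!\bigl(\log s(x_0)\bigr), \qquad
      \log\det J = \sum_{c,t} \log s(x_0), \\
      x_1  &= \bigl(x_1' - m(x_0)\bigr) \odot \exp\!\bigl(-\log s(x_0)\bigr).
    \end{aligned}

Because x0 passes through unchanged, inversion needs no iteration, which is what keeps the reverse=True path used by infer() cheap.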

+ 16 - 6
fish_speech/models/vqgan/modules/models.py

@@ -6,8 +6,9 @@ from fish_speech.models.vqgan.modules.encoders import (
     PosteriorEncoder,
     SpeakerEncoder,
     TextEncoder,
+    VQEncoder,
 )
-from fish_speech.models.vqgan.modules.flow import TransformerCouplingBlock
+from fish_speech.models.vqgan.modules.flow import ResidualCouplingBlock
 from fish_speech.models.vqgan.utils import rand_slice_segments


@@ -41,11 +42,19 @@ class SynthesizerTrn(nn.Module):
         upsample_initial_channel,
         upsample_kernel_sizes,
         gin_channels,
+        codebook_size,
+        kmeans_ckpt=None,
     ):
         super().__init__()

         self.segment_size = segment_size

+        self.vq = VQEncoder(
+            in_channels=in_channels,
+            vq_channels=in_channels,
+            codebook_size=codebook_size,
+            kmeans_ckpt=kmeans_ckpt,
+        )
         self.enc_p = TextEncoder(
             in_channels,
             inter_channels,
@@ -66,14 +75,12 @@ class SynthesizerTrn(nn.Module):
             num_layers=n_layers_spk,
             p_dropout=p_dropout,
         )
-        self.flow = TransformerCouplingBlock(
+        self.flow = ResidualCouplingBlock(
             channels=inter_channels,
             hidden_channels=hidden_channels,
-            filter_channels=filter_channels,
-            n_heads=n_heads,
-            n_layers=n_layers_flow,
             kernel_size=5,
-            p_dropout=p_dropout,
+            dilation_rate=1,
+            n_layers=n_layers_flow,
             n_flows=n_flows,
             gin_channels=gin_channels,
         )
@@ -99,6 +106,7 @@ class SynthesizerTrn(nn.Module):

     def forward(self, x, x_lengths, y):
         g = self.enc_spk(y, x_lengths)
+        x, vq_loss = self.vq(x)

         _, m_p, logs_p, _, x_mask = self.enc_p(x, x_lengths, g=g)
         z_q, m_q, logs_q, y_mask = self.enc_q(y, x_lengths, g=g)
@@ -115,10 +123,12 @@ class SynthesizerTrn(nn.Module):
             (z_q, z_p),
             (m_p, logs_p),
             (m_q, logs_q),
+            vq_loss,
         )

     def infer(self, x, x_lengths, y, max_len=None, noise_scale=0.35):
         g = self.enc_spk(y, x_lengths)
+        x, vq_loss = self.vq(x)
         z_p, m_p, logs_p, h_text, x_mask = self.enc_p(
             x, x_lengths, g=g, noise_scale=noise_scale
         )
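The VQ is wired in ahead of the text encoder for both training and inference, and init_weights() in encoders.py expects kmeans.pt to hold "centroids" (shape [codebook_size, vq_channels]) and "bins" (per-cluster counts), which it maps onto vector_quantize_pytorch's EMA codebook state; note that in this commit the hard-coded path inside init_weights() overwrites the kmeans_ckpt argument before loading. A hedged sketch of producing a compatible checkpoint with a toy k-means — the real pipeline presumably fits k-means on pooled HuBERT features, and everything below (sample counts, iterations) is illustrative:

    import torch

    codebook_size, vq_channels = 2048, 1024      # match the yaml config above
    feats = torch.randn(10_000, vq_channels)     # stand-in for HuBERT features

    # Toy Lloyd iterations; a real run would use faiss/scikit-learn on far more data.
    centroids = feats[torch.randperm(len(feats))[:codebook_size]].clone()
    for _ in range(10):
        assign = torch.cdist(feats, centroids).argmin(dim=1)
        bins = torch.bincount(assign, minlength=codebook_size).clamp(min=1)
        sums = torch.zeros_like(centroids).index_add_(0, assign, feats)
        centroids = sums / bins.unsqueeze(1)

    torch.save({"centroids": centroids, "bins": bins.float()},
               "results/hubert-vq-pretrain/kmeans.pt")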