Ver código fonte

Support flow & wavenet

Lengyue 2 anos atrás
pai
commit
a81036d28c

+ 36 - 10
fish_speech/configs/vqgan_pretrain.yaml

@@ -2,18 +2,32 @@ defaults:
   - base
   - _self_
 
-project: vq_reflow_debug
+project: vq_reflow_wavenet_group_fsq
+ckpt_path: results/vq_reflow_bf16/checkpoints/step_000248000.ckpt
+resume_weights_only: true
 
 # Lightning Trainer
 trainer:
   accelerator: gpu
   devices: auto
-  strategy: ddp_find_unused_parameters_true
-  precision: 16-mixed
+  precision: 32
   max_steps: 1_000_000
+  # max_steps: 100
   val_check_interval: 2000
   gradient_clip_algorithm: norm
   gradient_clip_val: 1.0
+  # limit_val_batches: 0.0
+
+  strategy: ddp #_find_unused_parameters_true
+  # strategy:
+  #   _target_: lightning.pytorch.strategies.DeepSpeedStrategy
+  #   stage: 1
+  #   overlap_comm: true
+
+  # profiler:
+  #   _target_: lightning.pytorch.profilers.PyTorchProfiler
+  #   export_to_chrome: true
+  #   filename: prof.txt
 
 sample_rate: 44100
 hop_length: 512
@@ -61,7 +75,8 @@ model:
   quantizer:
     _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
     input_dim: 512
-    n_codebooks: 8
+    n_codebooks: 1
+    n_groups: 8
     levels: [8, 5, 5, 5]
   
   aux_decoder:
@@ -71,13 +86,20 @@ model:
     depths: [6]
     dims: [384]
 
+  # reflow:
+  #   _target_: fish_speech.models.vqgan.modules.dit.DiT
+  #   hidden_size: 768
+  #   num_heads: 12
+  #   diffusion_num_layers: 12
+  #   channels: ${num_mels}
+  #   condition_dim: 512
+
   reflow:
-    _target_: fish_speech.models.vqgan.modules.dit.DiT
-    hidden_size: 768
-    num_heads: 12
-    diffusion_num_layers: 12
-    channels: ${num_mels}
-    condition_dim: 512
+    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
+    mel_channels: ${num_mels}
+    d_encoder: 512
+    residual_channels: 512
+    residual_layers: 20
 
   vocoder:
     _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
@@ -97,6 +119,7 @@ model:
     lr: 1e-4
     betas: [0.8, 0.99]
     eps: 1e-5
+    weight_decay: 0.01
 
   lr_scheduler:
     _target_: torch.optim.lr_scheduler.LambdaLR
@@ -115,3 +138,6 @@ callbacks:
       - aux_decoder
       - quantizer
       - reflow
+
+  model_checkpoint:
+    every_n_train_steps: ${trainer.val_check_interval}

+ 17 - 9
fish_speech/models/vqgan/lit_module.py

@@ -69,6 +69,7 @@ class VQGAN(L.LightningModule):
         self.spec_min = -12
         self.spec_max = 3
         self.sampling_rate = sampling_rate
+        self.strict_loading = False
 
     def on_save_checkpoint(self, checkpoint):
         # Do not save vocoder
@@ -96,6 +97,7 @@ class VQGAN(L.LightningModule):
     def denorm_spec(self, x):
         return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
 
+    # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
     def training_step(self, batch, batch_idx):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
 
@@ -124,7 +126,7 @@ class VQGAN(L.LightningModule):
         )
 
         # Reflow
-        x_1 = self.norm_spec(gt_mels.mT)
+        x_1 = self.norm_spec(gt_mels)
         t = torch.rand(gt_mels.shape[0], device=gt_mels.device)
         x_0 = torch.randn_like(x_1)
 
@@ -134,8 +136,9 @@ class VQGAN(L.LightningModule):
         v_pred = self.reflow(
             x_t,
             1000 * t,
-            condition=vq_recon_features.mT,
-            self_mask=mel_masks,
+            vq_recon_features,  # .detach()
+            x_masks=mel_masks_float_conv,
+            cond_masks=mel_masks_float_conv,
         )
 
         # Log L2 loss with
@@ -143,7 +146,7 @@ class VQGAN(L.LightningModule):
         loss_reflow = weights[:, None, None] * F.mse_loss(
             x_1 - x_0, v_pred, reduction="none"
         )
-        loss_reflow = (loss_reflow * mel_masks_float_conv.mT).mean()
+        loss_reflow = (loss_reflow * mel_masks_float_conv).mean()
 
         # Total loss
         loss = (
@@ -218,8 +221,12 @@ class VQGAN(L.LightningModule):
 
         # Reflow inference
         t_start = 0.0
-        infer_step = 20
-        gen_mels = torch.randn(gt_mels.shape, device=gt_mels.device).mT
+        infer_step = 10
+
+        x_1 = self.norm_spec(aux_mels)
+        x_0 = torch.randn_like(x_1)
+        gen_mels = (1 - t_start) * x_0 + t_start * x_1
+
         t = torch.zeros(gt_mels.shape[0], device=gt_mels.device)
         dt = (1.0 - t_start) / infer_step
 
@@ -228,14 +235,15 @@ class VQGAN(L.LightningModule):
                 self.reflow(
                     gen_mels,
                     1000 * t,
-                    condition=vq_result.z.mT,
-                    self_mask=mel_masks,
+                    vq_result.z,
+                    x_masks=mel_masks_float_conv,
+                    cond_masks=mel_masks_float_conv,
                 )
                 * dt
             )
             t += dt
 
-        gen_mels = self.denorm_spec(gen_mels).mT
+        gen_mels = self.denorm_spec(gen_mels)
         loss_recon_reflow = F.l1_loss(
             gen_mels * mel_masks_float_conv, gt_mels * mel_masks_float_conv
         )

+ 4 - 2
fish_speech/models/vqgan/modules/fsq.py

@@ -7,7 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from torch.nn.utils import weight_norm
-from vector_quantize_pytorch import ResidualFSQ
+from vector_quantize_pytorch import GroupedResidualFSQ
 
 from .convnext import ConvNeXtBlock
 
@@ -24,6 +24,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
         self,
         input_dim: int = 512,
         n_codebooks: int = 9,
+        n_groups: int = 1,
         levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
         downsample_factor: tuple[int] = (2, 2),
         downsample_dims: tuple[int] | None = None,
@@ -35,10 +36,11 @@ class DownsampleFiniteScalarQuantize(nn.Module):
 
         all_dims = (input_dim,) + tuple(downsample_dims)
 
-        self.residual_fsq = ResidualFSQ(
+        self.residual_fsq = GroupedResidualFSQ(
             dim=all_dims[-1],
             levels=levels,
             num_quantizers=n_codebooks,
+            groups=n_groups,
         )
 
         self.downsample_factor = downsample_factor

+ 236 - 0
fish_speech/models/vqgan/modules/wavenet.py

@@ -0,0 +1,236 @@
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))`` (Misra, 2019).

    Delegates to :func:`torch.nn.functional.mish`, the fused, numerically
    identical implementation of the original hand-rolled formula.
    """

    def forward(self, x):
        return F.mish(x)
+
+
class DiffusionEmbedding(nn.Module):
    """Sinusoidal diffusion-step embedding (Transformer-style).

    Maps a batch of scalar diffusion steps ``x`` of shape ``[B]`` to an
    embedding of shape ``[B, 2 * (d_denoiser // 2)]`` — use an even
    ``d_denoiser`` so the output width equals ``d_denoiser`` (it is the
    concatenation of ``d_denoiser // 2`` sine and cosine channels).
    """

    def __init__(self, d_denoiser):
        super().__init__()
        self.dim = d_denoiser

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        # max(..., 1) guards against ZeroDivisionError when dim <= 2
        # (half_dim == 1); previously this crashed.
        emb = math.log(10000) / max(half_dim - 1, 1)
        # Geometric frequency ladder from 1 down to 1/10000.
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
+
+
class LinearNorm(nn.Module):
    """Linear projection with Xavier-uniform weight initialisation.

    Args:
        in_features: input feature size.
        out_features: output feature size.
        bias: whether the underlying ``nn.Linear`` has a bias term
            (zero-initialised when present). Defaults to ``False``.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        # Pass bias by keyword for clarity (was positional).
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        nn.init.xavier_uniform_(self.linear.weight)
        if bias:
            nn.init.constant_(self.linear.bias, 0.0)

    def forward(self, x):
        return self.linear(x)
+
+
class ConvNorm(nn.Module):
    """1D convolution with Kaiming-normal weight initialisation.

    When ``padding`` is ``None`` it is derived for an odd ``kernel_size``
    so that stride-1 convolutions preserve the input length
    ("same" padding, accounting for dilation).

    NOTE(review): ``w_init_gain`` is accepted for call-site compatibility
    but is unused — ``nn.init.kaiming_normal_`` takes no gain argument.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super().__init__()

        if padding is None:
            assert kernel_size % 2 == 1
            # Integer floor division instead of float division + int().
            padding = dilation * (kernel_size - 1) // 2

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, signal):
        return self.conv(signal)
+
+
class ResidualBlock(nn.Module):
    """Gated WaveNet residual block with diffusion-step and conditioner injection.

    Args:
        d_encoder: channel count of the conditioner features.
        residual_channels: width of the residual stream.
        use_linear_bias: whether the diffusion-step projection has a bias.
        dilation: dilation of the 3-tap residual convolution.
    """

    def __init__(self, d_encoder, residual_channels, use_linear_bias=False, dilation=1):
        super().__init__()
        self.conv_layer = ConvNorm(
            residual_channels,
            2 * residual_channels,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        )
        self.diffusion_projection = LinearNorm(
            residual_channels, residual_channels, use_linear_bias
        )
        self.conditioner_projection = ConvNorm(
            d_encoder, 2 * residual_channels, kernel_size=1
        )
        self.output_projection = ConvNorm(
            residual_channels, 2 * residual_channels, kernel_size=1
        )

    def forward(self, x, conditioner, diffusion_step):
        """x: [B, residual_channels, T]; conditioner: [B, d_encoder, T];
        diffusion_step: [B, residual_channels].
        Returns ``(residual, skip)``, both [B, residual_channels, T]."""
        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        conditioner = self.conditioner_projection(conditioner)

        y = x + diffusion_step

        y = self.conv_layer(y) + conditioner

        # Gated activation unit: sigmoid(gate) * tanh(filter), per WaveNet.
        # Renamed the second chunk so it no longer shadows builtin `filter`.
        gate, filter_ = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter_)

        y = self.output_projection(y)
        residual, skip = torch.chunk(y, 2, dim=1)

        # 1/sqrt(2) keeps the residual stream variance stable across layers.
        return (x + residual) / math.sqrt(2.0), skip
+
+
class SpectrogramUpsampler(nn.Module):
    """Upsamples a spectrogram [B, M, T] along time by ``hop_size`` using two
    2D transposed convolutions (DiffWave-style).

    The second stage always upsamples time x16; the first stage supplies the
    remaining factor, so only hop sizes 256 (16 * 16) and 512 (32 * 16) are
    supported. Raises ``ValueError`` for any other hop size.
    """

    # hop_size -> (kernel, stride, padding) for the first stage.
    _STAGE1 = {
        256: ([3, 32], [1, 16], [1, 8]),
        512: ([3, 64], [1, 32], [1, 16]),
    }

    def __init__(self, hop_size):
        super().__init__()

        try:
            kernel, stride, padding = self._STAGE1[hop_size]
        except KeyError:
            raise ValueError(f"Unsupported hop_size: {hop_size}") from None

        self.conv1 = nn.ConvTranspose2d(1, 1, kernel, stride=stride, padding=padding)
        self.conv2 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])

    def forward(self, x):
        # [B, M, T] -> [B, 1, M, T]: treat the spectrogram as a 1-channel
        # image so the transposed convs widen only the time axis.
        x = torch.unsqueeze(x, 1)
        x = F.leaky_relu(self.conv1(x), 0.4)
        x = F.leaky_relu(self.conv2(x), 0.4)
        return torch.squeeze(x, 1)
+
+
class WaveNet(nn.Module):
    """Non-causal WaveNet denoiser (DiffWave/DiffSinger style).

    Predicts a velocity/noise field over a mel spectrogram given a diffusion
    step and encoder conditioning features.

    https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio

    Args:
        mel_channels: number of mel bins in the input and output.
        d_encoder: channel count of the conditioner features.
        residual_channels: width of the residual stream.
        residual_layers: number of stacked ``ResidualBlock``s.
        use_linear_bias: whether the linear projections carry biases.
        dilation_cycle: if truthy, layer ``i`` uses dilation
            ``2 ** (i % dilation_cycle)``; otherwise all layers use dilation 1.
    """

    def __init__(
        self,
        mel_channels=128,
        d_encoder=256,
        residual_channels=512,
        residual_layers=20,
        use_linear_bias=False,
        dilation_cycle=None,
    ):
        super().__init__()

        self.input_projection = ConvNorm(mel_channels, residual_channels, kernel_size=1)
        self.diffusion_embedding = DiffusionEmbedding(residual_channels)
        self.mlp = nn.Sequential(
            LinearNorm(residual_channels, residual_channels * 4, use_linear_bias),
            Mish(),
            LinearNorm(residual_channels * 4, residual_channels, use_linear_bias),
        )
        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    d_encoder,
                    residual_channels,
                    use_linear_bias=use_linear_bias,
                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
                )
                for i in range(residual_layers)
            ]
        )
        self.skip_projection = ConvNorm(
            residual_channels, residual_channels, kernel_size=1
        )
        self.output_projection = ConvNorm(
            residual_channels, mel_channels, kernel_size=1
        )
        # Zero-init the final conv weight so the network initially predicts
        # (near) zero. NOTE(review): the conv bias is NOT zeroed — confirm
        # this is intentional.
        nn.init.zeros_(self.output_projection.conv.weight)

    def forward(self, x, diffusion_step, conditioner, x_masks=None, cond_masks=None):
        """Run one denoising step.

        Args:
            x: noisy mel, ``[B, mel_channels, T]`` — or ``[B, 1, M, T]`` for
                DiffSVC compatibility (the extra dim is squeezed on entry and
                restored on return).
            diffusion_step: ``[B]`` step values for the sinusoidal embedding.
            conditioner: ``[B, d_encoder, T]`` conditioning features
                (the original docstring wrongly said ``[B, M, T]``).
            x_masks: optional mask multiplied into the input and the output;
                presumably ``[B, 1, T]``, broadcastable against ``[B, C, T]``
                — TODO confirm against callers.
            cond_masks: optional mask multiplied into the conditioner.

        Returns:
            ``[B, mel_channels, T]`` prediction (4-dim if the input was 4-dim).
        """

        # To keep compatibility with DiffSVC, [B, 1, M, T]
        use_4_dim = False
        if x.dim() == 4:
            x = x[:, 0]
            use_4_dim = True

        assert x.dim() == 3, f"mel must be 3 dim tensor, but got {x.dim()}"

        x = self.input_projection(x)  # x [B, residual_channel, T]
        x = F.relu(x)

        diffusion_step = self.diffusion_embedding(diffusion_step)
        diffusion_step = self.mlp(diffusion_step)

        if x_masks is not None:
            x = x * x_masks

        if cond_masks is not None:
            conditioner = conditioner * cond_masks

        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(x, conditioner, diffusion_step)
            skip.append(skip_connection)

        # Sum skip connections with 1/sqrt(N) scaling (DiffWave).
        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
        x = self.skip_projection(x)
        x = F.relu(x)
        x = self.output_projection(x)  # [B, mel_channels, T]

        if x_masks is not None:
            x = x * x_masks

        return x[:, None] if use_4_dim else x