Source index

Add wavenet & vq diffusion config

Lengyue 2 years ago
parent
commit
e14317713d

+ 1 - 1
fish_speech/configs/hubert_vq_diffusion.yaml

@@ -67,7 +67,7 @@ model:
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
     in_channels: 128
     vq_channels: 128
-    codebook_size: 16384
+    codebook_size: 4096
     downsample: 1
 
   speaker_encoder:

+ 134 - 0
fish_speech/configs/vq_diffusion.yaml

@@ -0,0 +1,134 @@
+defaults:
+  - base
+  - _self_
+
+project: vq_diffusion
+
+# Lightning Trainer
+trainer:
+  accelerator: gpu
+  devices: 4
+  strategy: ddp_find_unused_parameters_true
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
+  precision: 16-mixed
+  max_steps: 1_000_000
+  val_check_interval: 5000
+
+sample_rate: 24000
+hop_length: 256
+num_mels: 100
+n_fft: 1024
+win_length: 1024
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.split.train
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  slice_frames: 512
+
+val_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.split.valid
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+data:
+  _target_: fish_speech.datasets.vqgan.VQGANDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 8
+  batch_size: 32
+  val_batch_size: 4
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vq_diffusion.lit_module.VQDiffusion
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  speaker_use_feats: true
+
+  text_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
+    in_channels: 128
+    out_channels: 128
+    hidden_channels: 192
+    hidden_channels_ffn: 768
+    n_heads: 2
+    n_layers: 6
+    kernel_size: 1
+    dropout: 0.1
+    use_vae: false
+    gin_channels: 512
+    speaker_cond_layer: 0
+
+  vq_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
+    in_channels: 128
+    vq_channels: 128
+    codebook_size: 4096
+    downsample: 1
+
+  speaker_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
+    in_channels: 128
+    hidden_channels: 192
+    out_channels: 128
+    num_heads: 2
+    num_layers: 4
+    p_dropout: 0.1
+  
+  denoiser:
+    _target_: fish_speech.models.vq_diffusion.wavenet.WaveNet
+    d_encoder: 128
+    mel_channels: 100
+    residual_channels: 384
+    residual_layers: 20
+
+  vocoder:
+    _target_: fish_speech.models.vq_diffusion.bigvgan.BigVGAN
+
+  mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+    f_min: 0
+    f_max: 12000
+
+  feature_mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: 32000
+    n_fft: 2048
+    hop_length: 640
+    win_length: 2048
+    n_mels: 128
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 1e-4
+    betas: [0.9, 0.999]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 0
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.05
+
+callbacks:
+  grad_norm_monitor:
+    sub_module: 
+      - vq_encoder
+      - text_encoder
+      - speaker_encoder
+      - denoiser

+ 11 - 2
fish_speech/models/vq_diffusion/lit_module.py

@@ -36,6 +36,7 @@ class VQDiffusion(L.LightningModule):
         vocoder: nn.Module,
         hop_length: int = 640,
         sample_rate: int = 32000,
+        speaker_use_feats: bool = False,
     ):
         super().__init__()
 
@@ -58,6 +59,7 @@ class VQDiffusion(L.LightningModule):
         self.vocoder = vocoder
         self.hop_length = hop_length
         self.sampling_rate = sample_rate
+        self.speaker_use_feats = speaker_use_feats
 
         # Freeze vocoder
         for param in self.vocoder.parameters():
@@ -107,7 +109,11 @@ class VQDiffusion(L.LightningModule):
             gt_mels.dtype
         )
 
-        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
+        if self.speaker_use_feats:
+            speaker_features = self.speaker_encoder(features, feature_masks)
+        else:
+            speaker_features = self.speaker_encoder(gt_mels, mel_masks)
+
         # vq_features is 50 hz, need to convert to true mel size
         text_features = self.text_encoder(features, feature_masks)
         text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
@@ -184,7 +190,10 @@ class VQDiffusion(L.LightningModule):
             gt_mels.dtype
         )
 
-        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
+        if self.speaker_use_feats:
+            speaker_features = self.speaker_encoder(features, feature_masks)
+        else:
+            speaker_features = self.speaker_encoder(gt_mels, mel_masks)
 
         # vq_features is 50 hz, need to convert to true mel size
         text_features = self.text_encoder(features, feature_masks)

+ 227 - 0
fish_speech/models/vq_diffusion/wavenet.py

@@ -0,0 +1,227 @@
+import math
+from typing import Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Mish(nn.Module):
+    def forward(self, x):
+        return x * torch.tanh(F.softplus(x))
+
+
+class DiffusionEmbedding(nn.Module):
+    """Diffusion Step Embedding"""
+
+    def __init__(self, d_denoiser):
+        super(DiffusionEmbedding, self).__init__()
+        self.dim = d_denoiser
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+class LinearNorm(nn.Module):
+    """LinearNorm Projection"""
+
+    def __init__(self, in_features, out_features, bias=False):
+        super(LinearNorm, self).__init__()
+        self.linear = nn.Linear(in_features, out_features, bias)
+
+        nn.init.xavier_uniform_(self.linear.weight)
+        if bias:
+            nn.init.constant_(self.linear.bias, 0.0)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return x
+
+
+class ConvNorm(nn.Module):
+    """1D Convolution"""
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=1,
+        stride=1,
+        padding=None,
+        dilation=1,
+        bias=True,
+        w_init_gain="linear",
+    ):
+        super(ConvNorm, self).__init__()
+
+        if padding is None:
+            assert kernel_size % 2 == 1
+            padding = int(dilation * (kernel_size - 1) / 2)
+
+        self.conv = nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+        )
+        nn.init.kaiming_normal_(self.conv.weight)
+
+    def forward(self, signal):
+        conv_signal = self.conv(signal)
+
+        return conv_signal
+
+
+class ResidualBlock(nn.Module):
+    """Residual Block"""
+
+    def __init__(self, d_encoder, residual_channels, use_linear_bias=False, dilation=1):
+        super(ResidualBlock, self).__init__()
+        self.conv_layer = ConvNorm(
+            residual_channels,
+            2 * residual_channels,
+            kernel_size=3,
+            stride=1,
+            padding=dilation,
+            dilation=dilation,
+        )
+        self.diffusion_projection = LinearNorm(
+            residual_channels, residual_channels, use_linear_bias
+        )
+        self.condition_projection = ConvNorm(
+            d_encoder, 2 * residual_channels, kernel_size=1
+        )
+        self.output_projection = ConvNorm(
+            residual_channels, 2 * residual_channels, kernel_size=1
+        )
+
+    def forward(self, x, conditioner, diffusion_step):
+        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+        conditioner = self.condition_projection(conditioner)
+
+        y = x + diffusion_step
+
+        y = self.conv_layer(y) + conditioner
+
+        gate, filter = torch.chunk(y, 2, dim=1)
+        y = torch.sigmoid(gate) * torch.tanh(filter)
+
+        y = self.output_projection(y)
+        residual, skip = torch.chunk(y, 2, dim=1)
+
+        return (x + residual) / math.sqrt(2.0), skip
+
+
+class SpectrogramUpsampler(nn.Module):
+    def __init__(self, hop_size):
+        super().__init__()
+
+        if hop_size == 256:
+            self.conv1 = nn.ConvTranspose2d(
+                1, 1, [3, 32], stride=[1, 16], padding=[1, 8]
+            )
+        elif hop_size == 512:
+            self.conv1 = nn.ConvTranspose2d(
+                1, 1, [3, 64], stride=[1, 32], padding=[1, 16]
+            )
+        else:
+            raise ValueError(f"Unsupported hop_size: {hop_size}")
+
+        self.conv2 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
+
+    def forward(self, x):
+        x = torch.unsqueeze(x, 1)
+        x = self.conv1(x)
+        x = F.leaky_relu(x, 0.4)
+        x = self.conv2(x)
+        x = F.leaky_relu(x, 0.4)
+        x = torch.squeeze(x, 1)
+
+        return x
+
+
+class WaveNet(nn.Module):
+    """
+    WaveNet
+    https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio
+    """
+
+    def __init__(
+        self,
+        mel_channels=128,
+        d_encoder=256,
+        residual_channels=512,
+        residual_layers=20,
+        use_linear_bias=False,
+        dilation_cycle=None,
+    ):
+        super(WaveNet, self).__init__()
+
+        self.input_projection = ConvNorm(mel_channels, residual_channels, kernel_size=1)
+        self.diffusion_embedding = DiffusionEmbedding(residual_channels)
+        self.mlp = nn.Sequential(
+            LinearNorm(residual_channels, residual_channels * 4, use_linear_bias),
+            Mish(),
+            LinearNorm(residual_channels * 4, residual_channels, use_linear_bias),
+        )
+        self.residual_layers = nn.ModuleList(
+            [
+                ResidualBlock(
+                    d_encoder,
+                    residual_channels,
+                    use_linear_bias=use_linear_bias,
+                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
+                )
+                for i in range(residual_layers)
+            ]
+        )
+        self.skip_projection = ConvNorm(
+            residual_channels, residual_channels, kernel_size=1
+        )
+        self.output_projection = ConvNorm(
+            residual_channels, mel_channels, kernel_size=1
+        )
+        nn.init.zeros_(self.output_projection.conv.weight)
+
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        sample_mask: Optional[torch.Tensor] = None,
+        condition: Optional[torch.Tensor] = None,
+    ):
+        x = self.input_projection(sample)  # x [B, residual_channel, T]
+        x = F.relu(x)
+
+        diffusion_step = self.diffusion_embedding(timestep)
+        diffusion_step = self.mlp(diffusion_step)
+
+        if sample_mask is not None:
+            if sample_mask.ndim == 2:
+                sample_mask = sample_mask[:, None, :]
+
+            x = x * sample_mask
+
+        skip = []
+        for layer in self.residual_layers:
+            x, skip_connection = layer(x, condition, diffusion_step)
+            skip.append(skip_connection)
+
+        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
+        x = self.skip_projection(x)
+        x = F.relu(x)
+        x = self.output_projection(x)  # [B, 128, T]
+
+        if sample_mask is not None:
+            x = x * sample_mask
+
+        return x

+ 6 - 6
fish_speech/models/vqgan/modules/encoders.py

@@ -234,12 +234,12 @@ class VQEncoder(nn.Module):
     ):
         super().__init__()
 
-        self.vq = LFQ(
+        self.vq = VectorQuantize(
             dim=vq_channels,
             codebook_size=codebook_size,
-            # threshold_ema_dead_code=2,
-            # kmeans_init=False,
-            # channel_last=False,
+            threshold_ema_dead_code=2,
+            kmeans_init=False,
+            channel_last=False,
         )
         self.downsample = downsample
         self.conv_in = nn.Conv1d(
@@ -286,8 +286,8 @@ class VQEncoder(nn.Module):
             x_mask = F.pad(x_mask, (0, self.downsample - x_len % self.downsample))
 
         x = self.conv_in(x)
-        q, _, loss = self.vq(x.mT)
-        x = self.conv_out(q.mT) * x_mask
+        q, _, loss = self.vq(x)
+        x = self.conv_out(q) * x_mask
         x = x[:, :, :x_len]
 
         return x, loss