ソースを参照

Better convnext condition

Lengyue 2年前
コミット
9fe898fd2f

+ 17 - 25
fish_speech/configs/hubert_vq_diffusion.yaml

@@ -7,7 +7,7 @@ project: hubert_vq_diffusion
 # Lightning Trainer
 trainer:
   accelerator: gpu
-  devices: 1
+  devices: 4
   strategy:
     _target_: lightning.pytorch.strategies.DDPStrategy
     static_graph: true
@@ -41,7 +41,7 @@ data:
   _target_: fish_speech.datasets.vqgan.VQGANDataModule
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
-  num_workers: 0 #16
+  num_workers: 8
   batch_size: 32
   val_batch_size: 4
 
@@ -82,32 +82,16 @@ model:
     num_layers: 4
     p_dropout: 0.1
   
-  # denoiser:
-  #   _target_: fish_speech.models.vq_diffusion.convnext_1d.ConvNext1DModel
-  #   in_channels: 256
-  #   out_channels: 128
-  #   intermediate_dim: 512
-  #   mlp_dim: 2048
-  #   num_layers: 20
-  #   dilation_cycle_length: 2
-  #   time_embedding_type: "positional"
-
   denoiser:
-    _target_: fish_speech.models.vq_diffusion.wavenet.WaveNet
+    _target_: fish_speech.models.vq_diffusion.convnext_1d.ConvNext1DModel
     in_channels: 128
     out_channels: 128
-    d_encoder: 128
-    residual_channels: 512
-    residual_layers: 20
-    use_linear_bias: false
-    dilation_cycle: 2
-
-  # denoiser:
-  #   _target_: fish_speech.models.vq_diffusion.unet1d.Unet1DDenoiser
-  #   dim: 64
-  #   dim_mults: [1, 2, 4]
-  #   groups: 8
-  #   pe_scale: 1000
+    intermediate_dim: 512
+    condition_dim: 128
+    mlp_dim: 2048
+    num_layers: 20
+    dilation_cycle_length: 2
+    time_embedding_type: "positional"
 
   vocoder:
     _target_: fish_speech.models.vq_diffusion.adamos.ADaMoSHiFiGANV1
@@ -138,3 +122,11 @@ model:
       num_warmup_steps: 0
       num_training_steps: ${trainer.max_steps}
       final_lr_ratio: 0.05
+
+callbacks:
+  grad_norm_monitor:
+    sub_module: 
+      - vq_encoder
+      - text_encoder
+      - speaker_encoder
+      - denoiser

+ 3 - 5
fish_speech/models/vq_diffusion/convnext_1d.py

@@ -129,6 +129,7 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
         num_layers: int = 20,
         dilation_cycle_length: int = 4,
         time_embedding_type: str = "positional",
+        condition_dim: Optional[int] = None,
     ):
         super().__init__()
 
@@ -156,7 +157,7 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
             timestep_input_dim,
             intermediate_dim,
             act_fn="silu",
-            cond_proj_dim=None,  # No conditional projection for now
+            cond_proj_dim=condition_dim,
         )
 
         # Project to intermediate dim
@@ -218,12 +219,9 @@ class ConvNext1DModel(ModelMixin, ConfigMixin):
 
         # 1. time
         t_emb = self.time_proj(timestep)
-        t_emb = self.time_mlp(t_emb)[..., None]
+        t_emb = self.time_mlp(sample=t_emb[:, None], condition=condition.mT).mT
 
         # 2. pre-process
-        if condition is not None:
-            sample = torch.cat([sample, condition], dim=1)
-
         x = self.in_proj(sample)
 
         if sample_mask.ndim == 2:

+ 2 - 2
fish_speech/models/vq_diffusion/lit_module.py

@@ -11,7 +11,7 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
 
-from fish_speech.models.vq_diffusion.unet1d import Unet1DDenoiser
+from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
 from fish_speech.models.vqgan.modules.encoders import (
     SpeakerEncoder,
     TextEncoder,
@@ -29,7 +29,7 @@ class VQDiffusion(L.LightningModule):
         vq_encoder: VQEncoder,
         speaker_encoder: SpeakerEncoder,
         text_encoder: TextEncoder,
-        denoiser: Unet1DDenoiser,
+        denoiser: ConvNext1DModel,
         vocoder: nn.Module,
         hop_length: int = 640,
         sample_rate: int = 32000,

+ 0 - 198
fish_speech/models/vq_diffusion/unet1d.py

@@ -1,198 +0,0 @@
-# Refer to https://github.com/huawei-noah/Speech-Backbones/blob/main/Grad-TTS/model/diffusion.py
-
-import math
-
-import torch
-from einops import rearrange
-from torch import nn
-
-
-class Block(nn.Module):
-    def __init__(self, dim, dim_out, groups=8):
-        super().__init__()
-        self.block = nn.Sequential(
-            nn.Conv2d(dim, dim_out, 3, padding=1),
-            nn.GroupNorm(groups, dim_out),
-            nn.Mish(),
-        )
-
-    def forward(self, x, mask):
-        output = self.block(x * mask)
-        return output * mask
-
-
-class ResnetBlock(nn.Module):
-    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
-        super().__init__()
-        self.mlp = nn.Sequential(nn.Mish(), nn.Linear(time_emb_dim, dim_out))
-
-        self.block1 = Block(dim, dim_out, groups=groups)
-        self.block2 = Block(dim_out, dim_out, groups=groups)
-        if dim != dim_out:
-            self.res_conv = nn.Conv2d(dim, dim_out, 1)
-        else:
-            self.res_conv = nn.Identity()
-
-    def forward(self, x, mask, time_emb):
-        h = self.block1(x, mask)
-        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
-        h = self.block2(h, mask)
-        output = h + self.res_conv(x * mask)
-        return output
-
-
-class LinearAttention(nn.Module):
-    def __init__(self, dim, heads=4, dim_head=32, init_values=1e-5):
-        super().__init__()
-
-        self.heads = heads
-        hidden_dim = dim_head * heads
-
-        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
-        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-        self.gamma = nn.Parameter(torch.ones(dim) * init_values)
-
-    def forward(self, x):
-        b, c, h, w = x.shape
-        qkv = self.to_qkv(x)
-        q, k, v = rearrange(
-            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
-        )
-        k = k.softmax(dim=-1)
-        context = torch.einsum("bhdn,bhen->bhde", k, v)
-        out = torch.einsum("bhde,bhdn->bhen", context, q)
-        out = rearrange(
-            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
-        )
-        return self.to_out(out) * self.gamma.view(1, -1, 1, 1) + x
-
-
-class SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, x, scale=1000):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
-        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
-
-
-class Unet1DDenoiser(nn.Module):
-    def __init__(
-        self,
-        dim,
-        dim_mults=(1, 2, 4),
-        groups=8,
-        pe_scale=1000,
-    ):
-        super().__init__()
-        self.dim = dim
-        self.dim_mults = dim_mults
-        self.groups = groups
-        self.pe_scale = pe_scale
-
-        self.time_pos_emb = SinusoidalPosEmb(dim)
-        self.mlp = nn.Sequential(
-            nn.Linear(dim, dim * 4), nn.Mish(), nn.Linear(dim * 4, dim)
-        )
-        self.downsample_rate = 2 ** (len(dim_mults) - 1)
-
-        dims = [2, *map(lambda m: dim * m, dim_mults)]
-        in_out = list(zip(dims[:-1], dims[1:]))
-        self.downs = nn.ModuleList([])
-        self.ups = nn.ModuleList([])
-        num_resolutions = len(in_out)
-
-        for ind, (dim_in, dim_out) in enumerate(in_out):
-            is_last = ind >= (num_resolutions - 1)
-            self.downs.append(
-                nn.ModuleList(
-                    [
-                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
-                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
-                        LinearAttention(dim_out),
-                        nn.Conv2d(dim_out, dim_out, 3, 2, 1)
-                        if not is_last
-                        else nn.Identity(),
-                    ]
-                )
-            )
-
-        mid_dim = dims[-1]
-        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
-        self.mid_attn = LinearAttention(mid_dim)
-        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
-
-        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
-            self.ups.append(
-                nn.ModuleList(
-                    [
-                        ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
-                        ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
-                        LinearAttention(dim_in),
-                        nn.ConvTranspose2d(dim_in, dim_in, 4, 2, 1),
-                    ]
-                )
-            )
-        self.final_block = Block(dim, dim)
-        self.final_conv = nn.Conv2d(dim, 1, 1)
-
-    def forward(self, x, t, mask, condition):
-        t = self.time_pos_emb(t, scale=self.pe_scale)
-        t = self.mlp(t)
-
-        x = torch.stack([condition, x], 1)
-        mask = mask.unsqueeze(1)
-
-        original_len = x.shape[3]
-        if x.shape[3] % self.downsample_rate != 0:
-            x = nn.functional.pad(
-                x, (0, self.downsample_rate - x.shape[3] % self.downsample_rate)
-            )
-            mask = nn.functional.pad(
-                mask, (0, self.downsample_rate - mask.shape[3] % self.downsample_rate)
-            )
-
-        hiddens = []
-        masks = [mask]
-        for resnet1, resnet2, attn, downsample in self.downs:
-            mask_down = masks[-1]
-            x = resnet1(x, mask_down, t)
-            x = resnet2(x, mask_down, t)
-            x = attn(x)
-            hiddens.append(x)
-            x = downsample(x * mask_down)
-            masks.append(mask_down[:, :, :, ::2])
-
-        masks = masks[:-1]
-        mask_mid = masks[-1]
-        x = self.mid_block1(x, mask_mid, t)
-        x = self.mid_attn(x)
-        x = self.mid_block2(x, mask_mid, t)
-
-        for resnet1, resnet2, attn, upsample in self.ups:
-            mask_up = masks.pop()
-            x = torch.cat((x, hiddens.pop()), dim=1)
-            x = resnet1(x, mask_up, t)
-            x = resnet2(x, mask_up, t)
-            x = attn(x)
-            x = upsample(x * mask_up)
-
-        x = self.final_block(x, mask)
-        output = self.final_conv(x * mask)
-
-        output = (output * mask).squeeze(1)
-        return output[:, :, :original_len]
-
-
-if __name__ == "__main__":
-    model = Unet1DDenoiser(128)
-    mel = torch.randn(1, 128, 99)
-    mask = torch.ones(1, 1, 99)
-
-    print(model(mel, mask, torch.tensor([10], dtype=torch.long), mel).shape)

+ 0 - 186
fish_speech/models/vq_diffusion/wavenet.py

@@ -1,186 +0,0 @@
-import math
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-
-class Mish(nn.Module):
-    def forward(self, x):
-        return x * torch.tanh(F.softplus(x))
-
-
-class DiffusionEmbedding(nn.Module):
-    """Diffusion Step Embedding"""
-
-    def __init__(self, d_denoiser):
-        super(DiffusionEmbedding, self).__init__()
-        self.dim = d_denoiser
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
-
-
-class LinearNorm(nn.Module):
-    """LinearNorm Projection"""
-
-    def __init__(self, in_features, out_features, bias=False):
-        super(LinearNorm, self).__init__()
-        self.linear = nn.Linear(in_features, out_features, bias)
-
-        nn.init.xavier_uniform_(self.linear.weight)
-        if bias:
-            nn.init.constant_(self.linear.bias, 0.0)
-
-    def forward(self, x):
-        x = self.linear(x)
-        return x
-
-
-class ConvNorm(nn.Module):
-    """1D Convolution"""
-
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size=1,
-        stride=1,
-        padding=None,
-        dilation=1,
-        bias=True,
-        w_init_gain="linear",
-    ):
-        super(ConvNorm, self).__init__()
-
-        if padding is None:
-            assert kernel_size % 2 == 1
-            padding = int(dilation * (kernel_size - 1) / 2)
-
-        self.conv = nn.Conv1d(
-            in_channels,
-            out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            bias=bias,
-        )
-        nn.init.kaiming_normal_(self.conv.weight)
-
-    def forward(self, signal):
-        conv_signal = self.conv(signal)
-
-        return conv_signal
-
-
-class ResidualBlock(nn.Module):
-    """Residual Block"""
-
-    def __init__(self, d_encoder, residual_channels, use_linear_bias=False, dilation=1):
-        super(ResidualBlock, self).__init__()
-        self.conv_layer = ConvNorm(
-            residual_channels,
-            2 * residual_channels,
-            kernel_size=3,
-            stride=1,
-            padding=dilation,
-            dilation=dilation,
-        )
-        self.diffusion_projection = LinearNorm(
-            residual_channels, residual_channels, use_linear_bias
-        )
-        self.conditioner_projection = ConvNorm(
-            d_encoder, 2 * residual_channels, kernel_size=1
-        )
-        self.output_projection = ConvNorm(
-            residual_channels, 2 * residual_channels, kernel_size=1
-        )
-
-    def forward(self, x, conditioner, diffusion_step):
-        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
-        conditioner = self.conditioner_projection(conditioner)
-
-        y = x + diffusion_step
-
-        y = self.conv_layer(y) + conditioner
-
-        gate, filter = torch.chunk(y, 2, dim=1)
-        y = torch.sigmoid(gate) * torch.tanh(filter)
-
-        y = self.output_projection(y)
-        residual, skip = torch.chunk(y, 2, dim=1)
-
-        return (x + residual) / math.sqrt(2.0), skip
-
-
-class WaveNet(nn.Module):
-    """
-    WaveNet
-    https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio
-    """
-
-    def __init__(
-        self,
-        in_channels=128,
-        out_channels=128,
-        d_encoder=256,
-        residual_channels=512,
-        residual_layers=20,
-        use_linear_bias=False,
-        dilation_cycle=None,
-    ):
-        super(WaveNet, self).__init__()
-
-        self.input_projection = ConvNorm(in_channels, residual_channels, kernel_size=1)
-        self.diffusion_embedding = DiffusionEmbedding(residual_channels)
-        self.mlp = nn.Sequential(
-            LinearNorm(residual_channels, residual_channels * 4, use_linear_bias),
-            Mish(),
-            LinearNorm(residual_channels * 4, residual_channels, use_linear_bias),
-        )
-        self.residual_layers = nn.ModuleList(
-            [
-                ResidualBlock(
-                    d_encoder,
-                    residual_channels,
-                    use_linear_bias=use_linear_bias,
-                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
-                )
-                for i in range(residual_layers)
-            ]
-        )
-        self.skip_projection = ConvNorm(
-            residual_channels, residual_channels, kernel_size=1
-        )
-        self.output_projection = ConvNorm(
-            residual_channels, out_channels, kernel_size=1
-        )
-        nn.init.zeros_(self.output_projection.conv.weight)
-
-    def forward(self, x, diffusion_step, x_masks, condition):
-        x = self.input_projection(x)  # x [B, residual_channel, T]
-        x = F.relu(x)
-
-        diffusion_step = self.diffusion_embedding(diffusion_step)
-        diffusion_step = self.mlp(diffusion_step)
-
-        skip = []
-        for layer in self.residual_layers:
-            x, skip_connection = layer(x * x_masks, condition * x_masks, diffusion_step)
-            skip.append(skip_connection)
-
-        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
-        x = self.skip_projection(x)
-        x = F.relu(x)
-        x = self.output_projection(x)  # [B, 128, T]
-
-        x = x * x_masks
-
-        return x