Lengyue пре 2 година
родитељ
комит
c9e1f95503

+ 1 - 0
.gitignore

@@ -8,3 +8,4 @@ __pycache__
 *.filelist
 filelists
 /fish_speech/text/cmudict_cache.pickle
+/checkpoints

+ 3 - 3
fish_speech/configs/hubert_vq.yaml

@@ -45,7 +45,7 @@ data:
 
 # Model Configuration
 model:
-  _target_: fish_speech.models.vqgan.VQGAN
+  _target_: fish_speech.models.vq_diffusion.VQGAN
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   segment_size: 20480
@@ -110,5 +110,5 @@ callbacks:
     sub_module: generator
 
 # Resume from rcell's checkpoint
-# ckpt_path: results/hubert-vq-pretrain/rcell/ckpt_23000_pl.pth
-# resume_weights_only: true
+ckpt_path: results/hubert-vq-pretrain/rcell/ckpt_23000_pl.pth
+resume_weights_only: true

+ 130 - 0
fish_speech/configs/hubert_vq_diffusion.yaml

@@ -0,0 +1,130 @@
+defaults:
+  - base
+  - _self_
+
+project: hubert_vq_diffusion
+
+# Lightning Trainer
+trainer:
+  accelerator: gpu
+  devices: 4
+  strategy:
+    _target_: lightning.pytorch.strategies.DDPStrategy
+    static_graph: true
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
+  precision: bf16-mixed
+  max_steps: 1_000_000
+  val_check_interval: 1000
+
+sample_rate: 44100
+hop_length: 512
+num_mels: 128
+n_fft: 2048
+win_length: 2048
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/vq_train_filelist.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  slice_frames: 512
+
+val_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/vq_val_filelist.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+data:
+  _target_: fish_speech.datasets.vqgan.VQGANDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 8
+  val_batch_size: 4
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vq_diffusion.lit_module.VQDiffusion
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+  text_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
+    in_channels: 1024
+    out_channels: 128
+    hidden_channels: 192
+    hidden_channels_ffn: 768
+    n_heads: 2
+    n_layers: 4
+    kernel_size: 1
+    dropout: 0.1
+    use_vae: false
+    gin_channels: 512
+    speaker_cond_layer: 0
+
+  vq_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
+    in_channels: 1024
+    vq_channels: 1024
+    codebook_size: 2048
+    downsample: 2
+    kmeans_ckpt: results/hubert-vq-pretrain/kmeans.pt
+
+  speaker_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
+    in_channels: 128
+    hidden_channels: 192
+    out_channels: 512
+    num_heads: 2
+    num_layers: 4
+    p_dropout: 0.1
+  
+  # denoiser:
+  #   _target_: fish_speech.models.vq_diffusion.convnext_1d.ConvNext1DModel
+  #   in_channels: 256
+  #   out_channels: 128
+  #   intermediate_dim: 512
+  #   mlp_dim: 2048
+  #   num_layers: 20
+  #   dilation_cycle_length: 2
+  #   time_embedding_type: "positional"
+
+  denoiser:
+    _target_: fish_speech.models.vq_diffusion.unet1d.Unet1DDenoiser
+    dim: 64
+    dim_mults: [1, 2, 4]
+    groups: 8
+    pe_scale: 1000
+
+  vocoder:
+    _target_: fish_speech.models.vq_diffusion.adamos.ADaMoSHiFiGANV1
+
+  mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+    f_min: 40
+    f_max: 16000
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 1e-4
+    betas: [0.9, 0.999]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 0
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.05

+ 16 - 12
fish_speech/datasets/vqgan.py

@@ -48,23 +48,26 @@ class VQGANDataset(Dataset):
         if self.slice_frames is not None and features.shape[0] > self.slice_frames:
             start = np.random.randint(0, features.shape[0] - self.slice_frames)
             features = features[start : start + self.slice_frames]
+            feature_hop_length = features.shape[0] * (32000 // 50)
             audio = audio[
-                start * self.hop_length : (start + self.slice_frames) * self.hop_length
+                start
+                * feature_hop_length : (start + self.slice_frames)
+                * feature_hop_length
             ]
 
-        if features.shape[0] % 2 != 0:
-            features = features[:-1]
+        # if features.shape[0] % 2 != 0:
+        #     features = features[:-1]
 
-        if len(audio) > len(features) * self.hop_length:
-            audio = audio[: features.shape[0] * self.hop_length]
+        # if len(audio) > len(features) * self.hop_length:
+        #     audio = audio[: features.shape[0] * self.hop_length]
 
-        if len(audio) < len(features) * self.hop_length:
-            audio = np.pad(
-                audio,
-                (0, len(features) * self.hop_length - len(audio)),
-                mode="constant",
-                constant_values=0,
-            )
+        # if len(audio) < len(features) * self.hop_length:
+        #     audio = np.pad(
+        #         audio,
+        #         (0, len(features) * self.hop_length - len(audio)),
+        #         mode="constant",
+        #         constant_values=0,
+        #     )
 
         return {
             "audio": torch.from_numpy(audio),
@@ -90,6 +93,7 @@ class VQGANCollator:
         audio_maxlen = audio_lengths.max()
         feature_maxlen = feature_lengths.max()
 
+        # Rounds up to nearest multiple of 2 (audio_lengths)
         audios, features = [], []
         for x in batch:
             audios.append(

+ 3 - 0
fish_speech/models/vq_diffusion/adamos/__init__.py

@@ -0,0 +1,3 @@
+from .adamos import ADaMoSHiFiGANV1
+
+__all__ = ["ADaMoSHiFiGANV1"]

+ 88 - 0
fish_speech/models/vq_diffusion/adamos/adamos.py

@@ -0,0 +1,88 @@
+import librosa
+import torch
+from torch import nn
+
+from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
+
+from .encoder import ConvNeXtEncoder
+from .hifigan import HiFiGANGenerator
+
+
class ADaMoSHiFiGANV1(nn.Module):
    """Pretrained ADaMoS HiFi-GAN vocoder pipeline (44.1 kHz).

    A ConvNeXt backbone (128 mel bins -> 512-dim features) followed by a
    HiFi-GAN generator head (512 channels -> waveform, hop 512).  Frozen
    weights are loaded from ``checkpoint_path`` and the module is switched
    to eval mode at construction time; it is intended to be used frozen,
    for inference only (``encode``: audio -> log-mel, ``decode``: mel ->
    waveform).
    """

    def __init__(
        self,
        checkpoint_path: str = "checkpoints/adamos-generator-1640000.pth",
    ):
        super().__init__()

        self.backbone = ConvNeXtEncoder(
            input_channels=128,
            depths=[3, 3, 9, 3],
            dims=[128, 256, 384, 512],
            drop_path_rate=0,
            kernel_sizes=(7,),
        )

        # prod(upsample_rates) == 4*4*2*2*2*2*2 == 512 == hop_length,
        # which the HiFiGANGenerator asserts internally.
        self.head = HiFiGANGenerator(
            hop_length=512,
            upsample_rates=(4, 4, 2, 2, 2, 2, 2),
            upsample_kernel_sizes=(8, 8, 4, 4, 4, 4, 4),
            resblock_kernel_sizes=(3, 7, 11, 13),
            resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
            num_mels=512,
            upsample_initial_channel=1024,
            use_template=False,
            pre_conv_kernel_size=13,
            post_conv_kernel_size=13,
        )
        self.sampling_rate = 44100

        ckpt_state = torch.load(checkpoint_path, map_location="cpu")

        # Lightning-style checkpoints nest the weights under "state_dict".
        if "state_dict" in ckpt_state:
            ckpt_state = ckpt_state["state_dict"]

        # Keep only generator weights and strip their prefix.
        # NOTE(review): str.replace removes "generator." anywhere in the key,
        # not only as a prefix — presumably it occurs once per key; verify.
        if any(k.startswith("generator.") for k in ckpt_state):
            ckpt_state = {
                k.replace("generator.", ""): v
                for k, v in ckpt_state.items()
                if k.startswith("generator.")
            }

        self.load_state_dict(ckpt_state)
        self.eval()

        # Mel settings must match what the generator was trained with
        # (44.1 kHz, 2048 FFT, hop 512, 128 mel bins).
        self.mel_transform = LogMelSpectrogram(
            sample_rate=44100,
            n_fft=2048,
            win_length=2048,
            hop_length=512,
            f_min=40,
            f_max=16000,
            n_mels=128,
        )

    @torch.no_grad()
    def decode(self, mel):
        # mel: log-mel spectrogram (assumed (B, 128, T) — matches encode()).
        y = self.backbone(mel)
        y = self.head(y)

        return y

    @torch.no_grad()
    def encode(self, x):
        # x: waveform tensor; returns its log-mel spectrogram.
        return self.mel_transform(x)
+
+
if __name__ == "__main__":
    import soundfile as sf

    # Smoke test: round-trip one wav file through mel encode + vocoder decode.
    x = "data/StarRail/Chinese/罗刹/archive_luocha_2.wav"
    model = ADaMoSHiFiGANV1()

    wav, sr = librosa.load(x, sr=44100, mono=True)
    wav = torch.from_numpy(wav).float()[None]  # add batch dim -> (1, T)
    mel = model.encode(wav)

    # decode() output is batched; [0].mT puts samples first for soundfile.
    wav = model.decode(mel)[0].mT
    sf.write("test.wav", wav.cpu().numpy(), 44100)

+ 238 - 0
fish_speech/models/vq_diffusion/adamos/encoder.py

@@ -0,0 +1,238 @@
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Stochastic Depth: randomly zero entire samples of a residual branch.

    Unlike ordinary dropout, the Bernoulli mask is drawn once per sample
    (shape ``(B, 1, 1, ...)``) so the whole branch is either kept or dropped
    for that sample.  Identity when ``drop_prob`` is 0 or when not training.
    Known as "drop path" in the timm/EfficientNet lineage; see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956.
    """
    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1 - drop_prob
    # One mask entry per sample, broadcast over every remaining dimension.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    # Rescale survivors so the expected activation magnitude is unchanged.
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
+
+
class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (Stochastic Depth).

    Active only while the module is in training mode; identity during eval.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Delegate to the functional form with this module's training flag.
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        # Shown in the module repr, e.g. "DropPath(drop_prob=0.100)".
        return f"drop_prob={round(self.drop_prob, 3):0.3f}"
+
+
class LayerNorm(nn.Module):
    r"""LayerNorm supporting channels_last and channels_first layouts.

    channels_last expects inputs shaped ``(..., C)`` and defers to
    ``F.layer_norm``; channels_first expects ``(N, C, ...)`` and normalizes
    over dim 1 with an explicit mean/variance computation.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_first":
            # Manual normalization over the channel dim (dim 1).
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            normed = (x - mean) / torch.sqrt(var + self.eps)
            # weight/bias broadcast as (C, 1) over the trailing dims.
            return self.weight[:, None] * normed + self.bias[:, None]
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+
+
class ConvNeXtBlock(nn.Module):
    r"""1D ConvNeXt block: depthwise conv -> LayerNorm -> pointwise MLP,
    with optional learnable per-channel scaling (layer scale) and stochastic
    depth on the residual branch.

    Args:
        dim (int): Number of input/output channels.
        drop_path (float): Stochastic depth rate. Default: 0.0.
        layer_scale_init_value (float): Init value for layer scale; <= 0
            disables it. Default: 1e-6.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
            Default: 4.0.
        kernel_size (int): Kernel size of the depthwise conv. Default: 7.
        dilation (int): Dilation of the depthwise conv. Default: 1.
    """

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()

        hidden_dim = int(mlp_ratio * dim)
        # Depthwise conv mixes information along time only, per channel;
        # padding keeps the sequence length unchanged.
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=int(dilation * (kernel_size - 1) / 2),
            groups=dim,
        )
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convs expressed as linears on channels-last data.
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma = None
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        shortcut = x

        out = self.dwconv(x)
        out = out.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        out = self.norm(out)
        out = self.pwconv2(self.act(self.pwconv1(out)))

        if self.gamma is not None:
            out = self.gamma * out

        out = out.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        out = self.drop_path(out)

        # Branches of ParallelConvNeXtBlock skip their own residual.
        return shortcut + out if apply_residual else out
+
+
class ParallelConvNeXtBlock(nn.Module):
    """Runs several ConvNeXt blocks with different kernel sizes in parallel
    and sums their branch outputs together with the input (one shared
    residual connection instead of one per branch)."""

    def __init__(self, kernel_sizes: list[int], *args, **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList(
            ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
            for kernel_size in kernel_sizes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each branch runs without its own residual; x is added back once.
        branches = [block(x, apply_residual=False) for block in self.blocks]
        branches.append(x)
        return torch.stack(branches, dim=1).sum(dim=1)
+
+
class ConvNeXtEncoder(nn.Module):
    """Multi-stage 1D ConvNeXt encoder (no temporal downsampling).

    A stem conv projects ``input_channels`` to ``dims[0]``; between stages a
    LayerNorm + 1x1 conv changes the channel count, and each stage is a
    sequence of ConvNeXt blocks.  With more than one entry in
    ``kernel_sizes`` the blocks become ParallelConvNeXtBlocks with one
    branch per kernel size.

    Args:
        input_channels: Channels of the input sequence.
        depths: Number of blocks per stage (one entry per stage).
        dims: Channel count per stage; must have the same length as depths.
        drop_path_rate: Maximum stochastic-depth rate, linearly scheduled
            from 0 across all blocks.
        layer_scale_init_value: Layer-scale init passed to every block.
        kernel_sizes: Depthwise kernel size(s) for the blocks.
    """

    def __init__(
        self,
        input_channels=3,
        depths=(3, 3, 9, 3),  # fix: immutable defaults (were mutable lists)
        dims=(96, 192, 384, 768),
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        kernel_sizes: tuple[int] = (7,),
    ):
        super().__init__()
        assert len(depths) == len(dims)

        self.channel_layers = nn.ModuleList()
        # Stem: project input channels up, replicate-pad to preserve length.
        stem = nn.Sequential(
            nn.Conv1d(
                input_channels,
                dims[0],
                kernel_size=7,
                padding=3,
                padding_mode="replicate",
            ),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.channel_layers.append(stem)

        # Between-stage 1x1 convs change the channel count.
        for i in range(len(depths) - 1):
            mid_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
            )
            self.channel_layers.append(mid_layer)

        block_fn = (
            partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
            if len(kernel_sizes) == 1
            else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
        )

        self.stages = nn.ModuleList()
        # Linearly increasing stochastic-depth rate, one value per block.
        drop_path_rates = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]

        cur = 0
        for i in range(len(depths)):
            stage = nn.Sequential(
                *[
                    block_fn(
                        dim=dims[i],
                        drop_path=drop_path_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal weights for convs/linears; zero biases.
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            # fix: guard against bias-less layers (bias=False would crash)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Encode (N, input_channels, T) -> normalized (N, dims[-1], T)."""
        for channel_layer, stage in zip(self.channel_layers, self.stages):
            x = channel_layer(x)
            x = stage(x)

        return self.norm(x)

+ 237 - 0
fish_speech/models/vq_diffusion/adamos/hifigan.py

@@ -0,0 +1,237 @@
+from functools import partial
+from math import prod
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
+
+
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialize conv weights in-place from N(mean, std).

    Intended for ``module.apply``: only modules whose class name contains
    "Conv" are touched; every other module is left unchanged.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
+
+
def get_padding(kernel_size, dilation=1):
    """Padding that keeps a stride-1 dilated conv length-preserving."""
    return dilation * (kernel_size - 1) // 2
+
+
class ResBlock1(torch.nn.Module):
    """HiFi-GAN residual block.

    Each residual step applies SiLU -> dilated conv -> SiLU -> plain conv
    and adds the result back onto the input.  One step is created per entry
    of ``dilation``.

    Args:
        channels: Number of input/output channels.
        kernel_size: Kernel size of every conv. Default: 3.
        dilation: Dilation of the first conv of each pair; the second conv
            of a pair always uses dilation 1. Default: (1, 3, 5).
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        def make_conv(d):
            # "Same"-padded weight-normalized conv, stride 1.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=d,
                    padding=get_padding(kernel_size, d),
                )
            )

        # Generalized: one conv pair per dilation entry instead of exactly
        # three copy-pasted constructions (backward compatible for the
        # default 3-tuple; creation order and init are unchanged).
        self.convs1 = nn.ModuleList([make_conv(d) for d in dilation])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([make_conv(1) for _ in dilation])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.silu(x)
            xt = c1(xt)
            xt = F.silu(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        # Strip the weight_norm parametrization for inference export.
        for conv in self.convs1:
            remove_weight_norm(conv)
        for conv in self.convs2:
            remove_weight_norm(conv)
+
+
class HiFiGANGenerator(nn.Module):
    """HiFi-GAN generator: upsamples a feature sequence to a raw waveform.

    A pre-conv projects ``num_mels`` channels to ``upsample_initial_channel``.
    Each stage then upsamples time by ``upsample_rates[i]`` with a transposed
    conv while halving the channel count, and feeds the result through a bank
    of ``len(resblock_kernel_sizes)`` residual blocks whose outputs are
    averaged.  ``prod(upsample_rates)`` must equal ``hop_length`` so one
    input frame maps to exactly one hop of output samples.  With
    ``use_template`` a template/excitation signal is injected at every stage
    through strided convs (a ``template`` tensor must then be passed to
    ``forward``).  Output is a tanh-bounded single-channel waveform.
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        use_template: bool = True,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        # Total upsampling must reproduce exactly hop_length samples/frame.
        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.conv_pre = weight_norm(
            nn.Conv1d(
                num_mels,
                upsample_initial_channel,
                pre_conv_kernel_size,
                1,
                padding=get_padding(pre_conv_kernel_size),
            )
        )

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        self.noise_convs = nn.ModuleList()
        self.use_template = use_template
        self.ups = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channels halve at every stage.
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

            if not use_template:
                continue

            # Template path: downsample the full-rate template to this
            # stage's temporal resolution before adding it in.
            if i + 1 < len(upsample_rates):
                stride_f0 = np.prod(upsample_rates[i + 1 :])
                self.noise_convs.append(
                    Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        # num_kernels resblocks per stage, consumed in order in forward().
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock1(ch, k, d))

        self.activation_post = post_activation()
        self.conv_post = weight_norm(
            nn.Conv1d(
                ch,
                1,
                post_conv_kernel_size,
                1,
                padding=get_padding(post_conv_kernel_size),
            )
        )
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, template=None):
        # x: (B, num_mels, T) features; template: (B, 1, T*hop_length) when
        # use_template is set, otherwise ignored.
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.use_template:
                x = x + self.noise_convs[i](template)

            # Average the parallel resblock bank for this stage.
            xs = None

            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)

            x = xs / self.num_kernels

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)  # bound output to [-1, 1]

        return x

    def remove_weight_norm(self):
        # Strip weight_norm parametrizations everywhere (inference export).
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)

+ 244 - 0
fish_speech/models/vq_diffusion/convnext_1d.py

@@ -0,0 +1,244 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.embeddings import (
+    GaussianFourierProjection,
+    TimestepEmbedding,
+    Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import BaseOutput
+
+
class ConvNeXtBlock(nn.Module):
    """ConvNeXt block (https://github.com/facebookresearch/ConvNeXt) adapted
    to 1D signals, with an optional additive condition (projected through a
    small pointwise MLP) and an optional frame mask.

    Args:
        dim (int): Number of input/output channels.
        intermediate_dim (int): Hidden size of the pointwise MLP.
        dilation (int): Dilation of the depthwise conv. Defaults to 1.
        layer_scale_init_value (float, optional): Initial value for the
            layer scale; None or <= 0 disables scaling. Defaults to 1e-6.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
        layer_scale_init_value: Optional[float] = 1e-6,
    ):
        super().__init__()
        # Depthwise conv along time; "same" padding for the fixed kernel 7.
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=7,
            groups=dim,
            dilation=dilation,
            padding=int(dilation * (7 - 1) / 2),
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convs expressed as linears on (B, T, C) data.
        self.pwconv1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        use_scale = layer_scale_init_value is not None and layer_scale_init_value > 0
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if use_scale
            else None
        )

        # Maps the (channel-matched) condition into the block's input space.
        self.condition_projection = nn.Sequential(
            nn.Conv1d(dim, dim, 1),
            nn.GELU(),
            nn.Conv1d(dim, dim, 1),
        )

    def forward(
        self,
        x: torch.Tensor,
        condition: Optional[torch.Tensor] = None,
        x_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        skip = x

        if condition is not None:
            x = x + self.condition_projection(condition)
        if x_mask is not None:
            x = x * x_mask

        h = self.dwconv(x).transpose(1, 2)  # (B, C, T) -> (B, T, C)
        h = self.pwconv2(self.act(self.pwconv1(self.norm(h))))
        if self.gamma is not None:
            h = self.gamma * h
        h = h.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        return skip + h
+
+
@dataclass
class ConvNext1DOutput(BaseOutput):
    """
    The output of [`ConvNext1DModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
            The hidden states output from the last layer of the model.
    """

    sample: torch.FloatTensor
+
+
class ConvNext1DModel(ModelMixin, ConfigMixin):
    r"""
    A ConvNeXt model that takes a noisy sample and a timestep and returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (`int`, *optional*, defaults to 128):
            Number of channels in the input sample. When a `condition` is passed to `forward`, it is concatenated
            to the sample along the channel dim, so `in_channels` must include the condition channels as well.
        out_channels (`int`, *optional*, defaults to 128):
            Number of channels in the output.
        intermediate_dim (`int`, *optional*, defaults to 512):
            Dimensionality of the intermediate blocks. Must be even.
        mlp_dim (`int`, *optional*, defaults to 2048):
            Dimensionality of the MLP inside each ConvNeXt block.
        num_layers (`int`, *optional*, defaults to 20):
            Number of ConvNeXt blocks in the model.
        dilation_cycle_length (`int`, *optional*, defaults to 4):
            Blocks cycle through dilations `2**0 .. 2**(dilation_cycle_length - 1)`.
        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 128,
        out_channels: int = 128,
        intermediate_dim: int = 512,
        mlp_dim: int = 2048,
        num_layers: int = 20,
        dilation_cycle_length: int = 4,
        time_embedding_type: str = "positional",
    ):
        super().__init__()

        if intermediate_dim % 2 != 0:
            raise ValueError("intermediate_dim must be divisible by 2.")

        # Timestep projection: fourier needs half the dim (sin/cos doubles it).
        if time_embedding_type == "fourier":
            self.time_proj = GaussianFourierProjection(
                intermediate_dim // 2,
                set_W_to_weight=False,
                log=False,
                flip_sin_to_cos=False,
            )
            timestep_input_dim = intermediate_dim
        elif time_embedding_type == "positional":
            self.time_proj = Timesteps(in_channels, False, 0)
            timestep_input_dim = in_channels
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )

        self.time_mlp = TimestepEmbedding(
            timestep_input_dim,
            intermediate_dim,
            act_fn="silu",
            cond_proj_dim=None,  # No conditional projection for now
        )

        # 1x1 convs project in/out of the blocks' working dim.
        self.in_proj = nn.Conv1d(in_channels, intermediate_dim, 1)
        self.out_proj = nn.Conv1d(intermediate_dim, out_channels, 1)

        # Stack of dilated ConvNeXt blocks; dilation cycles every
        # `dilation_cycle_length` layers to grow the receptive field.
        self.blocks = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=intermediate_dim,
                    intermediate_dim=mlp_dim,
                    dilation=2 ** (i % dilation_cycle_length),
                )
                for i in range(num_layers)
            ]
        )

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value: bool = False):
        # Hook used by ModelMixin.enable_gradient_checkpointing().
        self.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        sample_mask: Optional[torch.Tensor] = None,
        condition: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        r"""
        The [`ConvNext1DModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
            timestep (`torch.FloatTensor` or `float` or `int`):
                The number of timesteps to denoise an input.
            sample_mask (`torch.Tensor`, *optional*):
                Frame mask of shape `(batch_size, sample_size)` or `(batch_size, 1, sample_size)`;
                zero entries are masked out inside each block.
            condition (`torch.FloatTensor`, *optional*):
                Conditioning tensor concatenated to `sample` on the channel dim before projection.

        Returns:
            `torch.FloatTensor` of shape `(batch_size, out_channels, sample_size)`: the denoised sample.
        """

        # 1. time: embed the timestep and broadcast it over the time axis.
        t_emb = self.time_proj(timestep)
        t_emb = self.time_mlp(t_emb)[..., None]

        # 2. pre-process
        if condition is not None:
            sample = torch.cat([sample, condition], dim=1)

        x = self.in_proj(sample)

        # Fix: the mask is optional — only reshape it when it was provided.
        if sample_mask is not None and sample_mask.ndim == 2:
            sample_mask = sample_mask[:, None, :]

        # 3. blocks
        # Fix: the flag is `gradient_checkpointing`; the original read a
        # nonexistent `is_gradient_checkpointing` attribute.
        for block in self.blocks:
            if self.training and self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(block, x, t_emb, sample_mask)
            else:
                x = block(x, t_emb, sample_mask)

        # 4. post-process
        return self.out_proj(x)

+ 270 - 0
fish_speech/models/vq_diffusion/lit_module.py

@@ -0,0 +1,270 @@
+import itertools
+from typing import Any, Callable
+
+import lightning as L
+import torch
+import torch.nn.functional as F
+import wandb
+from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler
+from diffusers.utils.torch_utils import randn_tensor
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
+from matplotlib import pyplot as plt
+from torch import nn
+
+from fish_speech.models.vq_diffusion.unet1d import Unet1DDenoiser
+from fish_speech.models.vqgan.modules.encoders import (
+    SpeakerEncoder,
+    TextEncoder,
+    VQEncoder,
+)
+from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
+
+
class VQDiffusion(L.LightningModule):
    """Diffusion-based mel-spectrogram generator conditioned on VQ content
    features and a global speaker embedding, decoded to audio with a frozen
    vocoder (used for validation logging only).
    """

    def __init__(
        self,
        optimizer: Callable,
        lr_scheduler: Callable,
        mel_transform: nn.Module,
        vq_encoder: VQEncoder,
        speaker_encoder: SpeakerEncoder,
        text_encoder: TextEncoder,
        denoiser: Unet1DDenoiser,
        vocoder: nn.Module,
        hop_length: int = 640,
        sample_rate: int = 32000,
    ):
        """
        Args:
            optimizer: factory mapping parameters -> a torch optimizer.
            lr_scheduler: factory mapping an optimizer -> an LR scheduler.
            mel_transform: waveform -> log-mel module.
            vq_encoder: content encoder producing quantized features.
            speaker_encoder: mel -> global speaker embedding.
            text_encoder: conditioning encoder over upsampled VQ features.
            denoiser: diffusion model predicting the clean mels.
            vocoder: frozen mel -> waveform decoder.
            hop_length: audio samples per mel frame.
            sample_rate: audio sample rate in Hz.
        """
        super().__init__()

        # Optimizer/scheduler factories, invoked in configure_optimizers().
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler

        self.mel_transform = mel_transform

        # training_step regresses the denoiser output directly onto the clean
        # (normalized) mels — i.e. x0/"sample" prediction.  diffusers
        # schedulers default to prediction_type="epsilon", which would make
        # scheduler.step() misinterpret the model output as predicted noise
        # during sampling, so the prediction type must be set explicitly.
        self.noise_scheduler_train = DDIMScheduler(
            num_train_timesteps=1000, prediction_type="sample"
        )
        self.noise_scheduler_infer = UniPCMultistepScheduler(
            num_train_timesteps=1000, prediction_type="sample"
        )
        self.noise_scheduler_infer.set_timesteps(20)

        # Modules
        self.vq_encoder = vq_encoder
        self.speaker_encoder = speaker_encoder
        self.text_encoder = text_encoder
        self.denoiser = denoiser

        self.vocoder = vocoder
        self.hop_length = hop_length
        self.sampling_rate = sample_rate

        # The vocoder is only used for decoding samples; keep it frozen.
        for param in self.vocoder.parameters():
            param.requires_grad = False
+
+    def configure_optimizers(self):
+        optimizer = self.optimizer_builder(self.parameters())
+        lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "interval": "step",
+            },
+        }
+
+    def normalize_mels(self, x):
+        return (x + 11.5129251) / (1 + 11.5129251) * 2 - 1
+
+    def denormalize_mels(self, x):
+        return (x + 1) / 2 * (1.0 + 11.5129251) - 11.5129251
+
+    def training_step(self, batch, batch_idx):
+        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+        features, feature_lengths = batch["features"], batch["feature_lengths"]
+
+        audios = audios.float()
+        features = features.float().mT
+        audios = audios[:, None, :]
+
+        with torch.no_grad():
+            gt_mels = self.mel_transform(audios)
+
+        mel_lengths = audio_lengths // self.hop_length
+
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(gt_mels.dtype)
+        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
+            gt_mels.dtype
+        )
+
+        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
+        vq_features, _ = self.vq_encoder(features, feature_masks)
+
+        # vq_features is 50 hz, need to convert to true mel size
+        vq_features = F.interpolate(vq_features, size=gt_mels.shape[2], mode="nearest")
+        text_features = self.text_encoder(vq_features, mel_masks, g=speaker_features)
+
+        # Sample noise that we'll add to the images
+        normalized_gt_mels = self.normalize_mels(gt_mels)
+        noise = torch.randn_like(normalized_gt_mels)
+
+        # Sample a random timestep for each image
+        timesteps = torch.randint(
+            0,
+            self.noise_scheduler_train.config.num_train_timesteps,
+            (normalized_gt_mels.shape[0],),
+            device=normalized_gt_mels.device,
+        ).long()
+
+        # Add noise to the clean images according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_images = self.noise_scheduler_train.add_noise(
+            normalized_gt_mels, noise, timesteps
+        )
+
+        # Predict
+        model_output = self.denoiser(noisy_images, timesteps, mel_masks, text_features)
+
+        # MSE loss without the mask
+        loss = (
+            (model_output * mel_masks - normalized_gt_mels * mel_masks) ** 2
+        ).sum() / (mel_masks.sum() * gt_mels.shape[1])
+
+        self.log(
+            "train/loss",
+            loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        return loss
+
    def validation_step(self, batch: Any, batch_idx: int):
        """Sample mels via the full reverse-diffusion loop, vocode them, log an
        L1 mel metric, and push spectrogram plots plus audio to the logger.
        """
        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
        features, feature_lengths = batch["features"], batch["feature_lengths"]

        audios = audios.float()
        features = features.float().mT  # [B, T, C] -> [B, C, T]
        audios = audios[:, None, :]
        gt_mels = self.mel_transform(audios)
        mel_lengths = audio_lengths // self.hop_length

        # Validity masks (1 = real frame, 0 = padding), shape [B, 1, T].
        feature_masks = torch.unsqueeze(
            sequence_mask(feature_lengths, features.shape[2]), 1
        ).to(gt_mels.dtype)
        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
            gt_mels.dtype
        )

        # Conditioning: global speaker embedding + quantized content features.
        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
        vq_features, _ = self.vq_encoder(features, feature_masks)

        # vq_features is 50 hz, need to convert to true mel size
        vq_features = F.interpolate(vq_features, size=gt_mels.shape[2], mode="nearest")
        text_features = self.text_encoder(vq_features, mel_masks, g=speaker_features)

        # Begin sampling: start from pure Gaussian noise in normalized mel space.
        sampled_mels = torch.randn_like(gt_mels)
        self.noise_scheduler_infer.set_timesteps(20)

        for t in self.noise_scheduler_infer.timesteps:
            timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)

            # 1. run the denoiser at this timestep (training regresses its
            #    output onto the clean mels).
            model_output = self.denoiser(
                sampled_mels, timesteps, mel_masks, text_features
            )

            # 2. compute previous image: x_t -> x_t-1
            sampled_mels = self.noise_scheduler_infer.step(
                model_output, t, sampled_mels
            ).prev_sample

        # Map [-1, 1] back to the log-mel range before vocoding.
        sampled_mels = self.denormalize_mels(sampled_mels)

        with torch.autocast(device_type=sampled_mels.device.type, enabled=False):
            # Run vocoder on fp32
            fake_audios = self.vocoder.decode(sampled_mels.float())

        # NOTE(review): unmasked L1 — padded frames contribute to the metric;
        # confirm this is intended.
        mel_loss = F.l1_loss(gt_mels, sampled_mels)
        self.log(
            "val/mel_loss",
            mel_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        # Per-sample logging of generated vs. ground-truth mels and audio.
        for idx, (
            mel,
            gen_mel,
            audio,
            gen_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                sampled_mels,
                audios,
                fake_audios,
                audio_lengths,
            )
        ):
            mel_len = audio_len // self.hop_length

            image_mels = plot_mel(
                [
                    gen_mel[:, :mel_len],
                    mel[:, :mel_len],
                ],
                [
                    "Generated Spectrogram",
                    "Ground-Truth Spectrogram",
                ],
            )

            if isinstance(self.logger, WandbLogger):
                # NOTE(review): wandb.Audio receives torch tensors directly —
                # presumably wandb converts them; verify on a CUDA run.
                self.logger.experiment.log(
                    {
                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
                        "wavs": [
                            wandb.Audio(
                                audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="gt",
                            ),
                            wandb.Audio(
                                gen_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
                                caption="prediction",
                            ),
                        ],
                    },
                )

            if isinstance(self.logger, TensorBoardLogger):
                self.logger.experiment.add_figure(
                    f"sample-{idx}/mels",
                    image_mels,
                    global_step=self.global_step,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/gt",
                    audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/prediction",
                    gen_audio[0, :audio_len],
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )

            # Free the matplotlib figure to avoid leaking across batches.
            plt.close(image_mels)

+ 198 - 0
fish_speech/models/vq_diffusion/unet1d.py

@@ -0,0 +1,198 @@
+# Refer to https://github.com/huawei-noah/Speech-Backbones/blob/main/Grad-TTS/model/diffusion.py
+
+import math
+
+import torch
+from einops import rearrange
+from torch import nn
+
+
class Block(nn.Module):
    """Masked Conv2d -> GroupNorm -> Mish unit.

    The mask is applied both before and after the convolution so padded
    positions never influence (or carry) activations.
    """

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        conv = nn.Conv2d(dim, dim_out, 3, padding=1)
        norm = nn.GroupNorm(groups, dim_out)
        self.block = nn.Sequential(conv, norm, nn.Mish())

    def forward(self, x, mask):
        return self.block(x * mask) * mask


class ResnetBlock(nn.Module):
    """Two masked conv blocks with a broadcast additive time-embedding bias
    and a 1x1 (or identity) shortcut for the residual path."""

    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super().__init__()
        self.mlp = nn.Sequential(nn.Mish(), nn.Linear(time_emb_dim, dim_out))

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        # Project the shortcut only when the channel count changes.
        self.res_conv = (
            nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
        )

    def forward(self, x, mask, time_emb):
        h = self.block1(x, mask)
        h = h + self.mlp(time_emb)[..., None, None]
        h = self.block2(h, mask)
        return h + self.res_conv(x * mask)
+
+
class LinearAttention(nn.Module):
    """Linear (kernelized) attention over the flattened H*W axis, added back
    to the input through a small learnable per-channel gate (gamma starts at
    `init_values`, so the module is near-identity at init)."""

    def __init__(self, dim, heads=4, dim_head=32, init_values=1e-5):
        super().__init__()

        self.heads = heads
        inner_dim = dim_head * heads

        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x):
        b, _, h, w = x.shape
        # Split channels as (qkv, heads, dim_head) and flatten the spatial
        # grid: (b, 3*heads*d, h, w) -> 3 x (b, heads, d, h*w).
        qkv = self.to_qkv(x).reshape(b, 3, self.heads, -1, h * w)
        q, k, v = qkv.unbind(dim=1)
        k = k.softmax(dim=-1)
        # Aggregate values into a (d x d) context, then read it out with q —
        # linear in the number of positions.
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = out.reshape(b, -1, h, w)
        return self.to_out(out) * self.gamma.view(1, -1, 1, 1) + x
+
+
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal timestep embedding of width `dim`
    (first half sin, second half cos)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        half = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(-step * torch.arange(half, device=x.device).float())
        angles = scale * x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
+
+
class Unet1DDenoiser(nn.Module):
    """2D U-Net denoiser over (mel-bin, time), Grad-TTS style.

    forward() stacks the conditioning tensor with the noisy sample into a
    2-channel "image" [B, 2, n_mels, T], pads the time axis to a multiple of
    the total downsample rate, and runs a masked down/mid/up U-Net.
    """

    def __init__(
        self,
        dim,
        dim_mults=(1, 2, 4),
        groups=8,
        pe_scale=1000,
    ):
        super().__init__()
        self.dim = dim
        self.dim_mults = dim_mults
        self.groups = groups
        self.pe_scale = pe_scale  # scale applied inside the sinusoidal embedding

        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.Mish(), nn.Linear(dim * 4, dim)
        )
        # Every level except the deepest halves the time/freq axes once.
        self.downsample_rate = 2 ** (len(dim_mults) - 1)

        # First channel count is 2: [condition, sample] stacked in forward().
        dims = [2, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                nn.ModuleList(
                    [
                        ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                        ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                        LinearAttention(dim_out),
                        # No spatial reduction at the deepest level.
                        nn.Conv2d(dim_out, dim_out, 3, 2, 1)
                        if not is_last
                        else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
        self.mid_attn = LinearAttention(mid_dim)
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            self.ups.append(
                nn.ModuleList(
                    [
                        # dim_out * 2: upsampled features are concatenated with
                        # the skip connection from the matching down level.
                        ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
                        ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
                        LinearAttention(dim_in),
                        nn.ConvTranspose2d(dim_in, dim_in, 4, 2, 1),
                    ]
                )
            )
        self.final_block = Block(dim, dim)
        self.final_conv = nn.Conv2d(dim, 1, 1)

    def forward(self, x, t, mask, condition):
        """Denoise `x` at timestep `t`.

        Args:
            x: noisy mels, [B, n_mels, T].
            t: integer timestep tensor, [B] (or broadcastable).
            mask: validity mask, [B, 1, T] (1 = valid frame).
            condition: conditioning tensor with the same shape as `x`.

        Returns:
            Denoised mels, [B, n_mels, T], trimmed back to the input length.
        """
        t = self.time_pos_emb(t, scale=self.pe_scale)
        t = self.mlp(t)

        # Treat the pair as a 2-channel image: [B, 2, n_mels, T].
        x = torch.stack([condition, x], 1)
        mask = mask.unsqueeze(1)

        # Pad the time axis so repeated /2 downsampling divides exactly.
        original_len = x.shape[3]
        if x.shape[3] % self.downsample_rate != 0:
            x = nn.functional.pad(
                x, (0, self.downsample_rate - x.shape[3] % self.downsample_rate)
            )
            mask = nn.functional.pad(
                mask, (0, self.downsample_rate - mask.shape[3] % self.downsample_rate)
            )

        hiddens = []
        masks = [mask]
        for resnet1, resnet2, attn, downsample in self.downs:
            mask_down = masks[-1]
            x = resnet1(x, mask_down, t)
            x = resnet2(x, mask_down, t)
            x = attn(x)
            hiddens.append(x)  # skip connection for the matching up level
            x = downsample(x * mask_down)
            # Track the halved time resolution by striding the mask.
            masks.append(mask_down[:, :, :, ::2])

        # The last appended mask over-subsamples (the deepest "downsample"
        # is an Identity), so drop it before the mid blocks.
        masks = masks[:-1]
        mask_mid = masks[-1]
        x = self.mid_block1(x, mask_mid, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, mask_mid, t)

        # NOTE(review): the first (full-resolution) entry of `hiddens` is
        # never popped — this mirrors the Grad-TTS reference layout, but
        # worth confirming it is intentional.
        for resnet1, resnet2, attn, upsample in self.ups:
            mask_up = masks.pop()
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = resnet1(x, mask_up, t)
            x = resnet2(x, mask_up, t)
            x = attn(x)
            x = upsample(x * mask_up)

        x = self.final_block(x, mask)
        output = self.final_conv(x * mask)

        # Collapse the singleton channel and trim the time padding.
        output = (output * mask).squeeze(1)
        return output[:, :, :original_len]
+
+
if __name__ == "__main__":
    # Smoke test.  forward() takes (x, t, mask, condition) — the original
    # call passed the mask as the timestep and the timestep as the mask,
    # which crashes inside the sinusoidal time embedding.
    model = Unet1DDenoiser(128)
    mel = torch.randn(1, 128, 99)
    mask = torch.ones(1, 1, 99)
    t = torch.tensor([10], dtype=torch.long)

    print(model(mel, t, mask, mel).shape)

+ 1 - 3
fish_speech/models/vqgan/lit_module.py

@@ -233,9 +233,7 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]
 
         gt_mels = self.mel_transform(audios)
-        assert (
-            gt_mels.shape[2] == features.shape[1]
-        ), f"Shapes do not match: {gt_mels.shape}, {features.shape}"
+        gt_mels = gt_mels[:, :, : features.shape[1]]
 
         fake_audios = self.generator.infer(features, feature_lengths, gt_mels)
         posterior_audios = self.generator.reconstruct(gt_mels, feature_lengths)

+ 37 - 30
fish_speech/models/vqgan/modules/encoders.py

@@ -2,6 +2,7 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from vector_quantize_pytorch import VectorQuantize
 
 from fish_speech.models.vqgan.modules.modules import WN
@@ -13,7 +14,7 @@ from fish_speech.models.vqgan.utils import sequence_mask
 class TextEncoder(nn.Module):
     def __init__(
         self,
-        n_vocab: int,
+        in_channels: int,
         out_channels: int,
         hidden_channels: int,
         hidden_channels_ffn: int,
@@ -23,11 +24,12 @@ class TextEncoder(nn.Module):
         dropout: float,
         gin_channels=0,
         speaker_cond_layer=0,
+        use_vae=True,
     ):
         """Text Encoder for VITS model.
 
         Args:
-            n_vocab (int): Number of characters for the embedding layer.
+            in_channels (int): Number of characters for the embedding layer.
             out_channels (int): Number of channels for the output.
             hidden_channels (int): Number of channels for the hidden layers.
             hidden_channels_ffn (int): Number of channels for the convolutional layers.
@@ -41,9 +43,7 @@ class TextEncoder(nn.Module):
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
 
-        # self.emb = nn.Linear(n_vocab, hidden_channels)
-        self.emb = nn.Linear(n_vocab, hidden_channels, 1)
-        # nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
+        self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)
 
         self.encoder = RelativePositionTransformer(
             in_channels=hidden_channels,
@@ -58,12 +58,15 @@ class TextEncoder(nn.Module):
             gin_channels=gin_channels,
             speaker_cond_layer=speaker_cond_layer,
         )
-        self.proj = nn.Linear(hidden_channels, out_channels * 2)
+        self.proj_out = nn.Conv1d(
+            hidden_channels, out_channels * 2 if use_vae else out_channels, 1
+        )
+        self.use_vae = use_vae
 
     def forward(
         self,
         x: torch.Tensor,
-        x_lengths: torch.Tensor,
+        x_mask: torch.Tensor,
         g: torch.Tensor = None,
         noise_scale: float = 1,
     ):
@@ -72,14 +75,14 @@ class TextEncoder(nn.Module):
             - x: :math:`[B, T]`
             - x_length: :math:`[B]`
         """
-        # x = self.emb(x).mT * math.sqrt(self.hidden_channels)  # [b, h, t]
-        x = self.emb(x).mT  # * math.sqrt(self.hidden_channels)  # [b, h, t]
-        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
-
+        x = self.proj_in(x) * x_mask
         x = self.encoder(x, x_mask, g=g)
-        stats = self.proj(x.mT).mT * x_mask
+        x = self.proj_out(x) * x_mask
 
-        m, logs = torch.split(stats, self.out_channels, dim=1)
+        if self.use_vae is False:
+            return x
+
+        m, logs = torch.split(x, self.out_channels, dim=1)
         z = m + torch.randn_like(m) * torch.exp(logs) * x_mask * noise_scale
         return z, m, logs, x, x_mask
 
@@ -113,7 +116,7 @@ class PosteriorEncoder(nn.Module):
         super().__init__()
         self.out_channels = out_channels
 
-        self.pre = nn.Linear(in_channels, hidden_channels)
+        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
         self.enc = WN(
             hidden_channels,
             kernel_size,
@@ -121,7 +124,7 @@ class PosteriorEncoder(nn.Module):
             n_layers,
             gin_channels=gin_channels,
         )
-        self.proj = nn.Linear(hidden_channels, out_channels * 2)
+        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
     def forward(
         self,
@@ -137,9 +140,9 @@ class PosteriorEncoder(nn.Module):
             - g: :math:`[B, C, 1]`
         """
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
-        x = self.pre(x.mT).mT * x_mask
+        x = self.pre(x) * x_mask
         x = self.enc(x, x_mask, g=g)
-        stats = self.proj(x.mT).mT * x_mask
+        stats = self.proj(x) * x_mask
         m, logs = torch.split(stats, self.out_channels, dim=1)
         z = m + torch.randn_like(m) * torch.exp(logs) * x_mask * noise_scale
         return z, m, logs, x_mask
@@ -180,22 +183,19 @@ class SpeakerEncoder(nn.Module):
         )
         self.out_proj = nn.Linear(hidden_channels, out_channels)
 
-    def forward(self, mels, mel_lengths: torch.Tensor):
+    def forward(self, mels, mel_masks: torch.Tensor):
         """
         Shapes:
             - x: :math:`[B, C, T]`
             - x_lengths: :math:`[B, 1]`
         """
 
-        x_mask = torch.unsqueeze(sequence_mask(mel_lengths, mels.size(2)), 1).to(
-            mels.dtype
-        )
-        x = self.in_proj(mels) * x_mask
-        x = self.encoder(x, x_mask)
+        x = self.in_proj(mels) * mel_masks
+        x = self.encoder(x, mel_masks)
 
         # Avg Pooling
-        x = x * x_mask
-        x = torch.sum(x, dim=2) / torch.sum(x_mask, dim=2)
+        x = x * mel_masks
+        x = torch.sum(x, dim=2) / torch.sum(mel_masks, dim=2)
         x = self.out_proj(x)[..., None]
 
         return x
@@ -219,7 +219,7 @@ class VQEncoder(nn.Module):
             kmeans_init=False,
             channel_last=False,
         )
-
+        self.downsample = downsample
         self.conv_in = nn.Conv1d(
             in_channels, vq_channels, kernel_size=downsample, stride=downsample
         )
@@ -253,10 +253,17 @@ class VQEncoder(nn.Module):
 
         self.vq.load_state_dict(state_dict, strict=True)
 
-    def forward(self, x):
-        # x: [B, T, C]
-        x = self.conv_in(x.mT)
+    def forward(self, x, x_mask):
+        # x: [B, C, T], x_mask: [B, 1, T]
+        x_len = x.shape[2]
+
+        if x_len % self.downsample != 0:
+            x = F.pad(x, (0, self.downsample - x_len % self.downsample))
+            x_mask = F.pad(x_mask, (0, self.downsample - x_len % self.downsample))
+
+        x = self.conv_in(x)
         q, _, loss = self.vq(x)
-        x = self.conv_out(q).mT
+        x = self.conv_out(q) * x_mask
+        x = x[:, :, :x_len]
 
         return x, loss

+ 5 - 5
fish_speech/models/vqgan/modules/models.py

@@ -104,12 +104,12 @@ class SynthesizerTrn(nn.Module):
             gin_channels=gin_channels,
         )
 
-    def forward(self, x, x_lengths, y):
-        g = self.enc_spk(y, x_lengths)
+    def forward(self, x, x_lengths, specs):
+        g = self.enc_spk(specs, x_lengths)
         x, vq_loss = self.vq(x)
 
         _, m_p, logs_p, _, x_mask = self.enc_p(x, x_lengths, g=g)
-        z_q, m_q, logs_q, y_mask = self.enc_q(y, x_lengths, g=g)
+        z_q, m_q, logs_q, y_mask = self.enc_q(specs, x_lengths, g=g)
         z_p = self.flow(z_q, y_mask, g=g, reverse=False)
 
         z_slice, ids_slice = rand_slice_segments(z_q, x_lengths, self.segment_size)
@@ -126,8 +126,8 @@ class SynthesizerTrn(nn.Module):
             vq_loss,
         )
 
-    def infer(self, x, x_lengths, y, max_len=None, noise_scale=0.35):
-        g = self.enc_spk(y, x_lengths)
+    def infer(self, x, x_lengths, specs, max_len=None, noise_scale=0.35):
+        g = self.enc_spk(specs, x_lengths)
         x, vq_loss = self.vq(x)
         z_p, m_p, logs_p, h_text, x_mask = self.enc_p(
             x, x_lengths, g=g, noise_scale=noise_scale

+ 4 - 4
fish_speech/models/vqgan/modules/modules.py

@@ -32,7 +32,7 @@ class WN(nn.Module):
         self.drop = nn.Dropout(p_dropout)
 
         if gin_channels != 0:
-            cond_layer = nn.Linear(gin_channels, 2 * hidden_channels * n_layers)
+            cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
             self.cond_layer = weight_norm(cond_layer, name="weight")
 
         for i in range(n_layers):
@@ -52,7 +52,7 @@ class WN(nn.Module):
             res_skip_channels = (
                 2 * hidden_channels if i < n_layers - 1 else hidden_channels
             )
-            res_skip_layer = nn.Linear(hidden_channels, res_skip_channels)
+            res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1)
             res_skip_layer = weight_norm(res_skip_layer, name="weight")
             self.res_skip_layers.append(res_skip_layer)
 
@@ -61,7 +61,7 @@ class WN(nn.Module):
         n_channels_tensor = torch.IntTensor([self.hidden_channels])
 
         if g is not None:
-            g = self.cond_layer(g.mT).mT
+            g = self.cond_layer(g)
 
         for i in range(self.n_layers):
             x_in = self.in_layers[i](x)
@@ -74,7 +74,7 @@ class WN(nn.Module):
             acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
             acts = self.drop(acts)
 
-            res_skip_acts = self.res_skip_layers[i](acts.mT).mT
+            res_skip_acts = self.res_skip_layers[i](acts)
             if i < self.n_layers - 1:
                 res_acts = res_skip_acts[:, : self.hidden_channels, :]
                 x = (x + res_acts) * x_mask

+ 1 - 0
pyproject.toml

@@ -26,6 +26,7 @@ dependencies = [
     "vector-quantize-pytorch>=1.10.0",
     "rich>=13.5.3",
     "gradio>=4.0.0",
+    "diffusers@git+https://github.com/huggingface/diffusers",
     "cn2an",
     "pypinyin",
     "jieba",

+ 13 - 43
tools/vqgan/migrate_from_vits.py

@@ -27,22 +27,23 @@ def main(cfg: DictConfig):
 
     # Decoder
     generator_state = {
-        k[4:]: v for k, v in generator_weights.items() if k.startswith("dec.")
+        k[4:]: v
+        for k, v in generator_weights.items()
+        if k.startswith("dec.") and not k.startswith("dec.cond.")
     }
     logger.info(f"Found {len(generator_state)} HiFiGAN weights, restoring...")
-    model.generator.load_state_dict(generator_state, strict=True)
-    logger.info("Generator weights restored.")
+    r = model.generator.dec.load_state_dict(generator_state, strict=False)
+    logger.info(f"Generator weights restored. {r}")
 
     # Posterior Encoder
-    encoder_state = {
-        k[6:]: v
-        for k, v in generator_weights.items()
-        if k.startswith("enc_q.")
-        if not k.startswith("enc_q.proj.")
-    }
-    logger.info(f"Found {len(encoder_state)} posterior encoder weights, restoring...")
-    x = model.posterior_encoder.load_state_dict(encoder_state, strict=False)
-    logger.info(f"Posterior encoder weights restored. {x}")
+    # encoder_state = {
+    #     k[6:]: v
+    #     for k, v in generator_weights.items()
+    #     if k.startswith("enc_q.") and not k.startswith("enc_q.proj.")
+    # }
+    # logger.info(f"Found {len(encoder_state)} posterior encoder weights, restoring...")
+    # x = model.generator.enc_q.load_state_dict(encoder_state, strict=False)
+    # logger.info(f"Posterior encoder weights restored. {x}")
 
     # Flow
     # flow_state = {
@@ -61,37 +62,6 @@ def main(cfg: DictConfig):
     model.discriminator.load_state_dict(discriminator_weights, strict=True)
     logger.info("Discriminator weights restored.")
 
-    # Restore kmeans
-    logger.info("Reset vq projection layer to mimic avg pooling")
-    torch.nn.init.normal_(
-        model.semantic_encoder.in_proj.weight,
-        mean=1
-        / (
-            model.semantic_encoder.in_proj.weight.shape[0]
-            * model.semantic_encoder.in_proj.weight.shape[-1]
-        ),
-        std=1e-2,
-    )
-    model.semantic_encoder.in_proj.bias.data.zero_()
-
-    kmeans_ckpt = "results/hubert-vq-pretrain/kmeans.pt"
-    kmeans_ckpt = torch.load(kmeans_ckpt, map_location="cpu")
-
-    centroids = kmeans_ckpt["centroids"][0]
-    bins = kmeans_ckpt["bins"][0]
-    logger.info(
-        f"Restoring kmeans centroids with shape {centroids.shape} and bins {bins.shape}"
-    )
-
-    state_dict = {
-        "_codebook.inited": torch.Tensor([True]),
-        "_codebook.cluster_size": bins,
-        "_codebook.embed": centroids,
-        "_codebook.embed_avg": centroids.clone(),
-    }
-
-    model.semantic_encoder.vq.load_state_dict(state_dict, strict=True)
-
     torch.save(model.state_dict(), cfg.ckpt_path)
     logger.info("Done")