소스 검색

rm diffusion code

Lengyue 2 년 전
부모
커밋
e8a6123109

+ 0 - 137
fish_speech/configs/hubert_vq_diffusion.yaml

@@ -1,137 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: hubert_vq_diffusion
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: 4
-  strategy: ddp_find_unused_parameters_true
-  gradient_clip_val: 1.0
-  gradient_clip_algorithm: 'norm'
-  precision: 16-mixed
-  max_steps: 1_000_000
-  val_check_interval: 5000
-
-sample_rate: 44100
-hop_length: 512
-num_mels: 128
-n_fft: 2048
-win_length: 2048
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/filelist.split.train
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  slice_frames: 512
-
-val_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/filelist.split.valid
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-data:
-  _target_: fish_speech.datasets.vqgan.VQGANDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 8
-  batch_size: 32
-  val_batch_size: 4
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vq_diffusion.lit_module.VQDiffusion
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-  text_encoder:
-    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
-    in_channels: 128
-    out_channels: 128
-    hidden_channels: 192
-    hidden_channels_ffn: 768
-    n_heads: 2
-    n_layers: 6
-    kernel_size: 1
-    dropout: 0.1
-    use_vae: false
-    gin_channels: 512
-    speaker_cond_layer: 0
-
-  vq_encoder:
-    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
-    in_channels: 128
-    vq_channels: 128
-    codebook_size: 4096
-    downsample: 1
-
-  speaker_encoder:
-    _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
-    in_channels: 128
-    hidden_channels: 192
-    out_channels: 128
-    num_heads: 2
-    num_layers: 4
-    p_dropout: 0.1
-  
-  denoiser:
-    _target_: fish_speech.models.vq_diffusion.convnext_1d.ConvNext1DModel
-    in_channels: 256
-    out_channels: 128
-    intermediate_dim: 512
-    # condition_dim: 128
-    mlp_dim: 2048
-    num_layers: 20
-    dilation_cycle_length: 2
-    time_embedding_type: "positional"
-
-  vocoder:
-    _target_: fish_speech.models.vq_diffusion.adamos.ADaMoSHiFiGANV1
-
-  mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-    f_min: 40
-    f_max: 16000
-
-  feature_mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
-    sample_rate: 32000
-    n_fft: 2048
-    hop_length: 640
-    win_length: 2048
-    n_mels: 128
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    # Floats are written with an explicit mantissa dot: YAML 1.1 loaders such as
-    # plain PyYAML resolve `1e-4` (no dot) as a string rather than a float.
-    lr: 1.0e-4
-    betas: [0.9, 0.999]
-    eps: 1.0e-5
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 0
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0.05
-
-callbacks:
-  grad_norm_monitor:
-    sub_module: 
-      - vq_encoder
-      - text_encoder
-      - speaker_encoder
-      - denoiser

+ 0 - 157
fish_speech/configs/vq_diffusion.yaml

@@ -1,157 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vq_naive
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: 4
-  strategy: ddp_find_unused_parameters_true
-  gradient_clip_val: 1.0
-  gradient_clip_algorithm: 'norm'
-  precision: bf16-mixed
-  max_steps: 300_000
-  val_check_interval: 5000
-
-sample_rate: 24000
-hop_length: 256
-num_mels: 100
-n_fft: 1024
-win_length: 1024
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/filelist.split.train
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  slice_frames: 512
-
-val_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/filelist.split.valid
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-data:
-  _target_: fish_speech.datasets.vqgan.VQGANDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 8
-  batch_size: 32
-  val_batch_size: 4
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vq_diffusion.lit_module.VQDiffusion
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  speaker_use_feats: true
-
-  downsample:
-    _target_: fish_speech.models.vq_diffusion.lit_module.ConvDownSample
-    dims: [128, 512, 128]
-    kernel_sizes: [3, 3]
-    strides: [2, 2]
-
-  text_encoder:
-    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
-    in_channels: 128
-    out_channels: 128
-    hidden_channels: 192
-    hidden_channels_ffn: 768
-    n_heads: 2
-    n_layers: 6
-    kernel_size: 1
-    dropout: 0.1
-    use_vae: false
-    gin_channels: 512
-    speaker_cond_layer: 0
-
-  vq_encoder:
-    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
-    in_channels: 128
-    vq_channels: 128
-    codebook_size: 4096
-    downsample: 1
-
-  speaker_encoder:
-    _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
-    in_channels: 128
-    hidden_channels: 192
-    out_channels: 128
-    num_heads: 2
-    num_layers: 4
-    p_dropout: 0.1
-  
-  decoder:
-    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
-    in_channels: 128
-    out_channels: 100
-    hidden_channels: 192
-    hidden_channels_ffn: 768
-    n_heads: 2
-    n_layers: 6
-    kernel_size: 1
-    use_vae: false
-    dropout: 0
-    gin_channels: 128
-    speaker_cond_layer: 0
-
-  postnet:
-    _target_: fish_speech.models.vq_diffusion.convnext_1d.ConvNext1DModel
-    in_channels: 100
-    out_channels: 100
-    intermediate_dim: 256
-    mlp_dim: 1024
-    num_layers: 6
-    dilation_cycle_length: 2
-
-  vocoder:
-    _target_: fish_speech.models.vq_diffusion.bigvgan.BigVGAN
-
-  mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-    f_min: 0
-    f_max: 12000
-
-  feature_mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
-    sample_rate: 32000
-    n_fft: 2048
-    hop_length: 320
-    win_length: 2048
-    n_mels: 128
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    # Floats are written with an explicit mantissa dot: YAML 1.1 loaders such as
-    # plain PyYAML resolve `1e-4` (no dot) as a string rather than a float.
-    lr: 1.0e-4
-    betas: [0.9, 0.999]
-    eps: 1.0e-5
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 0
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0.05
-
-callbacks:
-  grad_norm_monitor:
-    sub_module: 
-      - vq_encoder
-      - text_encoder
-      - speaker_encoder
-      - decoder
-      - postnet

+ 0 - 3
fish_speech/models/vq_diffusion/adamos/__init__.py

@@ -1,3 +0,0 @@
-from .adamos import ADaMoSHiFiGANV1
-
-__all__ = ["ADaMoSHiFiGANV1"]

+ 0 - 88
fish_speech/models/vq_diffusion/adamos/adamos.py

@@ -1,88 +0,0 @@
-import librosa
-import torch
-from torch import nn
-
-from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
-
-from .encoder import ConvNeXtEncoder
-from .hifigan import HiFiGANGenerator
-
-
-class ADaMoSHiFiGANV1(nn.Module):
-    """Pretrained vocoder wrapper: a ConvNeXtEncoder backbone followed by a HiFiGANGenerator head.
-
-    Weights are loaded from ``checkpoint_path`` during construction and the module
-    is switched to eval mode. ``encode`` applies the ``LogMelSpectrogram`` transform
-    to its input; ``decode`` runs its input through ``backbone`` and then ``head``.
-    Both methods run under ``torch.no_grad()``.
-
-    Args:
-        checkpoint_path: Path handed to ``torch.load``. The checkpoint may be a raw
-            state dict or a dict holding one under ``"state_dict"``; keys may carry
-            a leading ``"generator."`` prefix.
-    """
-
-    def __init__(
-        self,
-        checkpoint_path: str = "checkpoints/adamos-generator-1640000.pth",
-    ):
-        super().__init__()
-
-        self.backbone = ConvNeXtEncoder(
-            input_channels=128,
-            depths=[3, 3, 9, 3],
-            dims=[128, 256, 384, 512],
-            drop_path_rate=0,
-            kernel_sizes=(7,),
-        )
-
-        # num_mels matches the backbone's final dim (512); the product of
-        # upsample_rates equals hop_length (4 * 4 * 2 * 2 * 2 * 2 * 2 = 512).
-        self.head = HiFiGANGenerator(
-            hop_length=512,
-            upsample_rates=(4, 4, 2, 2, 2, 2, 2),
-            upsample_kernel_sizes=(8, 8, 4, 4, 4, 4, 4),
-            resblock_kernel_sizes=(3, 7, 11, 13),
-            resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
-            num_mels=512,
-            upsample_initial_channel=1024,
-            use_template=False,
-            pre_conv_kernel_size=13,
-            post_conv_kernel_size=13,
-        )
-        self.sampling_rate = 44100
-
-        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
-        ckpt_state = torch.load(checkpoint_path, map_location="cpu")
-
-        if "state_dict" in ckpt_state:
-            ckpt_state = ckpt_state["state_dict"]
-
-        # Keep only the generator weights and strip the leading "generator." prefix.
-        # Slicing (instead of str.replace) leaves any later occurrence of the
-        # substring inside a key untouched.
-        prefix = "generator."
-        if any(k.startswith(prefix) for k in ckpt_state):
-            ckpt_state = {
-                k[len(prefix) :]: v
-                for k, v in ckpt_state.items()
-                if k.startswith(prefix)
-            }
-
-        self.load_state_dict(ckpt_state)
-        self.eval()
-
-        # Registered after load_state_dict, so the strict load above never looks
-        # for mel_transform entries in the checkpoint.
-        self.mel_transform = LogMelSpectrogram(
-            sample_rate=44100,
-            n_fft=2048,
-            win_length=2048,
-            hop_length=512,
-            f_min=40,
-            f_max=16000,
-            n_mels=128,
-        )
-
-    @torch.no_grad()
-    def decode(self, mel):
-        """Run ``mel`` through the backbone and the generator head and return the result."""
-        y = self.backbone(mel)
-        y = self.head(y)
-
-        return y
-
-    @torch.no_grad()
-    def encode(self, x):
-        """Return ``self.mel_transform(x)``."""
-        return self.mel_transform(x)
-
-
-if __name__ == "__main__":
-    import soundfile as sf
-
-    x = "data/StarRail/Chinese/罗刹/archive_luocha_2.wav"
-    model = ADaMoSHiFiGANV1()
-
-    wav, sr = librosa.load(x, sr=44100, mono=True)
-    wav = torch.from_numpy(wav).float()[None]
-    mel = model.encode(wav)
-
-    wav = model.decode(mel)[0].mT
-    sf.write("test.wav", wav.cpu().numpy(), 44100)

+ 0 - 238
fish_speech/models/vq_diffusion/adamos/encoder.py

@@ -1,238 +0,0 @@
-from functools import partial
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-
-def drop_path(
-    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
-):
-    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-
-    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
-    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
-    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
-    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
-    'survival rate' as the argument.
-
-    """  # noqa: E501
-
-    if drop_prob == 0.0 or not training:
-        return x
-    keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (
-        x.ndim - 1
-    )  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
-    if keep_prob > 0.0 and scale_by_keep:
-        random_tensor.div_(keep_prob)
-    return x * random_tensor
-
-
-class DropPath(nn.Module):
-    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""  # noqa: E501
-
-    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
-        super(DropPath, self).__init__()
-        self.drop_prob = drop_prob
-        self.scale_by_keep = scale_by_keep
-
-    def forward(self, x):
-        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
-
-    def extra_repr(self):
-        return f"drop_prob={round(self.drop_prob,3):0.3f}"
-
-
-class LayerNorm(nn.Module):
-    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
-    `data_format` gives the ordering of the dimensions in the inputs. channels_last normalizes
-    over the last dimension, e.g. inputs of shape (batch_size, length, channels), while
-    channels_first normalizes over dimension 1 of inputs with shape (batch_size, channels, length).
-    """  # noqa: E501
-
-    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(normalized_shape))
-        self.bias = nn.Parameter(torch.zeros(normalized_shape))
-        self.eps = eps
-        self.data_format = data_format
-        if self.data_format not in ["channels_last", "channels_first"]:
-            raise NotImplementedError
-        self.normalized_shape = (normalized_shape,)
-
-    def forward(self, x):
-        if self.data_format == "channels_last":
-            return F.layer_norm(
-                x, self.normalized_shape, self.weight, self.bias, self.eps
-            )
-        elif self.data_format == "channels_first":
-            u = x.mean(1, keepdim=True)
-            s = (x - u).pow(2).mean(1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.eps)
-            x = self.weight[:, None] * x + self.bias[:, None]
-            return x
-
-
-class ConvNeXtBlock(nn.Module):
-    r"""1-D ConvNeXt block operating on (N, C, L) tensors.
-
-    Pipeline: depthwise Conv1d -> permute to (N, L, C) -> LayerNorm (channels_last)
-    -> Linear -> GELU -> Linear -> optional layer scale -> permute back to (N, C, L)
-    -> drop path -> optional residual add.
-
-    Args:
-        dim (int): Number of input channels.
-        drop_path (float): Stochastic depth rate. Default: 0.0
-        layer_scale_init_value (float): Init value for Layer Scale; a value <= 0
-            disables layer scale. Default: 1e-6.
-        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
-        kernel_size (int): Kernel size for depthwise conv. Default: 7.
-        dilation (int): Dilation for depthwise conv. Default: 1.
-    """  # noqa: E501
-
-    def __init__(
-        self,
-        dim: int,
-        drop_path: float = 0.0,
-        layer_scale_init_value: float = 1e-6,
-        mlp_ratio: float = 4.0,
-        kernel_size: int = 7,
-        dilation: int = 1,
-    ):
-        super().__init__()
-
-        # The padding is sized for the dilated kernel, so the dilation itself must
-        # also reach the convolution; otherwise any dilation > 1 lengthens the
-        # output and the residual add in forward() no longer lines up.
-        self.dwconv = nn.Conv1d(
-            dim,
-            dim,
-            kernel_size=kernel_size,
-            padding=int(dilation * (kernel_size - 1) / 2),
-            dilation=dilation,
-            groups=dim,
-        )  # depthwise conv
-        self.norm = LayerNorm(dim, eps=1e-6)
-        self.pwconv1 = nn.Linear(
-            dim, int(mlp_ratio * dim)
-        )  # pointwise/1x1 convs, implemented with linear layers
-        self.act = nn.GELU()
-        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
-        self.gamma = (
-            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
-            if layer_scale_init_value > 0
-            else None
-        )
-        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
-    def forward(self, x, apply_residual: bool = True):
-        """Apply the block to ``x`` of shape (N, C, L).
-
-        Args:
-            x: Input tensor, channels in dimension 1.
-            apply_residual: When True (default) the input is added back onto the
-                block output; when False only the transformed branch is returned.
-
-        Returns:
-            Tensor in (N, C, L) layout.
-        """
-        residual = x
-
-        x = self.dwconv(x)
-        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
-        x = self.norm(x)
-        x = self.pwconv1(x)
-        x = self.act(x)
-        x = self.pwconv2(x)
-
-        if self.gamma is not None:
-            x = self.gamma * x
-
-        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
-        x = self.drop_path(x)
-
-        if apply_residual:
-            x = residual + x
-
-        return x
-
-
-class ParallelConvNeXtBlock(nn.Module):
-    def __init__(self, kernel_sizes: list[int], *args, **kwargs):
-        super().__init__()
-        self.blocks = nn.ModuleList(
-            [
-                ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
-                for kernel_size in kernel_sizes
-            ]
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return torch.stack(
-            [block(x, apply_residual=False) for block in self.blocks] + [x],
-            dim=1,
-        ).sum(dim=1)
-
-
-class ConvNeXtEncoder(nn.Module):
-    def __init__(
-        self,
-        input_channels=3,
-        depths=[3, 3, 9, 3],
-        dims=[96, 192, 384, 768],
-        drop_path_rate=0.0,
-        layer_scale_init_value=1e-6,
-        kernel_sizes: tuple[int] = (7,),
-    ):
-        super().__init__()
-        assert len(depths) == len(dims)
-
-        self.channel_layers = nn.ModuleList()
-        stem = nn.Sequential(
-            nn.Conv1d(
-                input_channels,
-                dims[0],
-                kernel_size=7,
-                padding=3,
-                padding_mode="replicate",
-            ),
-            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
-        )
-        self.channel_layers.append(stem)
-
-        for i in range(len(depths) - 1):
-            mid_layer = nn.Sequential(
-                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
-                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
-            )
-            self.channel_layers.append(mid_layer)
-
-        block_fn = (
-            partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
-            if len(kernel_sizes) == 1
-            else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
-        )
-
-        self.stages = nn.ModuleList()
-        drop_path_rates = [
-            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
-        ]
-
-        cur = 0
-        for i in range(len(depths)):
-            stage = nn.Sequential(
-                *[
-                    block_fn(
-                        dim=dims[i],
-                        drop_path=drop_path_rates[cur + j],
-                        layer_scale_init_value=layer_scale_init_value,
-                    )
-                    for j in range(depths[i])
-                ]
-            )
-            self.stages.append(stage)
-            cur += depths[i]
-
-        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, (nn.Conv1d, nn.Linear)):
-            nn.init.trunc_normal_(m.weight, std=0.02)
-            nn.init.constant_(m.bias, 0)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-    ) -> torch.Tensor:
-        for channel_layer, stage in zip(self.channel_layers, self.stages):
-            x = channel_layer(x)
-            x = stage(x)
-
-        return self.norm(x)

+ 0 - 237
fish_speech/models/vq_diffusion/adamos/hifigan.py

@@ -1,237 +0,0 @@
-from functools import partial
-from math import prod
-from typing import Callable
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn import Conv1d
-from torch.nn.utils.parametrizations import weight_norm
-from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return (kernel_size * dilation - dilation) // 2
-
-
-class ResBlock1(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super().__init__()
-
-        self.convs1 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[2],
-                        padding=get_padding(kernel_size, dilation[2]),
-                    )
-                ),
-            ]
-        )
-        self.convs1.apply(init_weights)
-
-        self.convs2 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-            ]
-        )
-        self.convs2.apply(init_weights)
-
-    def forward(self, x):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = F.silu(x)
-            xt = c1(xt)
-            xt = F.silu(xt)
-            xt = c2(xt)
-            x = xt + x
-        return x
-
-    def remove_weight_norm(self):
-        for conv in self.convs1:
-            remove_weight_norm(conv)
-        for conv in self.convs2:
-            remove_weight_norm(conv)
-
-
-class HiFiGANGenerator(nn.Module):
-    def __init__(
-        self,
-        *,
-        hop_length: int = 512,
-        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
-        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
-        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
-        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
-        num_mels: int = 128,
-        upsample_initial_channel: int = 512,
-        use_template: bool = True,
-        pre_conv_kernel_size: int = 7,
-        post_conv_kernel_size: int = 7,
-        post_activation: Callable = partial(nn.SiLU, inplace=True),
-    ):
-        super().__init__()
-
-        assert (
-            prod(upsample_rates) == hop_length
-        ), f"hop_length must be {prod(upsample_rates)}"
-
-        self.conv_pre = weight_norm(
-            nn.Conv1d(
-                num_mels,
-                upsample_initial_channel,
-                pre_conv_kernel_size,
-                1,
-                padding=get_padding(pre_conv_kernel_size),
-            )
-        )
-
-        self.num_upsamples = len(upsample_rates)
-        self.num_kernels = len(resblock_kernel_sizes)
-
-        self.noise_convs = nn.ModuleList()
-        self.use_template = use_template
-        self.ups = nn.ModuleList()
-
-        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            c_cur = upsample_initial_channel // (2 ** (i + 1))
-            self.ups.append(
-                weight_norm(
-                    nn.ConvTranspose1d(
-                        upsample_initial_channel // (2**i),
-                        upsample_initial_channel // (2 ** (i + 1)),
-                        k,
-                        u,
-                        padding=(k - u) // 2,
-                    )
-                )
-            )
-
-            if not use_template:
-                continue
-
-            if i + 1 < len(upsample_rates):
-                stride_f0 = np.prod(upsample_rates[i + 1 :])
-                self.noise_convs.append(
-                    Conv1d(
-                        1,
-                        c_cur,
-                        kernel_size=stride_f0 * 2,
-                        stride=stride_f0,
-                        padding=stride_f0 // 2,
-                    )
-                )
-            else:
-                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
-
-        self.resblocks = nn.ModuleList()
-        for i in range(len(self.ups)):
-            ch = upsample_initial_channel // (2 ** (i + 1))
-            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
-                self.resblocks.append(ResBlock1(ch, k, d))
-
-        self.activation_post = post_activation()
-        self.conv_post = weight_norm(
-            nn.Conv1d(
-                ch,
-                1,
-                post_conv_kernel_size,
-                1,
-                padding=get_padding(post_conv_kernel_size),
-            )
-        )
-        self.ups.apply(init_weights)
-        self.conv_post.apply(init_weights)
-
-    def forward(self, x, template=None):
-        x = self.conv_pre(x)
-
-        for i in range(self.num_upsamples):
-            x = F.silu(x, inplace=True)
-            x = self.ups[i](x)
-
-            if self.use_template:
-                x = x + self.noise_convs[i](template)
-
-            xs = None
-
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
-                else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-
-            x = xs / self.num_kernels
-
-        x = self.activation_post(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def remove_weight_norm(self):
-        for up in self.ups:
-            remove_weight_norm(up)
-        for block in self.resblocks:
-            block.remove_weight_norm()
-        remove_weight_norm(self.conv_pre)
-        remove_weight_norm(self.conv_post)

+ 0 - 3
fish_speech/models/vq_diffusion/bigvgan/__init__.py

@@ -1,3 +0,0 @@
-from .bigvgan import BigVGAN
-
-__all__ = ["BigVGAN"]

+ 0 - 126
fish_speech/models/vq_diffusion/bigvgan/activations.py

@@ -1,126 +0,0 @@
-# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
-#   LICENSE is in incl_licenses directory.
-
-import torch
-from torch import nn, pow, sin
-from torch.nn import Parameter
-
-
-class Snake(nn.Module):
-    """
-    Implementation of a sine-based periodic activation function
-    Shape:
-        - Input: (B, C, T)
-        - Output: (B, C, T), same shape as the input
-    Parameters:
-        - alpha - trainable parameter
-    References:
-        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
-        https://arxiv.org/abs/2006.08195
-    Examples:
-        >>> a1 = Snake(256)
-        >>> x = torch.randn(4, 256, 100)
-        >>> x = a1(x)
-    """
-
-    def __init__(
-        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
-    ):
-        """
-        Initialization.
-        INPUT:
-            - in_features: shape of the input
-            - alpha: trainable parameter
-            alpha is initialized to 1 by default, higher values = higher-frequency.
-            alpha will be trained along with the rest of your model.
-        """
-        super(Snake, self).__init__()
-        self.in_features = in_features
-
-        # initialize alpha
-        self.alpha_logscale = alpha_logscale
-        if self.alpha_logscale:  # log scale alphas initialized to zeros
-            self.alpha = Parameter(torch.zeros(in_features) * alpha)
-        else:  # linear scale alphas initialized to ones
-            self.alpha = Parameter(torch.ones(in_features) * alpha)
-
-        self.alpha.requires_grad = alpha_trainable
-
-        self.no_div_by_zero = 0.000000001
-
-    def forward(self, x):
-        """
-        Forward pass of the function.
-        Applies the function to the input elementwise.
-        Snake ∶= x + 1/a * sin^2 (xa)
-        """
-        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
-        if self.alpha_logscale:
-            alpha = torch.exp(alpha)
-        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
-
-        return x
-
-
-class SnakeBeta(nn.Module):
-    """
-    A modified Snake function which uses separate parameters for the magnitude of the periodic components
-    Shape:
-        - Input: (B, C, T)
-        - Output: (B, C, T), same shape as the input
-    Parameters:
-        - alpha - trainable parameter that controls frequency
-        - beta - trainable parameter that controls magnitude
-    References:
-        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
-        https://arxiv.org/abs/2006.08195
-    Examples:
-        >>> a1 = SnakeBeta(256)
-        >>> x = torch.randn(4, 256, 100)
-        >>> x = a1(x)
-    """
-
-    def __init__(
-        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
-    ):
-        """
-        Initialization.
-        INPUT:
-            - in_features: shape of the input
-            - alpha - trainable parameter that controls frequency
-            - beta - trainable parameter that controls magnitude
-            alpha is initialized to 1 by default, higher values = higher-frequency.
-            beta is initialized to 1 by default, higher values = higher-magnitude.
-            alpha will be trained along with the rest of your model.
-        """
-        super(SnakeBeta, self).__init__()
-        self.in_features = in_features
-
-        # initialize alpha
-        self.alpha_logscale = alpha_logscale
-        if self.alpha_logscale:  # log scale alphas initialized to zeros
-            self.alpha = Parameter(torch.zeros(in_features) * alpha)
-            self.beta = Parameter(torch.zeros(in_features) * alpha)
-        else:  # linear scale alphas initialized to ones
-            self.alpha = Parameter(torch.ones(in_features) * alpha)
-            self.beta = Parameter(torch.ones(in_features) * alpha)
-
-        self.alpha.requires_grad = alpha_trainable
-        self.beta.requires_grad = alpha_trainable
-
-        self.no_div_by_zero = 0.000000001
-
-    def forward(self, x):
-        """
-        Forward pass of the function.
-        Applies the function to the input elementwise.
-        SnakeBeta ∶= x + 1/b * sin^2 (xa)
-        """
-        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
-        beta = self.beta.unsqueeze(0).unsqueeze(-1)
-        if self.alpha_logscale:
-            alpha = torch.exp(alpha)
-            beta = torch.exp(beta)
-        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
-
-        return x

+ 0 - 6
fish_speech/models/vq_diffusion/bigvgan/alias_free_torch/__init__.py

@@ -1,6 +0,0 @@
-# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
-#   LICENSE is in incl_licenses directory.
-
-from .act import *
-from .filter import *
-from .resample import *

+ 0 - 31
fish_speech/models/vq_diffusion/bigvgan/alias_free_torch/act.py

@@ -1,31 +0,0 @@
-# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
-#   LICENSE is in incl_licenses directory.
-
-import torch.nn as nn
-
-from .resample import DownSample1d, UpSample1d
-
-
-class Activation1d(nn.Module):
-    def __init__(
-        self,
-        activation,
-        up_ratio: int = 2,
-        down_ratio: int = 2,
-        up_kernel_size: int = 12,
-        down_kernel_size: int = 12,
-    ):
-        super().__init__()
-        self.up_ratio = up_ratio
-        self.down_ratio = down_ratio
-        self.act = activation
-        self.upsample = UpSample1d(up_ratio, up_kernel_size)
-        self.downsample = DownSample1d(down_ratio, down_kernel_size)
-
-    # x: [B,C,T]
-    def forward(self, x):
-        x = self.upsample(x)
-        x = self.act(x)
-        x = self.downsample(x)
-
-        return x

+ 0 - 100
fish_speech/models/vq_diffusion/bigvgan/alias_free_torch/filter.py

@@ -1,100 +0,0 @@
-# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
-#   LICENSE is in incl_licenses directory.
-
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-if "sinc" not in dir(torch):
-    # Fallback used only when this torch build does not expose torch.sinc.
-    # This code is adapted from adefossez's julius.core.sinc under the MIT License
-    # https://adefossez.github.io/julius/julius/core.html
-    #   LICENSE is in incl_licenses directory.
-    def sinc(x: torch.Tensor):
-        """
-        Normalized sinc, i.e. sin(pi * x) / (pi * x), with the value at x == 0 set to 1.
-        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
-        """
-        at_origin = x == 0
-        one = torch.tensor(1.0, device=x.device, dtype=x.dtype)
-        ratio = torch.sin(math.pi * x) / math.pi / x
-        return torch.where(at_origin, one, ratio)
-
-else:
-    sinc = torch.sinc
-
-
-# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
-# https://adefossez.github.io/julius/julius/lowpass.html
-#   LICENSE is in incl_licenses directory.
-def kaiser_sinc_filter1d(
-    cutoff, half_width, kernel_size
-):  # return filter [1,1,kernel_size]
-    """Build a Kaiser-windowed sinc low-pass kernel.
-
-    Args:
-        cutoff: Cutoff value used to scale the sinc argument; ``0`` yields an
-            all-zero kernel.
-        half_width: Transition half-width; it determines the Kaiser ``beta``.
-        kernel_size: Number of taps in the kernel.
-
-    Returns:
-        A floating-point tensor of shape ``[1, 1, kernel_size]``. For a
-        non-zero cutoff the taps are normalized to sum to 1.
-    """
-    even = kernel_size % 2 == 0
-    half_size = kernel_size // 2
-
-    # For kaiser window
-    delta_f = 4 * half_width
-    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
-    if A > 50.0:
-        beta = 0.1102 * (A - 8.7)
-    elif A >= 21.0:
-        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
-    else:
-        beta = 0.0
-    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
-
-    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
-    if even:
-        time = torch.arange(-half_size, half_size) + 0.5
-    else:
-        time = torch.arange(kernel_size) - half_size
-    if cutoff == 0:
-        # Use the window's dtype: `time` is an integer tensor for odd kernel
-        # sizes, which would otherwise produce an integer filter.
-        filter_ = torch.zeros_like(window)
-    else:
-        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
-        # Normalize filter to have sum = 1, otherwise we will have a small leakage
-        # of the constant component in the input signal.
-        filter_ /= filter_.sum()
-
-    # Reshape on both branches. Previously the reshaped result was assigned
-    # only in the non-zero branch, so cutoff == 0 raised UnboundLocalError.
-    return filter_.view(1, 1, kernel_size)
-
-
-class LowPassFilter1d(nn.Module):
-    """Per-channel 1-D low-pass filter using a kernel from ``kaiser_sinc_filter1d``."""
-
-    def __init__(
-        self,
-        cutoff=0.5,
-        half_width=0.6,
-        stride: int = 1,
-        padding: bool = True,
-        padding_mode: str = "replicate",
-        kernel_size: int = 12,
-    ):
-        """
-        Args:
-            cutoff: Filter cutoff; must lie in [0, 0.5].
-            half_width: Transition half-width forwarded to the kernel builder.
-            stride: Stride of the filtering convolution.
-            padding: Whether to pad the input before filtering.
-            padding_mode: Mode passed to ``F.pad`` when padding is enabled.
-            kernel_size: Number of filter taps.
-
-        Raises:
-            ValueError: If ``cutoff`` is negative or greater than 0.5.
-        """
-        # kernel_size should be even number for stylegan3 setup,
-        # in this implementation, odd number is also possible.
-        super().__init__()
-
-        if cutoff < -0.0:
-            raise ValueError("Minimum cutoff must be larger than zero.")
-        if cutoff > 0.5:
-            raise ValueError("A cutoff above 0.5 does not make sense.")
-
-        is_even = kernel_size % 2 == 0
-        half = kernel_size // 2
-
-        self.kernel_size = kernel_size
-        self.even = is_even
-        # An even kernel needs one sample less padding on the left.
-        self.pad_left = half - int(is_even)
-        self.pad_right = half
-        self.stride = stride
-        self.padding = padding
-        self.padding_mode = padding_mode
-
-        kernel = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
-        self.register_buffer("filter", kernel)
-
-    # input [B, C, T]
-    def forward(self, x):
-        """Filter each channel of ``x`` independently with the stored kernel."""
-        _batch, channels, _time = x.shape
-
-        if self.padding:
-            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
-
-        # groups == channels makes this a depthwise convolution sharing one kernel.
-        depthwise_kernel = self.filter.expand(channels, -1, -1)
-        return F.conv1d(x, depthwise_kernel, stride=self.stride, groups=channels)

+ 0 - 58
fish_speech/models/vq_diffusion/bigvgan/alias_free_torch/resample.py

@@ -1,58 +0,0 @@
-# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
-#   LICENSE is in incl_licenses directory.
-
-import torch.nn as nn
-from torch.nn import functional as F
-
-from .filter import LowPassFilter1d, kaiser_sinc_filter1d
-
-
-class UpSample1d(nn.Module):
-    """Upsample a [B, C, T] signal by ``ratio`` with a Kaiser-windowed sinc kernel."""
-
-    def __init__(self, ratio=2, kernel_size=None):
-        super().__init__()
-        if kernel_size is None:
-            # Default kernel length derived from the ratio.
-            kernel_size = int(6 * ratio // 2) * 2
-
-        self.ratio = ratio
-        self.kernel_size = kernel_size
-        self.stride = ratio
-        self.pad = self.kernel_size // ratio - 1
-
-        # Number of samples trimmed from each end of the transposed-conv output.
-        base = self.pad * self.stride
-        excess = self.kernel_size - self.stride
-        self.pad_left = base + excess // 2
-        self.pad_right = base + (excess + 1) // 2
-
-        kernel = kaiser_sinc_filter1d(
-            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
-        )
-        self.register_buffer("filter", kernel)
-
-    # x: [B, C, T]
-    def forward(self, x):
-        """Pad, apply a depthwise transposed convolution scaled by ``ratio``, then trim."""
-        _batch, channels, _time = x.shape
-
-        padded = F.pad(x, (self.pad, self.pad), mode="replicate")
-        depthwise_kernel = self.filter.expand(channels, -1, -1)
-        upsampled = self.ratio * F.conv_transpose1d(
-            padded, depthwise_kernel, stride=self.stride, groups=channels
-        )
-
-        return upsampled[..., self.pad_left : -self.pad_right]
-
-
-class DownSample1d(nn.Module):
-    """Downsample by ``ratio`` using a strided ``LowPassFilter1d``."""
-
-    def __init__(self, ratio=2, kernel_size=None):
-        super().__init__()
-        if kernel_size is None:
-            # Default kernel length derived from the ratio.
-            kernel_size = int(6 * ratio // 2) * 2
-
-        self.ratio = ratio
-        self.kernel_size = kernel_size
-        self.lowpass = LowPassFilter1d(
-            cutoff=0.5 / ratio,
-            half_width=0.6 / ratio,
-            stride=ratio,
-            kernel_size=self.kernel_size,
-        )
-
-    def forward(self, x):
-        """Return the low-pass filtered, strided version of ``x``."""
-        return self.lowpass(x)

+ 0 - 389
fish_speech/models/vq_diffusion/bigvgan/bigvgan.py

@@ -1,389 +0,0 @@
-# Copyright (c) 2022 NVIDIA CORPORATION.
-#   Licensed under the MIT license.
-
-# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
-#   LICENSE is in incl_licenses directory.
-
-
-import json
-from pathlib import Path
-from typing import Optional
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn import Conv1d, Conv2d, ConvTranspose1d
-from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
-
-from fish_speech.models.vq_diffusion.bigvgan.activations import Snake, SnakeBeta
-from fish_speech.models.vq_diffusion.bigvgan.alias_free_torch import Activation1d
-from fish_speech.models.vq_diffusion.bigvgan.utils import get_padding, init_weights
-from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
-
-LRELU_SLOPE = 0.1
-
-
-class AttrDict(dict):
-    """Dictionary whose items are also readable and writable as attributes."""
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # Share storage between the mapping and the instance namespace.
-        self.__dict__ = self
-
-
-class AMPBlock1(torch.nn.Module):
-    """Residual block with three (dilated conv, plain conv) pairs.
-
-    Each convolution is preceded by its own ``Activation1d``-wrapped Snake or
-    SnakeBeta activation, selected by ``activation``.
-    """
-
-    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
-        super().__init__()
-        self.h = h
-
-        def make_conv(rate):
-            # Weight-normalized conv whose padding comes from get_padding.
-            return weight_norm(
-                Conv1d(
-                    channels,
-                    channels,
-                    kernel_size,
-                    1,
-                    dilation=rate,
-                    padding=get_padding(kernel_size, rate),
-                )
-            )
-
-        # First conv of each pair uses dilation[0], dilation[1], dilation[2] in turn.
-        self.convs1 = nn.ModuleList([make_conv(dilation[idx]) for idx in range(3)])
-        self.convs1.apply(init_weights)
-
-        # Second conv of each pair is always undilated.
-        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
-        self.convs2.apply(init_weights)
-
-        # total number of conv layers
-        self.num_layers = len(self.convs1) + len(self.convs2)
-
-        # periodic nonlinearity (snake or snakebeta) with anti-aliasing
-        if activation == "snake":
-            act_cls = Snake
-        elif activation == "snakebeta":
-            act_cls = SnakeBeta
-        else:
-            raise NotImplementedError(
-                "activation incorrectly specified. check the config file and look for 'activation'."
-            )
-
-        # One activation per conv layer.
-        self.activations = nn.ModuleList(
-            [
-                Activation1d(activation=act_cls(channels, alpha_logscale=h.snake_logscale))
-                for _ in range(self.num_layers)
-            ]
-        )
-
-    def forward(self, x):
-        """Run the residual pairs in sequence and return the result."""
-        for idx, (conv_a, conv_b) in enumerate(zip(self.convs1, self.convs2)):
-            # Even-indexed activations precede convs1, odd-indexed ones precede convs2.
-            act_a = self.activations[2 * idx]
-            act_b = self.activations[2 * idx + 1]
-            branch = conv_a(act_a(x))
-            branch = conv_b(act_b(branch))
-            x = branch + x
-
-        return x
-
-    def remove_weight_norm(self):
-        """Strip weight normalization from every convolution in the block."""
-        for conv in self.convs1:
-            remove_weight_norm(conv)
-        for conv in self.convs2:
-            remove_weight_norm(conv)
-
-
-class AMPBlock2(torch.nn.Module):
-    """Residual block with two dilated convolutions.
-
-    Each convolution is preceded by its own ``Activation1d``-wrapped Snake or
-    SnakeBeta activation, selected by ``activation``.
-    """
-
-    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
-        super().__init__()
-        self.h = h
-
-        def make_conv(rate):
-            # Weight-normalized conv whose padding comes from get_padding.
-            return weight_norm(
-                Conv1d(
-                    channels,
-                    channels,
-                    kernel_size,
-                    1,
-                    dilation=rate,
-                    padding=get_padding(kernel_size, rate),
-                )
-            )
-
-        # Only dilation[0] and dilation[1] are used.
-        self.convs = nn.ModuleList([make_conv(dilation[idx]) for idx in range(2)])
-        self.convs.apply(init_weights)
-
-        # total number of conv layers
-        self.num_layers = len(self.convs)
-
-        # periodic nonlinearity (snake or snakebeta) with anti-aliasing
-        if activation == "snake":
-            act_cls = Snake
-        elif activation == "snakebeta":
-            act_cls = SnakeBeta
-        else:
-            raise NotImplementedError(
-                "activation incorrectly specified. check the config file and look for 'activation'."
-            )
-
-        # One activation per conv layer.
-        self.activations = nn.ModuleList(
-            [
-                Activation1d(activation=act_cls(channels, alpha_logscale=h.snake_logscale))
-                for _ in range(self.num_layers)
-            ]
-        )
-
-    def forward(self, x):
-        """Run activation -> conv -> residual add for each layer in sequence."""
-        for conv, act in zip(self.convs, self.activations):
-            x = conv(act(x)) + x
-
-        return x
-
-    def remove_weight_norm(self):
-        """Strip weight normalization from every convolution in the block."""
-        for conv in self.convs:
-            remove_weight_norm(conv)
-
-
-class BigVGANModule(torch.nn.Module):
-    """BigVGAN generator.
-
-    Takes an input with ``h.num_mels`` channels, runs it through a stack of
-    transposed-conv upsamplers each followed by averaged AMP residual blocks,
-    and emits a single-channel output squashed by ``tanh``.
-
-    Args:
-        h: Hyper-parameter object read via attribute access. The fields used
-            here are ``resblock_kernel_sizes``, ``resblock_dilation_sizes``,
-            ``upsample_rates``, ``upsample_kernel_sizes``,
-            ``upsample_initial_channel``, ``num_mels``, ``resblock``,
-            ``activation`` and ``snake_logscale``.
-
-    Raises:
-        NotImplementedError: If ``h.activation`` is neither ``"snake"`` nor
-            ``"snakebeta"``.
-    """
-
-    # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
-    def __init__(self, h):
-        super(BigVGANModule, self).__init__()
-        self.h = h
-
-        self.num_kernels = len(h.resblock_kernel_sizes)
-        self.num_upsamples = len(h.upsample_rates)
-
-        # pre conv
-        self.conv_pre = weight_norm(
-            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
-        )
-
-        # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
-        resblock = AMPBlock1 if h.resblock == "1" else AMPBlock2
-
-        # transposed conv-based upsamplers. does not apply anti-aliasing
-        # Stage i halves the channel count: initial // 2**i -> initial // 2**(i + 1).
-        self.ups = nn.ModuleList()
-        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
-            self.ups.append(
-                nn.ModuleList(
-                    [
-                        weight_norm(
-                            ConvTranspose1d(
-                                h.upsample_initial_channel // (2**i),
-                                h.upsample_initial_channel // (2 ** (i + 1)),
-                                k,
-                                u,
-                                padding=(k - u) // 2,
-                            )
-                        )
-                    ]
-                )
-            )
-
-        # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
-        # Stored flat: stage i owns resblocks[i * num_kernels : (i + 1) * num_kernels].
-        self.resblocks = nn.ModuleList()
-        for i in range(len(self.ups)):
-            ch = h.upsample_initial_channel // (2 ** (i + 1))
-            for j, (k, d) in enumerate(
-                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
-            ):
-                self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
-
-        # post conv
-        # `ch` still holds the channel count of the last upsampling stage here;
-        # it is only assigned inside the loop above, so at least one stage is required.
-        if (
-            h.activation == "snake"
-        ):  # periodic nonlinearity with snake function and anti-aliasing
-            activation_post = Snake(ch, alpha_logscale=h.snake_logscale)
-            self.activation_post = Activation1d(activation=activation_post)
-        elif (
-            h.activation == "snakebeta"
-        ):  # periodic nonlinearity with snakebeta function and anti-aliasing
-            activation_post = SnakeBeta(ch, alpha_logscale=h.snake_logscale)
-            self.activation_post = Activation1d(activation=activation_post)
-        else:
-            raise NotImplementedError(
-                "activation incorrectly specified. check the config file and look for 'activation'."
-            )
-
-        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
-
-        # weight initialization
-        # Only the upsamplers and conv_post are re-initialized here; conv_pre is not.
-        for i in range(len(self.ups)):
-            self.ups[i].apply(init_weights)
-        self.conv_post.apply(init_weights)
-
-    def forward(self, x):
-        """Run the generator on ``x`` and return the tanh-squashed output."""
-        # pre conv
-        x = self.conv_pre(x)
-
-        for i in range(self.num_upsamples):
-            # upsampling
-            for i_up in range(len(self.ups[i])):
-                x = self.ups[i][i_up](x)
-            # AMP blocks
-            # Every block of this stage sees the same input; their outputs are averaged.
-            xs = None
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
-                else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-            x = xs / self.num_kernels
-
-        # post conv
-        x = self.activation_post(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def remove_weight_norm(self):
-        """Remove weight normalization from all convolutions in the generator."""
-        print("Removing weight norm...")
-        for l in self.ups:
-            for l_i in l:
-                remove_weight_norm(l_i)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-        remove_weight_norm(self.conv_pre)
-        remove_weight_norm(self.conv_post)
-
-
-class BigVGAN(nn.Module):
-    """Inference wrapper pairing a loaded ``BigVGANModule`` with its mel transform.
-
-    Args:
-        checkpoint_path: Path to a checkpoint whose ``"generator"`` entry holds
-            the generator state dict.
-        config_file: Path to the JSON hyper-parameter file. When ``None``,
-            ``config.json`` in the checkpoint's directory is used.
-    """
-
-    def __init__(
-        self,
-        checkpoint_path: str = "checkpoints/bigvgan-24k-100band/g_05000000",
-        config_file: Optional[str] = None,
-    ):
-        super().__init__()
-
-        # Fall back to the config that sits next to the checkpoint.
-        if config_file is None:
-            config_file = Path(checkpoint_path).parent / "config.json"
-
-        with open(config_file) as f:
-            hparams = json.loads(f.read())
-
-        self.h = AttrDict(hparams)
-        self.model = BigVGANModule(self.h)
-
-        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
-        checkpoint = torch.load(checkpoint_path, map_location="cpu")
-        self.model.load_state_dict(checkpoint["generator"], strict=True)
-        self.model.eval()
-        self.model.remove_weight_norm()
-
-        # The mel front end is configured from the same hyper-parameters as the generator.
-        h = self.h
-        self.mel_transform = LogMelSpectrogram(
-            sample_rate=h.sampling_rate,
-            n_fft=h.n_fft,
-            win_length=h.win_size,
-            hop_length=h.hop_size,
-            f_min=h.fmin,
-            f_max=h.fmax,
-            n_mels=h.num_mels,
-        )
-
-    @torch.no_grad()
-    def decode(self, mel):
-        """Run the generator on ``mel`` without tracking gradients."""
-        return self.model(mel)
-
-    @torch.no_grad()
-    def encode(self, x):
-        """Apply the mel transform to ``x`` without tracking gradients."""
-        return self.mel_transform(x)
-
-
-# Manual smoke test: encode one wav file to mel, decode it back, write the result.
-if __name__ == "__main__":
-    import librosa
-    import soundfile as sf
-
-    x = "data/StarRail/Chinese/罗刹/archive_luocha_2.wav"
-    # Uses the default checkpoint path and the config.json next to it.
-    model = BigVGAN()
-
-    # Load as mono at 24 kHz, then add a leading axis of size 1.
-    wav, sr = librosa.load(x, sr=24000, mono=True)
-    wav = torch.from_numpy(wav).float()[None]
-    mel = model.encode(wav)
-
-    # Take the first item and swap its last two axes before writing.
-    wav = model.decode(mel)[0].mT
-    sf.write("test.wav", wav.cpu().numpy(), 24000)

+ 0 - 80
fish_speech/models/vq_diffusion/bigvgan/utils.py

@@ -1,80 +0,0 @@
-# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
-#   LICENSE is in incl_licenses directory.
-
-import glob
-import os
-
-import matplotlib
-import torch
-from torch.nn.utils import weight_norm
-
-matplotlib.use("Agg")
-import matplotlib.pylab as plt
-from scipy.io.wavfile import write
-
-
-def plot_spectrogram(spectrogram):
-    """Render ``spectrogram`` as an image with a colorbar and return the figure."""
-    figure, axes = plt.subplots(figsize=(10, 2))
-    image = axes.imshow(
-        spectrogram, aspect="auto", origin="lower", interpolation="none"
-    )
-    plt.colorbar(image, ax=axes)
-
-    # Draw before closing so the returned figure has a rendered canvas.
-    figure.canvas.draw()
-    plt.close()
-
-    return figure
-
-
-def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
-    """Render ``spectrogram`` with the color scale limited to [1e-6, clip_max]."""
-    image_options = dict(
-        aspect="auto",
-        origin="lower",
-        interpolation="none",
-        vmin=1e-6,
-        vmax=clip_max,
-    )
-
-    figure, axes = plt.subplots(figsize=(10, 2))
-    image = axes.imshow(spectrogram, **image_options)
-    plt.colorbar(image, ax=axes)
-
-    # Draw before closing so the returned figure has a rendered canvas.
-    figure.canvas.draw()
-    plt.close()
-
-    return figure
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    """Re-draw ``m.weight`` from N(mean, std) when the class name of ``m`` contains "Conv"."""
-    if "Conv" in m.__class__.__name__:
-        m.weight.data.normal_(mean, std)
-
-
-def apply_weight_norm(m):
-    """Apply ``weight_norm`` to ``m`` when its class name contains "Conv"."""
-    if "Conv" in m.__class__.__name__:
-        weight_norm(m)
-
-
-def get_padding(kernel_size, dilation=1):
-    """Return ``int((kernel_size * dilation - dilation) / 2)``, the padding used for the convs."""
-    dilated_span = kernel_size * dilation - dilation
-    return int(dilated_span / 2)
-
-
-def load_checkpoint(filepath, device):
-    """Load and return the object stored at ``filepath``, mapped onto ``device``.
-
-    Asserts that ``filepath`` is an existing file and prints progress messages.
-    """
-    assert os.path.isfile(filepath)
-
-    print("Loading '{}'".format(filepath))
-    loaded = torch.load(filepath, map_location=device)
-    print("Complete.")
-
-    return loaded
-
-
-def save_checkpoint(filepath, obj):
-    """Serialize ``obj`` to ``filepath`` with ``torch.save``, printing progress messages."""
-    start_message = "Saving checkpoint to {}".format(filepath)
-    print(start_message)
-
-    torch.save(obj, filepath)
-
-    print("Complete.")
-
-
-def scan_checkpoint(cp_dir, prefix):
-    """Return the lexicographically last path in ``cp_dir`` named ``prefix`` plus eight characters.
-
-    Returns ``None`` when nothing matches.
-    """
-    candidates = glob.glob(os.path.join(cp_dir, prefix + "????????"))
-    if not candidates:
-        return None
-    return max(candidates)

+ 0 - 240
fish_speech/models/vq_diffusion/convnext_1d.py

@@ -1,240 +0,0 @@
-from dataclasses import dataclass
-from typing import Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.models.embeddings import (
-    GaussianFourierProjection,
-    TimestepEmbedding,
-    Timesteps,
-)
-from diffusers.models.modeling_utils import ModelMixin
-from diffusers.utils import BaseOutput
-
-
-class ConvNeXtBlock(nn.Module):
-    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
-
-    Args:
-        dim (int): Number of input (and output) channels.
-        intermediate_dim (int): Width of the hidden layer between the two pointwise projections.
-        dilation (int, optional): Dilation of the depthwise convolution. Defaults to 1.
-        layer_scale_init_value (float, optional): Initial value for the layer scale.
-            None or a non-positive value disables scaling. Defaults to 1e-6.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        intermediate_dim: int,
-        dilation: int = 1,
-        layer_scale_init_value: Optional[float] = 1e-6,
-    ):
-        super().__init__()
-
-        # depthwise conv; the padding keeps the time dimension unchanged
-        self.dwconv = nn.Conv1d(
-            dim,
-            dim,
-            kernel_size=7,
-            groups=dim,
-            dilation=dilation,
-            padding=int(dilation * (7 - 1) / 2),
-        )
-        self.norm = nn.LayerNorm(dim, eps=1e-6)
-
-        # pointwise/1x1 convs, implemented with linear layers
-        self.pwconv1 = nn.Linear(dim, intermediate_dim)
-        self.act = nn.GELU()
-        self.pwconv2 = nn.Linear(intermediate_dim, dim)
-
-        # Learnable per-channel scale applied to the branch output.
-        if layer_scale_init_value is not None and layer_scale_init_value > 0:
-            self.gamma = nn.Parameter(
-                layer_scale_init_value * torch.ones(dim), requires_grad=True
-            )
-        else:
-            self.gamma = None
-
-        self.condition_projection = nn.Sequential(
-            nn.Conv1d(dim, dim, 1),
-            nn.GELU(),
-            nn.Conv1d(dim, dim, 1),
-        )
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        condition: Optional[torch.Tensor] = None,
-        x_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Return ``x`` plus the output of the ConvNeXt branch.
-
-        ``condition``, when given, is projected and added to the branch input;
-        ``x_mask``, when given, is multiplied into the branch input.
-        """
-        skip = x
-
-        if condition is not None:
-            x = x + self.condition_projection(condition)
-
-        if x_mask is not None:
-            x = x * x_mask
-
-        hidden = self.dwconv(x).transpose(1, 2)  # (B, C, T) -> (B, T, C)
-        hidden = self.pwconv2(self.act(self.pwconv1(self.norm(hidden))))
-        if self.gamma is not None:
-            hidden = self.gamma * hidden
-        hidden = hidden.transpose(1, 2)  # (B, T, C) -> (B, C, T)
-
-        return skip + hidden
-
-
-class ConvNext1DModel(ModelMixin, ConfigMixin):
-    r"""
-    A ConvNext model that takes a noisy sample and a timestep and returns a sample shaped output.
-
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
-    for all models (such as downloading or saving).
-
-    Parameters:
-        in_channels (`int`, *optional*, defaults to 128):
-            Number of channels fed to the input projection. When a `condition` is passed to `forward`, it is
-            concatenated to `sample` on the channel axis first, so this must count both.
-        out_channels (`int`, *optional*, defaults to 128):
-            Number of channels in the output.
-        intermediate_dim (`int`, *optional*, defaults to 512):
-            Dimensionality of the intermediate blocks. Must be divisible by 2.
-        mlp_dim (`int`, *optional*, defaults to 2048):
-            Dimensionality of the MLP.
-        num_layers (`int`, *optional*, defaults to 20):
-            Number of layers in the model.
-        dilation_cycle_length (`int`, *optional*, defaults to 4):
-            Length of the dilation cycle; block `i` uses dilation `2 ** (i % dilation_cycle_length)`.
-        time_embedding_type (`str`, *optional*, defaults to `positional`):
-            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
-    """
-
-    _supports_gradient_checkpointing = True
-
-    @register_to_config
-    def __init__(
-        self,
-        in_channels: int = 128,
-        out_channels: int = 128,
-        intermediate_dim: int = 512,
-        mlp_dim: int = 2048,
-        num_layers: int = 20,
-        dilation_cycle_length: int = 4,
-        time_embedding_type: str = "positional",
-    ):
-        super().__init__()
-
-        if intermediate_dim % 2 != 0:
-            raise ValueError("intermediate_dim must be divisible by 2.")
-
-        # time
-        if time_embedding_type == "fourier":
-            self.time_proj = GaussianFourierProjection(
-                intermediate_dim // 2,
-                set_W_to_weight=False,
-                log=False,
-                flip_sin_to_cos=False,
-            )
-            timestep_input_dim = intermediate_dim
-        elif time_embedding_type == "positional":
-            self.time_proj = Timesteps(in_channels, False, 0)
-            timestep_input_dim = in_channels
-        else:
-            raise ValueError(
-                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
-            )
-
-        self.time_mlp = TimestepEmbedding(
-            timestep_input_dim,
-            intermediate_dim,
-            act_fn="silu",
-            cond_proj_dim=None,  # No conditional projection for now
-        )
-
-        # Project to intermediate dim
-        self.in_proj = nn.Conv1d(in_channels, intermediate_dim, 1)
-        self.out_proj = nn.Conv1d(intermediate_dim, out_channels, 1)
-
-        # Blocks
-        self.blocks = nn.ModuleList(
-            [
-                ConvNeXtBlock(
-                    dim=intermediate_dim,
-                    intermediate_dim=mlp_dim,
-                    dilation=2 ** (i % dilation_cycle_length),
-                )
-                for i in range(num_layers)
-            ]
-        )
-
-        # Initialize weights
-        self.apply(self._init_weights)
-
-        self.gradient_checkpointing = False
-
-    def _set_gradient_checkpointing(self, module, value: bool = False):
-        """Record whether gradient checkpointing is enabled for this model."""
-        self.gradient_checkpointing = value
-
-    def _init_weights(self, m):
-        """Truncated-normal init (std 0.02) for conv/linear weights; biases are zeroed."""
-        if isinstance(m, (nn.Conv2d, nn.Linear, nn.Conv1d)):
-            nn.init.trunc_normal_(m.weight, mean=0, std=0.02)
-            if m.bias is not None:
-                nn.init.zeros_(m.bias)
-
-    def forward(
-        self,
-        sample: torch.FloatTensor,
-        timestep: Union[torch.Tensor, float, int],
-        sample_mask: Optional[torch.Tensor] = None,
-        condition: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        r"""
-        The [`ConvNext1DModel`] forward method.
-
-        Args:
-            sample (`torch.FloatTensor`):
-                The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
-            timestep (`torch.FloatTensor` or `float` or `int`):
-                The number of timesteps to denoise an input.
-            sample_mask (`torch.Tensor`, *optional*):
-                A mask that every block multiplies into its input, so zero / False positions are suppressed
-                and non-zero / True positions are kept. A 2-D mask gets a channel axis inserted so that it
-                broadcasts over channels. `None` disables masking.
-            condition (`torch.Tensor`, *optional*):
-                Extra channels concatenated to `sample` along dim 1 before the input projection.
-
-        Returns:
-            `torch.FloatTensor`: The output of the final projection, with `out_channels` channels.
-        """
-
-        # 1. time
-        t_emb = self.time_proj(timestep)
-        t_emb = self.time_mlp(t_emb)[..., None]
-
-        # 2. pre-process
-        if condition is not None:
-            sample = torch.cat([sample, condition], dim=1)
-
-        x = self.in_proj(sample)
-
-        # `sample_mask` defaults to None, so guard before reading `.ndim`;
-        # the blocks already accept a None mask.
-        if sample_mask is not None and sample_mask.ndim == 2:
-            sample_mask = sample_mask[:, None, :]
-
-        # 3. blocks
-        # The time embedding is passed as each block's `condition` argument.
-        for block in self.blocks:
-            if self.training and self.is_gradient_checkpointing:
-                x = torch.utils.checkpoint.checkpoint(block, x, t_emb, sample_mask)
-            else:
-                x = block(x, t_emb, sample_mask)
-
-        # 4. post-process
-        return self.out_proj(x)

+ 0 - 373
fish_speech/models/vq_diffusion/lit_module.py

@@ -1,373 +0,0 @@
-import itertools
-from typing import Any, Callable, Optional
-
-import lightning as L
-import numpy as np
-import torch
-import torch.nn.functional as F
-import wandb
-from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler
-from diffusers.utils.torch_utils import randn_tensor
-from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
-from matplotlib import pyplot as plt
-from torch import nn
-from tqdm import tqdm
-
-from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
-from fish_speech.models.vqgan.modules.encoders import (
-    SpeakerEncoder,
-    TextEncoder,
-    VQEncoder,
-)
-from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
-
-
-class ConvDownSample(nn.Module):
-    def __init__(
-        self,
-        dims: list,
-        kernel_sizes: list,
-        strides: list,
-    ):
-        super().__init__()
-
-        self.dims = dims
-        self.kernel_sizes = kernel_sizes
-        self.strides = strides
-        self.total_strides = np.prod(self.strides)
-
-        self.convs = nn.ModuleList(
-            [
-                nn.ModuleList(
-                    [
-                        nn.Conv1d(
-                            in_channels=self.dims[i],
-                            out_channels=self.dims[i + 1],
-                            kernel_size=self.kernel_sizes[i],
-                            stride=self.strides[i],
-                            padding=(self.kernel_sizes[i] - 1) // 2,
-                        ),
-                        nn.LayerNorm(self.dims[i + 1], elementwise_affine=True),
-                        nn.GELU(),
-                    ]
-                )
-                for i in range(len(self.dims) - 1)
-            ]
-        )
-
-        self.apply(self.init_weights)
-
-    def init_weights(self, m):
-        if isinstance(m, nn.Conv1d):
-            nn.init.normal_(m.weight, std=0.02)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.ones_(m.weight)
-            nn.init.zeros_(m.bias)
-
-    def forward(self, x):
-        for conv, norm, act in self.convs:
-            x = conv(x)
-            x = norm(x.mT).mT
-            x = act(x)
-
-        return x
-
-
-class VQDiffusion(L.LightningModule):
-    """Lightning module that predicts mel spectrograms from VQ-quantized features.
-
-    Per batch it computes a target mel and a feature mel from the raw audio,
-    optionally downsamples the features, encodes them (text encoder followed by
-    the VQ encoder), stretches the result to the target mel length, and runs a
-    decoder followed by a postnet. Training minimises the VQ loss plus L1 losses
-    on the decoder and postnet outputs; validation additionally runs the frozen
-    vocoder and logs spectrograms and audio.
-
-    NOTE(review): `noise_scheduler`, `normalize_mels` and `denormalize_mels` are
-    defined here but not referenced by any method of this class.
-    """
-
-    def __init__(
-        self,
-        optimizer: Callable,
-        lr_scheduler: Callable,
-        mel_transform: nn.Module,
-        feature_mel_transform: nn.Module,
-        vq_encoder: VQEncoder,
-        speaker_encoder: SpeakerEncoder,
-        text_encoder: TextEncoder,
-        decoder: ConvNext1DModel,
-        postnet: ConvNext1DModel,
-        vocoder: nn.Module,
-        hop_length: int = 640,
-        sample_rate: int = 32000,
-        speaker_use_feats: bool = False,
-        downsample: Optional[nn.Module] = None,
-    ):
-        """Store the sub-modules and freeze the vocoder.
-
-        Args:
-            optimizer: Called with the module parameters to build the optimizer.
-            lr_scheduler: Called with the optimizer to build the LR scheduler.
-            mel_transform: Produces the target mel from raw audio.
-            feature_mel_transform: Produces the input features from raw audio;
-                its `sample_rate` and `hop_length` attributes are read when
-                computing feature lengths.
-            vq_encoder: Called as `vq_encoder(x, mask)`; returns `(features, vq_loss)`.
-            speaker_encoder: Called as `speaker_encoder(x, mask)`.
-            text_encoder: Called as `text_encoder(x, mask)`.
-            decoder: Called as `decoder(features, mel_masks, g=speaker_features)`.
-                NOTE(review): that call shape does not look like a
-                `ConvNext1DModel` forward (sample, timestep, mask, condition);
-                confirm the annotation.
-            postnet: Called as `postnet(mels, t, mel_masks)`.
-            vocoder: Only `vocoder.decode` is used (validation); its parameters
-                are frozen here.
-            hop_length: Hop length used to derive mel lengths from audio lengths.
-            sample_rate: Audio sample rate, stored as `self.sampling_rate`.
-            speaker_use_feats: If true, the speaker encoder consumes the input
-                features instead of the target mel.
-            downsample: Optional module applied to the features; its
-                `total_strides` attribute is read when computing feature lengths.
-        """
-        super().__init__()
-
-        # Optimizer / scheduler factories (invoked in configure_optimizers)
-        self.optimizer_builder = optimizer
-        self.lr_scheduler_builder = lr_scheduler
-
-        # Mel transforms and noise scheduler
-        self.mel_transform = mel_transform
-        self.feature_mel_transform = feature_mel_transform
-        self.noise_scheduler = DDIMScheduler(
-            num_train_timesteps=1000,
-            clip_sample=False,
-            beta_end=0.01,
-        )
-
-        # Modules
-        self.vq_encoder = vq_encoder
-        self.speaker_encoder = speaker_encoder
-        self.text_encoder = text_encoder
-        self.decoder = decoder
-        self.postnet = postnet
-        self.downsample = downsample
-
-        self.vocoder = vocoder
-        self.hop_length = hop_length
-        self.sampling_rate = sample_rate
-        self.speaker_use_feats = speaker_use_feats
-
-        # Freeze vocoder
-        for param in self.vocoder.parameters():
-            param.requires_grad = False
-
-    def configure_optimizers(self):
-        """Build the optimizer and a per-step LR scheduler from the stored factories."""
-        optimizer = self.optimizer_builder(self.parameters())
-        lr_scheduler = self.lr_scheduler_builder(optimizer)
-
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": {
-                "scheduler": lr_scheduler,
-                "interval": "step",
-            },
-        }
-
-    def normalize_mels(self, x):
-        """Linearly map `x` from [-10.1, 3.1] to [-1, 1]."""
-        # x is in range -10.1 to 3.1, normalize to -1 to 1
-        x_min, x_max = -10.1, 3.1
-        return (x - x_min) / (x_max - x_min) * 2 - 1
-
-    def denormalize_mels(self, x):
-        """Inverse of `normalize_mels`: map `x` from [-1, 1] back to [-10.1, 3.1]."""
-        x_min, x_max = -10.1, 3.1
-        return (x + 1) / 2 * (x_max - x_min) + x_min
-
-    def training_step(self, batch, batch_idx):
-        """Compute and log the VQ, decoder-mel and postnet-mel losses for one batch.
-
-        Reads `audios`, `audio_lengths`, `features` and `feature_lengths` from
-        `batch`. Returns the sum of the three losses.
-        """
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-        # NOTE(review): both values read here are overwritten below (features by
-        # feature_mel_transform, feature_lengths by the recomputation), so the
-        # batch entries only need to exist.
-        features, feature_lengths = batch["features"], batch["feature_lengths"]
-
-        audios = audios.float()
-        # features = features.float().mT
-        # Insert a singleton axis at position 1 before the mel transforms.
-        audios = audios[:, None, :]
-
-        # The mel transforms are run without gradient tracking.
-        with torch.no_grad():
-            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
-            features = self.feature_mel_transform(
-                audios, sample_rate=self.sampling_rate
-            )
-
-        if self.downsample is not None:
-            features = self.downsample(features)
-
-        mel_lengths = audio_lengths // self.hop_length
-        # Audio samples -> feature frames: rescale to the feature transform's
-        # sample rate, divide by its hop length and by the downsample stride,
-        # then truncate to integers.
-        feature_lengths = (
-            audio_lengths
-            / self.sampling_rate
-            * self.feature_mel_transform.sample_rate
-            / self.feature_mel_transform.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
-        ).long()
-
-        # Masks get an extra axis at position 1 and are cast to the mel dtype so
-        # they can be multiplied with the tensors below.
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
-        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
-            gt_mels.dtype
-        )
-
-        if self.speaker_use_feats:
-            speaker_features = self.speaker_encoder(features, feature_masks)
-        else:
-            speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.text_encoder(features, feature_masks)
-        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
-        text_features = F.interpolate(
-            text_features, size=gt_mels.shape[2], mode="nearest"
-        )
-
-        # Scale the target mels by 1/2.303; validation_step multiplies the
-        # prediction by the same constant to undo this. (2.303 is approximately
-        # ln 10 - presumably a natural-log to log10 conversion; TODO confirm.)
-        normalized_gt_mels = gt_mels / 2.303
-
-        # Predict
-        mels = self.decoder(text_features, mel_masks, g=speaker_features)
-        # The postnet is always called with timestep 0.
-        t = torch.tensor([0] * mels.shape[0], device=mels.device, dtype=torch.long)
-        postnet_mels = self.postnet(mels, t, mel_masks)
-
-        # L1 loss on masked mels: masked positions are zeroed on both sides, but
-        # the default mean reduction still averages over every element.
-        mel_loss = F.l1_loss(
-            mels * mel_masks,
-            normalized_gt_mels * mel_masks,
-        )
-
-        postnet_loss = F.l1_loss(
-            postnet_mels * mel_masks,
-            normalized_gt_mels * mel_masks,
-        )
-
-        self.log(
-            "train/mel_loss",
-            mel_loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-
-        self.log(
-            "train/postnet_loss",
-            postnet_loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-
-        self.log(
-            "train/vq_loss",
-            vq_loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-
-        return vq_loss + mel_loss + postnet_loss
-
-    def validation_step(self, batch: Any, batch_idx: int):
-        """Reconstruct mels for one batch, vocode them, and log loss, plots and audio.
-
-        The feature pipeline mirrors `training_step`. The logged `val/mel_loss`
-        is computed on mels multiplied back by 2.303, so it is on a different
-        scale from the training losses.
-        """
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-        # NOTE(review): overwritten below, exactly as in training_step.
-        features, feature_lengths = batch["features"], batch["feature_lengths"]
-
-        audios = audios.float()
-        # features = features.float().mT
-        audios = audios[:, None, :]
-        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
-        features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)
-
-        if self.downsample is not None:
-            features = self.downsample(features)
-
-        mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = (
-            audio_lengths
-            / self.sampling_rate
-            * self.feature_mel_transform.sample_rate
-            / self.feature_mel_transform.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
-        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
-            gt_mels.dtype
-        )
-
-        if self.speaker_use_feats:
-            speaker_features = self.speaker_encoder(features, feature_masks)
-        else:
-            speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.text_encoder(features, feature_masks)
-        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
-        text_features = F.interpolate(
-            text_features, size=gt_mels.shape[2], mode="nearest"
-        )
-
-        # Predict (single decoder + postnet pass, postnet at timestep 0)
-        mels = self.decoder(text_features, mel_masks, g=speaker_features)
-        t = torch.tensor([0] * mels.shape[0], device=mels.device, dtype=torch.long)
-        postnet_mels = self.postnet(mels, t, mel_masks)
-
-        # Undo the 1/2.303 scaling applied to the training targets, then mask.
-        sampled_mels = postnet_mels * 2.303
-        sampled_mels = sampled_mels * mel_masks
-
-        with torch.autocast(device_type=sampled_mels.device.type, enabled=False):
-            # Run vocoder on fp32
-            fake_audios = self.vocoder.decode(sampled_mels.float())
-
-        mel_loss = F.l1_loss(gt_mels * mel_masks, sampled_mels * mel_masks)
-        self.log(
-            "val/mel_loss",
-            mel_loss,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-
-        # Log one spectrogram figure and a ground-truth / predicted audio pair
-        # per sample in the batch.
-        for idx, (
-            mel,
-            gen_mel,
-            audio,
-            gen_audio,
-            audio_len,
-        ) in enumerate(
-            zip(
-                gt_mels,
-                sampled_mels,
-                audios,
-                fake_audios,
-                audio_lengths,
-            )
-        ):
-            mel_len = audio_len // self.hop_length
-
-            # Crop both spectrograms to the valid length before plotting.
-            image_mels = plot_mel(
-                [
-                    gen_mel[:, :mel_len],
-                    mel[:, :mel_len],
-                ],
-                [
-                    "Generated Spectrogram",
-                    "Ground-Truth Spectrogram",
-                ],
-            )
-
-            if isinstance(self.logger, WandbLogger):
-                self.logger.experiment.log(
-                    {
-                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
-                        "wavs": [
-                            wandb.Audio(
-                                audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="gt",
-                            ),
-                            wandb.Audio(
-                                gen_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="prediction",
-                            ),
-                        ],
-                    },
-                )
-
-            if isinstance(self.logger, TensorBoardLogger):
-                self.logger.experiment.add_figure(
-                    f"sample-{idx}/mels",
-                    image_mels,
-                    global_step=self.global_step,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/gt",
-                    audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/prediction",
-                    gen_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-
-            # Release the figure once it has been logged.
-            plt.close(image_mels)

+ 0 - 227
fish_speech/models/vq_diffusion/wavenet.py

@@ -1,227 +0,0 @@
-import math
-from typing import Optional, Union
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-
-class Mish(nn.Module):
-    def forward(self, x):
-        return x * torch.tanh(F.softplus(x))
-
-
-class DiffusionEmbedding(nn.Module):
-    """Diffusion Step Embedding"""
-
-    def __init__(self, d_denoiser):
-        super(DiffusionEmbedding, self).__init__()
-        self.dim = d_denoiser
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
-
-
-class LinearNorm(nn.Module):
-    """LinearNorm Projection"""
-
-    def __init__(self, in_features, out_features, bias=False):
-        super(LinearNorm, self).__init__()
-        self.linear = nn.Linear(in_features, out_features, bias)
-
-        nn.init.xavier_uniform_(self.linear.weight)
-        if bias:
-            nn.init.constant_(self.linear.bias, 0.0)
-
-    def forward(self, x):
-        x = self.linear(x)
-        return x
-
-
-class ConvNorm(nn.Module):
-    """1D Convolution"""
-
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size=1,
-        stride=1,
-        padding=None,
-        dilation=1,
-        bias=True,
-        w_init_gain="linear",
-    ):
-        super(ConvNorm, self).__init__()
-
-        if padding is None:
-            assert kernel_size % 2 == 1
-            padding = int(dilation * (kernel_size - 1) / 2)
-
-        self.conv = nn.Conv1d(
-            in_channels,
-            out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            bias=bias,
-        )
-        nn.init.kaiming_normal_(self.conv.weight)
-
-    def forward(self, signal):
-        conv_signal = self.conv(signal)
-
-        return conv_signal
-
-
-class ResidualBlock(nn.Module):
-    """Residual Block"""
-
-    def __init__(self, d_encoder, residual_channels, use_linear_bias=False, dilation=1):
-        super(ResidualBlock, self).__init__()
-        self.conv_layer = ConvNorm(
-            residual_channels,
-            2 * residual_channels,
-            kernel_size=3,
-            stride=1,
-            padding=dilation,
-            dilation=dilation,
-        )
-        self.diffusion_projection = LinearNorm(
-            residual_channels, residual_channels, use_linear_bias
-        )
-        self.condition_projection = ConvNorm(
-            d_encoder, 2 * residual_channels, kernel_size=1
-        )
-        self.output_projection = ConvNorm(
-            residual_channels, 2 * residual_channels, kernel_size=1
-        )
-
-    def forward(self, x, conditioner, diffusion_step):
-        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
-        conditioner = self.condition_projection(conditioner)
-
-        y = x + diffusion_step
-
-        y = self.conv_layer(y) + conditioner
-
-        gate, filter = torch.chunk(y, 2, dim=1)
-        y = torch.sigmoid(gate) * torch.tanh(filter)
-
-        y = self.output_projection(y)
-        residual, skip = torch.chunk(y, 2, dim=1)
-
-        return (x + residual) / math.sqrt(2.0), skip
-
-
-class SpectrogramUpsampler(nn.Module):
-    def __init__(self, hop_size):
-        super().__init__()
-
-        if hop_size == 256:
-            self.conv1 = nn.ConvTranspose2d(
-                1, 1, [3, 32], stride=[1, 16], padding=[1, 8]
-            )
-        elif hop_size == 512:
-            self.conv1 = nn.ConvTranspose2d(
-                1, 1, [3, 64], stride=[1, 32], padding=[1, 16]
-            )
-        else:
-            raise ValueError(f"Unsupported hop_size: {hop_size}")
-
-        self.conv2 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
-
-    def forward(self, x):
-        x = torch.unsqueeze(x, 1)
-        x = self.conv1(x)
-        x = F.leaky_relu(x, 0.4)
-        x = self.conv2(x)
-        x = F.leaky_relu(x, 0.4)
-        x = torch.squeeze(x, 1)
-
-        return x
-
-
-class WaveNet(nn.Module):
-    """
-    WaveNet
-    https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio
-
-    Non-causal stack of gated residual blocks used as a denoiser: the input is
-    projected to `residual_channels`, passed through `residual_layers` blocks
-    that each receive the condition and the diffusion-step embedding, and the
-    summed skip connections are projected back to `mel_channels`.
-    """
-
-    def __init__(
-        self,
-        mel_channels=128,
-        d_encoder=256,
-        residual_channels=512,
-        residual_layers=20,
-        use_linear_bias=False,
-        dilation_cycle=None,
-    ):
-        """
-        Args:
-            mel_channels: Channel count of the input and of the output.
-            d_encoder: Channel count of the `condition` tensor given to `forward`.
-            residual_channels: Width of the residual stream and of the step embedding.
-            residual_layers: Number of residual blocks.
-            use_linear_bias: Whether the linear layers (step MLP and per-block
-                step projection) carry a bias.
-            dilation_cycle: If truthy, block `i` uses dilation
-                `2 ** (i % dilation_cycle)`; otherwise every block uses dilation 1.
-        """
-        super(WaveNet, self).__init__()
-
-        self.input_projection = ConvNorm(mel_channels, residual_channels, kernel_size=1)
-        self.diffusion_embedding = DiffusionEmbedding(residual_channels)
-        self.mlp = nn.Sequential(
-            LinearNorm(residual_channels, residual_channels * 4, use_linear_bias),
-            Mish(),
-            LinearNorm(residual_channels * 4, residual_channels, use_linear_bias),
-        )
-        self.residual_layers = nn.ModuleList(
-            [
-                ResidualBlock(
-                    d_encoder,
-                    residual_channels,
-                    use_linear_bias=use_linear_bias,
-                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
-                )
-                for i in range(residual_layers)
-            ]
-        )
-        self.skip_projection = ConvNorm(
-            residual_channels, residual_channels, kernel_size=1
-        )
-        self.output_projection = ConvNorm(
-            residual_channels, mel_channels, kernel_size=1
-        )
-        # Zero the final projection's weight; the bias is left at its default
-        # initialisation.
-        nn.init.zeros_(self.output_projection.conv.weight)
-
-    def forward(
-        self,
-        sample: torch.FloatTensor,
-        timestep: Union[torch.Tensor, float, int],
-        sample_mask: Optional[torch.Tensor] = None,
-        condition: Optional[torch.Tensor] = None,
-    ):
-        """Run the denoiser on `sample`.
-
-        Args:
-            sample: Input with `mel_channels` channels on axis 1.
-            timestep: Diffusion step(s).
-                NOTE(review): the annotation allows float/int, but
-                `DiffusionEmbedding.forward` reads `.device` and indexes
-                `x[:, None]`, so a 1-D tensor appears to be required.
-            sample_mask: Optional mask multiplied into the projected input and
-                the final output; a 2-D mask gets an axis inserted at position 1.
-            condition: Conditioning tensor handed to every residual block.
-                NOTE(review): defaults to None but is passed unconditionally to
-                each block's condition projection, so it looks mandatory in
-                practice - confirm.
-
-        Returns:
-            Tensor with `mel_channels` channels on axis 1.
-        """
-        x = self.input_projection(sample)  # x [B, residual_channel, T]
-        x = F.relu(x)
-
-        diffusion_step = self.diffusion_embedding(timestep)
-        diffusion_step = self.mlp(diffusion_step)
-
-        if sample_mask is not None:
-            if sample_mask.ndim == 2:
-                sample_mask = sample_mask[:, None, :]
-
-            x = x * sample_mask
-
-        # The mask is applied only before and after the stack, not between blocks.
-        skip = []
-        for layer in self.residual_layers:
-            x, skip_connection = layer(x, condition, diffusion_step)
-            skip.append(skip_connection)
-
-        # Sum the skip connections and divide by sqrt(number of layers).
-        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
-        x = self.skip_projection(x)
-        x = F.relu(x)
-        x = self.output_projection(x)  # [B, 128, T]
-
-        if sample_mask is not None:
-            x = x * sample_mask
-
-        return x

+ 0 - 104
tools/vqgan/calculate_kmeans_init.py

@@ -1,104 +0,0 @@
-from pathlib import Path
-
-import click
-import numpy as np
-import torch
-from einops import rearrange, repeat
-from torch.utils.data import DataLoader, Dataset
-from tqdm import tqdm
-from vector_quantize_pytorch.vector_quantize_pytorch import (
-    batched_bincount,
-    batched_sample_vectors,
-    cdist,
-)
-
-
-class KMeansDataset(Dataset):
-    def __init__(self, filelist):
-        root = Path(filelist).parent
-
-        with open(filelist) as f:
-            self.files = f.read().splitlines()
-
-        self.files = [root / file.strip() for file in self.files]
-
-    def __len__(self):
-        return len(self.files)
-
-    def __getitem__(self, idx):
-        file = self.files[idx]
-        try:
-            feature = np.load(file.with_suffix(".npy"))
-        except Exception as e:
-            return None
-        return torch.from_numpy(feature).float()
-
-    @staticmethod
-    def collate_fn(features):
-        features = [feature for feature in features if feature is not None]
-        features = torch.concat(features, dim=0)
-        return features
-
-
-# CLI entry point: runs mini-batch k-means over the features listed in
-# --filelist and writes the centroids to --output after every epoch.
-# (Documented with comments rather than a docstring so the click --help
-# output is unchanged.)
-@click.command()
-@click.option(
-    "--filelist",
-    type=click.Path(exists=True, path_type=Path),
-    default="data/vq_train_filelist.txt",
-)
-@click.option("--output", type=click.Path(path_type=Path), default="kmeans.pt")
-@click.option("--num-clusters", type=int, default=2048)
-@click.option("--epochs", type=int, default=10)
-def main(filelist: Path, output: Path, num_clusters: int, epochs: int):
-    # Each loader batch is the concatenation of up to 1024 files' features.
-    loader = DataLoader(
-        KMeansDataset(filelist),
-        batch_size=1024,
-        shuffle=True,
-        num_workers=2,
-        collate_fn=KMeansDataset.collate_fn,
-    )
-
-    # Centroids; initialised lazily from the first batch and carried across epochs.
-    means = None
-    for epoch in tqdm(range(epochs), desc="Epochs", position=0):
-        # Per-epoch assignment counts, used only for the progress print below.
-        total_bins = torch.zeros(1, num_clusters, dtype=torch.int64, device="cuda")
-
-        for samples in tqdm(loader, desc="Batches", position=1):
-            # Move to the GPU and add a leading axis of size 1 (read below as
-            # num_codebooks).
-            samples = samples.cuda()[None]
-            num_codebooks, dim, dtype = (
-                samples.shape[0],
-                samples.shape[-1],
-                samples.dtype,
-            )
-
-            if means is None:
-                means = batched_sample_vectors(samples, num_clusters)
-
-            # Negated distances, so argmax picks the nearest centroid.
-            dists = -cdist(samples, means)
-
-            buckets = torch.argmax(dists, dim=-1)
-            bins = batched_bincount(buckets, minlength=num_clusters)
-
-            # Clamp empty clusters to a count of 1 to avoid dividing by zero.
-            zero_mask = bins == 0
-            bins_min_clamped = bins.masked_fill(zero_mask, 1)
-
-            new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype=dtype)
-
-            # Sum the samples assigned to each cluster, then divide by the counts.
-            new_means.scatter_add_(1, repeat(buckets, "h n -> h n d", d=dim), samples)
-            new_means = new_means / rearrange(bins_min_clamped, "... -> ... 1")
-
-            # Clusters with no samples in this batch keep their previous
-            # centroid; all others are replaced by this batch's mean.
-            means = torch.where(rearrange(zero_mask, "... -> ... 1"), means, new_means)
-
-            total_bins += bins
-
-        # Checkpoint after every epoch, overwriting the same output file.
-        # NOTE(review): "bins" holds the counts of the LAST batch only, not
-        # total_bins, and is undefined if the loader yielded no batches -
-        # confirm which one consumers of this file expect.
-        torch.save(
-            {
-                "centroids": means,
-                "bins": bins,
-            },
-            output,
-        )
-        print(f"Finished epoch {epoch}, total bins: {total_bins}")
-
-
-if __name__ == "__main__":
-    main()