Przeglądaj źródła

support new vqgan

Lengyue 1 rok temu
rodzic
commit
aad7ba4942

+ 2 - 3
fish_speech/configs/firefly_gan_vq.yaml

@@ -22,13 +22,12 @@ head:
   resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
   resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
   num_mels: 512
   num_mels: 512
   upsample_initial_channel: 512
   upsample_initial_channel: 512
-  use_template: false
   pre_conv_kernel_size: 13
   pre_conv_kernel_size: 13
   post_conv_kernel_size: 13
   post_conv_kernel_size: 13
 quantizer:
 quantizer:
   _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
   _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
   input_dim: 512
   input_dim: 512
-  n_groups: 4
+  n_groups: 8
   n_codebooks: 1
   n_codebooks: 1
   levels: [8, 5, 5, 5]
   levels: [8, 5, 5, 5]
-  downsample_factor: [2]
+  downsample_factor: [2, 2]

+ 0 - 3
fish_speech/models/vqgan/__init__.py

@@ -1,3 +0,0 @@
-from .lit_module import VQGAN
-
-__all__ = ["VQGAN"]

+ 0 - 442
fish_speech/models/vqgan/lit_module.py

@@ -1,442 +0,0 @@
-import itertools
-import math
-from typing import Any, Callable
-
-import lightning as L
-import torch
-import torch.nn.functional as F
-import wandb
-from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
-from matplotlib import pyplot as plt
-from torch import nn
-
-from fish_speech.models.vqgan.modules.discriminator import Discriminator
-from fish_speech.models.vqgan.modules.wavenet import WaveNet
-from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask
-
-
-class VQGAN(L.LightningModule):
-    def __init__(
-        self,
-        optimizer: Callable,
-        lr_scheduler: Callable,
-        encoder: WaveNet,
-        quantizer: nn.Module,
-        decoder: WaveNet,
-        discriminator: Discriminator,
-        vocoder: nn.Module,
-        encode_mel_transform: nn.Module,
-        gt_mel_transform: nn.Module,
-        weight_adv: float = 1.0,
-        weight_vq: float = 1.0,
-        weight_mel: float = 1.0,
-        sampling_rate: int = 44100,
-        freeze_encoder: bool = False,
-    ):
-        super().__init__()
-
-        # Model parameters
-        self.optimizer_builder = optimizer
-        self.lr_scheduler_builder = lr_scheduler
-
-        # Modules
-        self.encoder = encoder
-        self.quantizer = quantizer
-        self.decoder = decoder
-        self.vocoder = vocoder
-        self.discriminator = discriminator
-        self.encode_mel_transform = encode_mel_transform
-        self.gt_mel_transform = gt_mel_transform
-
-        # A simple linear layer to project quality to condition channels
-        self.quality_projection = nn.Linear(1, 768)
-
-        # Freeze vocoder
-        for param in self.vocoder.parameters():
-            param.requires_grad = False
-
-        # Loss weights
-        self.weight_adv = weight_adv
-        self.weight_vq = weight_vq
-        self.weight_mel = weight_mel
-
-        # Other parameters
-        self.sampling_rate = sampling_rate
-
-        # Disable strict loading
-        self.strict_loading = False
-
-        # If encoder is frozen
-        if freeze_encoder:
-            for param in self.encoder.parameters():
-                param.requires_grad = False
-
-            for param in self.quantizer.parameters():
-                param.requires_grad = False
-
-        self.automatic_optimization = False
-
-    def on_save_checkpoint(self, checkpoint):
-        # Do not save vocoder
-        state_dict = checkpoint["state_dict"]
-        for name in list(state_dict.keys()):
-            if "vocoder" in name:
-                state_dict.pop(name)
-
-    def configure_optimizers(self):
-        optimizer_generator = self.optimizer_builder(
-            itertools.chain(
-                self.encoder.parameters(),
-                self.quantizer.parameters(),
-                self.decoder.parameters(),
-                self.quality_projection.parameters(),
-            )
-        )
-        optimizer_discriminator = self.optimizer_builder(
-            self.discriminator.parameters()
-        )
-
-        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
-        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
-
-        return (
-            {
-                "optimizer": optimizer_generator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_generator,
-                    "interval": "step",
-                    "name": "optimizer/generator",
-                },
-            },
-            {
-                "optimizer": optimizer_discriminator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_discriminator,
-                    "interval": "step",
-                    "name": "optimizer/discriminator",
-                },
-            },
-        )
-
-    def training_step(self, batch, batch_idx):
-        optim_g, optim_d = self.optimizers()
-
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
-        with torch.no_grad():
-            encoded_mels = self.encode_mel_transform(audios)
-            gt_mels = self.gt_mel_transform(audios)
-            quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
-            quality = quality.unsqueeze(-1)
-
-        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
-        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        gt_mels = gt_mels * mel_masks_float_conv
-        encoded_mels = encoded_mels * mel_masks_float_conv
-
-        # Encode
-        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
-
-        # Quantize
-        vq_result = self.quantizer(encoded_features)
-        loss_vq = getattr("vq_result", "loss", 0.0)
-        vq_recon_features = vq_result.z * mel_masks_float_conv
-        vq_recon_features = (
-            vq_recon_features + self.quality_projection(quality)[:, :, None]
-        )
-
-        # VQ Decode
-        gen_mel = (
-            self.decoder(
-                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
-                condition=vq_recon_features,
-            )
-            * mel_masks_float_conv
-        )
-
-        # Discriminator
-        real_logits = self.discriminator(gt_mels)
-        fake_logits = self.discriminator(gen_mel.detach())
-        d_mask = F.interpolate(
-            mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
-        )
-
-        loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
-        loss_fake = avg_with_mask(fake_logits**2, d_mask)
-
-        loss_d = loss_real + loss_fake
-
-        self.log(
-            "train/discriminator/loss",
-            loss_d,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-        )
-
-        # Discriminator backward
-        optim_d.zero_grad()
-        self.manual_backward(loss_d)
-        self.clip_gradients(
-            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-        )
-        optim_d.step()
-
-        # Mel Loss, applying l1, using a weighted sum
-        mel_distance = (
-            gen_mel - gt_mels
-        ).abs()  # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5
-        loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
-        loss_mel_mid_freq = avg_with_mask(
-            mel_distance[:, 40:70, :], mel_masks_float_conv
-        )
-        loss_mel_high_freq = avg_with_mask(
-            mel_distance[:, 70:, :], mel_masks_float_conv
-        )
-        loss_mel = (
-            loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
-        )
-
-        # Adversarial Loss
-        fake_logits = self.discriminator(gen_mel)
-        loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)
-
-        # Total loss
-        loss = (
-            self.weight_vq * loss_vq
-            + self.weight_mel * loss_mel
-            + self.weight_adv * loss_adv
-        )
-
-        # Log losses
-        self.log(
-            "train/generator/loss",
-            loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-        )
-        self.log(
-            "train/generator/loss_vq",
-            loss_vq,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-        )
-        self.log(
-            "train/generator/loss_mel",
-            loss_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-        )
-        self.log(
-            "train/generator/loss_adv",
-            loss_adv,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-        )
-
-        # Generator backward
-        optim_g.zero_grad()
-        self.manual_backward(loss)
-        self.clip_gradients(
-            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-        )
-        optim_g.step()
-
-        scheduler_g, scheduler_d = self.lr_schedulers()
-        scheduler_g.step()
-        scheduler_d.step()
-
-    def validation_step(self, batch: Any, batch_idx: int):
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
-        encoded_mels = self.encode_mel_transform(audios)
-        gt_mels = self.gt_mel_transform(audios)
-
-        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
-        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        gt_mels = gt_mels * mel_masks_float_conv
-        encoded_mels = encoded_mels * mel_masks_float_conv
-
-        # Encode
-        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
-
-        # Quantize
-        vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
-        vq_recon_features = (
-            vq_recon_features
-            + self.quality_projection(
-                torch.ones(
-                    vq_recon_features.shape[0], 1, device=vq_recon_features.device
-                )
-                * 2
-            )[:, :, None]
-        )
-
-        # VQ Decode
-        gen_aux_mels = (
-            self.decoder(
-                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
-                condition=vq_recon_features,
-            )
-            * mel_masks_float_conv
-        )
-        loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)
-
-        self.log(
-            "val/loss_mel",
-            loss_mel,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        recon_audios = self.vocoder(gt_mels)
-        gen_aux_audios = self.vocoder(gen_aux_mels)
-
-        # only log the first batch
-        if batch_idx != 0:
-            return
-
-        for idx, (
-            gt_mel,
-            gen_aux_mel,
-            audio,
-            gen_aux_audio,
-            recon_audio,
-            audio_len,
-        ) in enumerate(
-            zip(
-                gt_mels,
-                gen_aux_mels,
-                audios.cpu().float(),
-                gen_aux_audios.cpu().float(),
-                recon_audios.cpu().float(),
-                audio_lengths,
-            )
-        ):
-            if idx > 4:
-                break
-
-            mel_len = audio_len // self.gt_mel_transform.hop_length
-
-            image_mels = plot_mel(
-                [
-                    gt_mel[:, :mel_len],
-                    gen_aux_mel[:, :mel_len],
-                ],
-                [
-                    "Ground-Truth",
-                    "Auxiliary",
-                ],
-            )
-
-            if isinstance(self.logger, WandbLogger):
-                self.logger.experiment.log(
-                    {
-                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
-                        "wavs": [
-                            wandb.Audio(
-                                audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="gt",
-                            ),
-                            wandb.Audio(
-                                gen_aux_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="aux",
-                            ),
-                            wandb.Audio(
-                                recon_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="recon",
-                            ),
-                        ],
-                    },
-                )
-
-            if isinstance(self.logger, TensorBoardLogger):
-                self.logger.experiment.add_figure(
-                    f"sample-{idx}/mels",
-                    image_mels,
-                    global_step=self.global_step,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/gt",
-                    audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/gen",
-                    gen_aux_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/recon",
-                    recon_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-
-            plt.close(image_mels)
-
-    def encode(self, audios, audio_lengths):
-        audios = audios.float()
-
-        mels = self.encode_mel_transform(audios)
-        mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
-        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-        mels = mels * mel_masks_float_conv
-
-        # Encode
-        encoded_features = self.encoder(mels) * mel_masks_float_conv
-        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
-
-        return self.quantizer.encode(encoded_features), feature_lengths
-
-    def decode(self, indices, feature_lengths, return_audios=False):
-        factor = math.prod(self.quantizer.downsample_factor)
-        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
-        mel_masks_float_conv = mel_masks[:, None, :].float()
-
-        z = self.quantizer.decode(indices) * mel_masks_float_conv
-        z = (
-            z
-            + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
-                :, :, None
-            ]
-        )
-
-        gen_mel = (
-            self.decoder(
-                torch.randn_like(z) * mel_masks_float_conv,
-                condition=z,
-            )
-            * mel_masks_float_conv
-        )
-
-        if return_audios:
-            return self.vocoder(gen_mel)
-
-        return gen_mel

+ 0 - 44
fish_speech/models/vqgan/modules/discriminator.py

@@ -1,44 +0,0 @@
-import torch
-from torch import nn
-from torch.nn.utils.parametrizations import weight_norm
-
-
-class Discriminator(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-        blocks = []
-        convs = [
-            (1, 64, (3, 9), 1, (1, 4)),
-            (64, 128, (3, 9), (1, 2), (1, 4)),
-            (128, 256, (3, 9), (1, 2), (1, 4)),
-            (256, 512, (3, 9), (1, 2), (1, 4)),
-            (512, 1024, (3, 3), 1, (1, 1)),
-            (1024, 1, (3, 3), 1, (1, 1)),
-        ]
-
-        for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
-            convs
-        ):
-            blocks.append(
-                weight_norm(
-                    nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
-                )
-            )
-
-            if idx != len(convs) - 1:
-                blocks.append(nn.SiLU(inplace=True))
-
-        self.blocks = nn.Sequential(*blocks)
-
-    def forward(self, x):
-        return self.blocks(x[:, None])[:, 0]
-
-
-if __name__ == "__main__":
-    model = Discriminator()
-    print(sum(p.numel() for p in model.parameters()) / 1_000_000)
-    x = torch.randn(1, 128, 1024)
-    y = model(x)
-    print(y.shape)
-    print(y)

+ 167 - 196
fish_speech/models/vqgan/modules/firefly.py

@@ -1,25 +1,26 @@
-# A inference only version of the FireflyGAN model
-
 import math
 import math
 from functools import partial
 from functools import partial
 from math import prod
 from math import prod
 from typing import Callable
 from typing import Callable
 
 
-import numpy as np
 import torch
 import torch
 import torch.nn.functional as F
 import torch.nn.functional as F
 from torch import nn
 from torch import nn
-from torch.nn import Conv1d
 from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrizations import weight_norm
 from torch.nn.utils.parametrize import remove_parametrizations
 from torch.nn.utils.parametrize import remove_parametrizations
 from torch.utils.checkpoint import checkpoint
 from torch.utils.checkpoint import checkpoint
 
 
-from fish_speech.models.vqgan.utils import sequence_mask
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
 
 
 
 
 def init_weights(m, mean=0.0, std=0.01):
 def init_weights(m, mean=0.0, std=0.01):
     classname = m.__class__.__name__
     classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
+    if classname.find("Conv1D") != -1:
         m.weight.data.normal_(mean, std)
         m.weight.data.normal_(mean, std)
 
 
 
 
@@ -27,78 +28,141 @@ def get_padding(kernel_size, dilation=1):
     return (kernel_size * dilation - dilation) // 2
     return (kernel_size * dilation - dilation) // 2
 
 
 
 
+def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
+    """Remove padding from x, handling properly zero padding. Only for 1d!"""
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    assert (padding_left + padding_right) <= x.shape[-1]
+    end = x.shape[-1] - padding_right
+    return x[..., padding_left:end]
+
+
+def get_extra_padding_for_conv1d(
+    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+) -> int:
+    """See `pad_for_conv1d`."""
+    length = x.shape[-1]
+    n_frames = (length - kernel_size + padding_total) / stride + 1
+    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    return ideal_length - length
+
+
+def pad1d(
+    x: torch.Tensor,
+    paddings: tuple[int, int],
+    mode: str = "zeros",
+    value: float = 0.0,
+):
+    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+    If this is the case, we insert extra 0 padding to the right
+    before the reflection happen.
+    """
+    length = x.shape[-1]
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    if mode == "reflect":
+        max_pad = max(padding_left, padding_right)
+        extra_pad = 0
+        if length <= max_pad:
+            extra_pad = max_pad - length + 1
+            x = F.pad(x, (0, extra_pad))
+        padded = F.pad(x, paddings, mode, value)
+        end = padded.shape[-1] - extra_pad
+        return padded[..., :end]
+    else:
+        return F.pad(x, paddings, mode, value)
+
+
+class FishConvNet(nn.Module):
+    def __init__(
+        self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
+    ):
+        super(FishConvNet, self).__init__()
+        self.conv = nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            dilation=dilation,
+            groups=groups,
+        )
+        self.stride = stride
+        self.kernel_size = (kernel_size - 1) * dilation + 1
+        self.dilation = dilation
+
+    def forward(self, x):
+        pad = self.kernel_size - self.stride
+        extra_padding = get_extra_padding_for_conv1d(
+            x, self.kernel_size, self.stride, pad
+        )
+        x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
+        return self.conv(x).contiguous()
+
+    def weight_norm(self, name="weight", dim=0):
+        self.conv = weight_norm(self.conv, name=name, dim=dim)
+        return self
+
+    def remove_weight_norm(self):
+        self.conv = remove_parametrizations(self.conv)
+        return self
+
+
+class FishTransConvNet(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
+        super(FishTransConvNet, self).__init__()
+        self.conv = nn.ConvTranspose1d(
+            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
+        )
+        self.stride = stride
+        self.kernel_size = kernel_size
+
+    def forward(self, x):
+        x = self.conv(x)
+        pad = self.kernel_size - self.stride
+        padding_right = math.ceil(pad)
+        padding_left = pad - padding_right
+        x = unpad1d(x, (padding_left, padding_right))
+        return x.contiguous()
+
+    def weight_norm(self, name="weight", dim=0):
+        self.conv = weight_norm(self.conv, name=name, dim=dim)
+        return self
+
+    def remove_weight_norm(self):
+        self.conv = remove_parametrizations(self.conv)
+        return self
+
+
 class ResBlock1(torch.nn.Module):
 class ResBlock1(torch.nn.Module):
     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
         super().__init__()
         super().__init__()
 
 
         self.convs1 = nn.ModuleList(
         self.convs1 = nn.ModuleList(
             [
             [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[2],
-                        padding=get_padding(kernel_size, dilation[2]),
-                    )
-                ),
+                FishConvNet(
+                    channels, channels, kernel_size, stride=1, dilation=dilation[0]
+                ).weight_norm(),
+                FishConvNet(
+                    channels, channels, kernel_size, stride=1, dilation=dilation[1]
+                ).weight_norm(),
+                FishConvNet(
+                    channels, channels, kernel_size, stride=1, dilation=dilation[2]
+                ).weight_norm(),
             ]
             ]
         )
         )
         self.convs1.apply(init_weights)
         self.convs1.apply(init_weights)
 
 
         self.convs2 = nn.ModuleList(
         self.convs2 = nn.ModuleList(
             [
             [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
+                FishConvNet(
+                    channels, channels, kernel_size, stride=1, dilation=dilation[0]
+                ).weight_norm(),
+                FishConvNet(
+                    channels, channels, kernel_size, stride=1, dilation=dilation[1]
+                ).weight_norm(),
+                FishConvNet(
+                    channels, channels, kernel_size, stride=1, dilation=dilation[2]
+                ).weight_norm(),
             ]
             ]
         )
         )
         self.convs2.apply(init_weights)
         self.convs2.apply(init_weights)
@@ -153,7 +217,6 @@ class HiFiGANGenerator(nn.Module):
         resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
         resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
         num_mels: int = 128,
         num_mels: int = 128,
         upsample_initial_channel: int = 512,
         upsample_initial_channel: int = 512,
-        use_template: bool = True,
         pre_conv_kernel_size: int = 7,
         pre_conv_kernel_size: int = 7,
         post_conv_kernel_size: int = 7,
         post_conv_kernel_size: int = 7,
         post_activation: Callable = partial(nn.SiLU, inplace=True),
         post_activation: Callable = partial(nn.SiLU, inplace=True),
@@ -164,54 +227,29 @@ class HiFiGANGenerator(nn.Module):
             prod(upsample_rates) == hop_length
             prod(upsample_rates) == hop_length
         ), f"hop_length must be {prod(upsample_rates)}"
         ), f"hop_length must be {prod(upsample_rates)}"
 
 
-        self.conv_pre = weight_norm(
-            nn.Conv1d(
-                num_mels,
-                upsample_initial_channel,
-                pre_conv_kernel_size,
-                1,
-                padding=get_padding(pre_conv_kernel_size),
-            )
-        )
+        self.conv_pre = FishConvNet(
+            num_mels,
+            upsample_initial_channel,
+            pre_conv_kernel_size,
+            stride=1,
+        ).weight_norm()
 
 
         self.num_upsamples = len(upsample_rates)
         self.num_upsamples = len(upsample_rates)
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_kernels = len(resblock_kernel_sizes)
 
 
         self.noise_convs = nn.ModuleList()
         self.noise_convs = nn.ModuleList()
-        self.use_template = use_template
         self.ups = nn.ModuleList()
         self.ups = nn.ModuleList()
 
 
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            c_cur = upsample_initial_channel // (2 ** (i + 1))
             self.ups.append(
             self.ups.append(
-                weight_norm(
-                    nn.ConvTranspose1d(
-                        upsample_initial_channel // (2**i),
-                        upsample_initial_channel // (2 ** (i + 1)),
-                        k,
-                        u,
-                        padding=(k - u) // 2,
-                    )
-                )
+                FishTransConvNet(
+                    upsample_initial_channel // (2**i),
+                    upsample_initial_channel // (2 ** (i + 1)),
+                    k,
+                    stride=u,
+                ).weight_norm()
             )
             )
 
 
-            if not use_template:
-                continue
-
-            if i + 1 < len(upsample_rates):
-                stride_f0 = np.prod(upsample_rates[i + 1 :])
-                self.noise_convs.append(
-                    Conv1d(
-                        1,
-                        c_cur,
-                        kernel_size=stride_f0 * 2,
-                        stride=stride_f0,
-                        padding=stride_f0 // 2,
-                    )
-                )
-            else:
-                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
-
         self.resblocks = nn.ModuleList()
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
         for i in range(len(self.ups)):
             ch = upsample_initial_channel // (2 ** (i + 1))
             ch = upsample_initial_channel // (2 ** (i + 1))
@@ -220,29 +258,20 @@ class HiFiGANGenerator(nn.Module):
             )
             )
 
 
         self.activation_post = post_activation()
         self.activation_post = post_activation()
-        self.conv_post = weight_norm(
-            nn.Conv1d(
-                ch,
-                1,
-                post_conv_kernel_size,
-                1,
-                padding=get_padding(post_conv_kernel_size),
-            )
-        )
+        self.conv_post = FishConvNet(
+            ch, 1, post_conv_kernel_size, stride=1
+        ).weight_norm()
         self.ups.apply(init_weights)
         self.ups.apply(init_weights)
         self.conv_post.apply(init_weights)
         self.conv_post.apply(init_weights)
 
 
-    def forward(self, x, template=None):
+    def forward(self, x):
         x = self.conv_pre(x)
         x = self.conv_pre(x)
 
 
         for i in range(self.num_upsamples):
         for i in range(self.num_upsamples):
             x = F.silu(x, inplace=True)
             x = F.silu(x, inplace=True)
             x = self.ups[i](x)
             x = self.ups[i](x)
 
 
-            if self.use_template:
-                x = x + self.noise_convs[i](template)
-
-            if self.training:
+            if self.training and self.checkpointing:
                 x = checkpoint(
                 x = checkpoint(
                     self.resblocks[i],
                     self.resblocks[i],
                     x,
                     x,
@@ -364,11 +393,11 @@ class ConvNeXtBlock(nn.Module):
     ):
     ):
         super().__init__()
         super().__init__()
 
 
-        self.dwconv = nn.Conv1d(
+        self.dwconv = FishConvNet(
             dim,
             dim,
             dim,
             dim,
             kernel_size=kernel_size,
             kernel_size=kernel_size,
-            padding=int(dilation * (kernel_size - 1) / 2),
+            # padding=int(dilation * (kernel_size - 1) / 2),
             groups=dim,
             groups=dim,
         )  # depthwise conv
         )  # depthwise conv
         self.norm = LayerNorm(dim, eps=1e-6)
         self.norm = LayerNorm(dim, eps=1e-6)
@@ -421,12 +450,13 @@ class ConvNeXtEncoder(nn.Module):
 
 
         self.downsample_layers = nn.ModuleList()
         self.downsample_layers = nn.ModuleList()
         stem = nn.Sequential(
         stem = nn.Sequential(
-            nn.Conv1d(
+            FishConvNet(
                 input_channels,
                 input_channels,
                 dims[0],
                 dims[0],
-                kernel_size=kernel_size,
-                padding=kernel_size // 2,
-                padding_mode="zeros",
+                kernel_size=7,
+                # padding=3,
+                # padding_mode="replicate",
+                # padding_mode="zeros",
             ),
             ),
             LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
             LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
         )
         )
@@ -491,6 +521,7 @@ class FireflyArchitecture(nn.Module):
         self.head = head
         self.head = head
         self.quantizer = quantizer
         self.quantizer = quantizer
         self.spec_transform = spec_transform
         self.spec_transform = spec_transform
+        self.downsample_factor = math.prod(self.quantizer.downsample_factor)
 
 
     def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
     def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
         if self.spec_transform is not None:
         if self.spec_transform is not None:
@@ -512,7 +543,7 @@ class FireflyArchitecture(nn.Module):
         if x.ndim == 2:
         if x.ndim == 2:
             x = x[:, None, :]
             x = x[:, None, :]
 
 
-        if self.quantizer is not None:
+        if self.vq is not None:
             return x, vq_result
             return x, vq_result
 
 
         return x
         return x
@@ -528,25 +559,30 @@ class FireflyArchitecture(nn.Module):
 
 
         # Encode
         # Encode
         encoded_features = self.backbone(mels) * mel_masks_float_conv
         encoded_features = self.backbone(mels) * mel_masks_float_conv
-        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
+        feature_lengths = mel_lengths // self.downsample_factor
 
 
         return self.quantizer.encode(encoded_features), feature_lengths
         return self.quantizer.encode(encoded_features), feature_lengths
 
 
     def decode(self, indices, feature_lengths) -> torch.Tensor:
     def decode(self, indices, feature_lengths) -> torch.Tensor:
-        factor = math.prod(self.quantizer.downsample_factor)
-        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
+        mel_masks = sequence_mask(
+            feature_lengths * self.downsample_factor,
+            indices.shape[2] * self.downsample_factor,
+        )
         mel_masks_float_conv = mel_masks[:, None, :].float()
         mel_masks_float_conv = mel_masks[:, None, :].float()
+        audio_lengths = (
+            feature_lengths * self.downsample_factor * self.spec_transform.hop_length
+        )
 
 
         audio_masks = sequence_mask(
         audio_masks = sequence_mask(
-            feature_lengths * factor * self.spec_transform.hop_length,
-            indices.shape[2] * factor * self.spec_transform.hop_length,
+            audio_lengths,
+            indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
         )
         )
         audio_masks_float_conv = audio_masks[:, None, :].float()
         audio_masks_float_conv = audio_masks[:, None, :].float()
 
 
         z = self.quantizer.decode(indices) * mel_masks_float_conv
         z = self.quantizer.decode(indices) * mel_masks_float_conv
         x = self.head(z) * audio_masks_float_conv
         x = self.head(z) * audio_masks_float_conv
 
 
-        return x
+        return x, audio_lengths
 
 
     def remove_parametrizations(self):
     def remove_parametrizations(self):
         if hasattr(self.backbone, "remove_parametrizations"):
         if hasattr(self.backbone, "remove_parametrizations"):
@@ -558,68 +594,3 @@ class FireflyArchitecture(nn.Module):
     @property
     @property
     def device(self):
     def device(self):
         return next(self.parameters()).device
         return next(self.parameters()).device
-
-
-class FireflyBase(nn.Module):
-    def __init__(self, ckpt_path: str = None, pretrained: bool = True):
-        super().__init__()
-
-        self.backbone = ConvNeXtEncoder(
-            input_channels=128,
-            depths=[3, 3, 9, 3],
-            dims=[128, 256, 384, 512],
-            drop_path_rate=0.2,
-            kernel_size=7,
-        )
-
-        self.head = HiFiGANGenerator(
-            hop_length=512,
-            upsample_rates=[8, 8, 2, 2, 2],
-            upsample_kernel_sizes=[16, 16, 4, 4, 4],
-            resblock_kernel_sizes=[3, 7, 11],
-            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-            num_mels=512,
-            upsample_initial_channel=512,
-            use_template=False,
-            pre_conv_kernel_size=13,
-            post_conv_kernel_size=13,
-        )
-
-        if ckpt_path is not None:
-            state_dict = torch.load(ckpt_path, map_location="cpu")
-        elif pretrained:
-            state_dict = torch.hub.load_state_dict_from_url(
-                "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
-                map_location="cpu",
-                model_dir="checkpoints",
-            )
-
-        if "state_dict" in state_dict:
-            state_dict = state_dict["state_dict"]
-
-        if any("generator." in k for k in state_dict):
-            state_dict = {
-                k.replace("generator.", ""): v
-                for k, v in state_dict.items()
-                if "generator." in k
-            }
-
-        self.load_state_dict(state_dict, strict=True)
-        self.head.remove_parametrizations()
-
-    @torch.no_grad()
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.backbone(x)
-        x = self.head(x)
-        if x.ndim == 2:
-            x = x[:, None, :]
-        return x
-
-
-if __name__ == "__main__":
-    model = FireflyBase()
-    model.eval()
-    x = torch.randn(1, 128, 128)
-    with torch.no_grad():
-        y = model(x)
-    print(y.shape)

+ 4 - 27
fish_speech/models/vqgan/modules/fsq.py

@@ -6,7 +6,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from einops import rearrange
 from vector_quantize_pytorch import GroupedResidualFSQ
 from vector_quantize_pytorch import GroupedResidualFSQ
 
 
-from .firefly import ConvNeXtBlock
+from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
 
 
 
 
 @dataclass
 @dataclass
@@ -20,7 +20,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
     def __init__(
     def __init__(
         self,
         self,
         input_dim: int = 512,
         input_dim: int = 512,
-        n_codebooks: int = 1,
+        n_codebooks: int = 9,
         n_groups: int = 1,
         n_groups: int = 1,
         levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
         levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
         downsample_factor: tuple[int] = (2, 2),
         downsample_factor: tuple[int] = (2, 2),
@@ -46,7 +46,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
         self.downsample = nn.Sequential(
         self.downsample = nn.Sequential(
             *[
             *[
                 nn.Sequential(
                 nn.Sequential(
-                    nn.Conv1d(
+                    FishConvNet(
                         all_dims[idx],
                         all_dims[idx],
                         all_dims[idx + 1],
                         all_dims[idx + 1],
                         kernel_size=factor,
                         kernel_size=factor,
@@ -61,7 +61,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
         self.upsample = nn.Sequential(
         self.upsample = nn.Sequential(
             *[
             *[
                 nn.Sequential(
                 nn.Sequential(
-                    nn.ConvTranspose1d(
+                    FishTransConvNet(
                         all_dims[idx + 1],
                         all_dims[idx + 1],
                         all_dims[idx],
                         all_dims[idx],
                         kernel_size=factor,
                         kernel_size=factor,
@@ -114,26 +114,3 @@ class DownsampleFiniteScalarQuantize(nn.Module):
         z_q = self.residual_fsq.get_output_from_indices(indices)
         z_q = self.residual_fsq.get_output_from_indices(indices)
         z_q = self.upsample(z_q.mT)
         z_q = self.upsample(z_q.mT)
         return z_q
         return z_q
-
-    # def from_latents(self, latents: torch.Tensor):
-    #     z_q, z_p, codes = super().from_latents(latents)
-    #     z_q = self.upsample(z_q)
-    #     return z_q, z_p, codes
-
-
-if __name__ == "__main__":
-    rvq = DownsampleFiniteScalarQuantize(
-        n_codebooks=1,
-        downsample_factor=(2, 2),
-    )
-    x = torch.randn(16, 512, 80)
-
-    result = rvq(x)
-    print(rvq)
-    print(result.latents.shape, result.codes.shape, result.z.shape)
-
-    # y = rvq.from_codes(result.codes)
-    # print(y[0].shape)
-
-    # y = rvq.from_latents(result.latents)
-    # print(y[0].shape)

+ 0 - 115
fish_speech/models/vqgan/modules/reference.py

@@ -1,115 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from fish_speech.utils import autocast_exclude_mps
-
-from .wavenet import WaveNet
-
-
-class ReferenceEncoder(WaveNet):
-    def __init__(
-        self,
-        input_channels: Optional[int] = None,
-        output_channels: Optional[int] = None,
-        residual_channels: int = 512,
-        residual_layers: int = 20,
-        dilation_cycle: Optional[int] = 4,
-        num_heads: int = 8,
-        latent_len: int = 4,
-    ):
-        super().__init__(
-            input_channels=input_channels,
-            residual_channels=residual_channels,
-            residual_layers=residual_layers,
-            dilation_cycle=dilation_cycle,
-        )
-
-        self.head_dim = residual_channels // num_heads
-        self.num_heads = num_heads
-
-        self.latent_len = latent_len
-        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, residual_channels))
-
-        self.q = nn.Linear(residual_channels, residual_channels, bias=True)
-        self.kv = nn.Linear(residual_channels, residual_channels * 2, bias=True)
-        self.q_norm = nn.LayerNorm(self.head_dim)
-        self.k_norm = nn.LayerNorm(self.head_dim)
-        self.proj = nn.Linear(residual_channels, residual_channels)
-        self.proj_drop = nn.Dropout(0.1)
-
-        self.norm = nn.LayerNorm(residual_channels)
-        self.mlp = nn.Sequential(
-            nn.Linear(residual_channels, residual_channels * 4),
-            nn.SiLU(),
-            nn.Linear(residual_channels * 4, residual_channels),
-        )
-        self.output_projection_attn = nn.Linear(residual_channels, output_channels)
-
-        torch.nn.init.trunc_normal_(self.latent, std=0.02)
-        self.apply(self.init_weights)
-
-    def init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            torch.nn.init.trunc_normal_(m.weight, std=0.02)
-            if m.bias is not None:
-                torch.nn.init.constant_(m.bias, 0)
-
-    def forward(self, x, attn_mask=None):
-        x = super().forward(x).mT
-        B, N, C = x.shape
-
-        # Calculate mask
-        if attn_mask is not None:
-            assert attn_mask.shape == (B, N) and attn_mask.dtype == torch.bool
-
-            attn_mask = attn_mask[:, None, None, :].expand(
-                B, self.num_heads, self.latent_len, N
-            )
-
-        q_latent = self.latent.expand(B, -1, -1)
-        q = (
-            self.q(q_latent)
-            .reshape(B, self.latent_len, self.num_heads, self.head_dim)
-            .transpose(1, 2)
-        )
-
-        kv = (
-            self.kv(x)
-            .reshape(B, N, 2, self.num_heads, self.head_dim)
-            .permute(2, 0, 3, 1, 4)
-        )
-        k, v = kv.unbind(0)
-
-        q, k = self.q_norm(q), self.k_norm(k)
-        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
-
-        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-
-        x = x + self.mlp(self.norm(x))
-        x = self.output_projection_attn(x)
-        x = x.mean(1)
-
-        return x
-
-
-if __name__ == "__main__":
-    with autocast_exclude_mps(device_type="cpu", dtype=torch.bfloat16):
-        model = ReferenceEncoder(
-            input_channels=128,
-            output_channels=64,
-            residual_channels=384,
-            residual_layers=20,
-            dilation_cycle=4,
-            num_heads=8,
-        )
-        x = torch.randn(4, 128, 64)
-        mask = torch.ones(4, 64, dtype=torch.bool)
-        y = model(x, mask)
-        print(y.shape)
-        loss = F.mse_loss(y, torch.randn(4, 64))
-        loss.backward()

+ 0 - 225
fish_speech/models/vqgan/modules/wavenet.py

@@ -1,225 +0,0 @@
-import math
-from typing import Optional
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-
-class Mish(nn.Module):
-    def forward(self, x):
-        return x * torch.tanh(F.softplus(x))
-
-
-class DiffusionEmbedding(nn.Module):
-    """Diffusion Step Embedding"""
-
-    def __init__(self, d_denoiser):
-        super(DiffusionEmbedding, self).__init__()
-        self.dim = d_denoiser
-
-    def forward(self, x):
-        device = x.device
-        half_dim = self.dim // 2
-        emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
-        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
-        return emb
-
-
-class LinearNorm(nn.Module):
-    """LinearNorm Projection"""
-
-    def __init__(self, in_features, out_features, bias=False):
-        super(LinearNorm, self).__init__()
-        self.linear = nn.Linear(in_features, out_features, bias)
-
-        nn.init.xavier_uniform_(self.linear.weight)
-        if bias:
-            nn.init.constant_(self.linear.bias, 0.0)
-
-    def forward(self, x):
-        x = self.linear(x)
-        return x
-
-
-class ConvNorm(nn.Module):
-    """1D Convolution"""
-
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size=1,
-        stride=1,
-        padding=None,
-        dilation=1,
-        bias=True,
-        w_init_gain="linear",
-    ):
-        super(ConvNorm, self).__init__()
-
-        if padding is None:
-            assert kernel_size % 2 == 1
-            padding = int(dilation * (kernel_size - 1) / 2)
-
-        self.conv = nn.Conv1d(
-            in_channels,
-            out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            bias=bias,
-        )
-        nn.init.kaiming_normal_(self.conv.weight)
-
-    def forward(self, signal):
-        conv_signal = self.conv(signal)
-
-        return conv_signal
-
-
-class ResidualBlock(nn.Module):
-    """Residual Block"""
-
-    def __init__(
-        self,
-        residual_channels,
-        use_linear_bias=False,
-        dilation=1,
-        condition_channels=None,
-    ):
-        super(ResidualBlock, self).__init__()
-        self.conv_layer = ConvNorm(
-            residual_channels,
-            2 * residual_channels,
-            kernel_size=3,
-            stride=1,
-            padding=dilation,
-            dilation=dilation,
-        )
-
-        if condition_channels is not None:
-            self.diffusion_projection = LinearNorm(
-                residual_channels, residual_channels, use_linear_bias
-            )
-            self.condition_projection = ConvNorm(
-                condition_channels, 2 * residual_channels, kernel_size=1
-            )
-
-        self.output_projection = ConvNorm(
-            residual_channels, 2 * residual_channels, kernel_size=1
-        )
-
-    def forward(self, x, condition=None, diffusion_step=None):
-        y = x
-
-        if diffusion_step is not None:
-            diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
-            y = y + diffusion_step
-
-        y = self.conv_layer(y)
-
-        if condition is not None:
-            condition = self.condition_projection(condition)
-            y = y + condition
-
-        gate, filter = torch.chunk(y, 2, dim=1)
-        y = torch.sigmoid(gate) * torch.tanh(filter)
-
-        y = self.output_projection(y)
-        residual, skip = torch.chunk(y, 2, dim=1)
-
-        return (x + residual) / math.sqrt(2.0), skip
-
-
-class WaveNet(nn.Module):
-    def __init__(
-        self,
-        input_channels: Optional[int] = None,
-        output_channels: Optional[int] = None,
-        residual_channels: int = 512,
-        residual_layers: int = 20,
-        dilation_cycle: Optional[int] = 4,
-        is_diffusion: bool = False,
-        condition_channels: Optional[int] = None,
-    ):
-        super().__init__()
-
-        # Input projection
-        self.input_projection = None
-        if input_channels is not None and input_channels != residual_channels:
-            self.input_projection = ConvNorm(
-                input_channels, residual_channels, kernel_size=1
-            )
-
-        if input_channels is None:
-            input_channels = residual_channels
-
-        self.input_channels = input_channels
-
-        # Residual layers
-        self.residual_layers = nn.ModuleList(
-            [
-                ResidualBlock(
-                    residual_channels=residual_channels,
-                    use_linear_bias=False,
-                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
-                    condition_channels=condition_channels,
-                )
-                for i in range(residual_layers)
-            ]
-        )
-
-        # Skip projection
-        self.skip_projection = ConvNorm(
-            residual_channels, residual_channels, kernel_size=1
-        )
-
-        # Output projection
-        self.output_projection = None
-        if output_channels is not None and output_channels != residual_channels:
-            self.output_projection = ConvNorm(
-                residual_channels, output_channels, kernel_size=1
-            )
-
-        if is_diffusion:
-            self.diffusion_embedding = DiffusionEmbedding(residual_channels)
-            self.mlp = nn.Sequential(
-                LinearNorm(residual_channels, residual_channels * 4, False),
-                Mish(),
-                LinearNorm(residual_channels * 4, residual_channels, False),
-            )
-
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, (nn.Conv1d, nn.Linear)):
-            nn.init.trunc_normal_(m.weight, std=0.02)
-            if getattr(m, "bias", None) is not None:
-                nn.init.constant_(m.bias, 0)
-
-    def forward(self, x, t=None, condition=None):
-        if self.input_projection is not None:
-            x = self.input_projection(x)
-            x = F.silu(x)
-
-        if t is not None:
-            t = self.diffusion_embedding(t)
-            t = self.mlp(t)
-
-        skip = []
-        for layer in self.residual_layers:
-            x, skip_connection = layer(x, condition, t)
-            skip.append(skip_connection)
-
-        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
-        x = self.skip_projection(x)
-
-        if self.output_projection is not None:
-            x = F.silu(x)
-            x = self.output_projection(x)
-
-        return x

+ 3 - 1
tools/vqgan/inference.py

@@ -103,7 +103,9 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
 
 
     # Restore
     # Restore
     feature_lengths = torch.tensor([indices.shape[1]], device=device)
     feature_lengths = torch.tensor([indices.shape[1]], device=device)
-    fake_audios = model.decode(indices=indices[None], feature_lengths=feature_lengths)
+    fake_audios, _ = model.decode(
+        indices=indices[None], feature_lengths=feature_lengths
+    )
     audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
     audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
 
 
     logger.info(
     logger.info(