Преглед изворни кода

Implement mel direct generation

Lengyue пре 2 година
родитељ
комит
a5699b169d

+ 31 - 52
fish_speech/configs/vqgan_pretrain_v2.yaml

@@ -3,16 +3,14 @@ defaults:
   - _self_
 
 project: vqgan_pretrain_v2
-ckpt_path: checkpoints/hifigan-base-comb-mix-lb-020/step_001200000_weights_only.ckpt
-resume_weights_only: true
 
 # Lightning Trainer
 trainer:
   accelerator: gpu
   devices: auto
   strategy: ddp_find_unused_parameters_true
-  precision: 32
-  max_steps: 1_000_000
+  precision: bf16-mixed
+  max_steps: 10_000_000
   val_check_interval: 5000
 
 sample_rate: 44100
@@ -63,38 +61,31 @@ model:
   _target_: fish_speech.models.vqgan.VQGAN
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
-  segment_size: 32768
-  mode: pretrain
-  freeze_discriminator: true
-
-  downsample:
-    _target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
-    dims: ["${num_mels}", 512, 256]
-    kernel_sizes: [3, 3]
-    strides: [2, 2]
-
-  mel_encoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WN
-    hidden_channels: 256
+
+  encoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
+    hidden_channels: 384
     kernel_size: 3
     dilation_rate: 2
-    n_layers: 12
+    n_layers: 10
+    in_channels: ${num_mels}
 
-  vq_encoder:
+  vq:
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
-    in_channels: 256
-    vq_channels: 256
+    in_channels: 384
+    vq_channels: 384
     codebook_size: 256
-    codebook_groups: 4
-    downsample: 1
+    codebook_groups: 2
+    codebook_layers: 2
+    downsample: 4
 
   decoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WN
-    hidden_channels: 256
-    out_channels: ${num_mels}
+    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
+    hidden_channels: 384
     kernel_size: 3
     dilation_rate: 2
-    n_layers: 6
+    n_layers: 10
+    out_channels: ${num_mels}
 
   generator:
     _target_: fish_speech.models.vqgan.modules.decoder_v2.HiFiGANGenerator
@@ -108,27 +99,16 @@ model:
     use_template: true
     pre_conv_kernel_size: 7
     post_conv_kernel_size: 7
+    ckpt_path: checkpoints/hifigan-base-comb-mix-lb-020/step_001200000_weights_only.ckpt
+
+  discriminator:
+    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
+    hidden_channels: 256
+    kernel_size: 3
+    dilation_rate: 2
+    n_layers: 6
+    in_channels: ${num_mels}
 
-  discriminators:
-    _target_: torch.nn.ModuleDict
-    modules:
-      mpd:
-        _target_: fish_speech.models.vqgan.modules.discriminators.mpd.MultiPeriodDiscriminator
-        periods: [2, 3, 5, 7, 11, 17, 23, 37]
-
-      mrd:
-        _target_: fish_speech.models.vqgan.modules.discriminators.mrd.MultiResolutionDiscriminator
-        resolutions:
-          - ["${n_fft}", "${hop_length}", "${win_length}"]
-          - [1024, 120, 600]
-          - [2048, 240, 1200]
-          - [4096, 480, 2400]
-          - [512, 50, 240]
-
-  multi_resolution_stft_loss:
-    _target_: fish_speech.models.vqgan.losses.MultiResolutionSTFTLoss
-    resolutions: ${model.discriminators.modules.mrd.resolutions}
-  
   mel_transform:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     sample_rate: ${sample_rate}
@@ -140,7 +120,7 @@ model:
   optimizer:
     _target_: torch.optim.AdamW
     _partial_: true
-    lr: 1e-4
+    lr: 2e-4
     betas: [0.8, 0.99]
     eps: 1e-5
 
@@ -152,8 +132,7 @@ model:
 callbacks:
   grad_norm_monitor:
     sub_module: 
-      - generator
-      - discriminators
-      - mel_encoder
-      - vq_encoder
+      - encoder
+      - vq
       - decoder
+      - discriminator

+ 91 - 283
fish_speech/models/vqgan/lit_module.py

@@ -9,6 +9,7 @@ import wandb
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
+from torch.utils.checkpoint import checkpoint as gradient_checkpoint
 
 from fish_speech.models.vqgan.losses import (
     MultiResolutionSTFTLoss,
@@ -16,19 +17,9 @@ from fish_speech.models.vqgan.losses import (
     feature_loss,
     generator_loss,
 )
-from fish_speech.models.vqgan.modules.balancer import Balancer
-from fish_speech.models.vqgan.modules.decoder import Generator
-from fish_speech.models.vqgan.modules.encoders import (
-    ConvDownSampler,
-    TextEncoder,
-    VQEncoder,
-)
-from fish_speech.models.vqgan.utils import (
-    plot_mel,
-    rand_slice_segments,
-    sequence_mask,
-    slice_segments,
-)
+from fish_speech.models.vqgan.modules.convnext import ConvNeXt
+from fish_speech.models.vqgan.modules.encoders import VQEncoder
+from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
 
 
 @dataclass
@@ -41,9 +32,7 @@ class VQEncodeResult:
 
 @dataclass
 class VQDecodeResult:
-    audios: torch.Tensor
     mels: torch.Tensor
-    mel_lengths: torch.Tensor
 
 
 class VQGAN(L.LightningModule):
@@ -51,19 +40,15 @@ class VQGAN(L.LightningModule):
         self,
         optimizer: Callable,
         lr_scheduler: Callable,
-        downsample: ConvDownSampler,
-        vq_encoder: VQEncoder,
-        mel_encoder: TextEncoder,
-        decoder: TextEncoder,
-        generator: Generator,
-        discriminators: nn.ModuleDict,
+        encoder: ConvNeXt,
+        vq: VQEncoder,
+        decoder: ConvNeXt,
+        generator: nn.Module,
+        discriminator: ConvNeXt,
         mel_transform: nn.Module,
-        segment_size: int = 20480,
         hop_length: int = 640,
         sample_rate: int = 32000,
-        mode: Literal["pretrain", "finetune"] = "finetune",
         freeze_discriminator: bool = False,
-        multi_resolution_stft_loss: Optional[MultiResolutionSTFTLoss] = None,
     ):
         super().__init__()
 
@@ -74,68 +59,41 @@ class VQGAN(L.LightningModule):
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
 
-        # Generator and discriminators
-        self.downsample = downsample
-        self.vq_encoder = vq_encoder
-        self.mel_encoder = mel_encoder
+        # Generator and discriminator
+        self.encoder = encoder
+        self.vq = vq
         self.decoder = decoder
         self.generator = generator
-        self.discriminators = discriminators
+        self.discriminator = discriminator
         self.mel_transform = mel_transform
         self.freeze_discriminator = freeze_discriminator
 
         # Crop length for saving memory
-        self.segment_size = segment_size
         self.hop_length = hop_length
         self.sampling_rate = sample_rate
-        self.mode = mode
 
         # Disable automatic optimization
         self.automatic_optimization = False
 
-        # Finetune: Train the VQ only
-        if self.mode == "finetune":
-            for p in self.vq_encoder.parameters():
-                p.requires_grad = False
-
-            for p in self.mel_encoder.parameters():
-                p.requires_grad = False
-
-            for p in self.downsample.parameters():
-                p.requires_grad = False
-
         if self.freeze_discriminator:
-            for p in self.discriminators.parameters():
+            for p in self.discriminator.parameters():
                 p.requires_grad = False
 
-        # Losses
-        self.multi_resolution_stft_loss = multi_resolution_stft_loss
-        loss_dict = {
-            "mel": 1,
-            "adv": 1,
-            "fm": 1,
-        }
-
-        if self.multi_resolution_stft_loss is not None:
-            loss_dict["stft"] = 1
-
-        self.balancer = Balancer(loss_dict)
+        # Freeze generator
+        for p in self.generator.parameters():
+            p.requires_grad = False
 
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
-        components = [
-            self.downsample.parameters(),
-            self.vq_encoder.parameters(),
-            self.mel_encoder.parameters(),
-        ]
-
-        if self.decoder is not None:
-            components.append(self.decoder.parameters())
-
-        components.append(self.generator.parameters())
-        optimizer_generator = self.optimizer_builder(itertools.chain(*components))
+        optimizer_generator = self.optimizer_builder(
+            itertools.chain(
+                self.encoder.parameters(),
+                self.vq.parameters(),
+                self.decoder.parameters(),
+            )
+        )
         optimizer_discriminator = self.optimizer_builder(
-            self.discriminators.parameters()
+            self.discriminator.parameters()
         )
 
         lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
@@ -171,13 +129,6 @@ class VQGAN(L.LightningModule):
         with torch.no_grad():
             gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
 
-        if self.mode == "finetune":
-            # Disable gradient computation for VQ
-            torch.set_grad_enabled(False)
-            self.vq_encoder.eval()
-            self.mel_encoder.eval()
-            self.downsample.eval()
-
         mel_lengths = audio_lengths // self.hop_length
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
@@ -189,186 +140,80 @@ class VQGAN(L.LightningModule):
         if loss_vq.ndim > 1:
             loss_vq = loss_vq.mean()
 
-        if self.mode == "finetune":
-            # Enable gradient computation
-            torch.set_grad_enabled(True)
-
-        decoded = self.decode(
-            indices=vq_result.indices if self.mode == "finetune" else None,
-            features=vq_result.features if self.mode == "pretrain" else None,
+        decoded_mels = self.decode(
+            indices=None,
+            features=vq_result.features,
             audio_lengths=audio_lengths,
-            mel_only=True,
-        )
-        decoded_mels = decoded.mels
-        input_mels = gt_mels if self.mode == "pretrain" else decoded_mels
+        ).mels
 
-        if self.segment_size is not None:
-            audios, ids_slice = rand_slice_segments(
-                audios, audio_lengths, self.segment_size
-            )
-            input_mels = slice_segments(
-                input_mels,
-                ids_slice // self.hop_length,
-                self.segment_size // self.hop_length,
-            )
-            sliced_gt_mels = slice_segments(
-                gt_mels,
-                ids_slice // self.hop_length,
-                self.segment_size // self.hop_length,
-            )
-            gen_mel_masks = slice_segments(
-                mel_masks,
-                ids_slice // self.hop_length,
-                self.segment_size // self.hop_length,
-            )
-        else:
-            sliced_gt_mels = gt_mels
-            gen_mel_masks = mel_masks
+        with torch.no_grad():
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                fake_audios = self.generator(decoded_mels.float())
 
-        fake_audios = self.generator(input_mels)
-        fake_audio_mels = self.mel_transform(fake_audios.squeeze(1))
         assert (
             audios.shape == fake_audios.shape
         ), f"{audios.shape} != {fake_audios.shape}"
 
-        # Multi-Resolution STFT Loss
-        if self.multi_resolution_stft_loss is not None:
-            with torch.autocast(device_type=audios.device.type, enabled=False):
-                sc_loss, mag_loss = self.multi_resolution_stft_loss(
-                    fake_audios.squeeze(1).float(), audios.squeeze(1).float()
-                )
-                loss_stft = sc_loss + mag_loss
-
         # Discriminator
         if self.freeze_discriminator is False:
-            loss_disc_all = []
-
-            for key, disc in self.discriminators.items():
-                scores, _ = disc(audios)
-                score_fakes, _ = disc(fake_audios.detach())
-
-                with torch.autocast(device_type=audios.device.type, enabled=False):
-                    loss_disc, _, _ = discriminator_loss(scores, score_fakes)
-
-                self.log(
-                    f"train/discriminator/{key}",
-                    loss_disc,
-                    on_step=True,
-                    on_epoch=False,
-                    prog_bar=False,
-                    logger=True,
-                    sync_dist=True,
-                )
+            scores = self.discriminator(gt_mels)
+            score_fakes = self.discriminator(decoded_mels.detach())
 
-                loss_disc_all.append(loss_disc)
-
-            loss_disc_all = torch.stack(loss_disc_all).mean()
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                loss_disc, _, _ = discriminator_loss([scores], [score_fakes])
 
             self.log(
-                "train/discriminator/loss",
-                loss_disc_all,
+                f"train/discriminator/loss",
+                loss_disc,
                 on_step=True,
                 on_epoch=False,
-                prog_bar=True,
+                prog_bar=False,
                 logger=True,
                 sync_dist=True,
             )
 
             optim_d.zero_grad()
-            self.manual_backward(loss_disc_all)
+            self.manual_backward(loss_disc)
             self.clip_gradients(
                 optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
             )
             optim_d.step()
 
         # Adv Loss
-        loss_adv_all = []
-        loss_fm_all = []
-
-        for key, disc in self.discriminators.items():
-            score_fakes, feat_fake = disc(fake_audios)
-
-            # Adversarial Loss
-            with torch.autocast(device_type=audios.device.type, enabled=False):
-                loss_fake, _ = generator_loss(score_fakes)
-
-            self.log(
-                f"train/generator/adv_{key}",
-                loss_fake,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
-
-            loss_adv_all.append(loss_fake)
-
-            # Feature Matching Loss
-            _, feat_real = disc(audios)
-
-            with torch.autocast(device_type=audios.device.type, enabled=False):
-                loss_fm = feature_loss(feat_real, feat_fake)
-
-            self.log(
-                f"train/generator/adv_fm_{key}",
-                loss_fm,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
-
-            loss_fm_all.append(loss_fm)
-
-        loss_adv_all = torch.stack(loss_adv_all).mean()
-        loss_fm_all = torch.stack(loss_fm_all).mean()
+        score_fakes = self.discriminator(decoded_mels)
 
+        # Adversarial Loss
         with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
-            loss_mel = F.l1_loss(
-                sliced_gt_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
-            )
-
-            loss_dict = {
-                "mel": loss_mel,
-                "adv": loss_adv_all,
-                "fm": loss_fm_all,
-            }
+            loss_adv, _ = generator_loss([score_fakes])
 
-            if self.multi_resolution_stft_loss is not None:
-                loss_dict["stft"] = loss_stft
+        self.log(
+            f"train/generator/adv",
+            loss_adv,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
 
-            generator_out_grad = self.balancer.compute(
-                loss_dict,
-                fake_audios,
-            )
+        # Feature Matching Loss
+        score_gts = self.discriminator(gt_mels)
 
-            if self.mode == "pretrain":
-                loss_vq_all = loss_decoded_mel + loss_vq
+        with torch.autocast(device_type=audios.device.type, enabled=False):
+            loss_fm = feature_loss([score_gts], [score_fakes])
 
-        # Loss vq and loss decoded mel are only used in pretrain stage
-        if self.mode == "pretrain":
-            self.log(
-                "train/generator/loss_vq",
-                loss_vq,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
+        self.log(
+            f"train/generator/adv_fm",
+            loss_fm,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
 
-            self.log(
-                "train/generator/loss_decoded_mel",
-                loss_decoded_mel,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
+        with torch.autocast(device_type=audios.device.type, enabled=False):
+            loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
 
         self.log(
             "train/generator/loss_mel",
@@ -380,29 +225,20 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
-        if self.multi_resolution_stft_loss is not None:
-            self.log(
-                "train/generator/loss_stft",
-                loss_stft,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
-
         self.log(
-            "train/generator/loss_fm_all",
-            loss_fm_all,
+            "train/generator/loss_vq",
+            loss_vq,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
             logger=True,
             sync_dist=True,
         )
+
+        loss = loss_mel * 20 + loss_vq + loss_adv + loss_fm
         self.log(
-            "train/generator/loss_adv_all",
-            loss_adv_all,
+            "train/generator/loss",
+            loss,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
@@ -412,11 +248,7 @@ class VQGAN(L.LightningModule):
 
         optim_g.zero_grad()
 
-        # Only backpropagate loss_vq_all in pretrain stage
-        if self.mode == "pretrain":
-            self.manual_backward(loss_vq_all, retain_graph=True)
-
-        self.manual_backward(fake_audios, gradient=generator_out_grad)
+        self.manual_backward(loss)
         self.clip_gradients(
             optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
         )
@@ -440,19 +272,11 @@ class VQGAN(L.LightningModule):
         )
 
         vq_result = self.encode(audios, audio_lengths)
-        decoded = self.decode(
+        decoded_mels = self.decode(
             indices=vq_result.indices,
             audio_lengths=audio_lengths,
-            mel_only=self.mode == "pretrain",
-        )
-
-        decoded_mels = decoded.mels
-
-        # Use gt mel as input for pretrain
-        if self.mode == "pretrain":
-            fake_audios = self.generator(gt_mels)
-        else:
-            fake_audios = decoded.audios
+        ).mels
+        fake_audios = self.generator(decoded_mels)
 
         fake_mels = self.mel_transform(fake_audios.squeeze(1))
 
@@ -557,21 +381,25 @@ class VQGAN(L.LightningModule):
         with torch.no_grad():
             features = self.mel_transform(audios, sample_rate=self.sampling_rate)
 
-        if self.downsample is not None:
-            features = self.downsample(features)
-
         feature_lengths = (
             audio_lengths
             / self.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
+            # / self.vq.downsample
         ).long()
 
+        # print(features.shape, feature_lengths.shape, torch.max(feature_lengths))
+
         feature_masks = torch.unsqueeze(
             sequence_mask(feature_lengths, features.shape[2]), 1
         ).to(features.dtype)
 
-        text_features = self.mel_encoder(features, feature_masks)
-        vq_features, indices, loss = self.vq_encoder(text_features, feature_masks)
+        features = (
+            gradient_checkpoint(
+                self.encoder, features, feature_masks, use_reentrant=False
+            )
+            * feature_masks
+        )
+        vq_features, indices, loss = self.vq(features, feature_masks)
 
         return VQEncodeResult(
             features=vq_features,
@@ -581,18 +409,13 @@ class VQGAN(L.LightningModule):
         )
 
     def calculate_audio_lengths(self, feature_lengths):
-        return (
-            feature_lengths
-            * self.hop_length
-            * (self.downsample.total_strides if self.downsample is not None else 1)
-        )
+        return feature_lengths * self.hop_length * self.vq.downsample
 
     def decode(
         self,
         indices=None,
         features=None,
         audio_lengths=None,
-        mel_only=False,
         feature_lengths=None,
     ):
         assert (
@@ -611,26 +434,11 @@ class VQGAN(L.LightningModule):
         ).float()
 
         if indices is not None:
-            features = self.vq_encoder.decode(indices)
-
-        features = F.interpolate(features, size=mel_masks.shape[2], mode="nearest")
+            features = self.vq.decode(indices)
 
         # Sample mels
-        if self.decoder is not None:
-            decoded_mels = self.decoder(features, mel_masks)
-        else:
-            decoded_mels = features
-
-        if mel_only:
-            return VQDecodeResult(
-                audios=None,
-                mels=decoded_mels,
-                mel_lengths=mel_lengths,
-            )
+        decoded = gradient_checkpoint(self.decoder, features, use_reentrant=False)
 
-        fake_audios = self.generator(decoded_mels)
         return VQDecodeResult(
-            audios=fake_audios,
-            mels=decoded_mels,
-            mel_lengths=mel_lengths,
+            mels=decoded,
         )

+ 3 - 4
fish_speech/models/vqgan/losses.py

@@ -6,10 +6,9 @@ from torch import nn
 def feature_loss(fmap_r: list[torch.Tensor], fmap_g: list[torch.Tensor]):
     loss = 0
     for dr, dg in zip(fmap_r, fmap_g):
-        for rl, gl in zip(dr, dg):
-            rl = rl.float().detach()
-            gl = gl.float()
-            loss += torch.mean(torch.abs(rl - gl))
+        dr = dr.float().detach()
+        dg = dg.float()
+        loss += torch.mean(torch.abs(dr - dg))
 
     return loss * 2
 

+ 0 - 37
fish_speech/models/vqgan/modules/condition.py

@@ -1,37 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-class MultiCondLayer(nn.Module):
-    def __init__(
-        self,
-        gin_channels: int,
-        out_channels: int,
-        n_cond: int,
-    ):
-        """MultiCondLayer of VITS model.
-
-        Args:
-            gin_channels (int): Number of conditioning tensor channels.
-            out_channels (int): Number of output tensor channels.
-            n_cond (int): Number of conditions.
-        """
-        super().__init__()
-        self.n_cond = n_cond
-
-        self.cond_layers = nn.ModuleList()
-        for _ in range(n_cond):
-            self.cond_layers.append(nn.Linear(gin_channels, out_channels))
-
-    def forward(self, cond: torch.Tensor, x_mask: torch.Tensor):
-        """
-        Shapes:
-            - cond: :math:`[B, C, N]`
-            - x_mask: :math`[B, 1, T]`
-        """
-
-        cond_out = torch.zeros_like(cond)
-        for i in range(self.n_cond):
-            cond_in = self.cond_layers[i](cond.mT).mT
-            cond_out = cond_out + cond_in
-        return cond_out * x_mask

+ 315 - 0
fish_speech/models/vqgan/modules/convnext.py

@@ -0,0 +1,315 @@
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from vocos.spectral_ops import ISTFT
+
+
+def drop_path(
+    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
+    """  # noqa: E501
+
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""  # noqa: E501
+
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
+class LayerNorm(nn.Module):
+    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+    with shape (batch_size, channels, height, width).
+    """  # noqa: E501
+
+    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.data_format = data_format
+        if self.data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError
+        self.normalized_shape = (normalized_shape,)
+
+    def forward(self, x):
+        if self.data_format == "channels_last":
+            return F.layer_norm(
+                x, self.normalized_shape, self.weight, self.bias, self.eps
+            )
+        elif self.data_format == "channels_first":
+            u = x.mean(1, keepdim=True)
+            s = (x - u).pow(2).mean(1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.eps)
+            x = self.weight[:, None] * x + self.bias[:, None]
+            return x
+
+
+class ConvNeXtBlock(nn.Module):
+    r"""ConvNeXt Block. There are two equivalent implementations:
+    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+    We use (2) as we find it slightly faster in PyTorch
+
+    Args:
+        dim (int): Number of input channels.
+        drop_path (float): Stochastic depth rate. Default: 0.0
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+        kernel_size (int): Kernel size for depthwise conv. Default: 7.
+        dilation (int): Dilation for depthwise conv. Default: 1.
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        dim: int,
+        drop_path: float = 0.0,
+        layer_scale_init_value: float = 1e-6,
+        mlp_ratio: float = 4.0,
+        kernel_size: int = 7,
+        dilation: int = 1,
+    ):
+        super().__init__()
+
+        self.dwconv = nn.Conv1d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding=int(dilation * (kernel_size - 1) / 2),
+            groups=dim,
+        )  # depthwise conv
+        self.norm = LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(
+            dim, int(mlp_ratio * dim)
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
+        self.gamma = (
+            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x, apply_residual: bool = True):
+        input = x
+
+        x = self.dwconv(x)
+        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+
+        if self.gamma is not None:
+            x = self.gamma * x
+
+        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
+        x = self.drop_path(x)
+
+        if apply_residual:
+            x = input + x
+
+        return x
+
+
+class ParallelConvNeXtBlock(nn.Module):
+    def __init__(self, kernel_sizes: list[int], *args, **kwargs):
+        super().__init__()
+        self.blocks = nn.ModuleList(
+            [
+                ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
+                for kernel_size in kernel_sizes
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.stack(
+            [block(x, apply_residual=False) for block in self.blocks] + [x],
+            dim=1,
+        ).sum(dim=1)
+
+
+class ConvNeXt(nn.Module):
+    def __init__(
+        self,
+        input_channels: int = 3,
+        depths: list[int] = [3, 3, 9, 3],
+        dims: list[int] = [96, 192, 384, 768],
+        drop_path_rate: float = 0.0,
+        layer_scale_init_value: float = 1e-6,
+        kernel_sizes: tuple[int] = (7,),
+    ):
+        super().__init__()
+        assert len(depths) == len(dims)
+
+        self.channel_layers = nn.ModuleList()
+        stem = nn.Sequential(
+            nn.Conv1d(
+                input_channels,
+                dims[0],
+                kernel_size=7,
+                padding=3,
+                # padding_mode="replicate",
+                padding_mode="zeros",
+            ),
+            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+        )
+        self.channel_layers.append(stem)
+
+        for i in range(len(depths) - 1):
+            mid_layer = nn.Sequential(
+                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
+            )
+            self.channel_layers.append(mid_layer)
+
+        block_fn = (
+            partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
+            if len(kernel_sizes) == 1
+            else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
+        )
+
+        self.stages = nn.ModuleList()
+        drop_path_rates = [
+            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+        ]
+
+        cur = 0
+        for i in range(len(depths)):
+            stage = nn.Sequential(
+                *[
+                    block_fn(
+                        dim=dims[i],
+                        drop_path=drop_path_rates[cur + j],
+                        layer_scale_init_value=layer_scale_init_value,
+                    )
+                    for j in range(depths[i])
+                ]
+            )
+            self.stages.append(stage)
+            cur += depths[i]
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv1d, nn.Linear)):
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        return_features: bool = False,
+    ) -> torch.Tensor:
+        features = []
+
+        for channel_layer, stage in zip(self.channel_layers, self.stages):
+            x = channel_layer(x)
+            x = stage(x)
+
+            if return_features:
+                features.append(x)
+
+        if return_features:
+            return features
+
+        return x
+
+
+class ISTFTHead(nn.Module):
+    """
+    ISTFT Head module for predicting STFT complex coefficients.
+
+    Args:
+        dim (int): Hidden dimension of the model.
+        n_fft (int): Size of Fourier transform.
+        hop_length (int): The distance between neighboring sliding window frames, which should align with
+                          the resolution of the input features.
+        win_length (int): The size of window frame and STFT filter.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        dim: int,
+        n_fft: int,
+        hop_length: int,
+        win_length: int,
+        padding: str = "same",
+    ):
+        super().__init__()
+
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+
+        self.istft = ISTFT(
+            n_fft=n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            padding=padding,
+        )
+
+        out_dim = n_fft * 2
+        self.out = nn.Conv1d(dim, out_dim, 1)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the ISTFTHead module.
+
+        Args:
+            x (Tensor): Input tensor of shape (B, H, L), where B is the batch size,
+                        L is the sequence length, and H denotes the model dimension.
+
+        Returns:
+            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+        """  # noqa: E501
+
+        x = self.out(x)
+
+        mag, p = x.chunk(2, dim=1)
+        mag = torch.exp(mag)
+        mag = torch.clip(
+            mag, max=1e2
+        )  # safeguard to prevent excessively large magnitudes
+
+        # wrapping happens here. These two lines produce real and imaginary value
+        x = torch.cos(p)
+        y = torch.sin(p)
+
+        S = mag * (x + 1j * y)
+
+        x = self.istft(S)
+        return x.unsqueeze(1)

+ 0 - 270
fish_speech/models/vqgan/modules/decoder.py

@@ -1,270 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.utils.parametrizations import weight_norm
-from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
-
-from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
-from fish_speech.models.vqgan.utils import get_padding, init_weights
-
-
-class Generator(nn.Module):
-    def __init__(
-        self,
-        initial_channel,
-        resblock,
-        resblock_kernel_sizes,
-        resblock_dilation_sizes,
-        upsample_rates,
-        upsample_initial_channel,
-        upsample_kernel_sizes,
-        gin_channels=0,
-        ckpt_path=None,
-    ):
-        super(Generator, self).__init__()
-        self.num_kernels = len(resblock_kernel_sizes)
-        self.num_upsamples = len(upsample_rates)
-        self.conv_pre = weight_norm(
-            nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
-        )
-        resblock = ResBlock1 if resblock == "1" else ResBlock2
-
-        self.ups = nn.ModuleList()
-        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            self.ups.append(
-                weight_norm(
-                    nn.ConvTranspose1d(
-                        upsample_initial_channel // (2**i),
-                        upsample_initial_channel // (2 ** (i + 1)),
-                        k,
-                        u,
-                        padding=(k - u) // 2,
-                    )
-                )
-            )
-
-        self.resblocks = nn.ModuleList()
-        for i in range(len(self.ups)):
-            ch = upsample_initial_channel // (2 ** (i + 1))
-            for j, (k, d) in enumerate(
-                zip(resblock_kernel_sizes, resblock_dilation_sizes)
-            ):
-                self.resblocks.append(resblock(ch, k, d))
-
-        self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
-        self.ups.apply(init_weights)
-
-        if gin_channels != 0:
-            self.cond = nn.Linear(gin_channels, upsample_initial_channel)
-
-        if ckpt_path is not None:
-            self.load_state_dict(torch.load(ckpt_path)["generator"], strict=True)
-
-    def forward(self, x, g=None):
-        x = self.conv_pre(x)
-        if g is not None:
-            x = x + self.cond(g.mT).mT
-
-        for i in range(self.num_upsamples):
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            x = self.ups[i](x)
-            xs = None
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
-                else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-            x = xs / self.num_kernels
-        x = F.leaky_relu(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def remove_weight_norm(self):
-        print("Removing weight norm...")
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-
-
-class ResBlock1(nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super(ResBlock1, self).__init__()
-        self.convs1 = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[2],
-                        padding=get_padding(kernel_size, dilation[2]),
-                    )
-                ),
-            ]
-        )
-        self.convs1.apply(init_weights)
-
-        self.convs2 = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-            ]
-        )
-        self.convs2.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c1(xt)
-            xt = F.leaky_relu(xt, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c2(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs1:
-            remove_weight_norm(l)
-        for l in self.convs2:
-            remove_weight_norm(l)
-
-
-class ResBlock2(nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-        super(ResBlock2, self).__init__()
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-            ]
-        )
-        self.convs.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c in self.convs:
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs:
-            remove_weight_norm(l)
-
-
-if __name__ == "__main__":
-    import librosa
-    import soundfile as sf
-
-    from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
-
-    gen = Generator(
-        80,
-        "1",
-        [3, 7, 11],
-        [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        [8, 8, 2, 2],
-        512,
-        [16, 16, 4, 4],
-        ckpt_path="checkpoints/hifigan-v1-universal-22050/g_02500000",
-    )
-
-    spec = LogMelSpectrogram(
-        sample_rate=22050,
-        n_fft=1024,
-        win_length=1024,
-        hop_length=256,
-        n_mels=80,
-        f_min=0.0,
-        f_max=8000.0,
-    )
-
-    audio = librosa.load("data/StarRail/Chinese/符玄/archive_fuxuan_9.wav", sr=22050)[0]
-    audio = torch.from_numpy(audio).unsqueeze(0)
-
-    spec = spec(audio)
-    print(spec.shape)
-
-    audio = gen(spec)
-    print(audio.shape)
-
-    sf.write("test.wav", audio.detach().squeeze().numpy(), 22050)

+ 45 - 0
fish_speech/models/vqgan/modules/decoder_v2.py

@@ -150,6 +150,7 @@ class HiFiGANGenerator(nn.Module):
         post_conv_kernel_size: int = 7,
         post_activation: Callable = partial(nn.SiLU, inplace=True),
         checkpointing: bool = False,
+        ckpt_path: str = None,
     ):
         super().__init__()
 
@@ -229,6 +230,17 @@ class HiFiGANGenerator(nn.Module):
         # Gradient checkpointing
         self.checkpointing = checkpointing
 
+        if ckpt_path is not None:
+            states = torch.load(ckpt_path, map_location="cpu")
+            if "state_dict" in states:
+                states = states["state_dict"]
+            states = {
+                k.replace("generator.", ""): v
+                for k, v in states.items()
+                if k.startswith("generator")
+            }
+            self.load_state_dict(states, strict=True)
+
     def forward(self, x, template=None):
         if self.use_template and template is None:
             length = x.shape[-1] * self.hop_length
@@ -268,3 +280,36 @@ class HiFiGANGenerator(nn.Module):
             block.remove_parametrizations()
         remove_parametrizations(self.conv_pre)
         remove_parametrizations(self.conv_post)
+
+
+if __name__ == "__main__":
+    import torchaudio
+
+    from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
+
+    spec = LogMelSpectrogram(n_mels=160)
+    audio, sr = torchaudio.load("test.wav")
+    audio = audio[None, :]
+    spec = spec(audio, sample_rate=sr)
+
+    model = HiFiGANGenerator(
+        hop_length=512,
+        upsample_rates=(8, 8, 2, 2, 2),
+        upsample_kernel_sizes=(16, 16, 4, 4, 4),
+        resblock_kernel_sizes=(3, 7, 11),
+        resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+        num_mels=160,
+        upsample_initial_channel=512,
+        use_template=True,
+        pre_conv_kernel_size=7,
+        post_conv_kernel_size=7,
+        post_activation=partial(nn.SiLU, inplace=True),
+        ckpt_path="checkpoints/hifigan-base-comb-mix-lb-020/step_001200000_weights_only.ckpt",
+    )
+
+    print(model)
+
+    out = model(spec)
+    print(out.shape)
+
+    torchaudio.save("out.wav", out[0], 44100)

+ 0 - 80
fish_speech/models/vqgan/modules/discriminators/mpd.py

@@ -1,80 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.utils.parametrizations import weight_norm
-
-
-class DiscriminatorP(nn.Module):
-    def __init__(
-        self,
-        *,
-        period: int,
-        kernel_size: int = 5,
-        stride: int = 3,
-        channels: tuple[int] = (1, 64, 128, 256, 512, 1024),
-    ) -> None:
-        super(DiscriminatorP, self).__init__()
-
-        self.period = period
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv2d(
-                        in_channels,
-                        out_channels,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(kernel_size // 2, 0),
-                    )
-                )
-                for in_channels, out_channels in zip(channels[:-1], channels[1:])
-            ]
-        )
-
-        self.conv_post = weight_norm(
-            nn.Conv2d(channels[-1], 1, (3, 1), 1, padding=(1, 0))
-        )
-
-    def forward(self, x):
-        fmap = []
-
-        # 1d to 2d
-        b, c, t = x.shape
-        if t % self.period != 0:  # pad first
-            n_pad = self.period - (t % self.period)
-            x = F.pad(x, (0, n_pad), "constant")
-            t = t + n_pad
-        x = x.view(b, c, t // self.period, self.period)
-
-        for conv in self.convs:
-            x = conv(x)
-            x = F.silu(x, inplace=True)
-            fmap.append(x)
-
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class MultiPeriodDiscriminator(nn.Module):
-    def __init__(self, periods: tuple[int] = (2, 3, 5, 7, 11)) -> None:
-        super().__init__()
-
-        self.discriminators = nn.ModuleList(
-            [DiscriminatorP(period=period) for period in periods]
-        )
-
-    def forward(
-        self, x: torch.Tensor
-    ) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]:
-        scores, feature_map = [], []
-
-        for disc in self.discriminators:
-            res, fmap = disc(x)
-
-            scores.append(res)
-            feature_map.append(fmap)
-
-        return scores, feature_map

+ 0 - 100
fish_speech/models/vqgan/modules/discriminators/mrd.py

@@ -1,100 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.utils.parametrizations import weight_norm
-
-
-class DiscriminatorR(torch.nn.Module):
-    def __init__(
-        self,
-        *,
-        n_fft: int = 1024,
-        hop_length: int = 120,
-        win_length: int = 600,
-    ):
-        super(DiscriminatorR, self).__init__()
-
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
-                weight_norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
-                weight_norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
-                weight_norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
-                weight_norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
-            ]
-        )
-
-        self.conv_post = weight_norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
-
-    def forward(self, x):
-        fmap = []
-
-        x = self.spectrogram(x)
-        x = x.unsqueeze(1)
-
-        for conv in self.convs:
-            x = conv(x)
-            x = F.silu(x, inplace=True)
-            fmap.append(x)
-
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-    def spectrogram(self, x):
-        x = F.pad(
-            x,
-            (
-                (self.n_fft - self.hop_length) // 2,
-                (self.n_fft - self.hop_length + 1) // 2,
-            ),
-            mode="reflect",
-        )
-        x = x.squeeze(1)
-        x = torch.stft(
-            x,
-            n_fft=self.n_fft,
-            hop_length=self.hop_length,
-            win_length=self.win_length,
-            center=False,
-            return_complex=True,
-        )
-        x = torch.view_as_real(x)  # [B, F, TT, 2]
-        mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]
-
-        return mag
-
-
-class MultiResolutionDiscriminator(torch.nn.Module):
-    def __init__(self, resolutions: list[tuple[int]]):
-        super().__init__()
-
-        self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorR(
-                    n_fft=n_fft,
-                    hop_length=hop_length,
-                    win_length=win_length,
-                )
-                for n_fft, hop_length, win_length in resolutions
-            ]
-        )
-
-    def forward(
-        self, x: torch.Tensor
-    ) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]:
-        scores, feature_map = [], []
-
-        for disc in self.discriminators:
-            res, fmap = disc(x)
-
-            scores.append(res)
-            feature_map.append(fmap)
-
-        return scores, feature_map

+ 0 - 188
fish_speech/models/vqgan/modules/discriminators/mssbcqtd.py

@@ -1,188 +0,0 @@
-# Copyright (c) 2023 Amphion.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# Monkey patching to fix a bug in nnAudio
-import numpy as np
-import torch
-import torchaudio.transforms as T
-from einops import rearrange
-from nnAudio import features
-from torch import nn
-
-from .msstftd import NormConv2d, get_2d_padding
-
-np.float = float
-
-LRELU_SLOPE = 0.1
-
-
-class DiscriminatorCQT(nn.Module):
-    def __init__(
-        self,
-        hop_length,
-        n_octaves,
-        bins_per_octave,
-        filters=32,
-        max_filters=1024,
-        filters_scale=1,
-        dilations=[1, 2, 4],
-        in_channels=1,
-        out_channels=1,
-        sample_rate=16000,
-    ):
-        super().__init__()
-
-        self.filters = filters
-        self.max_filters = max_filters
-        self.filters_scale = filters_scale
-        self.kernel_size = (3, 9)
-        self.dilations = dilations
-        self.stride = (1, 2)
-
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.fs = sample_rate
-        self.hop_length = hop_length
-        self.n_octaves = n_octaves
-        self.bins_per_octave = bins_per_octave
-
-        self.cqt_transform = features.cqt.CQT2010v2(
-            sr=self.fs * 2,
-            hop_length=self.hop_length,
-            n_bins=self.bins_per_octave * self.n_octaves,
-            bins_per_octave=self.bins_per_octave,
-            output_format="Complex",
-            pad_mode="constant",
-        )
-
-        self.conv_pres = nn.ModuleList()
-        for i in range(self.n_octaves):
-            self.conv_pres.append(
-                NormConv2d(
-                    self.in_channels * 2,
-                    self.in_channels * 2,
-                    kernel_size=self.kernel_size,
-                    padding=get_2d_padding(self.kernel_size),
-                )
-            )
-
-        self.convs = nn.ModuleList()
-
-        self.convs.append(
-            NormConv2d(
-                self.in_channels * 2,
-                self.filters,
-                kernel_size=self.kernel_size,
-                padding=get_2d_padding(self.kernel_size),
-            )
-        )
-
-        in_chs = min(self.filters_scale * self.filters, self.max_filters)
-        for i, dilation in enumerate(self.dilations):
-            out_chs = min(
-                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
-            )
-            self.convs.append(
-                NormConv2d(
-                    in_chs,
-                    out_chs,
-                    kernel_size=self.kernel_size,
-                    stride=self.stride,
-                    dilation=(dilation, 1),
-                    padding=get_2d_padding(self.kernel_size, (dilation, 1)),
-                    norm="weight_norm",
-                )
-            )
-            in_chs = out_chs
-        out_chs = min(
-            (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
-            self.max_filters,
-        )
-        self.convs.append(
-            NormConv2d(
-                in_chs,
-                out_chs,
-                kernel_size=(self.kernel_size[0], self.kernel_size[0]),
-                padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
-                norm="weight_norm",
-            )
-        )
-
-        self.conv_post = NormConv2d(
-            out_chs,
-            self.out_channels,
-            kernel_size=(self.kernel_size[0], self.kernel_size[0]),
-            padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
-            norm="weight_norm",
-        )
-
-        self.activation = torch.nn.LeakyReLU(negative_slope=LRELU_SLOPE)
-        self.resample = T.Resample(orig_freq=self.fs, new_freq=self.fs * 2)
-
-    def forward(self, x):
-        fmap = []
-
-        x = self.resample(x)
-
-        z = self.cqt_transform(x)
-
-        z_amplitude = z[:, :, :, 0].unsqueeze(1)
-        z_phase = z[:, :, :, 1].unsqueeze(1)
-
-        z = torch.cat([z_amplitude, z_phase], dim=1)
-        z = rearrange(z, "b c w t -> b c t w")
-
-        latent_z = []
-        for i in range(self.n_octaves):
-            latent_z.append(
-                self.conv_pres[i](
-                    z[
-                        :,
-                        :,
-                        :,
-                        i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
-                    ]
-                )
-            )
-        latent_z = torch.cat(latent_z, dim=-1)
-
-        for i, layer in enumerate(self.convs):
-            latent_z = layer(latent_z)
-
-            latent_z = self.activation(latent_z)
-            fmap.append(latent_z)
-
-        latent_z = self.conv_post(latent_z)
-
-        return latent_z, fmap
-
-
-class MultiScaleSubbandCQTDiscriminator(nn.Module):
-    def __init__(self, hop_lengths, n_octaves, bins_per_octaves, **kwargs):
-        super().__init__()
-
-        self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorCQT(
-                    hop_length=hop_length,
-                    n_octaves=n_octaves,
-                    bins_per_octave=bins_per_octave,
-                    **kwargs,
-                )
-                for hop_length, n_octaves, bins_per_octave in zip(
-                    hop_lengths, n_octaves, bins_per_octaves
-                )
-            ]
-        )
-
-    def forward(self, x: torch.Tensor):
-        logits = []
-        fmaps = []
-        for disc in self.discriminators:
-            logit, fmap = disc(x)
-            logits.append(logit)
-            fmaps.append(fmap)
-
-        return logits, fmaps

+ 0 - 303
fish_speech/models/vqgan/modules/discriminators/msstftd.py

@@ -1,303 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""MS-STFT discriminator, provided here for reference."""
-
-import typing as tp
-
-import einops
-import torch
-import torchaudio
-from einops import rearrange
-from torch import nn
-from torch.nn.utils import spectral_norm, weight_norm
-
-FeatureMapType = tp.List[torch.Tensor]
-LogitsType = torch.Tensor
-DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
-
-
-class ConvLayerNorm(nn.LayerNorm):
-    """
-    Convolution-friendly LayerNorm that moves channels to last dimensions
-    before running the normalization and moves them back to original position right after.
-    """  # noqa: E501
-
-    def __init__(
-        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
-    ):
-        super().__init__(normalized_shape, **kwargs)
-
-    def forward(self, x):
-        x = einops.rearrange(x, "b ... t -> b t ...")
-        x = super().forward(x)
-        x = einops.rearrange(x, "b t ... -> b ... t")
-        return
-
-
-CONV_NORMALIZATIONS = frozenset(
-    [
-        "none",
-        "weight_norm",
-        "spectral_norm",
-        "time_layer_norm",
-        "layer_norm",
-        "time_group_norm",
-    ]
-)
-
-
-def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
-    assert norm in CONV_NORMALIZATIONS
-    if norm == "weight_norm":
-        return weight_norm(module)
-    elif norm == "spectral_norm":
-        return spectral_norm(module)
-    else:
-        # We already check was in CONV_NORMALIZATION, so any other choice
-        # doesn't need reparametrization.
-        return module
-
-
-def get_norm_module(
-    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
-) -> nn.Module:
-    """Return the proper normalization module. If causal is True, this will ensure the returned
-    module is causal, or return an error if the normalization doesn't support causal evaluation.
-    """  # noqa: E501
-    assert norm in CONV_NORMALIZATIONS
-    if norm == "layer_norm":
-        assert isinstance(module, nn.modules.conv._ConvNd)
-        return ConvLayerNorm(module.out_channels, **norm_kwargs)
-    elif norm == "time_group_norm":
-        if causal:
-            raise ValueError("GroupNorm doesn't support causal evaluation.")
-        assert isinstance(module, nn.modules.conv._ConvNd)
-        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
-    else:
-        return nn.Identity()
-
-
-class NormConv2d(nn.Module):
-    """Wrapper around Conv2d and normalization applied to this conv
-    to provide a uniform interface across normalization approaches.
-    """
-
-    def __init__(
-        self,
-        *args,
-        norm: str = "none",
-        norm_kwargs: tp.Dict[str, tp.Any] = {},
-        **kwargs,
-    ):
-        super().__init__()
-        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
-        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
-        self.norm_type = norm
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.norm(x)
-        return x
-
-
-def get_2d_padding(
-    kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)
-):
-    return (
-        ((kernel_size[0] - 1) * dilation[0]) // 2,
-        ((kernel_size[1] - 1) * dilation[1]) // 2,
-    )
-
-
-class DiscriminatorSTFT(nn.Module):
-    """STFT sub-discriminator.
-    Args:
-        filters (int): Number of filters in convolutions
-        in_channels (int): Number of input channels. Default: 1
-        out_channels (int): Number of output channels. Default: 1
-        n_fft (int): Size of FFT for each scale. Default: 1024
-        hop_length (int): Length of hop between STFT windows for each scale. Default: 256
-        kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
-        stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
-        dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
-        win_length (int): Window size for each scale. Default: 1024
-        normalized (bool): Whether to normalize by magnitude after stft. Default: True
-        norm (str): Normalization method. Default: `'weight_norm'`
-        activation (str): Activation function. Default: `'LeakyReLU'`
-        activation_params (dict): Parameters to provide to the activation function.
-        growth (int): Growth factor for the filters. Default: 1
-    """  # noqa: E501
-
-    def __init__(
-        self,
-        filters: int,
-        in_channels: int = 1,
-        out_channels: int = 1,
-        n_fft: int = 1024,
-        hop_length: int = 256,
-        win_length: int = 1024,
-        max_filters: int = 1024,
-        filters_scale: int = 1,
-        kernel_size: tp.Tuple[int, int] = (3, 9),
-        dilations: tp.List = [1, 2, 4],
-        stride: tp.Tuple[int, int] = (1, 2),
-        normalized: bool = True,
-        norm: str = "weight_norm",
-        activation: str = "LeakyReLU",
-        activation_params: dict = {"negative_slope": 0.2},
-    ):
-        super().__init__()
-        assert len(kernel_size) == 2
-        assert len(stride) == 2
-        self.filters = filters
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-        self.normalized = normalized
-        self.activation = getattr(torch.nn, activation)(**activation_params)
-        self.spec_transform = torchaudio.transforms.Spectrogram(
-            n_fft=self.n_fft,
-            hop_length=self.hop_length,
-            win_length=self.win_length,
-            window_fn=torch.hann_window,
-            normalized=self.normalized,
-            center=False,
-            pad_mode=None,
-            power=None,
-        )
-        spec_channels = 2 * self.in_channels
-        self.convs = nn.ModuleList()
-        self.convs.append(
-            NormConv2d(
-                spec_channels,
-                self.filters,
-                kernel_size=kernel_size,
-                padding=get_2d_padding(kernel_size),
-            )
-        )
-        in_chs = min(filters_scale * self.filters, max_filters)
-        for i, dilation in enumerate(dilations):
-            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
-            self.convs.append(
-                NormConv2d(
-                    in_chs,
-                    out_chs,
-                    kernel_size=kernel_size,
-                    stride=stride,
-                    dilation=(dilation, 1),
-                    padding=get_2d_padding(kernel_size, (dilation, 1)),
-                    norm=norm,
-                )
-            )
-            in_chs = out_chs
-        out_chs = min(
-            (filters_scale ** (len(dilations) + 1)) * self.filters, max_filters
-        )
-        self.convs.append(
-            NormConv2d(
-                in_chs,
-                out_chs,
-                kernel_size=(kernel_size[0], kernel_size[0]),
-                padding=get_2d_padding((kernel_size[0], kernel_size[0])),
-                norm=norm,
-            )
-        )
-        self.conv_post = NormConv2d(
-            out_chs,
-            self.out_channels,
-            kernel_size=(kernel_size[0], kernel_size[0]),
-            padding=get_2d_padding((kernel_size[0], kernel_size[0])),
-            norm=norm,
-        )
-
-    def forward(self, x: torch.Tensor):
-        fmap = []
-        z = self.spec_transform(x)  # [B, 2, Freq, Frames, 2]
-        z = torch.cat([z.real, z.imag], dim=1)
-        z = rearrange(z, "b c w t -> b c t w")
-        for i, layer in enumerate(self.convs):
-            z = layer(z)
-            z = self.activation(z)
-            fmap.append(z)
-        z = self.conv_post(z)
-        return z, fmap
-
-
-class MultiScaleSTFTDiscriminator(nn.Module):
-    """Multi-Scale STFT (MS-STFT) discriminator.
-    Args:
-        filters (int): Number of filters in convolutions
-        in_channels (int): Number of input channels. Default: 1
-        out_channels (int): Number of output channels. Default: 1
-        n_ffts (Sequence[int]): Size of FFT for each scale
-        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale
-        win_lengths (Sequence[int]): Window size for each scale
-        **kwargs: additional args for STFTDiscriminator
-    """
-
-    def __init__(
-        self,
-        filters: int,
-        in_channels: int = 1,
-        out_channels: int = 1,
-        n_ffts: tp.List[int] = [1024, 2048, 512],
-        hop_lengths: tp.List[int] = [256, 512, 128],
-        win_lengths: tp.List[int] = [1024, 2048, 512],
-        **kwargs,
-    ):
-        super().__init__()
-        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
-        self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorSTFT(
-                    filters,
-                    in_channels=in_channels,
-                    out_channels=out_channels,
-                    n_fft=n_ffts[i],
-                    win_length=win_lengths[i],
-                    hop_length=hop_lengths[i],
-                    **kwargs,
-                )
-                for i in range(len(n_ffts))
-            ]
-        )
-        self.num_discriminators = len(self.discriminators)
-
-    def forward(self, x: torch.Tensor) -> DiscriminatorOutput:
-        logits = []
-        fmaps = []
-        for disc in self.discriminators:
-            logit, fmap = disc(x)
-            logits.append(logit)
-            fmaps.append(fmap)
-
-        return logits, fmaps
-
-
-def test():
-    disc = MultiScaleSTFTDiscriminator(filters=32)
-    y = torch.randn(1, 1, 24000)
-    y_hat = torch.randn(1, 1, 24000)
-
-    y_disc_r, fmap_r = disc(y)
-    y_disc_gen, fmap_gen = disc(y_hat)
-    assert (
-        len(y_disc_r)
-        == len(y_disc_gen)
-        == len(fmap_r)
-        == len(fmap_gen)
-        == disc.num_discriminators
-    )
-    assert all([len(fm) == 5 for fm in fmap_r + fmap_gen])
-    assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm])
-    assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen])
-
-
-if __name__ == "__main__":
-    test()

+ 2 - 261
fish_speech/models/vqgan/modules/encoders.py

@@ -8,265 +8,6 @@ import torch.nn.functional as F
 from einops import rearrange
 from vector_quantize_pytorch import LFQ, GroupedResidualVQ, VectorQuantize
 
-from fish_speech.models.vqgan.modules.modules import WN
-from fish_speech.models.vqgan.modules.transformer import (
-    MultiHeadAttention,
-    RelativePositionTransformer,
-)
-from fish_speech.models.vqgan.utils import sequence_mask
-
-
-# * Ready and Tested
-class ConvDownSampler(nn.Module):
-    def __init__(
-        self,
-        dims: list,
-        kernel_sizes: list,
-        strides: list,
-    ):
-        super().__init__()
-
-        self.dims = dims
-        self.kernel_sizes = kernel_sizes
-        self.strides = strides
-        self.total_strides = np.prod(self.strides)
-
-        self.convs = nn.ModuleList(
-            [
-                nn.ModuleList(
-                    [
-                        nn.Conv1d(
-                            in_channels=self.dims[i],
-                            out_channels=self.dims[i + 1],
-                            kernel_size=self.kernel_sizes[i],
-                            stride=self.strides[i],
-                            padding=(self.kernel_sizes[i] - 1) // 2,
-                        ),
-                        nn.LayerNorm(self.dims[i + 1], elementwise_affine=True),
-                        nn.GELU(),
-                    ]
-                )
-                for i in range(len(self.dims) - 1)
-            ]
-        )
-
-        self.apply(self.init_weights)
-
-    def init_weights(self, m):
-        if isinstance(m, nn.Conv1d):
-            nn.init.normal_(m.weight, std=0.02)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.ones_(m.weight)
-            nn.init.zeros_(m.bias)
-
-    def forward(self, x):
-        for conv, norm, act in self.convs:
-            x = conv(x)
-            x = norm(x.mT).mT
-            x = act(x)
-
-        return x
-
-
-# * Ready and Tested
-class TextEncoder(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        hidden_channels: int,
-        hidden_channels_ffn: int,
-        n_heads: int,
-        n_layers: int,
-        kernel_size: int,
-        dropout: float,
-        gin_channels=0,
-        speaker_cond_layer=0,
-        use_vae=True,
-        use_embedding=False,
-    ):
-        """Text Encoder for VITS model.
-
-        Args:
-            in_channels (int): Number of characters for the embedding layer.
-            out_channels (int): Number of channels for the output.
-            hidden_channels (int): Number of channels for the hidden layers.
-            hidden_channels_ffn (int): Number of channels for the convolutional layers.
-            n_heads (int): Number of attention heads for the Transformer layers.
-            n_layers (int): Number of Transformer layers.
-            kernel_size (int): Kernel size for the FFN layers in Transformer network.
-            dropout (float): Dropout rate for the Transformer layers.
-            gin_channels (int, optional): Number of channels for speaker embedding. Defaults to 0.
-        """
-        super().__init__()
-        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
-        self.use_embedding = use_embedding
-
-        if use_embedding:
-            self.proj_in = nn.Embedding(in_channels, hidden_channels)
-        else:
-            self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)
-
-        self.encoder = RelativePositionTransformer(
-            in_channels=hidden_channels,
-            out_channels=hidden_channels,
-            hidden_channels=hidden_channels,
-            hidden_channels_ffn=hidden_channels_ffn,
-            n_heads=n_heads,
-            n_layers=n_layers,
-            kernel_size=kernel_size,
-            dropout=dropout,
-            window_size=4,
-            gin_channels=gin_channels,
-            speaker_cond_layer=speaker_cond_layer,
-        )
-        self.proj_out = nn.Conv1d(
-            hidden_channels, out_channels * 2 if use_vae else out_channels, 1
-        )
-        self.use_vae = use_vae
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        x_mask: torch.Tensor,
-        g: torch.Tensor = None,
-        noise_scale: float = 1,
-    ):
-        """
-        Shapes:
-            - x: :math:`[B, T]`
-            - x_length: :math:`[B]`
-        """
-
-        if self.use_embedding:
-            x = self.proj_in(x.long()).mT * x_mask
-        else:
-            x = self.proj_in(x) * x_mask
-
-        x = self.encoder(x, x_mask, g=g)
-        x = self.proj_out(x) * x_mask
-
-        if self.use_vae is False:
-            return x
-
-        m, logs = torch.split(x, self.out_channels, dim=1)
-        z = m + torch.randn_like(m) * torch.exp(logs) * x_mask * noise_scale
-        return z, m, logs, x, x_mask
-
-
-# * Ready and Tested
-class PosteriorEncoder(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        hidden_channels: int,
-        kernel_size: int,
-        dilation_rate: int,
-        n_layers: int,
-        gin_channels=0,
-    ):
-        """Posterior Encoder of VITS model.
-
-        ::
-            x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z
-
-        Args:
-            in_channels (int): Number of input tensor channels.
-            out_channels (int): Number of output tensor channels.
-            hidden_channels (int): Number of hidden channels.
-            kernel_size (int): Kernel size of the WaveNet convolution layers.
-            dilation_rate (int): Dilation rate of the WaveNet layers.
-            num_layers (int): Number of the WaveNet layers.
-            cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0.
-        """
-        super().__init__()
-        self.out_channels = out_channels
-
-        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
-        self.enc = WN(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            gin_channels=gin_channels,
-        )
-        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        x_mask: torch.Tensor,
-        g: torch.Tensor,
-        noise_scale: float = 1,
-    ):
-        """
-        Shapes:
-            - x: :math:`[B, C, T]`
-            - x_lengths: :math:`[B, 1]`
-            - g: :math:`[B, C, 1]`
-        """
-        x = self.pre(x) * x_mask
-        x = self.enc(x, x_mask, g=g)
-        stats = self.proj(x) * x_mask
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        z = m + torch.randn_like(m) * torch.exp(logs) * x_mask * noise_scale
-        return z, m, logs, x_mask
-
-
-# TODO: Ready for testing
-class SpeakerEncoder(nn.Module):
-    def __init__(
-        self,
-        in_channels: int = 128,
-        hidden_channels: int = 192,
-        out_channels: int = 512,
-        num_layers: int = 4,
-    ) -> None:
-        super().__init__()
-
-        self.in_proj = nn.Sequential(
-            nn.Conv1d(in_channels, hidden_channels, 1),
-            nn.Mish(),
-            nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
-            nn.Mish(),
-            nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
-            nn.Mish(),
-        )
-        self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1)
-        self.apply(self._init_weights)
-
-        self.encoder = WN(
-            hidden_channels,
-            kernel_size=3,
-            dilation_rate=1,
-            n_layers=num_layers,
-        )
-
-    def _init_weights(self, m):
-        if isinstance(m, (nn.Conv1d, nn.Linear)):
-            nn.init.normal_(m.weight, mean=0, std=0.02)
-            nn.init.zeros_(m.bias)
-
-    def forward(self, mels, mel_masks: torch.Tensor):
-        """
-        Shapes:
-            - x: :math:`[B, C, T]`
-            - x_lengths: :math:`[B, 1]`
-        """
-
-        x = self.in_proj(mels) * mel_masks
-        x = self.encoder(x, mel_masks)
-
-        # Avg Pooling
-        x = x * mel_masks
-        x = self.out_proj(x)
-        x = torch.sum(x, dim=-1) / torch.sum(mel_masks, dim=-1)
-        x = x[..., None]
-
-        return x
-
 
 class VQEncoder(nn.Module):
     def __init__(
@@ -286,7 +27,7 @@ class VQEncoder(nn.Module):
                 dim=vq_channels,
                 codebook_size=codebook_size,
                 threshold_ema_dead_code=threshold_ema_dead_code,
-                kmeans_init=False,
+                kmeans_init=True,
                 groups=codebook_groups,
                 num_quantizers=codebook_layers,
             )
@@ -295,7 +36,7 @@ class VQEncoder(nn.Module):
                 dim=vq_channels,
                 codebook_size=codebook_size,
                 threshold_ema_dead_code=threshold_ema_dead_code,
-                kmeans_init=False,
+                kmeans_init=True,
             )
 
         self.codebook_groups = codebook_groups

+ 0 - 297
fish_speech/models/vqgan/modules/flow.py

@@ -1,297 +0,0 @@
-import torch
-from torch import nn
-
-from fish_speech.models.vqgan.modules.modules import WN, Flip
-from fish_speech.models.vqgan.modules.normalization import LayerNorm
-from fish_speech.models.vqgan.modules.transformer import FFN, MultiHeadAttention
-
-
-class ResidualCouplingBlock(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        n_flows=4,
-        gin_channels=0,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.n_flows = n_flows
-        self.gin_channels = gin_channels
-
-        self.flows = nn.ModuleList()
-
-        for i in range(n_flows):
-            self.flows.append(
-                ResidualCouplingLayer(
-                    channels,
-                    hidden_channels,
-                    kernel_size,
-                    dilation_rate,
-                    n_layers,
-                    gin_channels=gin_channels,
-                    mean_only=True,
-                )
-            )
-            self.flows.append(Flip())
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        if not reverse:
-            for flow in self.flows:
-                x, _ = flow(x, x_mask, g=g, reverse=reverse)
-        else:
-            for flow in reversed(self.flows):
-                x = flow(x, x_mask, g=g, reverse=reverse)
-        return x
-
-
-class ResidualCouplingLayer(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        p_dropout=0,
-        gin_channels=0,
-        mean_only=False,
-    ):
-        assert channels % 2 == 0, "channels should be divisible by 2"
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.half_channels = channels // 2
-        self.mean_only = mean_only
-
-        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-        self.enc = WN(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            p_dropout=p_dropout,
-            gin_channels=gin_channels,
-        )
-        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-        self.post.weight.data.zero_()
-        self.post.bias.data.zero_()
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-        h = self.pre(x0) * x_mask
-        h = self.enc(h, x_mask, g=g)
-        stats = self.post(h) * x_mask
-        if not self.mean_only:
-            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
-        else:
-            m = stats
-            logs = torch.zeros_like(m)
-
-        if not reverse:
-            x1 = m + x1 * torch.exp(logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            logdet = torch.sum(logs, [1, 2])
-            return x, logdet
-        else:
-            x1 = (x1 - m) * torch.exp(-logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            return x
-
-
-class TransformerCouplingBlock(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size,
-        p_dropout,
-        n_flows=4,
-        gin_channels=0,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.n_flows = n_flows
-        self.gin_channels = gin_channels
-
-        self.flows = nn.ModuleList()
-
-        for i in range(n_flows):
-            self.flows.append(
-                TransformerCouplingLayer(
-                    channels,
-                    hidden_channels,
-                    kernel_size,
-                    n_layers,
-                    n_heads,
-                    p_dropout,
-                    filter_channels,
-                    mean_only=True,
-                    gin_channels=self.gin_channels,
-                )
-            )
-            self.flows.append(Flip())
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        if not reverse:
-            for flow in self.flows:
-                x, _ = flow(x, x_mask, g=g, reverse=reverse)
-        else:
-            for flow in reversed(self.flows):
-                x = flow(x, x_mask, g=g, reverse=reverse)
-        return x
-
-
-class TransformerCouplingLayer(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        n_layers,
-        n_heads,
-        p_dropout=0,
-        filter_channels=0,
-        mean_only=False,
-        gin_channels=0,
-    ):
-        super().__init__()
-
-        assert channels % 2 == 0, "channels should be divisible by 2"
-
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.half_channels = channels // 2
-        self.mean_only = mean_only
-
-        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-        self.enc = Encoder(
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers,
-            kernel_size,
-            p_dropout,
-            gin_channels=gin_channels,
-        )
-        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-        self.post.weight.data.zero_()
-        self.post.bias.data.zero_()
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-        h = self.pre(x0) * x_mask
-        h = self.enc(h, x_mask, g=g)
-        stats = self.post(h) * x_mask
-        if not self.mean_only:
-            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
-        else:
-            m = stats
-            logs = torch.zeros_like(m)
-
-        if not reverse:
-            x1 = m + x1 * torch.exp(logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            logdet = torch.sum(logs, [1, 2])
-            return x, logdet
-        else:
-            x1 = (x1 - m) * torch.exp(-logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            return x
-
-
-class Encoder(nn.Module):
-    def __init__(
-        self,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size=1,
-        p_dropout=0.0,
-        window_size=4,
-        gin_channels=512,
-        cond_layer_idx=2,
-    ):
-        super().__init__()
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.window_size = window_size
-
-        self.spk_emb_linear = nn.Linear(gin_channels, self.hidden_channels)
-        self.cond_layer_idx = cond_layer_idx
-
-        assert (
-            self.cond_layer_idx < self.n_layers
-        ), "cond_layer_idx should be less than n_layers"
-
-        self.drop = nn.Dropout(p_dropout)
-        self.attn_layers = nn.ModuleList()
-        self.norm_layers_1 = nn.ModuleList()
-        self.ffn_layers = nn.ModuleList()
-        self.norm_layers_2 = nn.ModuleList()
-        for i in range(self.n_layers):
-            self.attn_layers.append(
-                MultiHeadAttention(
-                    hidden_channels,
-                    hidden_channels,
-                    n_heads,
-                    p_dropout=p_dropout,
-                    window_size=window_size,
-                )
-            )
-            self.norm_layers_1.append(LayerNorm(hidden_channels))
-            self.ffn_layers.append(
-                FFN(
-                    hidden_channels,
-                    hidden_channels,
-                    filter_channels,
-                    kernel_size,
-                    p_dropout=p_dropout,
-                )
-            )
-            self.norm_layers_2.append(LayerNorm(hidden_channels))
-
-    def forward(self, x, x_mask, g=None):
-        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-        x = x * x_mask
-
-        for i in range(self.n_layers):
-            if i == self.cond_layer_idx and g is not None:
-                g = self.spk_emb_linear(g.transpose(1, 2))
-                g = g.transpose(1, 2)
-                x = x + g
-                x = x * x_mask
-            y = self.attn_layers[i](x, x, attn_mask)
-            y = self.drop(y)
-            x = self.norm_layers_1[i](x + y)
-
-            y = self.ffn_layers[i](x, x_mask)
-            y = self.drop(y)
-            x = self.norm_layers_2[i](x + y)
-
-        x = x * x_mask
-
-        return x

+ 0 - 174
fish_speech/models/vqgan/modules/models.py

@@ -1,174 +0,0 @@
-import torch
-from torch import nn
-
-from fish_speech.models.vqgan.modules.decoder import Generator
-from fish_speech.models.vqgan.modules.encoders import (
-    PosteriorEncoder,
-    SpeakerEncoder,
-    TextEncoder,
-    VQEncoder,
-)
-from fish_speech.models.vqgan.modules.flow import ResidualCouplingBlock
-from fish_speech.models.vqgan.utils import rand_slice_segments, sequence_mask
-
-
-class SynthesizerTrn(nn.Module):
-    """
-    Synthesizer for Training
-    """
-
-    def __init__(
-        self,
-        *,
-        in_channels,
-        spec_channels,
-        segment_size,
-        inter_channels,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        n_flows,
-        n_layers_q,
-        n_layers_spk,
-        n_layers_flow,
-        kernel_size,
-        p_dropout,
-        speaker_cond_layer,
-        resblock,
-        resblock_kernel_sizes,
-        resblock_dilation_sizes,
-        upsample_rates,
-        upsample_initial_channel,
-        upsample_kernel_sizes,
-        gin_channels,
-        codebook_size,
-        kmeans_ckpt=None,
-    ):
-        super().__init__()
-
-        self.segment_size = segment_size
-
-        # self.vq = VQEncoder(
-        #     in_channels=in_channels,
-        #     vq_channels=in_channels,
-        #     codebook_size=codebook_size,
-        #     kmeans_ckpt=kmeans_ckpt,
-        # )
-        self.enc_p = TextEncoder(
-            in_channels,
-            inter_channels,
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers,
-            kernel_size,
-            p_dropout,
-            gin_channels=gin_channels,
-            speaker_cond_layer=speaker_cond_layer,
-        )
-        self.enc_spk = SpeakerEncoder(
-            in_channels=spec_channels,
-            hidden_channels=inter_channels,
-            out_channels=gin_channels,
-            num_heads=n_heads,
-            num_layers=n_layers_spk,
-            p_dropout=p_dropout,
-        )
-        self.flow = ResidualCouplingBlock(
-            channels=inter_channels,
-            hidden_channels=hidden_channels,
-            kernel_size=5,
-            dilation_rate=1,
-            n_layers=n_layers_flow,
-            n_flows=n_flows,
-            gin_channels=gin_channels,
-        )
-        self.enc_q = PosteriorEncoder(
-            spec_channels,
-            inter_channels,
-            hidden_channels,
-            5,
-            1,
-            n_layers_q,
-            gin_channels=gin_channels,
-        )
-        self.dec = Generator(
-            inter_channels,
-            resblock,
-            resblock_kernel_sizes,
-            resblock_dilation_sizes,
-            upsample_rates,
-            upsample_initial_channel,
-            upsample_kernel_sizes,
-            gin_channels=gin_channels,
-        )
-
-    def forward(self, x, x_lengths, specs):
-        # x = x.mT
-
-        min_length = min(x.shape[1], specs.shape[2])
-        if min_length % 2 != 0:
-            min_length -= 1
-
-        x = x[:, :min_length]
-        specs = specs[:, :, :min_length]
-        x_lengths = torch.clamp(x_lengths, max=min_length)
-
-        spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
-            specs.dtype
-        )
-        x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
-
-        g = self.enc_spk(specs, spec_masks)
-
-        # with torch.no_grad():
-        #     x, _ = self.vq(x, x_masks)
-        #     vq_loss = 0
-
-        _, m_p, logs_p, _, _ = self.enc_p(x, x_masks, g=g)
-        z_q, m_q, logs_q, _ = self.enc_q(specs, spec_masks, g=g)
-        z_p = self.flow(z_q, spec_masks, g=g, reverse=False)
-
-        z_slice, ids_slice = rand_slice_segments(z_q, x_lengths, self.segment_size)
-        o = self.dec(z_slice, g=g)
-
-        return (
-            o,
-            ids_slice,
-            x_masks,
-            spec_masks,
-            (z_q, z_p),
-            (m_p, logs_p),
-            (m_q, logs_q),
-            # vq_loss,
-        )
-
-    def infer(self, x, x_lengths, specs, max_len=None, noise_scale=0.35):
-        # x = x.mT
-        spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
-            specs.dtype
-        )
-        # print(x_lengths, x.shape)
-        x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
-        g = self.enc_spk(specs, spec_masks)
-        # x, vq_loss = self.vq(x, x_masks)
-        z_p, m_p, logs_p, h_text, _ = self.enc_p(
-            x, x_masks, g=g, noise_scale=noise_scale
-        )
-        z_p = self.flow(z_p, x_masks, g=g, reverse=True)
-
-        o = self.dec((z_p * x_masks)[:, :, :max_len], g=g)
-        return o
-
-    def reconstruct(self, specs, spec_lengths, max_len=None, noise_scale=0.35):
-        spec_masks = torch.unsqueeze(sequence_mask(spec_lengths, specs.shape[2]), 1).to(
-            specs.dtype
-        )
-        g = self.enc_spk(specs, spec_masks)
-        z_q, m_q, logs_q, _ = self.enc_q(
-            specs, spec_masks, g=g, noise_scale=noise_scale
-        )
-        o = self.dec((z_q * spec_masks)[:, :, :max_len], g=g)
-
-        return o

+ 17 - 33
fish_speech/models/vqgan/modules/modules.py

@@ -10,31 +10,30 @@ LRELU_SLOPE = 0.1
 
 # ! PosteriorEncoder
 # ! ResidualCouplingLayer
-class WN(nn.Module):
+class WaveNet(nn.Module):
     def __init__(
         self,
         hidden_channels,
         kernel_size,
         dilation_rate,
         n_layers,
-        gin_channels=0,
         p_dropout=0,
         out_channels=None,
+        in_channels=None,
     ):
-        super(WN, self).__init__()
+        super(WaveNet, self).__init__()
         assert kernel_size % 2 == 1
         self.hidden_channels = hidden_channels
         self.kernel_size = (kernel_size,)
         self.n_layers = n_layers
-        self.gin_channels = gin_channels
 
         self.in_layers = nn.ModuleList()
         self.res_skip_layers = nn.ModuleList()
         self.drop = nn.Dropout(p_dropout)
 
-        if gin_channels != 0:
-            cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
-            self.cond_layer = weight_norm(cond_layer, name="weight")
+        self.in_channels = in_channels
+        if in_channels is not None:
+            self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)
 
         for i in range(n_layers):
             dilation = dilation_rate**i
@@ -61,33 +60,31 @@ class WN(nn.Module):
         if out_channels is not None:
             self.out_layer = nn.Conv1d(hidden_channels, out_channels, 1)
 
-    def forward(self, x, x_mask, g=None, **kwargs):
-        output = torch.zeros_like(x)
+    def forward(self, x, x_mask=None):
         n_channels_tensor = torch.IntTensor([self.hidden_channels])
 
-        if g is not None:
-            g = self.cond_layer(g)
+        if self.in_channels is not None:
+            x = self.proj_in(x)
+
+        output = torch.zeros_like(x)
 
         for i in range(self.n_layers):
             x_in = self.in_layers[i](x)
-            if g is not None:
-                cond_offset = i * 2 * self.hidden_channels
-                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-            else:
-                g_l = torch.zeros_like(x_in)
-
-            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+            acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor)
             acts = self.drop(acts)
 
             res_skip_acts = self.res_skip_layers[i](acts)
             if i < self.n_layers - 1:
                 res_acts = res_skip_acts[:, : self.hidden_channels, :]
-                x = (x + res_acts) * x_mask
+                x = x + res_acts
+                if x_mask is not None:
+                    x = x * x_mask
                 output = output + res_skip_acts[:, self.hidden_channels :, :]
             else:
                 output = output + res_skip_acts
 
-        x = output * x_mask
+        if x_mask is not None:
+            x = output * x_mask
 
         if self.out_channels is not None:
             x = self.out_layer(x)
@@ -101,16 +98,3 @@ class WN(nn.Module):
             remove_parametrizations(l)
         for l in self.res_skip_layers:
             remove_parametrizations(l)
-
-
-# ! StochasticDurationPredictor
-# ! ResidualCouplingBlock
-# TODO convert to class method
-class Flip(nn.Module):
-    def forward(self, x, *args, reverse=False, **kwargs):
-        x = torch.flip(x, [1])
-        if not reverse:
-            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
-            return x, logdet
-        else:
-            return x

+ 0 - 34
fish_speech/models/vqgan/modules/normalization.py

@@ -1,34 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class LayerNorm(nn.Module):
-    def __init__(self, channels, eps=1e-5):
-        super().__init__()
-        self.channels = channels
-        self.eps = eps
-
-        self.gamma = nn.Parameter(torch.ones(channels))
-        self.beta = nn.Parameter(torch.zeros(channels))
-
-    def forward(self, x: torch.Tensor):
-        x = F.layer_norm(x.mT, (self.channels,), self.gamma, self.beta, self.eps)
-        return x.mT
-
-
-class CondLayerNorm(nn.Module):
-    def __init__(self, channels, eps=1e-5, cond_channels=0):
-        super().__init__()
-        self.channels = channels
-        self.eps = eps
-
-        self.linear_gamma = nn.Linear(cond_channels, channels)
-        self.linear_beta = nn.Linear(cond_channels, channels)
-
-    def forward(self, x: torch.Tensor, cond: torch.Tensor):
-        gamma = self.linear_gamma(cond)
-        beta = self.linear_beta(cond)
-
-        x = F.layer_norm(x.mT, (self.channels,), gamma, beta, self.eps)
-        return x.mT

+ 0 - 328
fish_speech/models/vqgan/modules/transformer.py

@@ -1,328 +0,0 @@
-import math
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from fish_speech.models.vqgan.modules.normalization import LayerNorm
-from fish_speech.models.vqgan.utils import convert_pad_shape
-
-
-# TODO add conditioning on language
-# TODO check whether we need to stop gradient for speaker embedding
-class RelativePositionTransformer(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        hidden_channels: int,
-        out_channels: int,
-        hidden_channels_ffn: int,
-        n_heads: int,
-        n_layers: int,
-        kernel_size=1,
-        dropout=0.0,
-        window_size=4,
-        gin_channels=0,
-        speaker_cond_layer=0,
-    ):
-        super().__init__()
-        assert (
-            out_channels == hidden_channels
-        ), "out_channels must be equal to hidden_channels"
-
-        self.n_layers = n_layers
-        self.speaker_cond_layer = speaker_cond_layer
-
-        self.drop = nn.Dropout(dropout)
-        self.attn_layers = nn.ModuleList()
-        self.norm_layers_1 = nn.ModuleList()
-        self.ffn_layers = nn.ModuleList()
-        self.norm_layers_2 = nn.ModuleList()
-        for i in range(self.n_layers):
-            self.attn_layers.append(
-                MultiHeadAttention(
-                    hidden_channels if i != 0 else in_channels,
-                    hidden_channels,
-                    n_heads,
-                    p_dropout=dropout,
-                    window_size=window_size,
-                )
-            )
-            self.norm_layers_1.append(LayerNorm(hidden_channels))
-            self.ffn_layers.append(
-                FFN(
-                    hidden_channels,
-                    hidden_channels,
-                    hidden_channels_ffn,
-                    kernel_size,
-                    p_dropout=dropout,
-                )
-            )
-            self.norm_layers_2.append(LayerNorm(hidden_channels))
-
-        if gin_channels != 0:
-            self.cond = nn.Conv1d(gin_channels, hidden_channels, 1)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        x_mask: torch.Tensor,
-        g: torch.Tensor = None,
-    ):
-        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-        x = x * x_mask
-        for i in range(self.n_layers):
-            # TODO consider using other conditioning
-            # TODO https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/modules/attentions.py#L12
-            if i == self.speaker_cond_layer and g is not None:
-                # ! g = torch.detach(g)
-                x = x + self.cond(g)
-                x = x * x_mask
-
-            y = self.attn_layers[i](x, x, attn_mask)
-            y = self.drop(y)
-            x = self.norm_layers_1[i](x + y)
-
-            y = self.ffn_layers[i](x, x_mask)
-            y = self.drop(y)
-            x = self.norm_layers_2[i](x + y)
-        x = x * x_mask
-        return x
-
-
-class MultiHeadAttention(nn.Module):
-    def __init__(
-        self,
-        channels,
-        out_channels,
-        n_heads,
-        p_dropout=0.0,
-        window_size=None,
-        heads_share=True,
-        block_length=None,
-        proximal_bias=False,
-        proximal_init=False,
-    ):
-        super().__init__()
-        assert channels % n_heads == 0
-
-        self.channels = channels
-        self.out_channels = out_channels
-        self.n_heads = n_heads
-        self.p_dropout = p_dropout
-        self.window_size = window_size
-        self.heads_share = heads_share
-        self.block_length = block_length
-        self.proximal_bias = proximal_bias
-        self.proximal_init = proximal_init
-        self.attn = None
-
-        self.k_channels = channels // n_heads
-        self.conv_q = nn.Linear(channels, channels)
-        self.conv_k = nn.Linear(channels, channels)
-        self.conv_v = nn.Linear(channels, channels)
-        self.conv_o = nn.Linear(channels, out_channels)
-        self.drop = nn.Dropout(p_dropout)
-
-        if window_size is not None:
-            n_heads_rel = 1 if heads_share else n_heads
-            rel_stddev = self.k_channels**-0.5
-            self.emb_rel_k = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-            self.emb_rel_v = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-
-        nn.init.xavier_uniform_(self.conv_q.weight)
-        nn.init.xavier_uniform_(self.conv_k.weight)
-        nn.init.xavier_uniform_(self.conv_v.weight)
-        if proximal_init:
-            with torch.no_grad():
-                self.conv_k.weight.copy_(self.conv_q.weight)
-                self.conv_k.bias.copy_(self.conv_q.bias)
-
-    def forward(self, x, c, attn_mask=None):
-        q = self.conv_q(x.mT).mT
-        k = self.conv_k(c.mT).mT
-        v = self.conv_v(c.mT).mT
-
-        x, self.attn = self.attention(q, k, v, mask=attn_mask)
-
-        x = self.conv_o(x.mT).mT
-        return x
-
-    def attention(self, query, key, value, mask=None):
-        # reshape [b, d, t] -> [b, n_h, t, d_k]
-        b, d, t_s, t_t = (*key.size(), query.size(2))
-        query = query.view(b, self.n_heads, self.k_channels, t_t).mT
-        key = key.view(b, self.n_heads, self.k_channels, t_s).mT
-        value = value.view(b, self.n_heads, self.k_channels, t_s).mT
-
-        scores = torch.matmul(query / math.sqrt(self.k_channels), key.mT)
-        if self.window_size is not None:
-            assert (
-                t_s == t_t
-            ), "Relative attention is only available for self-attention."
-            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-            rel_logits = self._matmul_with_relative_keys(
-                query / math.sqrt(self.k_channels), key_relative_embeddings
-            )
-            scores_local = self._relative_position_to_absolute_position(rel_logits)
-            scores = scores + scores_local
-        if self.proximal_bias:
-            assert t_s == t_t, "Proximal bias is only available for self-attention."
-            scores = scores + self._attention_bias_proximal(t_s).to(
-                device=scores.device, dtype=scores.dtype
-            )
-        if mask is not None:
-            scores = scores.masked_fill(mask == 0, -1e4)
-            if self.block_length is not None:
-                assert (
-                    t_s == t_t
-                ), "Local attention is only available for self-attention."
-                block_mask = (
-                    torch.ones_like(scores)
-                    .triu(-self.block_length)
-                    .tril(self.block_length)
-                )
-                scores = scores.masked_fill(block_mask == 0, -1e4)
-        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
-        p_attn = self.drop(p_attn)
-        output = torch.matmul(p_attn, value)
-        if self.window_size is not None:
-            relative_weights = self._absolute_position_to_relative_position(p_attn)
-            value_relative_embeddings = self._get_relative_embeddings(
-                self.emb_rel_v, t_s
-            )
-            output = output + self._matmul_with_relative_values(
-                relative_weights, value_relative_embeddings
-            )
-        output = output.mT.contiguous().view(
-            b, d, t_t
-        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
-        return output, p_attn
-
-    def _matmul_with_relative_values(self, x: torch.Tensor, y: torch.Tensor):
-        """
-        x: [b, h, l, m]
-        y: [h or 1, m, d]
-        ret: [b, h, l, d]
-        """
-        return torch.matmul(x, y.unsqueeze(0))
-
-    def _matmul_with_relative_keys(self, x: torch.Tensor, y: torch.Tensor):
-        """
-        x: [b, h, l, d]
-        y: [h or 1, m, d]
-        ret: [b, h, l, m]
-        """
-        return torch.matmul(x, y.unsqueeze(0).mT)
-
-    def _get_relative_embeddings(self, relative_embeddings, length):
-        max_relative_position = 2 * self.window_size + 1
-        # Pad first before slice to avoid using cond ops.
-        pad_length = max(length - (self.window_size + 1), 0)
-        slice_start_position = max((self.window_size + 1) - length, 0)
-        slice_end_position = slice_start_position + 2 * length - 1
-        if pad_length > 0:
-            padded_relative_embeddings = F.pad(
-                relative_embeddings,
-                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
-            )
-        else:
-            padded_relative_embeddings = relative_embeddings
-        used_relative_embeddings = padded_relative_embeddings[
-            :, slice_start_position:slice_end_position
-        ]
-        return used_relative_embeddings
-
-    def _relative_position_to_absolute_position(self, x):
-        """
-        x: [b, h, l, 2*l-1]
-        ret: [b, h, l, l]
-        """
-        batch, heads, length, _ = x.size()
-        # Concat columns of pad to shift from relative to absolute indexing.
-        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
-
-        # Concat extra elements so to add up to shape (len+1, 2*len-1).
-        x_flat = x.view([batch, heads, length * 2 * length])
-        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
-
-        # Reshape and slice out the padded elements.
-        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
-            :, :, :length, length - 1 :
-        ]
-        return x_final
-
-    def _absolute_position_to_relative_position(self, x):
-        """
-        x: [b, h, l, l]
-        ret: [b, h, l, 2*l-1]
-        """
-        batch, heads, length, _ = x.size()
-        # pad along column
-        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
-        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
-        # add 0's in the beginning that will skew the elements after reshape
-        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
-        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
-        return x_final
-
-    def _attention_bias_proximal(self, length):
-        """Bias for self-attention to encourage attention to close positions.
-        Args:
-          length: an integer scalar.
-        Returns:
-          a Tensor with shape [1, 1, length, length]
-        """
-        r = torch.arange(length, dtype=torch.float32)
-        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
-        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
-
-
-class FFN(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        filter_channels,
-        kernel_size,
-        p_dropout=0.0,
-        causal=False,
-    ):
-        super().__init__()
-        self.kernel_size = kernel_size
-        self.padding = self._causal_padding if causal else self._same_padding
-
-        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
-        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
-        self.drop = nn.Dropout(p_dropout)
-
-    def forward(self, x, x_mask):
-        x = self.conv_1(self.padding(x * x_mask))
-        x = torch.relu(x)
-        x = self.drop(x)
-        x = self.conv_2(self.padding(x * x_mask))
-        return x * x_mask
-
-    def _causal_padding(self, x):
-        if self.kernel_size == 1:
-            return x
-        pad_l = self.kernel_size - 1
-        pad_r = 0
-        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, convert_pad_shape(padding))
-        return x
-
-    def _same_padding(self, x):
-        if self.kernel_size == 1:
-            return x
-        pad_l = (self.kernel_size - 1) // 2
-        pad_r = self.kernel_size // 2
-        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, convert_pad_shape(padding))
-        return x

+ 2 - 3
fish_speech/models/vqgan/utils.py

@@ -40,7 +40,7 @@ def plot_mel(data, titles=None):
         mel = data[i]
 
         if isinstance(mel, torch.Tensor):
-            mel = mel.detach().cpu().numpy()
+            mel = mel.float().detach().cpu().numpy()
 
         axes[i][0].imshow(mel, origin="lower")
         axes[i][0].set_aspect(2.5, adjustable="box")
@@ -73,9 +73,8 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
 
 
 @torch.jit.script
-def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
     n_channels_int = n_channels[0]
-    in_act = input_a + input_b
     t_act = torch.tanh(in_act[:, :n_channels_int, :])
     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
     acts = t_act * s_act