فهرست منبع

Merge VQGAN v2 to dev (#56)

* squash vqgan v2 changes

* Merge pretrain stage 1 and 2

* Optimize vqgan inference (remove redundant code)

* Implement data mixing

* Optimize vqgan v2 config

* Add support to freeze discriminator

* Add stft loss & larger segment size
Leng Yue 2 سال پیش
والد
کامیت
1609e9bad4

+ 1 - 0
.pre-commit-config.yaml

@@ -18,6 +18,7 @@ repos:
     hooks:
       - id: codespell
         files: ^.*\.(py|md|rst|yml)$
+        args: [-L=fro]
 
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.5.0

+ 59 - 25
fish_speech/configs/vqgan_pretrain_v2.yaml

@@ -3,6 +3,8 @@ defaults:
   - _self_
 
 project: vqgan_pretrain_v2
+ckpt_path: checkpoints/hifigan-base-comb-mix-lb-020/step_001200000_weights_only.ckpt
+resume_weights_only: true
 
 # Lightning Trainer
 trainer:
@@ -15,22 +17,36 @@ trainer:
 
 sample_rate: 44100
 hop_length: 512
-num_mels: 128
+num_mels: 160
 n_fft: 2048
 win_length: 2048
 segment_size: 256
 
 # Dataset Configuration
 train_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/Genshin/vq_train_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  slice_frames: ${segment_size}
+  _target_: fish_speech.datasets.vqgan.MixDatast
+  datasets:
+    high-quality-441:
+      prob: 0.5
+      dataset:
+        _target_: fish_speech.datasets.vqgan.VQGANDataset
+        filelist: data/vocoder_data_441/vq_train_filelist.txt
+        sample_rate: ${sample_rate}
+        hop_length: ${hop_length}
+        slice_frames: ${segment_size}
+    
+    common-voice:
+      prob: 0.5
+      dataset:
+        _target_: fish_speech.datasets.vqgan.VQGANDataset
+        filelist: data/cv-corpus-16.0-2023-12-06/vq_train_filelist.txt
+        sample_rate: ${sample_rate}
+        hop_length: ${hop_length}
+        slice_frames: ${segment_size}
 
 val_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/Genshin/vq_val_filelist.txt
+  filelist: data/vocoder_data_441/vq_val_filelist.txt
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
 
@@ -47,8 +63,9 @@ model:
   _target_: fish_speech.models.vqgan.VQGAN
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
-  segment_size: 8192
-  mode: pretrain-stage1
+  segment_size: 32768
+  mode: pretrain
+  freeze_discriminator: true
 
   downsample:
     _target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
@@ -67,8 +84,8 @@ model:
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
     in_channels: 256
     vq_channels: 256
-    codebook_size: 1024
-    codebook_layers: 4
+    codebook_size: 256
+    codebook_groups: 4
     downsample: 1
 
   decoder:
@@ -80,19 +97,38 @@ model:
     n_layers: 6
 
   generator:
-    _target_: fish_speech.models.vqgan.modules.decoder.Generator
-    initial_channel: ${num_mels}
-    resblock: "1"
+    _target_: fish_speech.models.vqgan.modules.decoder_v2.HiFiGANGenerator
+    hop_length: ${hop_length}
+    upsample_rates: [8, 8, 2, 2, 2]  # aka. strides
+    upsample_kernel_sizes: [16, 16, 4, 4, 4]
     resblock_kernel_sizes: [3, 7, 11]
     resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    upsample_rates: [8, 8, 2, 2, 2]
+    num_mels: ${num_mels}
     upsample_initial_channel: 512
-    upsample_kernel_sizes: [16, 16, 4, 4, 4]
-
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
-    periods: [2, 3, 5, 7, 11, 17, 23, 37]
-
+    use_template: true
+    pre_conv_kernel_size: 7
+    post_conv_kernel_size: 7
+
+  discriminators:
+    _target_: torch.nn.ModuleDict
+    modules:
+      mpd:
+        _target_: fish_speech.models.vqgan.modules.discriminators.mpd.MultiPeriodDiscriminator
+        periods: [2, 3, 5, 7, 11, 17, 23, 37]
+
+      mrd:
+        _target_: fish_speech.models.vqgan.modules.discriminators.mrd.MultiResolutionDiscriminator
+        resolutions:
+          - ["${n_fft}", "${hop_length}", "${win_length}"]
+          - [1024, 120, 600]
+          - [2048, 240, 1200]
+          - [4096, 480, 2400]
+          - [512, 50, 240]
+
+  multi_resolution_stft_loss:
+    _target_: fish_speech.models.vqgan.losses.MultiResolutionSTFTLoss
+    resolutions: ${model.discriminators.modules.mrd.resolutions}
+  
   mel_transform:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     sample_rate: ${sample_rate}
@@ -100,13 +136,11 @@ model:
     hop_length: ${hop_length}
     win_length: ${win_length}
     n_mels: ${num_mels}
-    f_min: 0
-    f_max: 16000
 
   optimizer:
     _target_: torch.optim.AdamW
     _partial_: true
-    lr: 2e-4
+    lr: 1e-4
     betas: [0.8, 0.99]
     eps: 1e-5
 
@@ -119,7 +153,7 @@ callbacks:
   grad_norm_monitor:
     sub_module: 
       - generator
-      - discriminator
+      - discriminators
       - mel_encoder
       - vq_encoder
       - decoder

+ 29 - 2
fish_speech/datasets/vqgan.py

@@ -6,7 +6,7 @@ import librosa
 import numpy as np
 import torch
 from lightning import LightningDataModule
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader, Dataset, IterableDataset
 
 from fish_speech.utils import RankedLogger
 
@@ -72,6 +72,33 @@ class VQGANDataset(Dataset):
             return None
 
 
+class MixDatast(IterableDataset):
+    def __init__(self, datasets: dict[str, dict], seed: int = 42) -> None:
+        values = list(datasets.values())
+        probs = [v["prob"] for v in values]
+        self.datasets = [v["dataset"] for v in values]
+
+        total_probs = sum(probs)
+        self.probs = [p / total_probs for p in probs]
+        self.seed = seed
+
+    def __iter__(self):
+        rng = np.random.default_rng(self.seed)
+        dataset_iterators = [iter(dataset) for dataset in self.datasets]
+
+        while True:
+            # Random choice one
+            dataset_idx = rng.choice(len(self.datasets), p=self.probs)
+            dataset_iterator = dataset_iterators[dataset_idx]
+
+            try:
+                yield next(dataset_iterator)
+            except StopIteration:
+                # Exhausted, create a new iterator
+                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
+                yield next(dataset_iterators[dataset_idx])
+
+
 @dataclass
 class VQGANCollator:
     def __call__(self, batch):
@@ -116,7 +143,7 @@ class VQGANDataModule(LightningDataModule):
             batch_size=self.batch_size,
             collate_fn=VQGANCollator(),
             num_workers=self.num_workers,
-            shuffle=True,
+            shuffle=not isinstance(self.train_dataset, IterableDataset),
         )
 
     def val_dataloader(self):

+ 298 - 151
fish_speech/models/vqgan/lit_module.py

@@ -1,5 +1,6 @@
 import itertools
-from typing import Any, Callable, Literal
+from dataclasses import dataclass
+from typing import Any, Callable, Literal, Optional
 
 import lightning as L
 import torch
@@ -8,19 +9,17 @@ import wandb
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
-from vector_quantize_pytorch import VectorQuantize
 
 from fish_speech.models.vqgan.losses import (
+    MultiResolutionSTFTLoss,
     discriminator_loss,
     feature_loss,
     generator_loss,
-    kl_loss,
 )
+from fish_speech.models.vqgan.modules.balancer import Balancer
 from fish_speech.models.vqgan.modules.decoder import Generator
-from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
 from fish_speech.models.vqgan.modules.encoders import (
     ConvDownSampler,
-    SpeakerEncoder,
     TextEncoder,
     VQEncoder,
 )
@@ -32,6 +31,21 @@ from fish_speech.models.vqgan.utils import (
 )
 
 
+@dataclass
+class VQEncodeResult:
+    features: torch.Tensor
+    indices: torch.Tensor
+    loss: torch.Tensor
+    feature_lengths: torch.Tensor
+
+
+@dataclass
+class VQDecodeResult:
+    audios: torch.Tensor
+    mels: torch.Tensor
+    mel_lengths: torch.Tensor
+
+
 class VQGAN(L.LightningModule):
     def __init__(
         self,
@@ -42,18 +56,18 @@ class VQGAN(L.LightningModule):
         mel_encoder: TextEncoder,
         decoder: TextEncoder,
         generator: Generator,
-        discriminator: EnsembleDiscriminator,
+        discriminators: nn.ModuleDict,
         mel_transform: nn.Module,
         segment_size: int = 20480,
         hop_length: int = 640,
         sample_rate: int = 32000,
-        mode: Literal["pretrain-stage1", "pretrain-stage2", "finetune"] = "finetune",
-        speaker_encoder: SpeakerEncoder = None,
+        mode: Literal["pretrain", "finetune"] = "finetune",
+        freeze_discriminator: bool = False,
+        multi_resolution_stft_loss: Optional[MultiResolutionSTFTLoss] = None,
     ):
         super().__init__()
 
-        # pretrain-stage1: vq use gt mel as target, hifigan use gt mel as input
-        # pretrain-stage2: end-to-end training, use gt mel as hifi gan target
+        # pretrain: vq use gt mel as target, hifigan use gt mel as input
         # finetune: end-to-end training, use gt mel as hifi gan target but freeze vq
 
         # Model parameters
@@ -64,11 +78,11 @@ class VQGAN(L.LightningModule):
         self.downsample = downsample
         self.vq_encoder = vq_encoder
         self.mel_encoder = mel_encoder
-        self.speaker_encoder = speaker_encoder
         self.decoder = decoder
         self.generator = generator
-        self.discriminator = discriminator
+        self.discriminators = discriminators
         self.mel_transform = mel_transform
+        self.freeze_discriminator = freeze_discriminator
 
         # Crop length for saving memory
         self.segment_size = segment_size
@@ -90,20 +104,30 @@ class VQGAN(L.LightningModule):
             for p in self.downsample.parameters():
                 p.requires_grad = False
 
+        if self.freeze_discriminator:
+            for p in self.discriminators.parameters():
+                p.requires_grad = False
+
+        # Losses
+        self.multi_resolution_stft_loss = multi_resolution_stft_loss
+        loss_dict = {
+            "mel": 1,
+            "adv": 1,
+            "fm": 1,
+        }
+
+        if self.multi_resolution_stft_loss is not None:
+            loss_dict["stft"] = 1
+
+        self.balancer = Balancer(loss_dict)
+
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
-        components = []
-        if self.mode != "finetune":
-            components.extend(
-                [
-                    self.downsample.parameters(),
-                    self.vq_encoder.parameters(),
-                    self.mel_encoder.parameters(),
-                ]
-            )
-
-        if self.speaker_encoder is not None:
-            components.append(self.speaker_encoder.parameters())
+        components = [
+            self.downsample.parameters(),
+            self.vq_encoder.parameters(),
+            self.mel_encoder.parameters(),
+        ]
 
         if self.decoder is not None:
             components.append(self.decoder.parameters())
@@ -111,7 +135,7 @@ class VQGAN(L.LightningModule):
         components.append(self.generator.parameters())
         optimizer_generator = self.optimizer_builder(itertools.chain(*components))
         optimizer_discriminator = self.optimizer_builder(
-            self.discriminator.parameters()
+            self.discriminators.parameters()
         )
 
         lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
@@ -145,9 +169,7 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]
 
         with torch.no_grad():
-            features = gt_mels = self.mel_transform(
-                audios, sample_rate=self.sampling_rate
-            )
+            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
 
         if self.mode == "finetune":
             # Disable gradient computation for VQ
@@ -156,29 +178,13 @@ class VQGAN(L.LightningModule):
             self.mel_encoder.eval()
             self.downsample.eval()
 
-        if self.downsample is not None:
-            features = self.downsample(features)
-
         mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = (
-            audio_lengths
-            / self.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
         )
 
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.mel_encoder(features, feature_masks)
-        text_features, _, loss_vq = self.vq_encoder(text_features, feature_masks)
-        text_features = F.interpolate(
-            text_features, size=gt_mels.shape[2], mode="nearest"
-        )
+        vq_result = self.encode(audios, audio_lengths)
+        loss_vq = vq_result.loss
 
         if loss_vq.ndim > 1:
             loss_vq = loss_vq.mean()
@@ -187,18 +193,15 @@ class VQGAN(L.LightningModule):
             # Enable gradient computation
             torch.set_grad_enabled(True)
 
-        # Sample mels
-        if self.decoder is not None:
-            speaker_features = (
-                self.speaker_encoder(gt_mels, mel_masks)
-                if self.speaker_encoder is not None
-                else None
-            )
-            decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
-        else:
-            decoded_mels = text_features
+        decoded = self.decode(
+            indices=vq_result.indices if self.mode == "finetune" else None,
+            features=vq_result.features if self.mode == "pretrain" else None,
+            audio_lengths=audio_lengths,
+            mel_only=True,
+        )
+        decoded_mels = decoded.mels
+        input_mels = gt_mels if self.mode == "pretrain" else decoded_mels
 
-        input_mels = gt_mels if self.mode == "pretrain-stage1" else decoded_mels
         if self.segment_size is not None:
             audios, ids_slice = rand_slice_segments(
                 audios, audio_lengths, self.segment_size
@@ -228,75 +231,145 @@ class VQGAN(L.LightningModule):
             audios.shape == fake_audios.shape
         ), f"{audios.shape} != {fake_audios.shape}"
 
+        # Multi-Resolution STFT Loss
+        if self.multi_resolution_stft_loss is not None:
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                sc_loss, mag_loss = self.multi_resolution_stft_loss(
+                    fake_audios.squeeze(1).float(), audios.squeeze(1).float()
+                )
+                loss_stft = sc_loss + mag_loss
+
         # Discriminator
-        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios.detach())
+        if self.freeze_discriminator is False:
+            loss_disc_all = []
+
+            for key, disc in self.discriminators.items():
+                scores, _ = disc(audios)
+                score_fakes, _ = disc(fake_audios.detach())
+
+                with torch.autocast(device_type=audios.device.type, enabled=False):
+                    loss_disc, _, _ = discriminator_loss(scores, score_fakes)
+
+                self.log(
+                    f"train/discriminator/{key}",
+                    loss_disc,
+                    on_step=True,
+                    on_epoch=False,
+                    prog_bar=False,
+                    logger=True,
+                    sync_dist=True,
+                )
 
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
+                loss_disc_all.append(loss_disc)
 
-        self.log(
-            "train/discriminator/loss",
-            loss_disc_all,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
+            loss_disc_all = torch.stack(loss_disc_all).mean()
 
-        optim_d.zero_grad()
-        self.manual_backward(loss_disc_all)
-        self.clip_gradients(
-            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-        )
-        optim_d.step()
+            self.log(
+                "train/discriminator/loss",
+                loss_disc_all,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=True,
+                logger=True,
+                sync_dist=True,
+            )
+
+            optim_d.zero_grad()
+            self.manual_backward(loss_disc_all)
+            self.clip_gradients(
+                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+            )
+            optim_d.step()
+
+        # Adv Loss
+        loss_adv_all = []
+        loss_fm_all = []
+
+        for key, disc in self.discriminators.items():
+            score_fakes, feat_fake = disc(fake_audios)
+
+            # Adversarial Loss
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                loss_fake, _ = generator_loss(score_fakes)
+
+            self.log(
+                f"train/generator/adv_{key}",
+                loss_fake,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
+            loss_adv_all.append(loss_fake)
+
+            # Feature Matching Loss
+            _, feat_real = disc(audios)
+
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                loss_fm = feature_loss(feat_real, feat_fake)
 
-        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(audios, fake_audios)
+            self.log(
+                f"train/generator/adv_fm_{key}",
+                loss_fm,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
+            loss_fm_all.append(loss_fm)
+
+        loss_adv_all = torch.stack(loss_adv_all).mean()
+        loss_fm_all = torch.stack(loss_fm_all).mean()
 
         with torch.autocast(device_type=audios.device.type, enabled=False):
             loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
             loss_mel = F.l1_loss(
                 sliced_gt_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
             )
-            loss_adv, _ = generator_loss(y_d_hat_g)
-            loss_fm = feature_loss(fmap_r, fmap_g)
 
-            if self.mode == "pretrain-stage1":
+            loss_dict = {
+                "mel": loss_mel,
+                "adv": loss_adv_all,
+                "fm": loss_fm_all,
+            }
+
+            if self.multi_resolution_stft_loss is not None:
+                loss_dict["stft"] = loss_stft
+
+            generator_out_grad = self.balancer.compute(
+                loss_dict,
+                fake_audios,
+            )
+
+            if self.mode == "pretrain":
                 loss_vq_all = loss_decoded_mel + loss_vq
-                loss_gen_all = loss_mel * 45 + loss_fm + loss_adv
-            else:
-                loss_gen_all = loss_mel * 45 + loss_vq * 45 + loss_fm + loss_adv
 
-        self.log(
-            "train/generator/loss_gen_all",
-            loss_gen_all,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
+        # Loss vq and loss decoded mel are only used in pretrain stage
+        if self.mode == "pretrain":
+            self.log(
+                "train/generator/loss_vq",
+                loss_vq,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
 
-        if self.mode == "pretrain-stage1":
             self.log(
-                "train/generator/loss_vq_all",
-                loss_vq_all,
+                "train/generator/loss_decoded_mel",
+                loss_decoded_mel,
                 on_step=True,
                 on_epoch=False,
-                prog_bar=True,
+                prog_bar=False,
                 logger=True,
                 sync_dist=True,
             )
 
-        self.log(
-            "train/generator/loss_decoded_mel",
-            loss_decoded_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
         self.log(
             "train/generator/loss_mel",
             loss_mel,
@@ -306,18 +379,21 @@ class VQGAN(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
+
+        if self.multi_resolution_stft_loss is not None:
+            self.log(
+                "train/generator/loss_stft",
+                loss_stft,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
         self.log(
-            "train/generator/loss_fm",
-            loss_fm,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-        self.log(
-            "train/generator/loss_adv",
-            loss_adv,
+            "train/generator/loss_fm_all",
+            loss_fm_all,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
@@ -325,8 +401,8 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
         self.log(
-            "train/generator/loss_vq",
-            loss_vq,
+            "train/generator/loss_adv_all",
+            loss_adv_all,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
@@ -336,11 +412,11 @@ class VQGAN(L.LightningModule):
 
         optim_g.zero_grad()
 
-        # Only backpropagate loss_vq_all in pretrain-stage1
-        if self.mode == "pretrain-stage1":
-            self.manual_backward(loss_vq_all)
+        # Only backpropagate loss_vq_all in pretrain stage
+        if self.mode == "pretrain":
+            self.manual_backward(loss_vq_all, retain_graph=True)
 
-        self.manual_backward(loss_gen_all)
+        self.manual_backward(fake_audios, gradient=generator_out_grad)
         self.clip_gradients(
             optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
         )
@@ -357,44 +433,26 @@ class VQGAN(L.LightningModule):
         audios = audios.float()
         audios = audios[:, None, :]
 
-        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
-
-        if self.downsample is not None:
-            features = self.downsample(features)
-
+        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
         mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = (
-            audio_lengths
-            / self.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
         )
 
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.mel_encoder(features, feature_masks)
-        text_features, _, _ = self.vq_encoder(text_features, feature_masks)
-        text_features = F.interpolate(
-            text_features, size=gt_mels.shape[2], mode="nearest"
+        vq_result = self.encode(audios, audio_lengths)
+        decoded = self.decode(
+            indices=vq_result.indices,
+            audio_lengths=audio_lengths,
+            mel_only=self.mode == "pretrain",
         )
 
-        # Sample mels
-        if self.decoder is not None:
-            speaker_features = (
-                self.speaker_encoder(gt_mels, mel_masks)
-                if self.speaker_encoder is not None
-                else None
-            )
-            decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
-        else:
-            decoded_mels = text_features
+        decoded_mels = decoded.mels
 
-        fake_audios = self.generator(decoded_mels)
+        # Use gt mel as input for pretrain
+        if self.mode == "pretrain":
+            fake_audios = self.generator(gt_mels)
+        else:
+            fake_audios = decoded.audios
 
         fake_mels = self.mel_transform(fake_audios.squeeze(1))
 
@@ -487,3 +545,92 @@ class VQGAN(L.LightningModule):
                 )
 
             plt.close(image_mels)
+
+    def encode(self, audios, audio_lengths=None):
+        if audio_lengths is None:
+            audio_lengths = torch.tensor(
+                [audios.shape[-1]] * audios.shape[0],
+                device=audios.device,
+                dtype=torch.long,
+            )
+
+        with torch.no_grad():
+            features = self.mel_transform(audios, sample_rate=self.sampling_rate)
+
+        if self.downsample is not None:
+            features = self.downsample(features)
+
+        feature_lengths = (
+            audio_lengths
+            / self.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(features.dtype)
+
+        text_features = self.mel_encoder(features, feature_masks)
+        vq_features, indices, loss = self.vq_encoder(text_features, feature_masks)
+
+        return VQEncodeResult(
+            features=vq_features,
+            indices=indices,
+            loss=loss,
+            feature_lengths=feature_lengths,
+        )
+
+    def calculate_audio_lengths(self, feature_lengths):
+        return (
+            feature_lengths
+            * self.hop_length
+            * (self.downsample.total_strides if self.downsample is not None else 1)
+        )
+
+    def decode(
+        self,
+        indices=None,
+        features=None,
+        audio_lengths=None,
+        mel_only=False,
+        feature_lengths=None,
+    ):
+        assert (
+            indices is not None or features is not None
+        ), "indices or features must be provided"
+        assert (
+            feature_lengths is not None or audio_lengths is not None
+        ), "feature_lengths or audio_lengths must be provided"
+
+        if audio_lengths is None:
+            audio_lengths = self.calculate_audio_lengths(feature_lengths)
+
+        mel_lengths = audio_lengths // self.hop_length
+        mel_masks = torch.unsqueeze(
+            sequence_mask(mel_lengths, torch.max(mel_lengths)), 1
+        ).float()
+
+        if indices is not None:
+            features = self.vq_encoder.decode(indices)
+
+        features = F.interpolate(features, size=mel_masks.shape[2], mode="nearest")
+
+        # Sample mels
+        if self.decoder is not None:
+            decoded_mels = self.decoder(features, mel_masks)
+        else:
+            decoded_mels = features
+
+        if mel_only:
+            return VQDecodeResult(
+                audios=None,
+                mels=decoded_mels,
+                mel_lengths=mel_lengths,
+            )
+
+        fake_audios = self.generator(decoded_mels)
+        return VQDecodeResult(
+            audios=fake_audios,
+            mels=decoded_mels,
+            mel_lengths=mel_lengths,
+        )

+ 135 - 5
fish_speech/models/vqgan/losses.py

@@ -1,9 +1,9 @@
-from typing import List
-
 import torch
+import torch.nn.functional as F
+from torch import nn
 
 
-def feature_loss(fmap_r: List[torch.Tensor], fmap_g: List[torch.Tensor]):
+def feature_loss(fmap_r: list[torch.Tensor], fmap_g: list[torch.Tensor]):
     loss = 0
     for dr, dg in zip(fmap_r, fmap_g):
         for rl, gl in zip(dr, dg):
@@ -15,7 +15,7 @@ def feature_loss(fmap_r: List[torch.Tensor], fmap_g: List[torch.Tensor]):
 
 
 def discriminator_loss(
-    disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
+    disc_real_outputs: list[torch.Tensor], disc_generated_outputs: list[torch.Tensor]
 ):
     loss = 0
     r_losses = []
@@ -32,7 +32,7 @@ def discriminator_loss(
     return loss, r_losses, g_losses
 
 
-def generator_loss(disc_outputs: List[torch.Tensor]):
+def generator_loss(disc_outputs: list[torch.Tensor]):
     loss = 0
     gen_losses = []
     for dg in disc_outputs:
@@ -66,3 +66,133 @@ def kl_loss(
     kl = torch.sum(kl * z_mask)
     l = kl / torch.sum(z_mask)
     return l
+
+
+def stft(x, fft_size, hop_size, win_length, window):
+    """Perform STFT and convert to magnitude spectrogram.
+    Args:
+        x (Tensor): Input signal tensor (B, T).
+        fft_size (int): FFT size.
+        hop_size (int): Hop size.
+        win_length (int): Window length.
+        window (str): Window function type.
+    Returns:
+        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+    """
+    spec = torch.stft(
+        x,
+        fft_size,
+        hop_size,
+        win_length,
+        window,
+        return_complex=True,
+        pad_mode="reflect",
+    )
+    spec = torch.view_as_real(spec)
+
+    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
+    return torch.sqrt(torch.clamp(spec.pow(2).sum(-1), min=1e-6)).transpose(2, 1)
+
+
+class SpectralConvergengeLoss(nn.Module):
+    """Spectral convergence loss module."""
+
+    def __init__(self):
+        """Initialize spectral convergence loss module."""
+        super(SpectralConvergengeLoss, self).__init__()
+
+    def forward(self, x_mag, y_mag):
+        """Calculate forward propagation.
+        Args:
+            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+        Returns:
+            Tensor: Spectral convergence loss value.
+        """  # noqa: E501
+
+        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+
+
+class LogSTFTMagnitudeLoss(nn.Module):
+    """Log STFT magnitude loss module."""
+
+    def __init__(self):
+        """Initialize log STFT magnitude loss module."""
+        super(LogSTFTMagnitudeLoss, self).__init__()
+
+    def forward(self, x_mag, y_mag):
+        """Calculate forward propagation.
+        Args:
+            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+        Returns:
+            Tensor: Log STFT magnitude loss value.
+        """  # noqa: E501
+
+        return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
+
+
+class STFTLoss(nn.Module):
+    """STFT loss module."""
+
+    def __init__(
+        self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window
+    ):
+        """Initialize STFT loss module."""
+        super(STFTLoss, self).__init__()
+
+        self.fft_size = fft_size
+        self.shift_size = shift_size
+        self.win_length = win_length
+        self.register_buffer("window", window(win_length))
+        self.spectral_convergenge_loss = SpectralConvergengeLoss()
+        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+        Args:
+            x (Tensor): Predicted signal (B, T).
+            y (Tensor): Groundtruth signal (B, T).
+        Returns:
+            Tensor: Spectral convergence loss value.
+            Tensor: Log STFT magnitude loss value.
+        """
+
+        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
+        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
+        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+        return sc_loss, mag_loss
+
+
+class MultiResolutionSTFTLoss(nn.Module):
+    """Multi resolution STFT loss module."""
+
+    def __init__(self, resolutions, window=torch.hann_window):
+        super(MultiResolutionSTFTLoss, self).__init__()
+
+        self.stft_losses = nn.ModuleList()
+        for fs, ss, wl in resolutions:
+            self.stft_losses += [STFTLoss(fs, ss, wl, window)]
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+        Args:
+            x (Tensor): Predicted signal (B, T).
+            y (Tensor): Groundtruth signal (B, T).
+        Returns:
+            Tensor: Multi resolution spectral convergence loss value.
+            Tensor: Multi resolution log STFT magnitude loss value.
+        """
+        sc_loss = 0.0
+        mag_loss = 0.0
+        for f in self.stft_losses:
+            sc_l, mag_l = f(x, y)
+            sc_loss += sc_l
+            mag_loss += mag_l
+
+        sc_loss /= len(self.stft_losses)
+        mag_loss /= len(self.stft_losses)
+
+        return sc_loss, mag_loss

+ 193 - 0
fish_speech/models/vqgan/modules/balancer.py

@@ -0,0 +1,193 @@
+import typing as tp
+from collections import defaultdict
+
+import torch
+from torch import autograd
+
+
+def rank():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_rank()
+    else:
+        return 0
+
+
+def world_size():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_world_size()
+    else:
+        return 1
+
+
+def is_distributed():
+    return world_size() > 1
+
+
+def average_metrics(metrics: tp.Dict[str, float], count=1.0):
+    """Average a dictionary of metrics across all workers, using the optional
+    `count` as unnormalized weight.
+    """
+    if not is_distributed():
+        return metrics
+    keys, values = zip(*metrics.items())
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
+    tensor *= count
+    all_reduce(tensor)
+    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
+    return dict(zip(keys, averaged))
+
+
+def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
+    if is_distributed():
+        return torch.distributed.all_reduce(tensor, op)
+
+
+def averager(beta: float = 1):
+    """
+    Exponential Moving Average callback.
+    Returns a single function that can be called to repeatedly update the EMA
+    with a dict of metrics. The callback will return
+    the new averaged dict of metrics.
+
+    Note that for `beta=1`, this is just plain averaging.
+    """
+    fix: tp.Dict[str, float] = defaultdict(float)
+    total: tp.Dict[str, float] = defaultdict(float)
+
+    def _update(
+        metrics: tp.Dict[str, tp.Any], weight: float = 1
+    ) -> tp.Dict[str, float]:
+        nonlocal total, fix
+        for key, value in metrics.items():
+            total[key] = total[key] * beta + weight * float(value)
+            fix[key] = fix[key] * beta + weight
+        return {key: tot / fix[key] for key, tot in total.items()}
+
+    return _update
+
+
+class Balancer:
+    """Loss balancer.
+
+    The loss balancer combines losses together to compute gradients for the backward.
+    A call to the balancer will weight the losses according the specified weight coefficients.
+    A call to the backward method of the balancer will compute the gradients, combining all the losses and
+    potentially rescaling the gradients, which can help stabilize the training and reason
+    about multiple losses with varying scales.
+
+    Expected usage:
+        weights = {'loss_a': 1, 'loss_b': 4}
+        balancer = Balancer(weights, ...)
+        losses: dict = {}
+        losses['loss_a'] = compute_loss_a(x, y)
+        losses['loss_b'] = compute_loss_b(x, y)
+        if model.training():
+            balancer.backward(losses, x)
+
+    ..Warning:: It is unclear how this will interact with DistributedDataParallel,
+        in particular if you have some losses not handled by the balancer. In that case
+        you can use `encodec.distrib.sync_grad(model.parameters())` and
+        `encodec.distrib.sync_buffers(model.buffers())` as a safe alternative.
+
+    Args:
+        weights (Dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
+            from the backward method to match the weights keys to assign weight to each of the provided loss.
+        rescale_grads (bool): Whether to rescale gradients or not. If False, this is just
+            a regular weighted sum of losses.
+        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
+        ema_decay (float): EMA decay for averaging the norms when `rescale_grads` is True.
+        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
+            when rescaling the gradients.
+        epsilon (float): Epsilon value for numerical stability.
+        monitor (bool): Whether to store additional ratio for each loss key in metrics.
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        weights: tp.Dict[str, float],
+        rescale_grads: bool = True,
+        total_norm: float = 1.0,
+        ema_decay: float = 0.999,
+        per_batch_item: bool = True,
+        epsilon: float = 1e-12,
+        monitor: bool = False,
+    ):
+        self.weights = weights
+        self.per_batch_item = per_batch_item
+        self.total_norm = total_norm
+        self.averager = averager(ema_decay)
+        self.epsilon = epsilon
+        self.monitor = monitor
+        self.rescale_grads = rescale_grads
+        self._metrics: tp.Dict[str, tp.Any] = {}
+
+    @property
+    def metrics(self):
+        return self._metrics
+
+    def compute(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor):
+        norms = {}
+        grads = {}
+        for name, loss in losses.items():
+            (grad,) = autograd.grad(loss, [input], retain_graph=True)
+            if self.per_batch_item:
+                dims = tuple(range(1, grad.dim()))
+                norm = grad.norm(dim=dims).mean()
+            else:
+                norm = grad.norm()
+            norms[name] = norm
+            grads[name] = grad
+
+        count = 1
+        if self.per_batch_item:
+            count = len(grad)
+        avg_norms = average_metrics(self.averager(norms), count)
+        total = sum(avg_norms.values())
+
+        self._metrics = {}
+        if self.monitor:
+            for k, v in avg_norms.items():
+                self._metrics[f"ratio_{k}"] = v / total
+
+        total_weights = sum([self.weights[k] for k in avg_norms])
+        ratios = {k: w / total_weights for k, w in self.weights.items()}
+
+        out_grad: tp.Any = 0
+        for name, avg_norm in avg_norms.items():
+            if self.rescale_grads:
+                scale = ratios[name] * self.total_norm / (self.epsilon + avg_norm)
+                grad = grads[name] * scale
+            else:
+                grad = self.weights[name] * grads[name]
+            out_grad += grad
+
+        return out_grad
+
+
+def test():
+    from torch.nn import functional as F
+
+    x = torch.zeros(1, requires_grad=True)
+    one = torch.ones_like(x)
+    loss_1 = F.l1_loss(x, one)
+    loss_2 = 100 * F.l1_loss(x, -one)
+    losses = {"1": loss_1, "2": loss_2}
+
+    balancer = Balancer(weights={"1": 1, "2": 1}, rescale_grads=False)
+    out_grad = balancer.compute(losses, x)
+    x.backward(out_grad)
+    assert torch.allclose(x.grad, torch.tensor(99.0)), x.grad
+
+    loss_1 = F.l1_loss(x, one)
+    loss_2 = 100 * F.l1_loss(x, -one)
+    losses = {"1": loss_1, "2": loss_2}
+    x.grad = None
+    balancer = Balancer(weights={"1": 1, "2": 1}, rescale_grads=True)
+    out_grad = balancer.compute({"1": loss_1, "2": loss_2}, x)
+    x.backward(out_grad)
+    assert torch.allclose(x.grad, torch.tensor(0.0)), x.grad
+
+
+if __name__ == "__main__":
+    test()

+ 270 - 0
fish_speech/models/vqgan/modules/decoder_v2.py

@@ -0,0 +1,270 @@
+from functools import partial
+from math import prod
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return (kernel_size * dilation - dilation) // 2
+
+
+class ResBlock(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super().__init__()
+
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.silu(x)
+            xt = c1(xt)
+            xt = F.silu(xt)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_parametrizations(self):
+        for conv in self.convs1:
+            remove_parametrizations(conv)
+        for conv in self.convs2:
+            remove_parametrizations(conv)
+
+
+class ParralelBlock(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        kernel_sizes: tuple[int] = (3, 7, 11),
+        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+    ):
+        super().__init__()
+
+        assert len(kernel_sizes) == len(dilation_sizes)
+
+        self.blocks = nn.ModuleList()
+        for k, d in zip(kernel_sizes, dilation_sizes):
+            self.blocks.append(ResBlock(channels, k, d))
+
+    def forward(self, x):
+        xs = [block(x) for block in self.blocks]
+
+        return torch.stack(xs, dim=0).mean(dim=0)
+
+
+class HiFiGANGenerator(nn.Module):
+    def __init__(
+        self,
+        *,
+        hop_length: int = 512,
+        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
+        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
+        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
+        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+        num_mels: int = 128,
+        upsample_initial_channel: int = 512,
+        use_template: bool = True,
+        pre_conv_kernel_size: int = 7,
+        post_conv_kernel_size: int = 7,
+        post_activation: Callable = partial(nn.SiLU, inplace=True),
+        checkpointing: bool = False,
+    ):
+        super().__init__()
+
+        assert (
+            prod(upsample_rates) == hop_length
+        ), f"hop_length must be {prod(upsample_rates)}"
+
+        self.conv_pre = weight_norm(
+            nn.Conv1d(
+                num_mels,
+                upsample_initial_channel,
+                pre_conv_kernel_size,
+                1,
+                padding=get_padding(pre_conv_kernel_size),
+            )
+        )
+
+        self.hop_length = hop_length
+        self.num_upsamples = len(upsample_rates)
+        self.num_kernels = len(resblock_kernel_sizes)
+
+        self.noise_convs = nn.ModuleList()
+        self.use_template = use_template
+        self.ups = nn.ModuleList()
+
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            c_cur = upsample_initial_channel // (2 ** (i + 1))
+            self.ups.append(
+                weight_norm(
+                    nn.ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+            if not use_template:
+                continue
+
+            if i + 1 < len(upsample_rates):
+                stride_f0 = np.prod(upsample_rates[i + 1 :])
+                self.noise_convs.append(
+                    Conv1d(
+                        1,
+                        c_cur,
+                        kernel_size=stride_f0 * 2,
+                        stride=stride_f0,
+                        padding=stride_f0 // 2,
+                    )
+                )
+            else:
+                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            self.resblocks.append(
+                ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
+            )
+
+        self.activation_post = post_activation()
+        self.conv_post = weight_norm(
+            nn.Conv1d(
+                ch,
+                1,
+                post_conv_kernel_size,
+                1,
+                padding=get_padding(post_conv_kernel_size),
+            )
+        )
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+        # Gradient checkpointing
+        self.checkpointing = checkpointing
+
+    def forward(self, x, template=None):
+        if self.use_template and template is None:
+            length = x.shape[-1] * self.hop_length
+            template = (
+                torch.randn(x.shape[0], 1, length, device=x.device, dtype=x.dtype)
+                * 0.003
+            )
+
+        x = self.conv_pre(x)
+
+        for i in range(self.num_upsamples):
+            x = F.silu(x, inplace=True)
+            x = self.ups[i](x)
+
+            if self.use_template:
+                x = x + self.noise_convs[i](template)
+
+            if self.training and self.checkpointing:
+                x = torch.utils.checkpoint.checkpoint(
+                    self.resblocks[i],
+                    x,
+                    use_reentrant=False,
+                )
+            else:
+                x = self.resblocks[i](x)
+
+        x = self.activation_post(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_parametrizations(self):
+        for up in self.ups:
+            remove_parametrizations(up)
+        for block in self.resblocks:
+            block.remove_parametrizations()
+        remove_parametrizations(self.conv_pre)
+        remove_parametrizations(self.conv_post)

+ 0 - 166
fish_speech/models/vqgan/modules/discriminator.py

@@ -1,166 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.utils import spectral_norm, weight_norm
-
-from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
-from fish_speech.models.vqgan.utils import get_padding
-
-
-class DiscriminatorP(nn.Module):
-    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
-        super(DiscriminatorP, self).__init__()
-        self.period = period
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(
-                    nn.Conv2d(
-                        1,
-                        32,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        32,
-                        128,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        128,
-                        512,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        512,
-                        1024,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        1024,
-                        1024,
-                        (kernel_size, 1),
-                        1,
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-            ]
-        )
-        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
-
-    def forward(self, x):
-        fmap = []
-
-        # 1d to 2d
-        b, c, t = x.shape
-        if t % self.period != 0:  # pad first
-            n_pad = self.period - (t % self.period)
-            x = F.pad(x, (0, n_pad), "reflect")
-            t = t + n_pad
-        x = x.view(b, c, t // self.period, self.period)
-
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class DiscriminatorS(nn.Module):
-    def __init__(self, use_spectral_norm=False):
-        super(DiscriminatorS, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
-                norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
-                norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
-                norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
-                norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
-                norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
-                norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
-            ]
-        )
-        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
-
-    def forward(self, x):
-        fmap = []
-
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class EnsembleDiscriminator(nn.Module):
-    def __init__(self, ckpt_path=None, periods=(2, 3, 5, 7, 11)):
-        super(EnsembleDiscriminator, self).__init__()
-
-        discs = [DiscriminatorS(use_spectral_norm=True)]
-        discs = discs + [DiscriminatorP(i, use_spectral_norm=False) for i in periods]
-        self.discriminators = nn.ModuleList(discs)
-
-        if ckpt_path is not None:
-            self.restore_from_ckpt(ckpt_path)
-
-    def restore_from_ckpt(self, ckpt_path):
-        ckpt = torch.load(ckpt_path, map_location="cpu")
-        mpd, msd = ckpt["mpd"], ckpt["msd"]
-
-        all_keys = {}
-        for k, v in mpd.items():
-            keys = k.split(".")
-            keys[1] = str(int(keys[1]) + 1)
-            all_keys[".".join(keys)] = v
-
-        for k, v in msd.items():
-            if not k.startswith("discriminators.0"):
-                continue
-            all_keys[k] = v
-
-        self.load_state_dict(all_keys, strict=True)
-
-    def forward(self, y, y_hat):
-        y_d_rs = []
-        y_d_gs = []
-        fmap_rs = []
-        fmap_gs = []
-        for i, d in enumerate(self.discriminators):
-            y_d_r, fmap_r = d(y)
-            y_d_g, fmap_g = d(y_hat)
-            y_d_rs.append(y_d_r)
-            y_d_gs.append(y_d_g)
-            fmap_rs.append(fmap_r)
-            fmap_gs.append(fmap_g)
-
-        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-if __name__ == "__main__":
-    m = EnsembleDiscriminator(
-        ckpt_path="checkpoints/hifigan-v1-universal-22050/do_02500000"
-    )

+ 80 - 0
fish_speech/models/vqgan/modules/discriminators/mpd.py

@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.parametrizations import weight_norm
+
+
+class DiscriminatorP(nn.Module):
+    def __init__(
+        self,
+        *,
+        period: int,
+        kernel_size: int = 5,
+        stride: int = 3,
+        channels: tuple[int] = (1, 64, 128, 256, 512, 1024),
+    ) -> None:
+        super(DiscriminatorP, self).__init__()
+
+        self.period = period
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Conv2d(
+                        in_channels,
+                        out_channels,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(kernel_size // 2, 0),
+                    )
+                )
+                for in_channels, out_channels in zip(channels[:-1], channels[1:])
+            ]
+        )
+
+        self.conv_post = weight_norm(
+            nn.Conv2d(channels[-1], 1, (3, 1), 1, padding=(1, 0))
+        )
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "constant")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for conv in self.convs:
+            x = conv(x)
+            x = F.silu(x, inplace=True)
+            fmap.append(x)
+
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(nn.Module):
+    def __init__(self, periods: tuple[int] = (2, 3, 5, 7, 11)) -> None:
+        super().__init__()
+
+        self.discriminators = nn.ModuleList(
+            [DiscriminatorP(period=period) for period in periods]
+        )
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]:
+        scores, feature_map = [], []
+
+        for disc in self.discriminators:
+            res, fmap = disc(x)
+
+            scores.append(res)
+            feature_map.append(fmap)
+
+        return scores, feature_map

+ 100 - 0
fish_speech/models/vqgan/modules/discriminators/mrd.py

@@ -0,0 +1,100 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.parametrizations import weight_norm
+
+
+class DiscriminatorR(torch.nn.Module):
+    def __init__(
+        self,
+        *,
+        n_fft: int = 1024,
+        hop_length: int = 120,
+        win_length: int = 600,
+    ):
+        super(DiscriminatorR, self).__init__()
+
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
+                weight_norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
+            ]
+        )
+
+        self.conv_post = weight_norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
+
+    def forward(self, x):
+        fmap = []
+
+        x = self.spectrogram(x)
+        x = x.unsqueeze(1)
+
+        for conv in self.convs:
+            x = conv(x)
+            x = F.silu(x, inplace=True)
+            fmap.append(x)
+
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+    def spectrogram(self, x):
+        x = F.pad(
+            x,
+            (
+                (self.n_fft - self.hop_length) // 2,
+                (self.n_fft - self.hop_length + 1) // 2,
+            ),
+            mode="reflect",
+        )
+        x = x.squeeze(1)
+        x = torch.stft(
+            x,
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            center=False,
+            return_complex=True,
+        )
+        x = torch.view_as_real(x)  # [B, F, TT, 2]
+        mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]
+
+        return mag
+
+
+class MultiResolutionDiscriminator(torch.nn.Module):
+    def __init__(self, resolutions: list[tuple[int]]):
+        super().__init__()
+
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorR(
+                    n_fft=n_fft,
+                    hop_length=hop_length,
+                    win_length=win_length,
+                )
+                for n_fft, hop_length, win_length in resolutions
+            ]
+        )
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]:
+        scores, feature_map = [], []
+
+        for disc in self.discriminators:
+            res, fmap = disc(x)
+
+            scores.append(res)
+            feature_map.append(fmap)
+
+        return scores, feature_map

+ 188 - 0
fish_speech/models/vqgan/modules/discriminators/mssbcqtd.py

@@ -0,0 +1,188 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Monkey patching to fix a bug in nnAudio
+import numpy as np
+import torch
+import torchaudio.transforms as T
+from einops import rearrange
+from nnAudio import features
+from torch import nn
+
+from .msstftd import NormConv2d, get_2d_padding
+
+np.float = float
+
+LRELU_SLOPE = 0.1
+
+
+class DiscriminatorCQT(nn.Module):
+    def __init__(
+        self,
+        hop_length,
+        n_octaves,
+        bins_per_octave,
+        filters=32,
+        max_filters=1024,
+        filters_scale=1,
+        dilations=[1, 2, 4],
+        in_channels=1,
+        out_channels=1,
+        sample_rate=16000,
+    ):
+        super().__init__()
+
+        self.filters = filters
+        self.max_filters = max_filters
+        self.filters_scale = filters_scale
+        self.kernel_size = (3, 9)
+        self.dilations = dilations
+        self.stride = (1, 2)
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.fs = sample_rate
+        self.hop_length = hop_length
+        self.n_octaves = n_octaves
+        self.bins_per_octave = bins_per_octave
+
+        self.cqt_transform = features.cqt.CQT2010v2(
+            sr=self.fs * 2,
+            hop_length=self.hop_length,
+            n_bins=self.bins_per_octave * self.n_octaves,
+            bins_per_octave=self.bins_per_octave,
+            output_format="Complex",
+            pad_mode="constant",
+        )
+
+        self.conv_pres = nn.ModuleList()
+        for i in range(self.n_octaves):
+            self.conv_pres.append(
+                NormConv2d(
+                    self.in_channels * 2,
+                    self.in_channels * 2,
+                    kernel_size=self.kernel_size,
+                    padding=get_2d_padding(self.kernel_size),
+                )
+            )
+
+        self.convs = nn.ModuleList()
+
+        self.convs.append(
+            NormConv2d(
+                self.in_channels * 2,
+                self.filters,
+                kernel_size=self.kernel_size,
+                padding=get_2d_padding(self.kernel_size),
+            )
+        )
+
+        in_chs = min(self.filters_scale * self.filters, self.max_filters)
+        for i, dilation in enumerate(self.dilations):
+            out_chs = min(
+                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
+            )
+            self.convs.append(
+                NormConv2d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=self.kernel_size,
+                    stride=self.stride,
+                    dilation=(dilation, 1),
+                    padding=get_2d_padding(self.kernel_size, (dilation, 1)),
+                    norm="weight_norm",
+                )
+            )
+            in_chs = out_chs
+        out_chs = min(
+            (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
+            self.max_filters,
+        )
+        self.convs.append(
+            NormConv2d(
+                in_chs,
+                out_chs,
+                kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+                padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
+                norm="weight_norm",
+            )
+        )
+
+        self.conv_post = NormConv2d(
+            out_chs,
+            self.out_channels,
+            kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+            padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
+            norm="weight_norm",
+        )
+
+        self.activation = torch.nn.LeakyReLU(negative_slope=LRELU_SLOPE)
+        self.resample = T.Resample(orig_freq=self.fs, new_freq=self.fs * 2)
+
+    def forward(self, x):
+        fmap = []
+
+        x = self.resample(x)
+
+        z = self.cqt_transform(x)
+
+        z_amplitude = z[:, :, :, 0].unsqueeze(1)
+        z_phase = z[:, :, :, 1].unsqueeze(1)
+
+        z = torch.cat([z_amplitude, z_phase], dim=1)
+        z = rearrange(z, "b c w t -> b c t w")
+
+        latent_z = []
+        for i in range(self.n_octaves):
+            latent_z.append(
+                self.conv_pres[i](
+                    z[
+                        :,
+                        :,
+                        :,
+                        i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
+                    ]
+                )
+            )
+        latent_z = torch.cat(latent_z, dim=-1)
+
+        for i, layer in enumerate(self.convs):
+            latent_z = layer(latent_z)
+
+            latent_z = self.activation(latent_z)
+            fmap.append(latent_z)
+
+        latent_z = self.conv_post(latent_z)
+
+        return latent_z, fmap
+
+
+class MultiScaleSubbandCQTDiscriminator(nn.Module):
+    def __init__(self, hop_lengths, n_octaves, bins_per_octaves, **kwargs):
+        super().__init__()
+
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorCQT(
+                    hop_length=hop_length,
+                    n_octaves=n_octaves,
+                    bins_per_octave=bins_per_octave,
+                    **kwargs,
+                )
+                for hop_length, n_octaves, bins_per_octave in zip(
+                    hop_lengths, n_octaves, bins_per_octaves
+                )
+            ]
+        )
+
+    def forward(self, x: torch.Tensor):
+        logits = []
+        fmaps = []
+        for disc in self.discriminators:
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+
+        return logits, fmaps

+ 303 - 0
fish_speech/models/vqgan/modules/discriminators/msstftd.py

@@ -0,0 +1,303 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""MS-STFT discriminator, provided here for reference."""
+
+import typing as tp
+
+import einops
+import torch
+import torchaudio
+from einops import rearrange
+from torch import nn
+from torch.nn.utils import spectral_norm, weight_norm
+
+FeatureMapType = tp.List[torch.Tensor]
+LogitsType = torch.Tensor
+DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
+
+
+class ConvLayerNorm(nn.LayerNorm):
+    """
+    Convolution-friendly LayerNorm that moves channels to last dimensions
+    before running the normalization and moves them back to original position right after.
+    """  # noqa: E501
+
+    def __init__(
+        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
+    ):
+        super().__init__(normalized_shape, **kwargs)
+
+    def forward(self, x):
+        x = einops.rearrange(x, "b ... t -> b t ...")
+        x = super().forward(x)
+        x = einops.rearrange(x, "b t ... -> b ... t")
+        return x
+
+
+CONV_NORMALIZATIONS = frozenset(
+    [
+        "none",
+        "weight_norm",
+        "spectral_norm",
+        "time_layer_norm",
+        "layer_norm",
+        "time_group_norm",
+    ]
+)
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
+    assert norm in CONV_NORMALIZATIONS
+    if norm == "weight_norm":
+        return weight_norm(module)
+    elif norm == "spectral_norm":
+        return spectral_norm(module)
+    else:
+        # We already checked that `norm` is in CONV_NORMALIZATIONS, so any
+        # other choice doesn't need reparametrization.
+        return module
+
+
+def get_norm_module(
+    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
+) -> nn.Module:
+    """Return the proper normalization module. If causal is True, this will ensure the returned
+    module is causal, or return an error if the normalization doesn't support causal evaluation.
+    """  # noqa: E501
+    assert norm in CONV_NORMALIZATIONS
+    if norm == "layer_norm":
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return ConvLayerNorm(module.out_channels, **norm_kwargs)
+    elif norm == "time_group_norm":
+        if causal:
+            raise ValueError("GroupNorm doesn't support causal evaluation.")
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+    else:
+        return nn.Identity()
+
+
+class NormConv2d(nn.Module):
+    """Wrapper around Conv2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+
+    def __init__(
+        self,
+        *args,
+        norm: str = "none",
+        norm_kwargs: tp.Dict[str, tp.Any] = {},
+        **kwargs,
+    ):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+
+def get_2d_padding(
+    kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)
+):
+    return (
+        ((kernel_size[0] - 1) * dilation[0]) // 2,
+        ((kernel_size[1] - 1) * dilation[1]) // 2,
+    )
+
+
+class DiscriminatorSTFT(nn.Module):
+    """STFT sub-discriminator.
+    Args:
+        filters (int): Number of filters in convolutions
+        in_channels (int): Number of input channels. Default: 1
+        out_channels (int): Number of output channels. Default: 1
+        n_fft (int): Size of FFT for each scale. Default: 1024
+        hop_length (int): Length of hop between STFT windows for each scale. Default: 256
+        kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
+        stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
+        dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
+        win_length (int): Window size for each scale. Default: 1024
+        normalized (bool): Whether to normalize by magnitude after stft. Default: True
+        norm (str): Normalization method. Default: `'weight_norm'`
+        activation (str): Activation function. Default: `'LeakyReLU'`
+        activation_params (dict): Parameters to provide to the activation function.
+        growth (int): Growth factor for the filters. Default: 1
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        filters: int,
+        in_channels: int = 1,
+        out_channels: int = 1,
+        n_fft: int = 1024,
+        hop_length: int = 256,
+        win_length: int = 1024,
+        max_filters: int = 1024,
+        filters_scale: int = 1,
+        kernel_size: tp.Tuple[int, int] = (3, 9),
+        dilations: tp.List = [1, 2, 4],
+        stride: tp.Tuple[int, int] = (1, 2),
+        normalized: bool = True,
+        norm: str = "weight_norm",
+        activation: str = "LeakyReLU",
+        activation_params: dict = {"negative_slope": 0.2},
+    ):
+        super().__init__()
+        assert len(kernel_size) == 2
+        assert len(stride) == 2
+        self.filters = filters
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.normalized = normalized
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.spec_transform = torchaudio.transforms.Spectrogram(
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            window_fn=torch.hann_window,
+            normalized=self.normalized,
+            center=False,
+            pad_mode=None,
+            power=None,
+        )
+        spec_channels = 2 * self.in_channels
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            NormConv2d(
+                spec_channels,
+                self.filters,
+                kernel_size=kernel_size,
+                padding=get_2d_padding(kernel_size),
+            )
+        )
+        in_chs = min(filters_scale * self.filters, max_filters)
+        for i, dilation in enumerate(dilations):
+            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
+            self.convs.append(
+                NormConv2d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    dilation=(dilation, 1),
+                    padding=get_2d_padding(kernel_size, (dilation, 1)),
+                    norm=norm,
+                )
+            )
+            in_chs = out_chs
+        out_chs = min(
+            (filters_scale ** (len(dilations) + 1)) * self.filters, max_filters
+        )
+        self.convs.append(
+            NormConv2d(
+                in_chs,
+                out_chs,
+                kernel_size=(kernel_size[0], kernel_size[0]),
+                padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                norm=norm,
+            )
+        )
+        self.conv_post = NormConv2d(
+            out_chs,
+            self.out_channels,
+            kernel_size=(kernel_size[0], kernel_size[0]),
+            padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+            norm=norm,
+        )
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        z = self.spec_transform(x)  # complex tensor [B, C, Freq, Frames] since power=None
+        z = torch.cat([z.real, z.imag], dim=1)
+        z = rearrange(z, "b c w t -> b c t w")
+        for i, layer in enumerate(self.convs):
+            z = layer(z)
+            z = self.activation(z)
+            fmap.append(z)
+        z = self.conv_post(z)
+        return z, fmap
+
+
+class MultiScaleSTFTDiscriminator(nn.Module):
+    """Multi-Scale STFT (MS-STFT) discriminator.
+    Args:
+        filters (int): Number of filters in convolutions
+        in_channels (int): Number of input channels. Default: 1
+        out_channels (int): Number of output channels. Default: 1
+        n_ffts (Sequence[int]): Size of FFT for each scale
+        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale
+        win_lengths (Sequence[int]): Window size for each scale
+        **kwargs: additional args for STFTDiscriminator
+    """
+
+    def __init__(
+        self,
+        filters: int,
+        in_channels: int = 1,
+        out_channels: int = 1,
+        n_ffts: tp.List[int] = [1024, 2048, 512],
+        hop_lengths: tp.List[int] = [256, 512, 128],
+        win_lengths: tp.List[int] = [1024, 2048, 512],
+        **kwargs,
+    ):
+        super().__init__()
+        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorSTFT(
+                    filters,
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    n_fft=n_ffts[i],
+                    win_length=win_lengths[i],
+                    hop_length=hop_lengths[i],
+                    **kwargs,
+                )
+                for i in range(len(n_ffts))
+            ]
+        )
+        self.num_discriminators = len(self.discriminators)
+
+    def forward(self, x: torch.Tensor) -> DiscriminatorOutput:
+        logits = []
+        fmaps = []
+        for disc in self.discriminators:
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+
+        return logits, fmaps
+
+
+def test():
+    disc = MultiScaleSTFTDiscriminator(filters=32)
+    y = torch.randn(1, 1, 24000)
+    y_hat = torch.randn(1, 1, 24000)
+
+    y_disc_r, fmap_r = disc(y)
+    y_disc_gen, fmap_gen = disc(y_hat)
+    assert (
+        len(y_disc_r)
+        == len(y_disc_gen)
+        == len(fmap_r)
+        == len(fmap_gen)
+        == disc.num_discriminators
+    )
+    assert all([len(fm) == 5 for fm in fmap_r + fmap_gen])
+    assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm])
+    assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen])
+
+
+if __name__ == "__main__":
+    test()

+ 63 - 4
fish_speech/models/vqgan/modules/encoders.py

@@ -5,6 +5,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
 from vector_quantize_pytorch import LFQ, GroupedResidualVQ, VectorQuantize
 
 from fish_speech.models.vqgan.modules.modules import WN
@@ -298,6 +299,7 @@ class VQEncoder(nn.Module):
             )
 
         self.codebook_groups = codebook_groups
+        self.codebook_layers = codebook_layers
         self.downsample = downsample
         self.conv_in = nn.Conv1d(
             in_channels, vq_channels, kernel_size=downsample, stride=downsample
@@ -309,6 +311,17 @@ class VQEncoder(nn.Module):
             nn.Conv1d(vq_channels, in_channels, kernel_size=1, stride=1),
         )
 
+    @property
+    def mode(self):
+        if self.codebook_groups > 1 and self.codebook_layers > 1:
+            return "grouped-residual"
+        elif self.codebook_groups > 1:
+            return "grouped"
+        elif self.codebook_layers > 1:
+            return "residual"
+        else:
+            return "single"
+
     def forward(self, x, x_mask):
         # x: [B, C, T], x_mask: [B, 1, T]
         x_len = x.shape[2]
@@ -327,15 +340,61 @@ class VQEncoder(nn.Module):
         x = self.conv_out(q) * x_mask
         x = x[:, :, :x_len]
 
+        # Post process indices
+        if self.mode == "grouped-residual":
+            indices = rearrange(indices, "g b t r -> b (g r) t")
+        elif self.mode == "grouped":
+            indices = rearrange(indices, "g b t 1 -> b g t")
+        elif self.mode == "residual":
+            indices = rearrange(indices, "1 b t r -> b r t")
+        else:
+            indices = rearrange(indices, "b t -> b 1 t")
+
         return x, indices, loss
 
     def decode(self, indices):
+        # Undo rearrange
+        if self.mode == "grouped-residual":
+            indices = rearrange(indices, "b (g r) t -> g b t r", g=self.codebook_groups)
+        elif self.mode == "grouped":
+            indices = rearrange(indices, "b g t -> g b t 1")
+        elif self.mode == "residual":
+            indices = rearrange(indices, "b r t -> 1 b t r")
+        else:
+            indices = rearrange(indices, "b 1 t -> b t")
+
         q = self.vq.get_output_from_indices(indices)
 
-        if q.shape[1] != indices.shape[1] and indices.ndim != 4:
-            q = q.view(q.shape[0], indices.shape[1], -1)
-        q = q.mT
+        # Edge case for single vq
+        if self.mode == "single":
+            q = rearrange(q, "b (t c) -> b t c", t=indices.shape[-1])
 
-        x = self.conv_out(q)
+        x = self.conv_out(q.mT)
 
         return x
+
+
+if __name__ == "__main__":
+    # Test VQEncoder
+    for group, layer in [
+        (1, 1),
+        (1, 2),
+        (2, 1),
+        (2, 2),
+        (4, 1),
+        (4, 2),
+    ]:
+        encoder = VQEncoder(
+            in_channels=1024,
+            vq_channels=1024,
+            codebook_size=2048,
+            downsample=1,
+            codebook_groups=group,
+            codebook_layers=layer,
+            threshold_ema_dead_code=2,
+        )
+        x = torch.randn(2, 1024, 100)
+        x_mask = torch.ones(2, 1, 100)
+        x, indices, loss = encoder(x, x_mask)
+        x = encoder.decode(indices)
+        assert x.shape == (2, 1024, 100)

+ 1 - 0
pyproject.toml

@@ -34,6 +34,7 @@ dependencies = [
     "zibai-server>=0.9.0",
     "loguru>=0.6.0",
     "WeTextProcessing>=0.1.10",
+    "nnAudio>=0.3.2",
     "loralib>=0.1.2",
     "natsort>=8.4.0",
     "cn2an>=0.5.22"

+ 5 - 57
tools/api_server.py

@@ -138,36 +138,11 @@ class VQGANModel:
     def sematic_to_wav(self, indices):
         model = self.model
         indices = indices.to(model.device).long()
-        indices = indices.unsqueeze(1).unsqueeze(-1)
-
-        mel_lengths = indices.shape[2] * (
-            model.downsample.total_strides if model.downsample is not None else 1
-        )
-        mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
-        mel_masks = torch.ones(
-            (1, 1, mel_lengths), device=model.device, dtype=torch.float32
-        )
-
-        text_features = model.vq_encoder.decode(indices)
-
-        logger.info(
-            f"VQ Encoded, indices: {indices.shape} equivalent to "
-            + f"{1 / (mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
-        )
-
-        text_features = F.interpolate(
-            text_features, size=mel_lengths[0], mode="nearest"
-        )
-
-        # Sample mels
-        decoded_mels = model.decoder(text_features, mel_masks)
-        fake_audios = model.generator(decoded_mels)
-        logger.info(
-            f"Generated audio of shape {fake_audios.shape}, equivalent to {fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
-        )
+        feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
+        decoded = model.decode(indices=indices[None], feature_lengths=feature_lengths)
 
         # Save audio
-        fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
+        fake_audio = decoded.audios[0, 0].cpu().numpy().astype(np.float32)
 
         return fake_audio, model.sampling_rate
 
@@ -189,37 +164,10 @@ class VQGANModel:
         audio_lengths = torch.tensor(
             [audios.shape[2]], device=model.device, dtype=torch.long
         )
-
-        features = gt_mels = model.mel_transform(
-            audios, sample_rate=model.sampling_rate
-        )
-
-        if model.downsample is not None:
-            features = model.downsample(features)
-
-        mel_lengths = audio_lengths // model.hop_length
-        feature_lengths = (
-            audio_lengths
-            / model.hop_length
-            / (model.downsample.total_strides if model.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
-
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = model.mel_encoder(features, feature_masks)
-        _, indices, _ = model.vq_encoder(text_features, feature_masks)
-
-        if indices.ndim == 4 and indices.shape[1] == 1 and indices.shape[3] == 1:
-            indices = indices[:, 0, :, 0]
-        else:
-            logger.error(f"Unknown indices shape: {indices.shape}")
-            return
+        encoded = model.encode(audios, audio_lengths)
+        indices = encoded.indices[0]
 
         logger.info(f"Generated indices of shape {indices.shape}")
-
         return indices
 
 

+ 5 - 33
tools/vqgan/extract_vq.py

@@ -90,43 +90,15 @@ def process_batch(files: list[Path], model) -> float:
 
     # Calculate lengths
     with torch.no_grad():
-        # VQ Encoder
-        features = gt_mels = model.mel_transform(
-            audios, sample_rate=model.sampling_rate
-        )
-
-        if model.downsample is not None:
-            features = model.downsample(features)
-
-        feature_lengths = (
-            audio_lengths
-            / model.hop_length
-            / (model.downsample.total_strides if model.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
-
-        text_features = model.mel_encoder(features, feature_masks)
-        _, indices, _ = model.vq_encoder(text_features, feature_masks)
-
-        if indices.ndim == 4:
-            # Grouped vq
-            assert indices.shape[-1] == 1, f"Residual vq is not supported"
-            indices = indices.squeeze(-1)
-        elif indices.ndim == 2:
-            # Single vq
-            indices = indices.unsqueeze(0)
-        else:
-            raise ValueError(f"Invalid indices shape {indices.shape}")
-
-        indices = rearrange(indices, "c b t -> b c t")
+        out = model.encode(audios, audio_lengths)
+        indices, feature_lengths = out.indices, out.feature_lengths
 
     # Save to disk
     outputs = indices.cpu().numpy()
 
-    for file, length, feature, audio in zip(files, feature_lengths, outputs, audios):
+    for file, length, feature, audio_length in zip(
+        files, feature_lengths, outputs, audio_lengths
+    ):
         feature = feature[:, :length]
 
         # (T,)

+ 7 - 52
tools/vqgan/inference.py

@@ -67,37 +67,8 @@ def main(input_path, output_path, config_name, checkpoint_path):
         audio_lengths = torch.tensor(
             [audios.shape[2]], device=model.device, dtype=torch.long
         )
-
-        features = gt_mels = model.mel_transform(
-            audios, sample_rate=model.sampling_rate
-        )
-
-        if model.downsample is not None:
-            features = model.downsample(features)
-
-        mel_lengths = audio_lengths // model.hop_length
-        feature_lengths = (
-            audio_lengths
-            / model.hop_length
-            / (model.downsample.total_strides if model.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
-        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
-            gt_mels.dtype
-        )
-
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = model.mel_encoder(features, feature_masks)
-        _, indices, _ = model.vq_encoder(text_features, feature_masks)
-
-        if indices.ndim == 4 and indices.shape[1] == 1 and indices.shape[3] == 1:
-            indices = indices[:, 0, :, 0]
-        else:
-            logger.error(f"Unknown indices shape: {indices.shape}")
-            return
+        encoded = model.encode(audios, audio_lengths)
+        indices = encoded.indices[0]
 
         logger.info(f"Generated indices of shape {indices.shape}")
 
@@ -112,29 +83,13 @@ def main(input_path, output_path, config_name, checkpoint_path):
         raise ValueError(f"Unknown input type: {input_path}")
 
     # Restore
-    indices = indices.unsqueeze(1).unsqueeze(-1)
-    mel_lengths = indices.shape[2] * (
-        model.downsample.total_strides if model.downsample is not None else 1
-    )
-    mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
-    mel_masks = torch.ones(
-        (1, 1, mel_lengths), device=model.device, dtype=torch.float32
-    )
-
-    text_features = model.vq_encoder.decode(indices)
-
-    logger.info(
-        f"VQ Encoded, indices: {indices.shape} equivalent to "
-        + f"{1/(mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
-    )
-
-    text_features = F.interpolate(text_features, size=mel_lengths[0], mode="nearest")
+    feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
+    decoded = model.decode(indices=indices[None], feature_lengths=feature_lengths)
+    fake_audios = decoded.audios
+    audio_time = fake_audios.shape[-1] / model.sampling_rate
 
-    # Sample mels
-    decoded_mels = model.decoder(text_features, mel_masks)
-    fake_audios = model.generator(decoded_mels)
     logger.info(
-        f"Generated audio of shape {fake_audios.shape}, equivalent to {fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
+        f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
     )
 
     # Save audio