Просмотр исходного кода

Merge VQGAN v2 to dev (#56)

* squash vqgan v2 changes

* Merge pretrain stage 1 and 2

* Optimize vqgan inference (remove redundant code)

* Implement data mixing

* Optimize vqgan v2 config

* Add support to freeze discriminator

* Add stft loss & larger segment size
Leng Yue 2 года назад
Родитель
Commit
1609e9bad4

+ 1 - 0
.pre-commit-config.yaml

@@ -18,6 +18,7 @@ repos:
     hooks:
     hooks:
       - id: codespell
       - id: codespell
         files: ^.*\.(py|md|rst|yml)$
         files: ^.*\.(py|md|rst|yml)$
+        args: [-L=fro]
 
 
   - repo: https://github.com/pre-commit/pre-commit-hooks
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.5.0
     rev: v4.5.0

+ 59 - 25
fish_speech/configs/vqgan_pretrain_v2.yaml

@@ -3,6 +3,8 @@ defaults:
   - _self_
   - _self_
 
 
 project: vqgan_pretrain_v2
 project: vqgan_pretrain_v2
+ckpt_path: checkpoints/hifigan-base-comb-mix-lb-020/step_001200000_weights_only.ckpt
+resume_weights_only: true
 
 
 # Lightning Trainer
 # Lightning Trainer
 trainer:
 trainer:
@@ -15,22 +17,36 @@ trainer:
 
 
 sample_rate: 44100
 sample_rate: 44100
 hop_length: 512
 hop_length: 512
-num_mels: 128
+num_mels: 160
 n_fft: 2048
 n_fft: 2048
 win_length: 2048
 win_length: 2048
 segment_size: 256
 segment_size: 256
 
 
 # Dataset Configuration
 # Dataset Configuration
 train_dataset:
 train_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/Genshin/vq_train_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  slice_frames: ${segment_size}
+  _target_: fish_speech.datasets.vqgan.MixDatast
+  datasets:
+    high-quality-441:
+      prob: 0.5
+      dataset:
+        _target_: fish_speech.datasets.vqgan.VQGANDataset
+        filelist: data/vocoder_data_441/vq_train_filelist.txt
+        sample_rate: ${sample_rate}
+        hop_length: ${hop_length}
+        slice_frames: ${segment_size}
+    
+    common-voice:
+      prob: 0.5
+      dataset:
+        _target_: fish_speech.datasets.vqgan.VQGANDataset
+        filelist: data/cv-corpus-16.0-2023-12-06/vq_train_filelist.txt
+        sample_rate: ${sample_rate}
+        hop_length: ${hop_length}
+        slice_frames: ${segment_size}
 
 
 val_dataset:
 val_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/Genshin/vq_val_filelist.txt
+  filelist: data/vocoder_data_441/vq_val_filelist.txt
   sample_rate: ${sample_rate}
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   hop_length: ${hop_length}
 
 
@@ -47,8 +63,9 @@ model:
   _target_: fish_speech.models.vqgan.VQGAN
   _target_: fish_speech.models.vqgan.VQGAN
   sample_rate: ${sample_rate}
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   hop_length: ${hop_length}
-  segment_size: 8192
-  mode: pretrain-stage1
+  segment_size: 32768
+  mode: pretrain
+  freeze_discriminator: true
 
 
   downsample:
   downsample:
     _target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
     _target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
@@ -67,8 +84,8 @@ model:
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
     in_channels: 256
     in_channels: 256
     vq_channels: 256
     vq_channels: 256
-    codebook_size: 1024
-    codebook_layers: 4
+    codebook_size: 256
+    codebook_groups: 4
     downsample: 1
     downsample: 1
 
 
   decoder:
   decoder:
@@ -80,19 +97,38 @@ model:
     n_layers: 6
     n_layers: 6
 
 
   generator:
   generator:
-    _target_: fish_speech.models.vqgan.modules.decoder.Generator
-    initial_channel: ${num_mels}
-    resblock: "1"
+    _target_: fish_speech.models.vqgan.modules.decoder_v2.HiFiGANGenerator
+    hop_length: ${hop_length}
+    upsample_rates: [8, 8, 2, 2, 2]  # aka. strides
+    upsample_kernel_sizes: [16, 16, 4, 4, 4]
     resblock_kernel_sizes: [3, 7, 11]
     resblock_kernel_sizes: [3, 7, 11]
     resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
     resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    upsample_rates: [8, 8, 2, 2, 2]
+    num_mels: ${num_mels}
     upsample_initial_channel: 512
     upsample_initial_channel: 512
-    upsample_kernel_sizes: [16, 16, 4, 4, 4]
-
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
-    periods: [2, 3, 5, 7, 11, 17, 23, 37]
-
+    use_template: true
+    pre_conv_kernel_size: 7
+    post_conv_kernel_size: 7
+
+  discriminators:
+    _target_: torch.nn.ModuleDict
+    modules:
+      mpd:
+        _target_: fish_speech.models.vqgan.modules.discriminators.mpd.MultiPeriodDiscriminator
+        periods: [2, 3, 5, 7, 11, 17, 23, 37]
+
+      mrd:
+        _target_: fish_speech.models.vqgan.modules.discriminators.mrd.MultiResolutionDiscriminator
+        resolutions:
+          - ["${n_fft}", "${hop_length}", "${win_length}"]
+          - [1024, 120, 600]
+          - [2048, 240, 1200]
+          - [4096, 480, 2400]
+          - [512, 50, 240]
+
+  multi_resolution_stft_loss:
+    _target_: fish_speech.models.vqgan.losses.MultiResolutionSTFTLoss
+    resolutions: ${model.discriminators.modules.mrd.resolutions}
+  
   mel_transform:
   mel_transform:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     sample_rate: ${sample_rate}
     sample_rate: ${sample_rate}
@@ -100,13 +136,11 @@ model:
     hop_length: ${hop_length}
     hop_length: ${hop_length}
     win_length: ${win_length}
     win_length: ${win_length}
     n_mels: ${num_mels}
     n_mels: ${num_mels}
-    f_min: 0
-    f_max: 16000
 
 
   optimizer:
   optimizer:
     _target_: torch.optim.AdamW
     _target_: torch.optim.AdamW
     _partial_: true
     _partial_: true
-    lr: 2e-4
+    lr: 1e-4
     betas: [0.8, 0.99]
     betas: [0.8, 0.99]
     eps: 1e-5
     eps: 1e-5
 
 
@@ -119,7 +153,7 @@ callbacks:
   grad_norm_monitor:
   grad_norm_monitor:
     sub_module: 
     sub_module: 
       - generator
       - generator
-      - discriminator
+      - discriminators
       - mel_encoder
       - mel_encoder
       - vq_encoder
       - vq_encoder
       - decoder
       - decoder

+ 29 - 2
fish_speech/datasets/vqgan.py

@@ -6,7 +6,7 @@ import librosa
 import numpy as np
 import numpy as np
 import torch
 import torch
 from lightning import LightningDataModule
 from lightning import LightningDataModule
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import DataLoader, Dataset, IterableDataset
 
 
 from fish_speech.utils import RankedLogger
 from fish_speech.utils import RankedLogger
 
 
@@ -72,6 +72,33 @@ class VQGANDataset(Dataset):
             return None
             return None
 
 
 
 
+class MixDatast(IterableDataset):
+    def __init__(self, datasets: dict[str, dict], seed: int = 42) -> None:
+        values = list(datasets.values())
+        probs = [v["prob"] for v in values]
+        self.datasets = [v["dataset"] for v in values]
+
+        total_probs = sum(probs)
+        self.probs = [p / total_probs for p in probs]
+        self.seed = seed
+
+    def __iter__(self):
+        rng = np.random.default_rng(self.seed)
+        dataset_iterators = [iter(dataset) for dataset in self.datasets]
+
+        while True:
+            # Random choice one
+            dataset_idx = rng.choice(len(self.datasets), p=self.probs)
+            dataset_iterator = dataset_iterators[dataset_idx]
+
+            try:
+                yield next(dataset_iterator)
+            except StopIteration:
+                # Exhausted, create a new iterator
+                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
+                yield next(dataset_iterators[dataset_idx])
+
+
 @dataclass
 @dataclass
 class VQGANCollator:
 class VQGANCollator:
     def __call__(self, batch):
     def __call__(self, batch):
@@ -116,7 +143,7 @@ class VQGANDataModule(LightningDataModule):
             batch_size=self.batch_size,
             batch_size=self.batch_size,
             collate_fn=VQGANCollator(),
             collate_fn=VQGANCollator(),
             num_workers=self.num_workers,
             num_workers=self.num_workers,
-            shuffle=True,
+            shuffle=not isinstance(self.train_dataset, IterableDataset),
         )
         )
 
 
     def val_dataloader(self):
     def val_dataloader(self):

+ 298 - 151
fish_speech/models/vqgan/lit_module.py

@@ -1,5 +1,6 @@
 import itertools
 import itertools
-from typing import Any, Callable, Literal
+from dataclasses import dataclass
+from typing import Any, Callable, Literal, Optional
 
 
 import lightning as L
 import lightning as L
 import torch
 import torch
@@ -8,19 +9,17 @@ import wandb
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from matplotlib import pyplot as plt
 from torch import nn
 from torch import nn
-from vector_quantize_pytorch import VectorQuantize
 
 
 from fish_speech.models.vqgan.losses import (
 from fish_speech.models.vqgan.losses import (
+    MultiResolutionSTFTLoss,
     discriminator_loss,
     discriminator_loss,
     feature_loss,
     feature_loss,
     generator_loss,
     generator_loss,
-    kl_loss,
 )
 )
+from fish_speech.models.vqgan.modules.balancer import Balancer
 from fish_speech.models.vqgan.modules.decoder import Generator
 from fish_speech.models.vqgan.modules.decoder import Generator
-from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
 from fish_speech.models.vqgan.modules.encoders import (
 from fish_speech.models.vqgan.modules.encoders import (
     ConvDownSampler,
     ConvDownSampler,
-    SpeakerEncoder,
     TextEncoder,
     TextEncoder,
     VQEncoder,
     VQEncoder,
 )
 )
@@ -32,6 +31,21 @@ from fish_speech.models.vqgan.utils import (
 )
 )
 
 
 
 
+@dataclass
+class VQEncodeResult:
+    features: torch.Tensor
+    indices: torch.Tensor
+    loss: torch.Tensor
+    feature_lengths: torch.Tensor
+
+
+@dataclass
+class VQDecodeResult:
+    audios: torch.Tensor
+    mels: torch.Tensor
+    mel_lengths: torch.Tensor
+
+
 class VQGAN(L.LightningModule):
 class VQGAN(L.LightningModule):
     def __init__(
     def __init__(
         self,
         self,
@@ -42,18 +56,18 @@ class VQGAN(L.LightningModule):
         mel_encoder: TextEncoder,
         mel_encoder: TextEncoder,
         decoder: TextEncoder,
         decoder: TextEncoder,
         generator: Generator,
         generator: Generator,
-        discriminator: EnsembleDiscriminator,
+        discriminators: nn.ModuleDict,
         mel_transform: nn.Module,
         mel_transform: nn.Module,
         segment_size: int = 20480,
         segment_size: int = 20480,
         hop_length: int = 640,
         hop_length: int = 640,
         sample_rate: int = 32000,
         sample_rate: int = 32000,
-        mode: Literal["pretrain-stage1", "pretrain-stage2", "finetune"] = "finetune",
-        speaker_encoder: SpeakerEncoder = None,
+        mode: Literal["pretrain", "finetune"] = "finetune",
+        freeze_discriminator: bool = False,
+        multi_resolution_stft_loss: Optional[MultiResolutionSTFTLoss] = None,
     ):
     ):
         super().__init__()
         super().__init__()
 
 
-        # pretrain-stage1: vq use gt mel as target, hifigan use gt mel as input
-        # pretrain-stage2: end-to-end training, use gt mel as hifi gan target
+        # pretrain: vq use gt mel as target, hifigan use gt mel as input
         # finetune: end-to-end training, use gt mel as hifi gan target but freeze vq
         # finetune: end-to-end training, use gt mel as hifi gan target but freeze vq
 
 
         # Model parameters
         # Model parameters
@@ -64,11 +78,11 @@ class VQGAN(L.LightningModule):
         self.downsample = downsample
         self.downsample = downsample
         self.vq_encoder = vq_encoder
         self.vq_encoder = vq_encoder
         self.mel_encoder = mel_encoder
         self.mel_encoder = mel_encoder
-        self.speaker_encoder = speaker_encoder
         self.decoder = decoder
         self.decoder = decoder
         self.generator = generator
         self.generator = generator
-        self.discriminator = discriminator
+        self.discriminators = discriminators
         self.mel_transform = mel_transform
         self.mel_transform = mel_transform
+        self.freeze_discriminator = freeze_discriminator
 
 
         # Crop length for saving memory
         # Crop length for saving memory
         self.segment_size = segment_size
         self.segment_size = segment_size
@@ -90,20 +104,30 @@ class VQGAN(L.LightningModule):
             for p in self.downsample.parameters():
             for p in self.downsample.parameters():
                 p.requires_grad = False
                 p.requires_grad = False
 
 
+        if self.freeze_discriminator:
+            for p in self.discriminators.parameters():
+                p.requires_grad = False
+
+        # Losses
+        self.multi_resolution_stft_loss = multi_resolution_stft_loss
+        loss_dict = {
+            "mel": 1,
+            "adv": 1,
+            "fm": 1,
+        }
+
+        if self.multi_resolution_stft_loss is not None:
+            loss_dict["stft"] = 1
+
+        self.balancer = Balancer(loss_dict)
+
     def configure_optimizers(self):
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
         # Need two optimizers and two schedulers
-        components = []
-        if self.mode != "finetune":
-            components.extend(
-                [
-                    self.downsample.parameters(),
-                    self.vq_encoder.parameters(),
-                    self.mel_encoder.parameters(),
-                ]
-            )
-
-        if self.speaker_encoder is not None:
-            components.append(self.speaker_encoder.parameters())
+        components = [
+            self.downsample.parameters(),
+            self.vq_encoder.parameters(),
+            self.mel_encoder.parameters(),
+        ]
 
 
         if self.decoder is not None:
         if self.decoder is not None:
             components.append(self.decoder.parameters())
             components.append(self.decoder.parameters())
@@ -111,7 +135,7 @@ class VQGAN(L.LightningModule):
         components.append(self.generator.parameters())
         components.append(self.generator.parameters())
         optimizer_generator = self.optimizer_builder(itertools.chain(*components))
         optimizer_generator = self.optimizer_builder(itertools.chain(*components))
         optimizer_discriminator = self.optimizer_builder(
         optimizer_discriminator = self.optimizer_builder(
-            self.discriminator.parameters()
+            self.discriminators.parameters()
         )
         )
 
 
         lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
         lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
@@ -145,9 +169,7 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]
         audios = audios[:, None, :]
 
 
         with torch.no_grad():
         with torch.no_grad():
-            features = gt_mels = self.mel_transform(
-                audios, sample_rate=self.sampling_rate
-            )
+            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
 
 
         if self.mode == "finetune":
         if self.mode == "finetune":
             # Disable gradient computation for VQ
             # Disable gradient computation for VQ
@@ -156,29 +178,13 @@ class VQGAN(L.LightningModule):
             self.mel_encoder.eval()
             self.mel_encoder.eval()
             self.downsample.eval()
             self.downsample.eval()
 
 
-        if self.downsample is not None:
-            features = self.downsample(features)
-
         mel_lengths = audio_lengths // self.hop_length
         mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = (
-            audio_lengths
-            / self.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
             gt_mels.dtype
         )
         )
 
 
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.mel_encoder(features, feature_masks)
-        text_features, _, loss_vq = self.vq_encoder(text_features, feature_masks)
-        text_features = F.interpolate(
-            text_features, size=gt_mels.shape[2], mode="nearest"
-        )
+        vq_result = self.encode(audios, audio_lengths)
+        loss_vq = vq_result.loss
 
 
         if loss_vq.ndim > 1:
         if loss_vq.ndim > 1:
             loss_vq = loss_vq.mean()
             loss_vq = loss_vq.mean()
@@ -187,18 +193,15 @@ class VQGAN(L.LightningModule):
             # Enable gradient computation
             # Enable gradient computation
             torch.set_grad_enabled(True)
             torch.set_grad_enabled(True)
 
 
-        # Sample mels
-        if self.decoder is not None:
-            speaker_features = (
-                self.speaker_encoder(gt_mels, mel_masks)
-                if self.speaker_encoder is not None
-                else None
-            )
-            decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
-        else:
-            decoded_mels = text_features
+        decoded = self.decode(
+            indices=vq_result.indices if self.mode == "finetune" else None,
+            features=vq_result.features if self.mode == "pretrain" else None,
+            audio_lengths=audio_lengths,
+            mel_only=True,
+        )
+        decoded_mels = decoded.mels
+        input_mels = gt_mels if self.mode == "pretrain" else decoded_mels
 
 
-        input_mels = gt_mels if self.mode == "pretrain-stage1" else decoded_mels
         if self.segment_size is not None:
         if self.segment_size is not None:
             audios, ids_slice = rand_slice_segments(
             audios, ids_slice = rand_slice_segments(
                 audios, audio_lengths, self.segment_size
                 audios, audio_lengths, self.segment_size
@@ -228,75 +231,145 @@ class VQGAN(L.LightningModule):
             audios.shape == fake_audios.shape
             audios.shape == fake_audios.shape
         ), f"{audios.shape} != {fake_audios.shape}"
         ), f"{audios.shape} != {fake_audios.shape}"
 
 
+        # Multi-Resolution STFT Loss
+        if self.multi_resolution_stft_loss is not None:
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                sc_loss, mag_loss = self.multi_resolution_stft_loss(
+                    fake_audios.squeeze(1).float(), audios.squeeze(1).float()
+                )
+                loss_stft = sc_loss + mag_loss
+
         # Discriminator
         # Discriminator
-        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios.detach())
+        if self.freeze_discriminator is False:
+            loss_disc_all = []
+
+            for key, disc in self.discriminators.items():
+                scores, _ = disc(audios)
+                score_fakes, _ = disc(fake_audios.detach())
+
+                with torch.autocast(device_type=audios.device.type, enabled=False):
+                    loss_disc, _, _ = discriminator_loss(scores, score_fakes)
+
+                self.log(
+                    f"train/discriminator/{key}",
+                    loss_disc,
+                    on_step=True,
+                    on_epoch=False,
+                    prog_bar=False,
+                    logger=True,
+                    sync_dist=True,
+                )
 
 
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
+                loss_disc_all.append(loss_disc)
 
 
-        self.log(
-            "train/discriminator/loss",
-            loss_disc_all,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
+            loss_disc_all = torch.stack(loss_disc_all).mean()
 
 
-        optim_d.zero_grad()
-        self.manual_backward(loss_disc_all)
-        self.clip_gradients(
-            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-        )
-        optim_d.step()
+            self.log(
+                "train/discriminator/loss",
+                loss_disc_all,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=True,
+                logger=True,
+                sync_dist=True,
+            )
+
+            optim_d.zero_grad()
+            self.manual_backward(loss_disc_all)
+            self.clip_gradients(
+                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+            )
+            optim_d.step()
+
+        # Adv Loss
+        loss_adv_all = []
+        loss_fm_all = []
+
+        for key, disc in self.discriminators.items():
+            score_fakes, feat_fake = disc(fake_audios)
+
+            # Adversarial Loss
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                loss_fake, _ = generator_loss(score_fakes)
+
+            self.log(
+                f"train/generator/adv_{key}",
+                loss_fake,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
+            loss_adv_all.append(loss_fake)
+
+            # Feature Matching Loss
+            _, feat_real = disc(audios)
+
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                loss_fm = feature_loss(feat_real, feat_fake)
 
 
-        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(audios, fake_audios)
+            self.log(
+                f"train/generator/adv_fm_{key}",
+                loss_fm,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
+            loss_fm_all.append(loss_fm)
+
+        loss_adv_all = torch.stack(loss_adv_all).mean()
+        loss_fm_all = torch.stack(loss_fm_all).mean()
 
 
         with torch.autocast(device_type=audios.device.type, enabled=False):
         with torch.autocast(device_type=audios.device.type, enabled=False):
             loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
             loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
             loss_mel = F.l1_loss(
             loss_mel = F.l1_loss(
                 sliced_gt_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
                 sliced_gt_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
             )
             )
-            loss_adv, _ = generator_loss(y_d_hat_g)
-            loss_fm = feature_loss(fmap_r, fmap_g)
 
 
-            if self.mode == "pretrain-stage1":
+            loss_dict = {
+                "mel": loss_mel,
+                "adv": loss_adv_all,
+                "fm": loss_fm_all,
+            }
+
+            if self.multi_resolution_stft_loss is not None:
+                loss_dict["stft"] = loss_stft
+
+            generator_out_grad = self.balancer.compute(
+                loss_dict,
+                fake_audios,
+            )
+
+            if self.mode == "pretrain":
                 loss_vq_all = loss_decoded_mel + loss_vq
                 loss_vq_all = loss_decoded_mel + loss_vq
-                loss_gen_all = loss_mel * 45 + loss_fm + loss_adv
-            else:
-                loss_gen_all = loss_mel * 45 + loss_vq * 45 + loss_fm + loss_adv
 
 
-        self.log(
-            "train/generator/loss_gen_all",
-            loss_gen_all,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
+        # Loss vq and loss decoded mel are only used in pretrain stage
+        if self.mode == "pretrain":
+            self.log(
+                "train/generator/loss_vq",
+                loss_vq,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
 
 
-        if self.mode == "pretrain-stage1":
             self.log(
             self.log(
-                "train/generator/loss_vq_all",
-                loss_vq_all,
+                "train/generator/loss_decoded_mel",
+                loss_decoded_mel,
                 on_step=True,
                 on_step=True,
                 on_epoch=False,
                 on_epoch=False,
-                prog_bar=True,
+                prog_bar=False,
                 logger=True,
                 logger=True,
                 sync_dist=True,
                 sync_dist=True,
             )
             )
 
 
-        self.log(
-            "train/generator/loss_decoded_mel",
-            loss_decoded_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
         self.log(
         self.log(
             "train/generator/loss_mel",
             "train/generator/loss_mel",
             loss_mel,
             loss_mel,
@@ -306,18 +379,21 @@ class VQGAN(L.LightningModule):
             logger=True,
             logger=True,
             sync_dist=True,
             sync_dist=True,
         )
         )
+
+        if self.multi_resolution_stft_loss is not None:
+            self.log(
+                "train/generator/loss_stft",
+                loss_stft,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
         self.log(
         self.log(
-            "train/generator/loss_fm",
-            loss_fm,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-        self.log(
-            "train/generator/loss_adv",
-            loss_adv,
+            "train/generator/loss_fm_all",
+            loss_fm_all,
             on_step=True,
             on_step=True,
             on_epoch=False,
             on_epoch=False,
             prog_bar=False,
             prog_bar=False,
@@ -325,8 +401,8 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
             sync_dist=True,
         )
         )
         self.log(
         self.log(
-            "train/generator/loss_vq",
-            loss_vq,
+            "train/generator/loss_adv_all",
+            loss_adv_all,
             on_step=True,
             on_step=True,
             on_epoch=False,
             on_epoch=False,
             prog_bar=False,
             prog_bar=False,
@@ -336,11 +412,11 @@ class VQGAN(L.LightningModule):
 
 
         optim_g.zero_grad()
         optim_g.zero_grad()
 
 
-        # Only backpropagate loss_vq_all in pretrain-stage1
-        if self.mode == "pretrain-stage1":
-            self.manual_backward(loss_vq_all)
+        # Only backpropagate loss_vq_all in pretrain stage
+        if self.mode == "pretrain":
+            self.manual_backward(loss_vq_all, retain_graph=True)
 
 
-        self.manual_backward(loss_gen_all)
+        self.manual_backward(fake_audios, gradient=generator_out_grad)
         self.clip_gradients(
         self.clip_gradients(
             optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
             optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
         )
         )
@@ -357,44 +433,26 @@ class VQGAN(L.LightningModule):
         audios = audios.float()
         audios = audios.float()
         audios = audios[:, None, :]
         audios = audios[:, None, :]
 
 
-        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
-
-        if self.downsample is not None:
-            features = self.downsample(features)
-
+        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
         mel_lengths = audio_lengths // self.hop_length
         mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = (
-            audio_lengths
-            / self.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
             gt_mels.dtype
         )
         )
 
 
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.mel_encoder(features, feature_masks)
-        text_features, _, _ = self.vq_encoder(text_features, feature_masks)
-        text_features = F.interpolate(
-            text_features, size=gt_mels.shape[2], mode="nearest"
+        vq_result = self.encode(audios, audio_lengths)
+        decoded = self.decode(
+            indices=vq_result.indices,
+            audio_lengths=audio_lengths,
+            mel_only=self.mode == "pretrain",
         )
         )
 
 
-        # Sample mels
-        if self.decoder is not None:
-            speaker_features = (
-                self.speaker_encoder(gt_mels, mel_masks)
-                if self.speaker_encoder is not None
-                else None
-            )
-            decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
-        else:
-            decoded_mels = text_features
+        decoded_mels = decoded.mels
 
 
-        fake_audios = self.generator(decoded_mels)
+        # Use gt mel as input for pretrain
+        if self.mode == "pretrain":
+            fake_audios = self.generator(gt_mels)
+        else:
+            fake_audios = decoded.audios
 
 
         fake_mels = self.mel_transform(fake_audios.squeeze(1))
         fake_mels = self.mel_transform(fake_audios.squeeze(1))
 
 
@@ -487,3 +545,92 @@ class VQGAN(L.LightningModule):
                 )
                 )
 
 
             plt.close(image_mels)
             plt.close(image_mels)
+
+    def encode(self, audios, audio_lengths=None):
+        if audio_lengths is None:
+            audio_lengths = torch.tensor(
+                [audios.shape[-1]] * audios.shape[0],
+                device=audios.device,
+                dtype=torch.long,
+            )
+
+        with torch.no_grad():
+            features = self.mel_transform(audios, sample_rate=self.sampling_rate)
+
+        if self.downsample is not None:
+            features = self.downsample(features)
+
+        feature_lengths = (
+            audio_lengths
+            / self.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(features.dtype)
+
+        text_features = self.mel_encoder(features, feature_masks)
+        vq_features, indices, loss = self.vq_encoder(text_features, feature_masks)
+
+        return VQEncodeResult(
+            features=vq_features,
+            indices=indices,
+            loss=loss,
+            feature_lengths=feature_lengths,
+        )
+
+    def calculate_audio_lengths(self, feature_lengths):
+        return (
+            feature_lengths
+            * self.hop_length
+            * (self.downsample.total_strides if self.downsample is not None else 1)
+        )
+
+    def decode(
+        self,
+        indices=None,
+        features=None,
+        audio_lengths=None,
+        mel_only=False,
+        feature_lengths=None,
+    ):
+        assert (
+            indices is not None or features is not None
+        ), "indices or features must be provided"
+        assert (
+            feature_lengths is not None or audio_lengths is not None
+        ), "feature_lengths or audio_lengths must be provided"
+
+        if audio_lengths is None:
+            audio_lengths = self.calculate_audio_lengths(feature_lengths)
+
+        mel_lengths = audio_lengths // self.hop_length
+        mel_masks = torch.unsqueeze(
+            sequence_mask(mel_lengths, torch.max(mel_lengths)), 1
+        ).float()
+
+        if indices is not None:
+            features = self.vq_encoder.decode(indices)
+
+        features = F.interpolate(features, size=mel_masks.shape[2], mode="nearest")
+
+        # Sample mels
+        if self.decoder is not None:
+            decoded_mels = self.decoder(features, mel_masks)
+        else:
+            decoded_mels = features
+
+        if mel_only:
+            return VQDecodeResult(
+                audios=None,
+                mels=decoded_mels,
+                mel_lengths=mel_lengths,
+            )
+
+        fake_audios = self.generator(decoded_mels)
+        return VQDecodeResult(
+            audios=fake_audios,
+            mels=decoded_mels,
+            mel_lengths=mel_lengths,
+        )

+ 135 - 5
fish_speech/models/vqgan/losses.py

@@ -1,9 +1,9 @@
-from typing import List
-
 import torch
 import torch
+import torch.nn.functional as F
+from torch import nn
 
 
 
 
-def feature_loss(fmap_r: List[torch.Tensor], fmap_g: List[torch.Tensor]):
+def feature_loss(fmap_r: list[torch.Tensor], fmap_g: list[torch.Tensor]):
     loss = 0
     loss = 0
     for dr, dg in zip(fmap_r, fmap_g):
     for dr, dg in zip(fmap_r, fmap_g):
         for rl, gl in zip(dr, dg):
         for rl, gl in zip(dr, dg):
@@ -15,7 +15,7 @@ def feature_loss(fmap_r: List[torch.Tensor], fmap_g: List[torch.Tensor]):
 
 
 
 
 def discriminator_loss(
 def discriminator_loss(
-    disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
+    disc_real_outputs: list[torch.Tensor], disc_generated_outputs: list[torch.Tensor]
 ):
 ):
     loss = 0
     loss = 0
     r_losses = []
     r_losses = []
@@ -32,7 +32,7 @@ def discriminator_loss(
     return loss, r_losses, g_losses
     return loss, r_losses, g_losses
 
 
 
 
-def generator_loss(disc_outputs: List[torch.Tensor]):
+def generator_loss(disc_outputs: list[torch.Tensor]):
     loss = 0
     loss = 0
     gen_losses = []
     gen_losses = []
     for dg in disc_outputs:
     for dg in disc_outputs:
@@ -66,3 +66,133 @@ def kl_loss(
     kl = torch.sum(kl * z_mask)
     kl = torch.sum(kl * z_mask)
     l = kl / torch.sum(z_mask)
     l = kl / torch.sum(z_mask)
     return l
     return l
+
+
+def stft(x, fft_size, hop_size, win_length, window):
+    """Perform STFT and convert to magnitude spectrogram.
+    Args:
+        x (Tensor): Input signal tensor (B, T).
+        fft_size (int): FFT size.
+        hop_size (int): Hop size.
+        win_length (int): Window length.
+        window (str): Window function type.
+    Returns:
+        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+    """
+    spec = torch.stft(
+        x,
+        fft_size,
+        hop_size,
+        win_length,
+        window,
+        return_complex=True,
+        pad_mode="reflect",
+    )
+    spec = torch.view_as_real(spec)
+
+    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
+    return torch.sqrt(torch.clamp(spec.pow(2).sum(-1), min=1e-6)).transpose(2, 1)
+
+
+class SpectralConvergengeLoss(nn.Module):
+    """Spectral convergence loss module."""
+
+    def __init__(self):
+        """Initialize spectral convergence loss module."""
+        super(SpectralConvergengeLoss, self).__init__()
+
+    def forward(self, x_mag, y_mag):
+        """Calculate forward propagation.
+        Args:
+            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+        Returns:
+            Tensor: Spectral convergence loss value.
+        """  # noqa: E501
+
+        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+
+
+class LogSTFTMagnitudeLoss(nn.Module):
+    """Log STFT magnitude loss module."""
+
+    def __init__(self):
+        """Initialize log STFT magnitude loss module."""
+        super(LogSTFTMagnitudeLoss, self).__init__()
+
+    def forward(self, x_mag, y_mag):
+        """Calculate forward propagation.
+        Args:
+            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+        Returns:
+            Tensor: Log STFT magnitude loss value.
+        """  # noqa: E501
+
+        return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
+
+
+class STFTLoss(nn.Module):
+    """STFT loss module."""
+
+    def __init__(
+        self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window
+    ):
+        """Initialize STFT loss module."""
+        super(STFTLoss, self).__init__()
+
+        self.fft_size = fft_size
+        self.shift_size = shift_size
+        self.win_length = win_length
+        self.register_buffer("window", window(win_length))
+        self.spectral_convergenge_loss = SpectralConvergengeLoss()
+        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+        Args:
+            x (Tensor): Predicted signal (B, T).
+            y (Tensor): Groundtruth signal (B, T).
+        Returns:
+            Tensor: Spectral convergence loss value.
+            Tensor: Log STFT magnitude loss value.
+        """
+
+        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
+        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
+        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+        return sc_loss, mag_loss
+
+
+class MultiResolutionSTFTLoss(nn.Module):
+    """Multi resolution STFT loss module."""
+
+    def __init__(self, resolutions, window=torch.hann_window):
+        super(MultiResolutionSTFTLoss, self).__init__()
+
+        self.stft_losses = nn.ModuleList()
+        for fs, ss, wl in resolutions:
+            self.stft_losses += [STFTLoss(fs, ss, wl, window)]
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+        Args:
+            x (Tensor): Predicted signal (B, T).
+            y (Tensor): Groundtruth signal (B, T).
+        Returns:
+            Tensor: Multi resolution spectral convergence loss value.
+            Tensor: Multi resolution log STFT magnitude loss value.
+        """
+        sc_loss = 0.0
+        mag_loss = 0.0
+        for f in self.stft_losses:
+            sc_l, mag_l = f(x, y)
+            sc_loss += sc_l
+            mag_loss += mag_l
+
+        sc_loss /= len(self.stft_losses)
+        mag_loss /= len(self.stft_losses)
+
+        return sc_loss, mag_loss

+ 193 - 0
fish_speech/models/vqgan/modules/balancer.py

@@ -0,0 +1,193 @@
+import typing as tp
+from collections import defaultdict
+
+import torch
+from torch import autograd
+
+
+def rank():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_rank()
+    else:
+        return 0
+
+
+def world_size():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_world_size()
+    else:
+        return 1
+
+
+def is_distributed():
+    return world_size() > 1
+
+
+def average_metrics(metrics: tp.Dict[str, float], count=1.0):
+    """Average a dictionary of metrics across all workers, using the optional
+    `count` as unnormalized weight.
+    """
+    if not is_distributed():
+        return metrics
+    keys, values = zip(*metrics.items())
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
+    tensor *= count
+    all_reduce(tensor)
+    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
+    return dict(zip(keys, averaged))
+
+
+def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
+    if is_distributed():
+        return torch.distributed.all_reduce(tensor, op)
+
+
+def averager(beta: float = 1):
+    """
+    Exponential Moving Average callback.
+    Returns a single function that can be called to repeatedly update the EMA
+    with a dict of metrics. The callback will return
+    the new averaged dict of metrics.
+
+    Note that for `beta=1`, this is just plain averaging.
+    """
+    fix: tp.Dict[str, float] = defaultdict(float)
+    total: tp.Dict[str, float] = defaultdict(float)
+
+    def _update(
+        metrics: tp.Dict[str, tp.Any], weight: float = 1
+    ) -> tp.Dict[str, float]:
+        nonlocal total, fix
+        for key, value in metrics.items():
+            total[key] = total[key] * beta + weight * float(value)
+            fix[key] = fix[key] * beta + weight
+        return {key: tot / fix[key] for key, tot in total.items()}
+
+    return _update
+
+
+class Balancer:
+    """Loss balancer.
+
+    The loss balancer combines losses together to compute gradients for the backward.
+    A call to the balancer will weight the losses according the specified weight coefficients.
+    A call to the backward method of the balancer will compute the gradients, combining all the losses and
+    potentially rescaling the gradients, which can help stabilize the training and reason
+    about multiple losses with varying scales.
+
+    Expected usage:
+        weights = {'loss_a': 1, 'loss_b': 4}
+        balancer = Balancer(weights, ...)
+        losses: dict = {}
+        losses['loss_a'] = compute_loss_a(x, y)
+        losses['loss_b'] = compute_loss_b(x, y)
+        if model.training():
+            balancer.backward(losses, x)
+
+    ..Warning:: It is unclear how this will interact with DistributedDataParallel,
+        in particular if you have some losses not handled by the balancer. In that case
+        you can use `encodec.distrib.sync_grad(model.parameters())` and
+        `encodec.distrib.sync_buffers(model.buffers())` as a safe alternative.
+
+    Args:
+        weights (Dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
+            from the backward method to match the weights keys to assign weight to each of the provided loss.
+        rescale_grads (bool): Whether to rescale gradients or not. If False, this is just
+            a regular weighted sum of losses.
+        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
+        ema_decay (float): EMA decay for averaging the norms when `rescale_grads` is True.
+        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
+            when rescaling the gradients.
+        epsilon (float): Epsilon value for numerical stability.
+        monitor (bool): Whether to store additional ratio for each loss key in metrics.
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        weights: tp.Dict[str, float],
+        rescale_grads: bool = True,
+        total_norm: float = 1.0,
+        ema_decay: float = 0.999,
+        per_batch_item: bool = True,
+        epsilon: float = 1e-12,
+        monitor: bool = False,
+    ):
+        self.weights = weights
+        self.per_batch_item = per_batch_item
+        self.total_norm = total_norm
+        self.averager = averager(ema_decay)
+        self.epsilon = epsilon
+        self.monitor = monitor
+        self.rescale_grads = rescale_grads
+        self._metrics: tp.Dict[str, tp.Any] = {}
+
+    @property
+    def metrics(self):
+        return self._metrics
+
+    def compute(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor):
+        norms = {}
+        grads = {}
+        for name, loss in losses.items():
+            (grad,) = autograd.grad(loss, [input], retain_graph=True)
+            if self.per_batch_item:
+                dims = tuple(range(1, grad.dim()))
+                norm = grad.norm(dim=dims).mean()
+            else:
+                norm = grad.norm()
+            norms[name] = norm
+            grads[name] = grad
+
+        count = 1
+        if self.per_batch_item:
+            count = len(grad)
+        avg_norms = average_metrics(self.averager(norms), count)
+        total = sum(avg_norms.values())
+
+        self._metrics = {}
+        if self.monitor:
+            for k, v in avg_norms.items():
+                self._metrics[f"ratio_{k}"] = v / total
+
+        total_weights = sum([self.weights[k] for k in avg_norms])
+        ratios = {k: w / total_weights for k, w in self.weights.items()}
+
+        out_grad: tp.Any = 0
+        for name, avg_norm in avg_norms.items():
+            if self.rescale_grads:
+                scale = ratios[name] * self.total_norm / (self.epsilon + avg_norm)
+                grad = grads[name] * scale
+            else:
+                grad = self.weights[name] * grads[name]
+            out_grad += grad
+
+        return out_grad
+
+
+def test():
+    from torch.nn import functional as F
+
+    x = torch.zeros(1, requires_grad=True)
+    one = torch.ones_like(x)
+    loss_1 = F.l1_loss(x, one)
+    loss_2 = 100 * F.l1_loss(x, -one)
+    losses = {"1": loss_1, "2": loss_2}
+
+    balancer = Balancer(weights={"1": 1, "2": 1}, rescale_grads=False)
+    out_grad = balancer.compute(losses, x)
+    x.backward(out_grad)
+    assert torch.allclose(x.grad, torch.tensor(99.0)), x.grad
+
+    loss_1 = F.l1_loss(x, one)
+    loss_2 = 100 * F.l1_loss(x, -one)
+    losses = {"1": loss_1, "2": loss_2}
+    x.grad = None
+    balancer = Balancer(weights={"1": 1, "2": 1}, rescale_grads=True)
+    out_grad = balancer.compute({"1": loss_1, "2": loss_2}, x)
+    x.backward(out_grad)
+    assert torch.allclose(x.grad, torch.tensor(0.0)), x.grad
+
+
+if __name__ == "__main__":
+    test()

+ 270 - 0
fish_speech/models/vqgan/modules/decoder_v2.py

@@ -0,0 +1,270 @@
+from functools import partial
+from math import prod
+from typing import Callable
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return (kernel_size * dilation - dilation) // 2
+
+
+class ResBlock(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super().__init__()
+
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.silu(x)
+            xt = c1(xt)
+            xt = F.silu(xt)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_parametrizations(self):
+        for conv in self.convs1:
+            remove_parametrizations(conv)
+        for conv in self.convs2:
+            remove_parametrizations(conv)
+
+
+class ParralelBlock(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        kernel_sizes: tuple[int] = (3, 7, 11),
+        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+    ):
+        super().__init__()
+
+        assert len(kernel_sizes) == len(dilation_sizes)
+
+        self.blocks = nn.ModuleList()
+        for k, d in zip(kernel_sizes, dilation_sizes):
+            self.blocks.append(ResBlock(channels, k, d))
+
+    def forward(self, x):
+        xs = [block(x) for block in self.blocks]
+
+        return torch.stack(xs, dim=0).mean(dim=0)
+
+
+class HiFiGANGenerator(nn.Module):
+    def __init__(
+        self,
+        *,
+        hop_length: int = 512,
+        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
+        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
+        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
+        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+        num_mels: int = 128,
+        upsample_initial_channel: int = 512,
+        use_template: bool = True,
+        pre_conv_kernel_size: int = 7,
+        post_conv_kernel_size: int = 7,
+        post_activation: Callable = partial(nn.SiLU, inplace=True),
+        checkpointing: bool = False,
+    ):
+        super().__init__()
+
+        assert (
+            prod(upsample_rates) == hop_length
+        ), f"hop_length must be {prod(upsample_rates)}"
+
+        self.conv_pre = weight_norm(
+            nn.Conv1d(
+                num_mels,
+                upsample_initial_channel,
+                pre_conv_kernel_size,
+                1,
+                padding=get_padding(pre_conv_kernel_size),
+            )
+        )
+
+        self.hop_length = hop_length
+        self.num_upsamples = len(upsample_rates)
+        self.num_kernels = len(resblock_kernel_sizes)
+
+        self.noise_convs = nn.ModuleList()
+        self.use_template = use_template
+        self.ups = nn.ModuleList()
+
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            c_cur = upsample_initial_channel // (2 ** (i + 1))
+            self.ups.append(
+                weight_norm(
+                    nn.ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+            if not use_template:
+                continue
+
+            if i + 1 < len(upsample_rates):
+                stride_f0 = np.prod(upsample_rates[i + 1 :])
+                self.noise_convs.append(
+                    Conv1d(
+                        1,
+                        c_cur,
+                        kernel_size=stride_f0 * 2,
+                        stride=stride_f0,
+                        padding=stride_f0 // 2,
+                    )
+                )
+            else:
+                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            self.resblocks.append(
+                ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
+            )
+
+        self.activation_post = post_activation()
+        self.conv_post = weight_norm(
+            nn.Conv1d(
+                ch,
+                1,
+                post_conv_kernel_size,
+                1,
+                padding=get_padding(post_conv_kernel_size),
+            )
+        )
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+        # Gradient checkpointing
+        self.checkpointing = checkpointing
+
+    def forward(self, x, template=None):
+        if self.use_template and template is None:
+            length = x.shape[-1] * self.hop_length
+            template = (
+                torch.randn(x.shape[0], 1, length, device=x.device, dtype=x.dtype)
+                * 0.003
+            )
+
+        x = self.conv_pre(x)
+
+        for i in range(self.num_upsamples):
+            x = F.silu(x, inplace=True)
+            x = self.ups[i](x)
+
+            if self.use_template:
+                x = x + self.noise_convs[i](template)
+
+            if self.training and self.checkpointing:
+                x = torch.utils.checkpoint.checkpoint(
+                    self.resblocks[i],
+                    x,
+                    use_reentrant=False,
+                )
+            else:
+                x = self.resblocks[i](x)
+
+        x = self.activation_post(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_parametrizations(self):
+        for up in self.ups:
+            remove_parametrizations(up)
+        for block in self.resblocks:
+            block.remove_parametrizations()
+        remove_parametrizations(self.conv_pre)
+        remove_parametrizations(self.conv_post)

+ 0 - 166
fish_speech/models/vqgan/modules/discriminator.py

@@ -1,166 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.utils import spectral_norm, weight_norm
-
-from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
-from fish_speech.models.vqgan.utils import get_padding
-
-
-class DiscriminatorP(nn.Module):
-    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
-        super(DiscriminatorP, self).__init__()
-        self.period = period
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(
-                    nn.Conv2d(
-                        1,
-                        32,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        32,
-                        128,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        128,
-                        512,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        512,
-                        1024,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    nn.Conv2d(
-                        1024,
-                        1024,
-                        (kernel_size, 1),
-                        1,
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-            ]
-        )
-        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
-
-    def forward(self, x):
-        fmap = []
-
-        # 1d to 2d
-        b, c, t = x.shape
-        if t % self.period != 0:  # pad first
-            n_pad = self.period - (t % self.period)
-            x = F.pad(x, (0, n_pad), "reflect")
-            t = t + n_pad
-        x = x.view(b, c, t // self.period, self.period)
-
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class DiscriminatorS(nn.Module):
-    def __init__(self, use_spectral_norm=False):
-        super(DiscriminatorS, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
-                norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
-                norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
-                norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
-                norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
-                norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
-                norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
-            ]
-        )
-        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
-
-    def forward(self, x):
-        fmap = []
-
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class EnsembleDiscriminator(nn.Module):
-    def __init__(self, ckpt_path=None, periods=(2, 3, 5, 7, 11)):
-        super(EnsembleDiscriminator, self).__init__()
-
-        discs = [DiscriminatorS(use_spectral_norm=True)]
-        discs = discs + [DiscriminatorP(i, use_spectral_norm=False) for i in periods]
-        self.discriminators = nn.ModuleList(discs)
-
-        if ckpt_path is not None:
-            self.restore_from_ckpt(ckpt_path)
-
-    def restore_from_ckpt(self, ckpt_path):
-        ckpt = torch.load(ckpt_path, map_location="cpu")
-        mpd, msd = ckpt["mpd"], ckpt["msd"]
-
-        all_keys = {}
-        for k, v in mpd.items():
-            keys = k.split(".")
-            keys[1] = str(int(keys[1]) + 1)
-            all_keys[".".join(keys)] = v
-
-        for k, v in msd.items():
-            if not k.startswith("discriminators.0"):
-                continue
-            all_keys[k] = v
-
-        self.load_state_dict(all_keys, strict=True)
-
-    def forward(self, y, y_hat):
-        y_d_rs = []
-        y_d_gs = []
-        fmap_rs = []
-        fmap_gs = []
-        for i, d in enumerate(self.discriminators):
-            y_d_r, fmap_r = d(y)
-            y_d_g, fmap_g = d(y_hat)
-            y_d_rs.append(y_d_r)
-            y_d_gs.append(y_d_g)
-            fmap_rs.append(fmap_r)
-            fmap_gs.append(fmap_g)
-
-        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-if __name__ == "__main__":
-    m = EnsembleDiscriminator(
-        ckpt_path="checkpoints/hifigan-v1-universal-22050/do_02500000"
-    )

+ 80 - 0
fish_speech/models/vqgan/modules/discriminators/mpd.py

@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.parametrizations import weight_norm
+
+
+class DiscriminatorP(nn.Module):
+    def __init__(
+        self,
+        *,
+        period: int,
+        kernel_size: int = 5,
+        stride: int = 3,
+        channels: tuple[int] = (1, 64, 128, 256, 512, 1024),
+    ) -> None:
+        super(DiscriminatorP, self).__init__()
+
+        self.period = period
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Conv2d(
+                        in_channels,
+                        out_channels,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(kernel_size // 2, 0),
+                    )
+                )
+                for in_channels, out_channels in zip(channels[:-1], channels[1:])
+            ]
+        )
+
+        self.conv_post = weight_norm(
+            nn.Conv2d(channels[-1], 1, (3, 1), 1, padding=(1, 0))
+        )
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "constant")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for conv in self.convs:
+            x = conv(x)
+            x = F.silu(x, inplace=True)
+            fmap.append(x)
+
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(nn.Module):
+    def __init__(self, periods: tuple[int] = (2, 3, 5, 7, 11)) -> None:
+        super().__init__()
+
+        self.discriminators = nn.ModuleList(
+            [DiscriminatorP(period=period) for period in periods]
+        )
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]:
+        scores, feature_map = [], []
+
+        for disc in self.discriminators:
+            res, fmap = disc(x)
+
+            scores.append(res)
+            feature_map.append(fmap)
+
+        return scores, feature_map

+ 100 - 0
fish_speech/models/vqgan/modules/discriminators/mrd.py

@@ -0,0 +1,100 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.parametrizations import weight_norm
+
+
+class DiscriminatorR(torch.nn.Module):
+    def __init__(
+        self,
+        *,
+        n_fft: int = 1024,
+        hop_length: int = 120,
+        win_length: int = 600,
+    ):
+        super(DiscriminatorR, self).__init__()
+
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
+                weight_norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+                weight_norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
+            ]
+        )
+
+        self.conv_post = weight_norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
+
+    def forward(self, x):
+        fmap = []
+
+        x = self.spectrogram(x)
+        x = x.unsqueeze(1)
+
+        for conv in self.convs:
+            x = conv(x)
+            x = F.silu(x, inplace=True)
+            fmap.append(x)
+
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+    def spectrogram(self, x):
+        x = F.pad(
+            x,
+            (
+                (self.n_fft - self.hop_length) // 2,
+                (self.n_fft - self.hop_length + 1) // 2,
+            ),
+            mode="reflect",
+        )
+        x = x.squeeze(1)
+        x = torch.stft(
+            x,
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            center=False,
+            return_complex=True,
+        )
+        x = torch.view_as_real(x)  # [B, F, TT, 2]
+        mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]
+
+        return mag
+
+
+class MultiResolutionDiscriminator(torch.nn.Module):
+    def __init__(self, resolutions: list[tuple[int]]):
+        super().__init__()
+
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorR(
+                    n_fft=n_fft,
+                    hop_length=hop_length,
+                    win_length=win_length,
+                )
+                for n_fft, hop_length, win_length in resolutions
+            ]
+        )
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]:
+        scores, feature_map = [], []
+
+        for disc in self.discriminators:
+            res, fmap = disc(x)
+
+            scores.append(res)
+            feature_map.append(fmap)
+
+        return scores, feature_map

+ 188 - 0
fish_speech/models/vqgan/modules/discriminators/mssbcqtd.py

@@ -0,0 +1,188 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Monkey patching to fix a bug in nnAudio
+import numpy as np
+import torch
+import torchaudio.transforms as T
+from einops import rearrange
+from nnAudio import features
+from torch import nn
+
+from .msstftd import NormConv2d, get_2d_padding
+
+np.float = float
+
+LRELU_SLOPE = 0.1
+
+
+class DiscriminatorCQT(nn.Module):
+    def __init__(
+        self,
+        hop_length,
+        n_octaves,
+        bins_per_octave,
+        filters=32,
+        max_filters=1024,
+        filters_scale=1,
+        dilations=[1, 2, 4],
+        in_channels=1,
+        out_channels=1,
+        sample_rate=16000,
+    ):
+        super().__init__()
+
+        self.filters = filters
+        self.max_filters = max_filters
+        self.filters_scale = filters_scale
+        self.kernel_size = (3, 9)
+        self.dilations = dilations
+        self.stride = (1, 2)
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.fs = sample_rate
+        self.hop_length = hop_length
+        self.n_octaves = n_octaves
+        self.bins_per_octave = bins_per_octave
+
+        self.cqt_transform = features.cqt.CQT2010v2(
+            sr=self.fs * 2,
+            hop_length=self.hop_length,
+            n_bins=self.bins_per_octave * self.n_octaves,
+            bins_per_octave=self.bins_per_octave,
+            output_format="Complex",
+            pad_mode="constant",
+        )
+
+        self.conv_pres = nn.ModuleList()
+        for i in range(self.n_octaves):
+            self.conv_pres.append(
+                NormConv2d(
+                    self.in_channels * 2,
+                    self.in_channels * 2,
+                    kernel_size=self.kernel_size,
+                    padding=get_2d_padding(self.kernel_size),
+                )
+            )
+
+        self.convs = nn.ModuleList()
+
+        self.convs.append(
+            NormConv2d(
+                self.in_channels * 2,
+                self.filters,
+                kernel_size=self.kernel_size,
+                padding=get_2d_padding(self.kernel_size),
+            )
+        )
+
+        in_chs = min(self.filters_scale * self.filters, self.max_filters)
+        for i, dilation in enumerate(self.dilations):
+            out_chs = min(
+                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
+            )
+            self.convs.append(
+                NormConv2d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=self.kernel_size,
+                    stride=self.stride,
+                    dilation=(dilation, 1),
+                    padding=get_2d_padding(self.kernel_size, (dilation, 1)),
+                    norm="weight_norm",
+                )
+            )
+            in_chs = out_chs
+        out_chs = min(
+            (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
+            self.max_filters,
+        )
+        self.convs.append(
+            NormConv2d(
+                in_chs,
+                out_chs,
+                kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+                padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
+                norm="weight_norm",
+            )
+        )
+
+        self.conv_post = NormConv2d(
+            out_chs,
+            self.out_channels,
+            kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+            padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
+            norm="weight_norm",
+        )
+
+        self.activation = torch.nn.LeakyReLU(negative_slope=LRELU_SLOPE)
+        self.resample = T.Resample(orig_freq=self.fs, new_freq=self.fs * 2)
+
+    def forward(self, x):
+        fmap = []
+
+        x = self.resample(x)
+
+        z = self.cqt_transform(x)
+
+        z_amplitude = z[:, :, :, 0].unsqueeze(1)
+        z_phase = z[:, :, :, 1].unsqueeze(1)
+
+        z = torch.cat([z_amplitude, z_phase], dim=1)
+        z = rearrange(z, "b c w t -> b c t w")
+
+        latent_z = []
+        for i in range(self.n_octaves):
+            latent_z.append(
+                self.conv_pres[i](
+                    z[
+                        :,
+                        :,
+                        :,
+                        i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
+                    ]
+                )
+            )
+        latent_z = torch.cat(latent_z, dim=-1)
+
+        for i, layer in enumerate(self.convs):
+            latent_z = layer(latent_z)
+
+            latent_z = self.activation(latent_z)
+            fmap.append(latent_z)
+
+        latent_z = self.conv_post(latent_z)
+
+        return latent_z, fmap
+
+
+class MultiScaleSubbandCQTDiscriminator(nn.Module):
+    def __init__(self, hop_lengths, n_octaves, bins_per_octaves, **kwargs):
+        super().__init__()
+
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorCQT(
+                    hop_length=hop_length,
+                    n_octaves=n_octaves,
+                    bins_per_octave=bins_per_octave,
+                    **kwargs,
+                )
+                for hop_length, n_octaves, bins_per_octave in zip(
+                    hop_lengths, n_octaves, bins_per_octaves
+                )
+            ]
+        )
+
+    def forward(self, x: torch.Tensor):
+        logits = []
+        fmaps = []
+        for disc in self.discriminators:
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+
+        return logits, fmaps

+ 303 - 0
fish_speech/models/vqgan/modules/discriminators/msstftd.py

@@ -0,0 +1,303 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""MS-STFT discriminator, provided here for reference."""
+
+import typing as tp
+
+import einops
+import torch
+import torchaudio
+from einops import rearrange
+from torch import nn
+from torch.nn.utils import spectral_norm, weight_norm
+
+FeatureMapType = tp.List[torch.Tensor]
+LogitsType = torch.Tensor
+DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
+
+
+class ConvLayerNorm(nn.LayerNorm):
+    """
+    Convolution-friendly LayerNorm that moves channels to last dimensions
+    before running the normalization and moves them back to original position right after.
+    """  # noqa: E501
+
+    def __init__(
+        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
+    ):
+        super().__init__(normalized_shape, **kwargs)
+
+    def forward(self, x):
+        x = einops.rearrange(x, "b ... t -> b t ...")
+        x = super().forward(x)
+        x = einops.rearrange(x, "b t ... -> b ... t")
+        return
+
+
+CONV_NORMALIZATIONS = frozenset(
+    [
+        "none",
+        "weight_norm",
+        "spectral_norm",
+        "time_layer_norm",
+        "layer_norm",
+        "time_group_norm",
+    ]
+)
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
+    assert norm in CONV_NORMALIZATIONS
+    if norm == "weight_norm":
+        return weight_norm(module)
+    elif norm == "spectral_norm":
+        return spectral_norm(module)
+    else:
+        # We already check was in CONV_NORMALIZATION, so any other choice
+        # doesn't need reparametrization.
+        return module
+
+
+def get_norm_module(
+    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
+) -> nn.Module:
+    """Return the proper normalization module. If causal is True, this will ensure the returned
+    module is causal, or return an error if the normalization doesn't support causal evaluation.
+    """  # noqa: E501
+    assert norm in CONV_NORMALIZATIONS
+    if norm == "layer_norm":
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return ConvLayerNorm(module.out_channels, **norm_kwargs)
+    elif norm == "time_group_norm":
+        if causal:
+            raise ValueError("GroupNorm doesn't support causal evaluation.")
+        assert isinstance(module, nn.modules.conv._ConvNd)
+        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+    else:
+        return nn.Identity()
+
+
+class NormConv2d(nn.Module):
+    """Wrapper around Conv2d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+
+    def __init__(
+        self,
+        *args,
+        norm: str = "none",
+        norm_kwargs: tp.Dict[str, tp.Any] = {},
+        **kwargs,
+    ):
+        super().__init__()
+        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+
+def get_2d_padding(
+    kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)
+):
+    return (
+        ((kernel_size[0] - 1) * dilation[0]) // 2,
+        ((kernel_size[1] - 1) * dilation[1]) // 2,
+    )
+
+
+class DiscriminatorSTFT(nn.Module):
+    """STFT sub-discriminator.
+    Args:
+        filters (int): Number of filters in convolutions
+        in_channels (int): Number of input channels. Default: 1
+        out_channels (int): Number of output channels. Default: 1
+        n_fft (int): Size of FFT for each scale. Default: 1024
+        hop_length (int): Length of hop between STFT windows for each scale. Default: 256
+        kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
+        stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
+        dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
+        win_length (int): Window size for each scale. Default: 1024
+        normalized (bool): Whether to normalize by magnitude after stft. Default: True
+        norm (str): Normalization method. Default: `'weight_norm'`
+        activation (str): Activation function. Default: `'LeakyReLU'`
+        activation_params (dict): Parameters to provide to the activation function.
+        growth (int): Growth factor for the filters. Default: 1
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        filters: int,
+        in_channels: int = 1,
+        out_channels: int = 1,
+        n_fft: int = 1024,
+        hop_length: int = 256,
+        win_length: int = 1024,
+        max_filters: int = 1024,
+        filters_scale: int = 1,
+        kernel_size: tp.Tuple[int, int] = (3, 9),
+        dilations: tp.List = [1, 2, 4],
+        stride: tp.Tuple[int, int] = (1, 2),
+        normalized: bool = True,
+        norm: str = "weight_norm",
+        activation: str = "LeakyReLU",
+        activation_params: dict = {"negative_slope": 0.2},
+    ):
+        super().__init__()
+        assert len(kernel_size) == 2
+        assert len(stride) == 2
+        self.filters = filters
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.normalized = normalized
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.spec_transform = torchaudio.transforms.Spectrogram(
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            window_fn=torch.hann_window,
+            normalized=self.normalized,
+            center=False,
+            pad_mode=None,
+            power=None,
+        )
+        spec_channels = 2 * self.in_channels
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            NormConv2d(
+                spec_channels,
+                self.filters,
+                kernel_size=kernel_size,
+                padding=get_2d_padding(kernel_size),
+            )
+        )
+        in_chs = min(filters_scale * self.filters, max_filters)
+        for i, dilation in enumerate(dilations):
+            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
+            self.convs.append(
+                NormConv2d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    dilation=(dilation, 1),
+                    padding=get_2d_padding(kernel_size, (dilation, 1)),
+                    norm=norm,
+                )
+            )
+            in_chs = out_chs
+        out_chs = min(
+            (filters_scale ** (len(dilations) + 1)) * self.filters, max_filters
+        )
+        self.convs.append(
+            NormConv2d(
+                in_chs,
+                out_chs,
+                kernel_size=(kernel_size[0], kernel_size[0]),
+                padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                norm=norm,
+            )
+        )
+        self.conv_post = NormConv2d(
+            out_chs,
+            self.out_channels,
+            kernel_size=(kernel_size[0], kernel_size[0]),
+            padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+            norm=norm,
+        )
+
+    def forward(self, x: torch.Tensor):
+        fmap = []
+        z = self.spec_transform(x)  # [B, 2, Freq, Frames, 2]
+        z = torch.cat([z.real, z.imag], dim=1)
+        z = rearrange(z, "b c w t -> b c t w")
+        for i, layer in enumerate(self.convs):
+            z = layer(z)
+            z = self.activation(z)
+            fmap.append(z)
+        z = self.conv_post(z)
+        return z, fmap
+
+
+class MultiScaleSTFTDiscriminator(nn.Module):
+    """Multi-Scale STFT (MS-STFT) discriminator.
+    Args:
+        filters (int): Number of filters in convolutions
+        in_channels (int): Number of input channels. Default: 1
+        out_channels (int): Number of output channels. Default: 1
+        n_ffts (Sequence[int]): Size of FFT for each scale
+        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale
+        win_lengths (Sequence[int]): Window size for each scale
+        **kwargs: additional args for STFTDiscriminator
+    """
+
+    def __init__(
+        self,
+        filters: int,
+        in_channels: int = 1,
+        out_channels: int = 1,
+        n_ffts: tp.List[int] = [1024, 2048, 512],
+        hop_lengths: tp.List[int] = [256, 512, 128],
+        win_lengths: tp.List[int] = [1024, 2048, 512],
+        **kwargs,
+    ):
+        super().__init__()
+        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorSTFT(
+                    filters,
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    n_fft=n_ffts[i],
+                    win_length=win_lengths[i],
+                    hop_length=hop_lengths[i],
+                    **kwargs,
+                )
+                for i in range(len(n_ffts))
+            ]
+        )
+        self.num_discriminators = len(self.discriminators)
+
+    def forward(self, x: torch.Tensor) -> DiscriminatorOutput:
+        logits = []
+        fmaps = []
+        for disc in self.discriminators:
+            logit, fmap = disc(x)
+            logits.append(logit)
+            fmaps.append(fmap)
+
+        return logits, fmaps
+
+
+def test():
+    disc = MultiScaleSTFTDiscriminator(filters=32)
+    y = torch.randn(1, 1, 24000)
+    y_hat = torch.randn(1, 1, 24000)
+
+    y_disc_r, fmap_r = disc(y)
+    y_disc_gen, fmap_gen = disc(y_hat)
+    assert (
+        len(y_disc_r)
+        == len(y_disc_gen)
+        == len(fmap_r)
+        == len(fmap_gen)
+        == disc.num_discriminators
+    )
+    assert all([len(fm) == 5 for fm in fmap_r + fmap_gen])
+    assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm])
+    assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen])
+
+
+if __name__ == "__main__":
+    test()

+ 63 - 4
fish_speech/models/vqgan/modules/encoders.py

@@ -5,6 +5,7 @@ import numpy as np
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.functional as F
+from einops import rearrange
 from vector_quantize_pytorch import LFQ, GroupedResidualVQ, VectorQuantize
 from vector_quantize_pytorch import LFQ, GroupedResidualVQ, VectorQuantize
 
 
 from fish_speech.models.vqgan.modules.modules import WN
 from fish_speech.models.vqgan.modules.modules import WN
@@ -298,6 +299,7 @@ class VQEncoder(nn.Module):
             )
             )
 
 
         self.codebook_groups = codebook_groups
         self.codebook_groups = codebook_groups
+        self.codebook_layers = codebook_layers
         self.downsample = downsample
         self.downsample = downsample
         self.conv_in = nn.Conv1d(
         self.conv_in = nn.Conv1d(
             in_channels, vq_channels, kernel_size=downsample, stride=downsample
             in_channels, vq_channels, kernel_size=downsample, stride=downsample
@@ -309,6 +311,17 @@ class VQEncoder(nn.Module):
             nn.Conv1d(vq_channels, in_channels, kernel_size=1, stride=1),
             nn.Conv1d(vq_channels, in_channels, kernel_size=1, stride=1),
         )
         )
 
 
+    @property
+    def mode(self):
+        if self.codebook_groups > 1 and self.codebook_layers > 1:
+            return "grouped-residual"
+        elif self.codebook_groups > 1:
+            return "grouped"
+        elif self.codebook_layers > 1:
+            return "residual"
+        else:
+            return "single"
+
     def forward(self, x, x_mask):
     def forward(self, x, x_mask):
         # x: [B, C, T], x_mask: [B, 1, T]
         # x: [B, C, T], x_mask: [B, 1, T]
         x_len = x.shape[2]
         x_len = x.shape[2]
@@ -327,15 +340,61 @@ class VQEncoder(nn.Module):
         x = self.conv_out(q) * x_mask
         x = self.conv_out(q) * x_mask
         x = x[:, :, :x_len]
         x = x[:, :, :x_len]
 
 
+        # Post process indices
+        if self.mode == "grouped-residual":
+            indices = rearrange(indices, "g b t r -> b (g r) t")
+        elif self.mode == "grouped":
+            indices = rearrange(indices, "g b t 1 -> b g t")
+        elif self.mode == "residual":
+            indices = rearrange(indices, "1 b t r -> b r t")
+        else:
+            indices = rearrange(indices, "b t -> b 1 t")
+
         return x, indices, loss
         return x, indices, loss
 
 
     def decode(self, indices):
     def decode(self, indices):
+        # Undo rearrange
+        if self.mode == "grouped-residual":
+            indices = rearrange(indices, "b (g r) t -> g b t r", g=self.codebook_groups)
+        elif self.mode == "grouped":
+            indices = rearrange(indices, "b g t -> g b t 1")
+        elif self.mode == "residual":
+            indices = rearrange(indices, "b r t -> 1 b t r")
+        else:
+            indices = rearrange(indices, "b 1 t -> b t")
+
         q = self.vq.get_output_from_indices(indices)
         q = self.vq.get_output_from_indices(indices)
 
 
-        if q.shape[1] != indices.shape[1] and indices.ndim != 4:
-            q = q.view(q.shape[0], indices.shape[1], -1)
-        q = q.mT
+        # Edge case for single vq
+        if self.mode == "single":
+            q = rearrange(q, "b (t c) -> b t c", t=indices.shape[-1])
 
 
-        x = self.conv_out(q)
+        x = self.conv_out(q.mT)
 
 
         return x
         return x
+
+
+if __name__ == "__main__":
+    # Test VQEncoder
+    for group, layer in [
+        (1, 1),
+        (1, 2),
+        (2, 1),
+        (2, 2),
+        (4, 1),
+        (4, 2),
+    ]:
+        encoder = VQEncoder(
+            in_channels=1024,
+            vq_channels=1024,
+            codebook_size=2048,
+            downsample=1,
+            codebook_groups=group,
+            codebook_layers=layer,
+            threshold_ema_dead_code=2,
+        )
+        x = torch.randn(2, 1024, 100)
+        x_mask = torch.ones(2, 1, 100)
+        x, indices, loss = encoder(x, x_mask)
+        x = encoder.decode(indices)
+        assert x.shape == (2, 1024, 100)

+ 1 - 0
pyproject.toml

@@ -34,6 +34,7 @@ dependencies = [
     "zibai-server>=0.9.0",
     "zibai-server>=0.9.0",
     "loguru>=0.6.0",
     "loguru>=0.6.0",
     "WeTextProcessing>=0.1.10",
     "WeTextProcessing>=0.1.10",
+    "nnAudio>=0.3.2",
     "loralib>=0.1.2",
     "loralib>=0.1.2",
     "natsort>=8.4.0",
     "natsort>=8.4.0",
     "cn2an>=0.5.22"
     "cn2an>=0.5.22"

+ 5 - 57
tools/api_server.py

@@ -138,36 +138,11 @@ class VQGANModel:
     def sematic_to_wav(self, indices):
     def sematic_to_wav(self, indices):
         model = self.model
         model = self.model
         indices = indices.to(model.device).long()
         indices = indices.to(model.device).long()
-        indices = indices.unsqueeze(1).unsqueeze(-1)
-
-        mel_lengths = indices.shape[2] * (
-            model.downsample.total_strides if model.downsample is not None else 1
-        )
-        mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
-        mel_masks = torch.ones(
-            (1, 1, mel_lengths), device=model.device, dtype=torch.float32
-        )
-
-        text_features = model.vq_encoder.decode(indices)
-
-        logger.info(
-            f"VQ Encoded, indices: {indices.shape} equivalent to "
-            + f"{1 / (mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
-        )
-
-        text_features = F.interpolate(
-            text_features, size=mel_lengths[0], mode="nearest"
-        )
-
-        # Sample mels
-        decoded_mels = model.decoder(text_features, mel_masks)
-        fake_audios = model.generator(decoded_mels)
-        logger.info(
-            f"Generated audio of shape {fake_audios.shape}, equivalent to {fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
-        )
+        feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
+        decoded = model.decode(indices=indices[None], feature_lengths=feature_lengths)
 
 
         # Save audio
         # Save audio
-        fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
+        fake_audio = decoded.audios[0, 0].cpu().numpy().astype(np.float32)
 
 
         return fake_audio, model.sampling_rate
         return fake_audio, model.sampling_rate
 
 
@@ -189,37 +164,10 @@ class VQGANModel:
         audio_lengths = torch.tensor(
         audio_lengths = torch.tensor(
             [audios.shape[2]], device=model.device, dtype=torch.long
             [audios.shape[2]], device=model.device, dtype=torch.long
         )
         )
-
-        features = gt_mels = model.mel_transform(
-            audios, sample_rate=model.sampling_rate
-        )
-
-        if model.downsample is not None:
-            features = model.downsample(features)
-
-        mel_lengths = audio_lengths // model.hop_length
-        feature_lengths = (
-            audio_lengths
-            / model.hop_length
-            / (model.downsample.total_strides if model.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
-
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = model.mel_encoder(features, feature_masks)
-        _, indices, _ = model.vq_encoder(text_features, feature_masks)
-
-        if indices.ndim == 4 and indices.shape[1] == 1 and indices.shape[3] == 1:
-            indices = indices[:, 0, :, 0]
-        else:
-            logger.error(f"Unknown indices shape: {indices.shape}")
-            return
+        encoded = model.encode(audios, audio_lengths)
+        indices = encoded.indices[0]
 
 
         logger.info(f"Generated indices of shape {indices.shape}")
         logger.info(f"Generated indices of shape {indices.shape}")
-
         return indices
         return indices
 
 
 
 

+ 5 - 33
tools/vqgan/extract_vq.py

@@ -90,43 +90,15 @@ def process_batch(files: list[Path], model) -> float:
 
 
     # Calculate lengths
     # Calculate lengths
     with torch.no_grad():
     with torch.no_grad():
-        # VQ Encoder
-        features = gt_mels = model.mel_transform(
-            audios, sample_rate=model.sampling_rate
-        )
-
-        if model.downsample is not None:
-            features = model.downsample(features)
-
-        feature_lengths = (
-            audio_lengths
-            / model.hop_length
-            / (model.downsample.total_strides if model.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
-
-        text_features = model.mel_encoder(features, feature_masks)
-        _, indices, _ = model.vq_encoder(text_features, feature_masks)
-
-        if indices.ndim == 4:
-            # Grouped vq
-            assert indices.shape[-1] == 1, f"Residual vq is not supported"
-            indices = indices.squeeze(-1)
-        elif indices.ndim == 2:
-            # Single vq
-            indices = indices.unsqueeze(0)
-        else:
-            raise ValueError(f"Invalid indices shape {indices.shape}")
-
-        indices = rearrange(indices, "c b t -> b c t")
+        out = model.encode(audios, audio_lengths)
+        indices, feature_lengths = out.indices, out.feature_lengths
 
 
     # Save to disk
     # Save to disk
     outputs = indices.cpu().numpy()
     outputs = indices.cpu().numpy()
 
 
-    for file, length, feature, audio in zip(files, feature_lengths, outputs, audios):
+    for file, length, feature, audio_length in zip(
+        files, feature_lengths, outputs, audio_lengths
+    ):
         feature = feature[:, :length]
         feature = feature[:, :length]
 
 
         # (T,)
         # (T,)

+ 7 - 52
tools/vqgan/inference.py

@@ -67,37 +67,8 @@ def main(input_path, output_path, config_name, checkpoint_path):
         audio_lengths = torch.tensor(
         audio_lengths = torch.tensor(
             [audios.shape[2]], device=model.device, dtype=torch.long
             [audios.shape[2]], device=model.device, dtype=torch.long
         )
         )
-
-        features = gt_mels = model.mel_transform(
-            audios, sample_rate=model.sampling_rate
-        )
-
-        if model.downsample is not None:
-            features = model.downsample(features)
-
-        mel_lengths = audio_lengths // model.hop_length
-        feature_lengths = (
-            audio_lengths
-            / model.hop_length
-            / (model.downsample.total_strides if model.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
-        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
-            gt_mels.dtype
-        )
-
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = model.mel_encoder(features, feature_masks)
-        _, indices, _ = model.vq_encoder(text_features, feature_masks)
-
-        if indices.ndim == 4 and indices.shape[1] == 1 and indices.shape[3] == 1:
-            indices = indices[:, 0, :, 0]
-        else:
-            logger.error(f"Unknown indices shape: {indices.shape}")
-            return
+        encoded = model.encode(audios, audio_lengths)
+        indices = encoded.indices[0]
 
 
         logger.info(f"Generated indices of shape {indices.shape}")
         logger.info(f"Generated indices of shape {indices.shape}")
 
 
@@ -112,29 +83,13 @@ def main(input_path, output_path, config_name, checkpoint_path):
         raise ValueError(f"Unknown input type: {input_path}")
         raise ValueError(f"Unknown input type: {input_path}")
 
 
     # Restore
     # Restore
-    indices = indices.unsqueeze(1).unsqueeze(-1)
-    mel_lengths = indices.shape[2] * (
-        model.downsample.total_strides if model.downsample is not None else 1
-    )
-    mel_lengths = torch.tensor([mel_lengths], device=model.device, dtype=torch.long)
-    mel_masks = torch.ones(
-        (1, 1, mel_lengths), device=model.device, dtype=torch.float32
-    )
-
-    text_features = model.vq_encoder.decode(indices)
-
-    logger.info(
-        f"VQ Encoded, indices: {indices.shape} equivalent to "
-        + f"{1/(mel_lengths[0] * model.hop_length / model.sampling_rate / indices.shape[2]):.2f} Hz"
-    )
-
-    text_features = F.interpolate(text_features, size=mel_lengths[0], mode="nearest")
+    feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
+    decoded = model.decode(indices=indices[None], feature_lengths=feature_lengths)
+    fake_audios = decoded.audios
+    audio_time = fake_audios.shape[-1] / model.sampling_rate
 
 
-    # Sample mels
-    decoded_mels = model.decoder(text_features, mel_masks)
-    fake_audios = model.generator(decoded_mels)
     logger.info(
     logger.info(
-        f"Generated audio of shape {fake_audios.shape}, equivalent to {fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
+        f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
     )
     )
 
 
     # Save audio
     # Save audio