Lengyue 2 سال پیش
والد
کامیت
0855b454e4

+ 1 - 0
README.md

@@ -22,3 +22,4 @@ We do not hold any responsibility for any illegal usage of the codebase. Please
 - [GPT VITS](https://github.com/innnky/gpt-vits)
 - [MQTTS](https://github.com/b04901014/MQTTS)
 - [GPT Fast](https://github.com/pytorch-labs/gpt-fast)
+- [GPT-SoVITS](https://github.com/RVC-Boss/GPT-SoVITS)

+ 115 - 0
fish_speech/configs/vqgan_pretrain.yaml

@@ -0,0 +1,115 @@
+defaults:
+  - base
+  - _self_
+
+project: vqgan_pretrain
+ckpt_path: checkpoints/gpt_sovits_488k.pth
+resume_weights_only: true
+
+# Lightning Trainer
+trainer:
+  accelerator: gpu
+  devices: auto
+  strategy: ddp_find_unused_parameters_true
+  precision: 32
+  max_steps: 1_000_000
+  val_check_interval: 5000
+
+sample_rate: 32000
+hop_length: 640
+num_mels: 128
+n_fft: 2048
+win_length: 2048
+segment_size: 128
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/vq_train_filelist.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  slice_frames: ${segment_size}
+
+val_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/vq_val_filelist.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+data:
+  _target_: fish_speech.datasets.vqgan.VQGANDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 16
+  val_batch_size: 4
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vqgan.VQGAN
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  freeze_discriminator: true
+
+  weight_mel: 45
+  weight_kl: 0.1
+  weight_vq: 1.0
+
+  generator:
+    _target_: fish_speech.models.vqgan.modules.models.SynthesizerTrn
+    spec_channels: 1025
+    segment_size: 20480
+    inter_channels: 192
+    hidden_channels: 192
+    filter_channels: 768
+    n_heads: 2
+    n_layers: 6
+    kernel_size: 3
+    p_dropout: 0.1
+    resblock: "1"
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    upsample_rates: [10, 8, 2, 2, 2]
+    upsample_initial_channel: 512
+    upsample_kernel_sizes: [16, 16, 8, 2, 2]
+    gin_channels: 512
+    freeze_quantizer: false
+    codebook_size: 1024
+    num_codebooks: 2
+
+  discriminator:
+    _target_: fish_speech.models.vqgan.modules.models.EnsembledDiscriminator
+    periods: [2, 3, 5, 7, 11]
+
+  mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+
+  spec_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LinearSpectrogram
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    mode: pow2_sqrt
+  
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 1e-4
+    betas: [0.8, 0.99]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.ExponentialLR
+    _partial_: true
+    gamma: 0.99999
+
+callbacks:
+  grad_norm_monitor:
+    sub_module: 
+      - generator
+      - discriminator

+ 142 - 0
fish_speech/datasets/vqgan.py

@@ -0,0 +1,142 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import librosa
+import numpy as np
+import torch
+from lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
+
+from fish_speech.utils import RankedLogger
+
+logger = RankedLogger(__name__, rank_zero_only=False)
+
+
class VQGANDataset(Dataset):
    """Loads mono audio clips listed in a filelist for VQGAN training.

    Each non-empty line of the filelist is a path relative to the filelist's
    own directory. Items that fail to load (or decode to empty audio) resolve
    to ``None`` so the collator can drop them.
    """

    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
        hop_length: int = 640,
        slice_frames: Optional[int] = None,
    ):
        super().__init__()

        filelist_path = Path(filelist)
        base_dir = filelist_path.parent

        # Paths in the filelist are resolved relative to the filelist itself.
        self.files = [
            base_dir / stripped
            for stripped in map(str.strip, filelist_path.read_text().splitlines())
            if stripped
        ]
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.slice_frames = slice_frames

    def __len__(self):
        return len(self.files)

    def get_item(self, idx):
        path = self.files[idx]

        audio, _ = librosa.load(path, sr=self.sample_rate, mono=True)

        # Take a random window of slice_frames * hop_length samples when the
        # clip is longer than that, so training sees fixed-size slices.
        if self.slice_frames is not None:
            window = self.slice_frames * self.hop_length
            if audio.shape[0] > window:
                offset = np.random.randint(0, audio.shape[0] - window)
                audio = audio[offset : offset + window]

        if len(audio) == 0:
            return None

        # Only rescale clips that exceed full scale; quieter audio is kept as-is.
        peak = np.abs(audio).max()
        if peak > 1.0:
            audio = audio / peak

        return {"audio": torch.from_numpy(audio)}

    def __getitem__(self, idx):
        try:
            return self.get_item(idx)
        except Exception as e:
            logger.error(f"Error loading {self.files[idx]}: {e}")
            return None
+
+
@dataclass
class VQGANCollator:
    """Batches variable-length audio items, dropping failed (None) samples.

    Returns a dict with:
        audios: (batch, max_len) tensor, zero-padded on the right.
        audio_lengths: (batch,) tensor of true lengths before padding.
    """

    def __call__(self, batch):
        # The dataset returns None for unreadable/empty clips; drop them here.
        batch = [x for x in batch if x is not None]
        if not batch:
            # Previously this fell through to .max() on an empty tensor,
            # which raises an opaque RuntimeError.
            raise ValueError("All samples in the batch failed to load")

        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
        # Convert to a plain int so F.pad receives an integer pad width
        # rather than a 0-dim tensor.
        audio_maxlen = int(audio_lengths.max())

        # Right-pad every clip with zeros to the longest clip in the batch.
        # (A previous comment claimed rounding to a multiple of 2; no such
        # rounding is performed.)
        audios = [
            torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
            for x in batch
        ]

        return {
            "audios": torch.stack(audios),
            "audio_lengths": audio_lengths,
        }
+
+
class VQGANDataModule(LightningDataModule):
    """Lightning datamodule wrapping the train/val VQGAN datasets.

    Args:
        train_dataset / val_dataset: pre-built datasets (instantiated by Hydra).
        batch_size: training batch size.
        num_workers: dataloader worker count (shared by train and val).
        val_batch_size: validation batch size; defaults to ``batch_size``.
    """

    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            # Bug fix: this previously passed self.batch_size, silently
            # ignoring the configured val_batch_size (e.g. 4 in
            # vqgan_pretrain.yaml).
            batch_size=self.val_batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
        )
+
+
if __name__ == "__main__":
    # Smoke test: load one batch and print the padded shapes.
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    for batch in dataloader:
        # The collator emits only "audios" and "audio_lengths"; the previous
        # "features"/"feature_lengths" prints raised KeyError.
        print(batch["audios"].shape)
        print(batch["audio_lengths"])
        break

+ 2 - 0
fish_speech/models/text2semantic/llama.py

@@ -165,6 +165,8 @@ class Transformer(nn.Module):
             )
             x += torch.rand_like(x) * scaled_alpha
 
+            print("NEFT alpha:", scaled_alpha)
+
         return x
 
     def compute(

+ 195 - 164
fish_speech/models/vqgan/lit_module.py

@@ -16,9 +16,8 @@ from fish_speech.models.vqgan.losses import (
     discriminator_loss,
     feature_loss,
     generator_loss,
+    kl_loss,
 )
-from fish_speech.models.vqgan.modules.convnext import ConvNeXt
-from fish_speech.models.vqgan.modules.encoders import VQEncoder
 from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
 
 
@@ -41,35 +40,36 @@ class VQGAN(L.LightningModule):
         self,
         optimizer: Callable,
         lr_scheduler: Callable,
-        encoder: ConvNeXt,
-        vq: VQEncoder,
-        decoder: ConvNeXt,
         generator: nn.Module,
-        discriminator: ConvNeXt,
+        discriminator: nn.Module,
         mel_transform: nn.Module,
+        spec_transform: nn.Module,
         hop_length: int = 640,
         sample_rate: int = 32000,
         freeze_discriminator: bool = False,
+        weight_mel: float = 45,
+        weight_kl: float = 0.1,
+        weight_vq: float = 1.0,
     ):
         super().__init__()
 
-        # pretrain: vq use gt mel as target, hifigan use gt mel as input
-        # finetune: end-to-end training, use gt mel as hifi gan target but freeze vq
-
         # Model parameters
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
 
         # Generator and discriminator
-        self.encoder = encoder
-        self.vq = vq
-        self.decoder = decoder
         self.generator = generator
         self.discriminator = discriminator
         self.mel_transform = mel_transform
+        self.spec_transform = spec_transform
         self.freeze_discriminator = freeze_discriminator
 
-        # Crop length for saving memory
+        # Loss weights
+        self.weight_mel = weight_mel
+        self.weight_kl = weight_kl
+        self.weight_vq = weight_vq
+
+        # Other parameters
         self.hop_length = hop_length
         self.sampling_rate = sample_rate
 
@@ -80,19 +80,9 @@ class VQGAN(L.LightningModule):
             for p in self.discriminator.parameters():
                 p.requires_grad = False
 
-        # Freeze generator
-        for p in self.generator.parameters():
-            p.requires_grad = False
-
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
-        optimizer_generator = self.optimizer_builder(
-            itertools.chain(
-                self.encoder.parameters(),
-                self.vq.parameters(),
-                self.decoder.parameters(),
-            )
-        )
+        optimizer_generator = self.optimizer_builder(self.generator.parameters())
         optimizer_discriminator = self.optimizer_builder(
             self.discriminator.parameters()
         )
@@ -128,28 +118,23 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]
 
         with torch.no_grad():
-            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
-
-        mel_lengths = audio_lengths // self.hop_length
-        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
-            gt_mels.dtype
-        )
-
-        vq_result = self.encode(audios, audio_lengths)
-        loss_vq = vq_result.loss
-
-        if loss_vq.ndim > 1:
-            loss_vq = loss_vq.mean()
+            gt_mels = self.mel_transform(audios)
+            gt_specs = self.spec_transform(audios)
+
+        spec_lengths = audio_lengths // self.hop_length
+        spec_masks = torch.unsqueeze(
+            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
+        ).to(gt_mels.dtype)
+        (
+            fake_audios,
+            ids_slice,
+            y_mask,
+            y_mask,
+            (z, z_p, m_p, logs_p, m_q, logs_q),
+            quantized,
+        ) = self.generator(gt_specs, spec_lengths)
 
-        decoded_mels = self.decode(
-            indices=None,
-            features=vq_result.features,
-            audio_lengths=audio_lengths,
-        ).mels
-
-        with torch.no_grad():
-            with torch.autocast(device_type=audios.device.type, enabled=False):
-                fake_audios = self.generator(decoded_mels.float())
+        fake_mels = self.mel_transform(fake_audios.squeeze(1))
 
         assert (
             audios.shape == fake_audios.shape
@@ -157,11 +142,12 @@ class VQGAN(L.LightningModule):
 
         # Discriminator
         if self.freeze_discriminator is False:
-            scores = self.discriminator(gt_mels)
-            score_fakes = self.discriminator(decoded_mels.detach())
+            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(
+                audios, fake_audios.detach()
+            )
 
             with torch.autocast(device_type=audios.device.type, enabled=False):
-                loss_disc, _, _ = discriminator_loss([scores], [score_fakes])
+                loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
 
             self.log(
                 f"train/discriminator/loss",
@@ -181,11 +167,11 @@ class VQGAN(L.LightningModule):
             optim_d.step()
 
         # Adv Loss
-        score_fakes = self.discriminator(decoded_mels)
+        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios)
 
         # Adversarial Loss
         with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_adv, _ = generator_loss([score_fakes])
+            loss_adv, _ = generator_loss(y_d_hat_g)
 
         self.log(
             f"train/generator/adv",
@@ -197,11 +183,8 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
-        # Feature Matching Loss
-        score_gts = self.discriminator(gt_mels)
-
         with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_fm = feature_loss([score_gts], [score_fakes])
+            loss_fm = feature_loss(y_d_hat_r, y_d_hat_g)
 
         self.log(
             f"train/generator/adv_fm",
@@ -214,7 +197,7 @@ class VQGAN(L.LightningModule):
         )
 
         with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
+            loss_mel = F.l1_loss(gt_mels * spec_masks, fake_mels * spec_masks)
 
         self.log(
             "train/generator/loss_mel",
@@ -226,6 +209,7 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
+        loss_vq = quantized.commitment_loss + quantized.codebook_loss
         self.log(
             "train/generator/loss_vq",
             loss_vq,
@@ -236,7 +220,25 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
-        loss = loss_mel * 20 + loss_vq + loss_adv + loss_fm
+        loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, y_mask)
+
+        self.log(
+            "train/generator/loss_kl",
+            loss_kl,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+
+        loss = (
+            loss_mel * self.weight_mel
+            + loss_vq * self.weight_vq
+            + loss_kl * self.weight_kl
+            + loss_adv
+            + loss_fm
+        )
         self.log(
             "train/generator/loss",
             loss,
@@ -247,6 +249,7 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
+        # Backward
         optim_g.zero_grad()
 
         self.manual_backward(loss)
@@ -266,53 +269,70 @@ class VQGAN(L.LightningModule):
         audios = audios.float()
         audios = audios[:, None, :]
 
-        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
-        mel_lengths = audio_lengths // self.hop_length
-        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
-            gt_mels.dtype
-        )
-
-        vq_result = self.encode(audios, audio_lengths)
-        decoded_mels = self.decode(
-            indices=vq_result.indices,
-            audio_lengths=audio_lengths,
-        ).mels
-        fake_audios = self.generator(decoded_mels)
+        gt_mels = self.mel_transform(audios)
+        gt_specs = self.spec_transform(audios)
+        spec_lengths = audio_lengths // self.hop_length
+        spec_masks = torch.unsqueeze(
+            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
+        ).to(gt_mels.dtype)
 
-        fake_mels = self.mel_transform(fake_audios.squeeze(1))
+        prior_audios, _, _ = self.generator.infer(gt_specs, spec_lengths)
+        posterior_audios, _, _ = self.generator.infer_posterior(gt_specs, spec_lengths)
+        prior_mels = self.mel_transform(prior_audios.squeeze(1))
+        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))
 
         min_mel_length = min(
-            decoded_mels.shape[-1], gt_mels.shape[-1], fake_mels.shape[-1]
+            gt_mels.shape[-1], prior_mels.shape[-1], posterior_mels.shape[-1]
         )
-        decoded_mels = decoded_mels[:, :, :min_mel_length]
         gt_mels = gt_mels[:, :, :min_mel_length]
-        fake_mels = fake_mels[:, :, :min_mel_length]
+        prior_mels = prior_mels[:, :, :min_mel_length]
+        posterior_mels = posterior_mels[:, :, :min_mel_length]
+
+        prior_mel_loss = F.l1_loss(gt_mels * spec_masks, prior_mels * spec_masks)
+        posterior_mel_loss = F.l1_loss(
+            gt_mels * spec_masks, posterior_mels * spec_masks
+        )
+
+        self.log(
+            "val/prior_mel_loss",
+            prior_mel_loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
 
-        mel_loss = F.l1_loss(gt_mels * mel_masks, fake_mels * mel_masks)
         self.log(
-            "val/mel_loss",
-            mel_loss,
+            "val/posterior_mel_loss",
+            posterior_mel_loss,
             on_step=False,
             on_epoch=True,
-            prog_bar=True,
+            prog_bar=False,
             logger=True,
             sync_dist=True,
         )
 
+        # only log the first batch
+        if batch_idx != 0:
+            return
+
         for idx, (
             mel,
-            gen_mel,
-            decode_mel,
+            prior_mel,
+            posterior_mel,
             audio,
-            gen_audio,
+            prior_audio,
+            posterior_audio,
             audio_len,
         ) in enumerate(
             zip(
                 gt_mels,
-                fake_mels,
-                decoded_mels,
+                prior_mels,
+                posterior_mels,
                 audios.detach().float(),
-                fake_audios.detach().float(),
+                prior_audios.detach().float(),
+                posterior_audios.detach().float(),
                 audio_lengths,
             )
         ):
@@ -320,13 +340,13 @@ class VQGAN(L.LightningModule):
 
             image_mels = plot_mel(
                 [
-                    gen_mel[:, :mel_len],
-                    decode_mel[:, :mel_len],
+                    prior_mel[:, :mel_len],
+                    posterior_mel[:, :mel_len],
                     mel[:, :mel_len],
                 ],
                 [
-                    "Generated",
-                    "Decoded",
+                    "Prior (VQ)",
+                    "Posterior (Reconstruction)",
                     "Ground-Truth",
                 ],
             )
@@ -342,9 +362,14 @@ class VQGAN(L.LightningModule):
                                 caption="gt",
                             ),
                             wandb.Audio(
-                                gen_audio[0, :audio_len],
+                                prior_audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="prior",
+                            ),
+                            wandb.Audio(
+                                posterior_audio[0, :audio_len],
                                 sample_rate=self.sampling_rate,
-                                caption="prediction",
+                                caption="posterior",
                             ),
                         ],
                     },
@@ -363,85 +388,91 @@ class VQGAN(L.LightningModule):
                     sample_rate=self.sampling_rate,
                 )
                 self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/prediction",
-                    gen_audio[0, :audio_len],
+                    f"sample-{idx}/wavs/prior",
+                    prior_audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/posterior",
+                    posterior_audio[0, :audio_len],
                     self.global_step,
                     sample_rate=self.sampling_rate,
                 )
 
             plt.close(image_mels)
 
-    def encode(self, audios, audio_lengths=None):
-        if audio_lengths is None:
-            audio_lengths = torch.tensor(
-                [audios.shape[-1]] * audios.shape[0],
-                device=audios.device,
-                dtype=torch.long,
-            )
-
-        with torch.no_grad():
-            features = self.mel_transform(audios, sample_rate=self.sampling_rate)
-
-        feature_lengths = (
-            audio_lengths
-            / self.hop_length
-            # / self.vq.downsample
-        ).long()
-
-        # print(features.shape, feature_lengths.shape, torch.max(feature_lengths))
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(features.dtype)
-
-        features = (
-            gradient_checkpoint(
-                self.encoder, features, feature_masks, use_reentrant=False
-            )
-            * feature_masks
-        )
-        vq_features, indices, loss = self.vq(features, feature_masks)
-
-        return VQEncodeResult(
-            features=vq_features,
-            indices=indices,
-            loss=loss,
-            feature_lengths=feature_lengths,
-        )
-
-    def calculate_audio_lengths(self, feature_lengths):
-        return feature_lengths * self.hop_length * self.vq.downsample
-
-    def decode(
-        self,
-        indices=None,
-        features=None,
-        audio_lengths=None,
-        feature_lengths=None,
-        return_audios=False,
-    ):
-        assert (
-            indices is not None or features is not None
-        ), "indices or features must be provided"
-        assert (
-            feature_lengths is not None or audio_lengths is not None
-        ), "feature_lengths or audio_lengths must be provided"
-
-        if audio_lengths is None:
-            audio_lengths = self.calculate_audio_lengths(feature_lengths)
-
-        mel_lengths = audio_lengths // self.hop_length
-        mel_masks = torch.unsqueeze(
-            sequence_mask(mel_lengths, torch.max(mel_lengths)), 1
-        ).float()
-
-        if indices is not None:
-            features = self.vq.decode(indices)
-
-        # Sample mels
-        decoded = gradient_checkpoint(self.decoder, features, use_reentrant=False)
-
-        return VQDecodeResult(
-            mels=decoded,
-            audios=self.generator(decoded) if return_audios else None,
-        )
+    # def encode(self, audios, audio_lengths=None):
+    #     if audio_lengths is None:
+    #         audio_lengths = torch.tensor(
+    #             [audios.shape[-1]] * audios.shape[0],
+    #             device=audios.device,
+    #             dtype=torch.long,
+    #         )
+
+    #     with torch.no_grad():
+    #         features = self.mel_transform(audios, sample_rate=self.sampling_rate)
+
+    #     feature_lengths = (
+    #         audio_lengths
+    #         / self.hop_length
+    #         # / self.vq.downsample
+    #     ).long()
+
+    #     # print(features.shape, feature_lengths.shape, torch.max(feature_lengths))
+
+    #     feature_masks = torch.unsqueeze(
+    #         sequence_mask(feature_lengths, features.shape[2]), 1
+    #     ).to(features.dtype)
+
+    #     features = (
+    #         gradient_checkpoint(
+    #             self.encoder, features, feature_masks, use_reentrant=False
+    #         )
+    #         * feature_masks
+    #     )
+    #     vq_features, indices, loss = self.vq(features, feature_masks)
+
+    #     return VQEncodeResult(
+    #         features=vq_features,
+    #         indices=indices,
+    #         loss=loss,
+    #         feature_lengths=feature_lengths,
+    #     )
+
+    # def calculate_audio_lengths(self, feature_lengths):
+    #     return feature_lengths * self.hop_length * self.vq.downsample
+
+    # def decode(
+    #     self,
+    #     indices=None,
+    #     features=None,
+    #     audio_lengths=None,
+    #     feature_lengths=None,
+    #     return_audios=False,
+    # ):
+    #     assert (
+    #         indices is not None or features is not None
+    #     ), "indices or features must be provided"
+    #     assert (
+    #         feature_lengths is not None or audio_lengths is not None
+    #     ), "feature_lengths or audio_lengths must be provided"
+
+    #     if audio_lengths is None:
+    #         audio_lengths = self.calculate_audio_lengths(feature_lengths)
+
+    #     mel_lengths = audio_lengths // self.hop_length
+    #     mel_masks = torch.unsqueeze(
+    #         sequence_mask(mel_lengths, torch.max(mel_lengths)), 1
+    #     ).float()
+
+    #     if indices is not None:
+    #         features = self.vq.decode(indices)
+
+    #     # Sample mels
+    #     decoded = gradient_checkpoint(self.decoder, features, use_reentrant=False)
+
+    #     return VQDecodeResult(
+    #         mels=decoded,
+    #         audios=self.generator(decoded) if return_audios else None,
+    #     )

+ 349 - 0
fish_speech/models/vqgan/modules/attentions.py

@@ -0,0 +1,349 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm, weight_norm
+
+from fish_speech.models.vqgan.modules import commons
+from fish_speech.models.vqgan.modules.modules import LayerNorm
+
+
class Encoder(nn.Module):
    """Stack of self-attention + FFN blocks with post-layer-norm residuals.

    When ``isflow`` is set, a conditioning signal ``g`` is fused into the
    hidden state before every attention layer through a gated
    tanh/sigmoid unit (one shared 1x1 conv produces all per-layer gates).
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=False,
        gin_channels=0,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for _ in range(n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

        if isflow:
            # Conditioning path: a single weight-normed 1x1 conv emits
            # 2 * hidden_channels activations per layer, consumed slice by
            # slice in forward().
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
            self.cond_layer = weight_norm(cond_layer, "weight")
            self.gin_channels = gin_channels

    def forward(self, x, x_mask, g=None):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        if g is not None:
            g = self.cond_layer(g)

        for layer_idx in range(self.n_layers):
            if g is not None:
                # Fuse this layer's slice of the conditioning signal into x.
                x = self.cond_pre(x)
                offset = layer_idx * 2 * self.hidden_channels
                g_slice = g[:, offset : offset + 2 * self.hidden_channels, :]
                x = commons.fused_add_tanh_sigmoid_multiply(
                    x, g_slice, torch.IntTensor([self.hidden_channels])
                )

            attn_out = self.drop(self.attn_layers[layer_idx](x, x, attn_mask))
            x = self.norm_layers_1[layer_idx](x + attn_out)

            ffn_out = self.drop(self.ffn_layers[layer_idx](x, x_mask))
            x = self.norm_layers_2[layer_idx](x + ffn_out)

        return x * x_mask
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(
+        self,
+        channels,
+        out_channels,
+        n_heads,
+        p_dropout=0.0,
+        window_size=None,
+        heads_share=True,
+        block_length=None,
+        proximal_bias=False,
+        proximal_init=False,
+    ):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.out_channels = out_channels
+        self.n_heads = n_heads
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.heads_share = heads_share
+        self.block_length = block_length
+        self.proximal_bias = proximal_bias
+        self.proximal_init = proximal_init
+        self.attn = None
+
+        self.k_channels = channels // n_heads
+        self.conv_q = nn.Conv1d(channels, channels, 1)
+        self.conv_k = nn.Conv1d(channels, channels, 1)
+        self.conv_v = nn.Conv1d(channels, channels, 1)
+        self.conv_o = nn.Conv1d(channels, out_channels, 1)
+        self.drop = nn.Dropout(p_dropout)
+
+        if window_size is not None:
+            n_heads_rel = 1 if heads_share else n_heads
+            rel_stddev = self.k_channels**-0.5
+            self.emb_rel_k = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+            self.emb_rel_v = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+
+        nn.init.xavier_uniform_(self.conv_q.weight)
+        nn.init.xavier_uniform_(self.conv_k.weight)
+        nn.init.xavier_uniform_(self.conv_v.weight)
+        if proximal_init:
+            with torch.no_grad():
+                self.conv_k.weight.copy_(self.conv_q.weight)
+                self.conv_k.bias.copy_(self.conv_q.bias)
+
+    def forward(self, x, c, attn_mask=None):
+        q = self.conv_q(x)
+        k = self.conv_k(c)
+        v = self.conv_v(c)
+
+        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+        x = self.conv_o(x)
+        return x
+
+    def attention(self, query, key, value, mask=None):
+        # reshape [b, d, t] -> [b, n_h, t, d_k]
+        b, d, t_s, t_t = (*key.size(), query.size(2))
+        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+        if self.window_size is not None:
+            assert (
+                t_s == t_t
+            ), "Relative attention is only available for self-attention."
+            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+            rel_logits = self._matmul_with_relative_keys(
+                query / math.sqrt(self.k_channels), key_relative_embeddings
+            )
+            scores_local = self._relative_position_to_absolute_position(rel_logits)
+            scores = scores + scores_local
+        if self.proximal_bias:
+            assert t_s == t_t, "Proximal bias is only available for self-attention."
+            scores = scores + self._attention_bias_proximal(t_s).to(
+                device=scores.device, dtype=scores.dtype
+            )
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, -1e4)
+            if self.block_length is not None:
+                assert (
+                    t_s == t_t
+                ), "Local attention is only available for self-attention."
+                block_mask = (
+                    torch.ones_like(scores)
+                    .triu(-self.block_length)
+                    .tril(self.block_length)
+                )
+                scores = scores.masked_fill(block_mask == 0, -1e4)
+        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+        if self.window_size is not None:
+            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            value_relative_embeddings = self._get_relative_embeddings(
+                self.emb_rel_v, t_s
+            )
+            output = output + self._matmul_with_relative_values(
+                relative_weights, value_relative_embeddings
+            )
+        output = (
+            output.transpose(2, 3).contiguous().view(b, d, t_t)
+        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+        return output, p_attn
+
+    def _matmul_with_relative_values(self, x, y):
+        """
+        x: [b, h, l, m]
+        y: [h or 1, m, d]
+        ret: [b, h, l, d]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0))
+        return ret
+
+    def _matmul_with_relative_keys(self, x, y):
+        """
+        x: [b, h, l, d]
+        y: [h or 1, m, d]
+        ret: [b, h, l, m]
+        """
+        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+        return ret
+
+    def _get_relative_embeddings(self, relative_embeddings, length):
+        max_relative_position = 2 * self.window_size + 1
+        # Pad first before slice to avoid using cond ops.
+        pad_length = max(length - (self.window_size + 1), 0)
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        if pad_length > 0:
+            padded_relative_embeddings = F.pad(
+                relative_embeddings,
+                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+            )
+        else:
+            padded_relative_embeddings = relative_embeddings
+        used_relative_embeddings = padded_relative_embeddings[
+            :, slice_start_position:slice_end_position
+        ]
+        return used_relative_embeddings
+
+    def _relative_position_to_absolute_position(self, x):
+        """
+        x: [b, h, l, 2*l-1]
+        ret: [b, h, l, l]
+        """
+        batch, heads, length, _ = x.size()
+        # Concat columns of pad to shift from relative to absolute indexing.
+        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+        # Concat extra elements so to add up to shape (len+1, 2*len-1).
+        x_flat = x.view([batch, heads, length * 2 * length])
+        x_flat = F.pad(
+            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+        )
+
+        # Reshape and slice out the padded elements.
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+            :, :, :length, length - 1 :
+        ]
+        return x_final
+
+    def _absolute_position_to_relative_position(self, x):
+        """
+        x: [b, h, l, l]
+        ret: [b, h, l, 2*l-1]
+        """
+        batch, heads, length, _ = x.size()
+        # pad along column
+        x = F.pad(
+            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+        )
+        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+        # add 0's in the beginning that will skew the elements after reshape
+        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+        return x_final
+
+    def _attention_bias_proximal(self, length):
+        """Bias for self-attention to encourage attention to close positions.
+        Args:
+          length: an integer scalar.
+        Returns:
+          a Tensor with shape [1, 1, length, length]
+        """
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        filter_channels,
+        kernel_size,
+        p_dropout=0.0,
+        activation=None,
+        causal=False,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.activation = activation
+        self.causal = causal
+
+        if causal:
+            self.padding = self._causal_padding
+        else:
+            self.padding = self._same_padding
+
+        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+        self.drop = nn.Dropout(p_dropout)
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(self.padding(x * x_mask))
+        if self.activation == "gelu":
+            x = x * torch.sigmoid(1.702 * x)
+        else:
+            x = torch.relu(x)
+        x = self.drop(x)
+        x = self.conv_2(self.padding(x * x_mask))
+        return x * x_mask
+
+    def _causal_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = self.kernel_size - 1
+        pad_r = 0
+        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        x = F.pad(x, commons.convert_pad_shape(padding))
+        return x
+
+    def _same_padding(self, x):
+        if self.kernel_size == 1:
+            return x
+        pad_l = (self.kernel_size - 1) // 2
+        pad_r = self.kernel_size // 2
+        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+        x = F.pad(x, commons.convert_pad_shape(padding))
+        return x

+ 0 - 193
fish_speech/models/vqgan/modules/balancer.py

@@ -1,193 +0,0 @@
-import typing as tp
-from collections import defaultdict
-
-import torch
-from torch import autograd
-
-
-def rank():
-    if torch.distributed.is_initialized():
-        return torch.distributed.get_rank()
-    else:
-        return 0
-
-
-def world_size():
-    if torch.distributed.is_initialized():
-        return torch.distributed.get_world_size()
-    else:
-        return 1
-
-
-def is_distributed():
-    return world_size() > 1
-
-
-def average_metrics(metrics: tp.Dict[str, float], count=1.0):
-    """Average a dictionary of metrics across all workers, using the optional
-    `count` as unnormalized weight.
-    """
-    if not is_distributed():
-        return metrics
-    keys, values = zip(*metrics.items())
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
-    tensor *= count
-    all_reduce(tensor)
-    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
-    return dict(zip(keys, averaged))
-
-
-def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
-    if is_distributed():
-        return torch.distributed.all_reduce(tensor, op)
-
-
-def averager(beta: float = 1):
-    """
-    Exponential Moving Average callback.
-    Returns a single function that can be called to repeatedly update the EMA
-    with a dict of metrics. The callback will return
-    the new averaged dict of metrics.
-
-    Note that for `beta=1`, this is just plain averaging.
-    """
-    fix: tp.Dict[str, float] = defaultdict(float)
-    total: tp.Dict[str, float] = defaultdict(float)
-
-    def _update(
-        metrics: tp.Dict[str, tp.Any], weight: float = 1
-    ) -> tp.Dict[str, float]:
-        nonlocal total, fix
-        for key, value in metrics.items():
-            total[key] = total[key] * beta + weight * float(value)
-            fix[key] = fix[key] * beta + weight
-        return {key: tot / fix[key] for key, tot in total.items()}
-
-    return _update
-
-
-class Balancer:
-    """Loss balancer.
-
-    The loss balancer combines losses together to compute gradients for the backward.
-    A call to the balancer will weight the losses according to the specified weight coefficients.
-    A call to the backward method of the balancer will compute the gradients, combining all the losses and
-    potentially rescaling the gradients, which can help stabilize the training and reason
-    about multiple losses with varying scales.
-
-    Expected usage:
-        weights = {'loss_a': 1, 'loss_b': 4}
-        balancer = Balancer(weights, ...)
-        losses: dict = {}
-        losses['loss_a'] = compute_loss_a(x, y)
-        losses['loss_b'] = compute_loss_b(x, y)
-        if model.training():
-            balancer.backward(losses, x)
-
-    ..Warning:: It is unclear how this will interact with DistributedDataParallel,
-        in particular if you have some losses not handled by the balancer. In that case
-        you can use `encodec.distrib.sync_grad(model.parameters())` and
-        `encodec.distrib.sync_buffers(model.buffers())` as a safe alternative.
-
-    Args:
-        weights (Dict[str, float]): Weight coefficient for each loss. The balancer expect the losses keys
-            from the backward method to match the weights keys to assign weight to each of the provided loss.
-        rescale_grads (bool): Whether to rescale gradients or not. If False, this is just
-            a regular weighted sum of losses.
-        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
-        ema_decay (float): EMA decay for averaging the norms when `rescale_grads` is True.
-        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
-            when rescaling the gradients.
-        epsilon (float): Epsilon value for numerical stability.
-        monitor (bool): Whether to store additional ratio for each loss key in metrics.
-    """  # noqa: E501
-
-    def __init__(
-        self,
-        weights: tp.Dict[str, float],
-        rescale_grads: bool = True,
-        total_norm: float = 1.0,
-        ema_decay: float = 0.999,
-        per_batch_item: bool = True,
-        epsilon: float = 1e-12,
-        monitor: bool = False,
-    ):
-        self.weights = weights
-        self.per_batch_item = per_batch_item
-        self.total_norm = total_norm
-        self.averager = averager(ema_decay)
-        self.epsilon = epsilon
-        self.monitor = monitor
-        self.rescale_grads = rescale_grads
-        self._metrics: tp.Dict[str, tp.Any] = {}
-
-    @property
-    def metrics(self):
-        return self._metrics
-
-    def compute(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor):
-        norms = {}
-        grads = {}
-        for name, loss in losses.items():
-            (grad,) = autograd.grad(loss, [input], retain_graph=True)
-            if self.per_batch_item:
-                dims = tuple(range(1, grad.dim()))
-                norm = grad.norm(dim=dims).mean()
-            else:
-                norm = grad.norm()
-            norms[name] = norm
-            grads[name] = grad
-
-        count = 1
-        if self.per_batch_item:
-            count = len(grad)
-        avg_norms = average_metrics(self.averager(norms), count)
-        total = sum(avg_norms.values())
-
-        self._metrics = {}
-        if self.monitor:
-            for k, v in avg_norms.items():
-                self._metrics[f"ratio_{k}"] = v / total
-
-        total_weights = sum([self.weights[k] for k in avg_norms])
-        ratios = {k: w / total_weights for k, w in self.weights.items()}
-
-        out_grad: tp.Any = 0
-        for name, avg_norm in avg_norms.items():
-            if self.rescale_grads:
-                scale = ratios[name] * self.total_norm / (self.epsilon + avg_norm)
-                grad = grads[name] * scale
-            else:
-                grad = self.weights[name] * grads[name]
-            out_grad += grad
-
-        return out_grad
-
-
-def test():
-    from torch.nn import functional as F
-
-    x = torch.zeros(1, requires_grad=True)
-    one = torch.ones_like(x)
-    loss_1 = F.l1_loss(x, one)
-    loss_2 = 100 * F.l1_loss(x, -one)
-    losses = {"1": loss_1, "2": loss_2}
-
-    balancer = Balancer(weights={"1": 1, "2": 1}, rescale_grads=False)
-    out_grad = balancer.compute(losses, x)
-    x.backward(out_grad)
-    assert torch.allclose(x.grad, torch.tensor(99.0)), x.grad
-
-    loss_1 = F.l1_loss(x, one)
-    loss_2 = 100 * F.l1_loss(x, -one)
-    losses = {"1": loss_1, "2": loss_2}
-    x.grad = None
-    balancer = Balancer(weights={"1": 1, "2": 1}, rescale_grads=True)
-    out_grad = balancer.compute({"1": loss_1, "2": loss_2}, x)
-    x.backward(out_grad)
-    assert torch.allclose(x.grad, torch.tensor(0.0)), x.grad
-
-
-if __name__ == "__main__":
-    test()

+ 192 - 0
fish_speech/models/vqgan/modules/commons.py

@@ -0,0 +1,192 @@
+import math
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def intersperse(lst, item):
+    result = [item] * (len(lst) * 2 + 1)
+    result[1::2] = lst
+    return result
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+    """KL(P||Q)"""
+    kl = (logs_q - logs_p) - 0.5
+    kl += (
+        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+    )
+    return kl
+
+
+def rand_gumbel(shape):
+    """Sample from the Gumbel distribution, protect from overflows."""
+    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+    return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+    return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+    return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+    b, d, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    ids_str_max = x_lengths - segment_size + 1
+    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+    position = torch.arange(length, dtype=torch.float)
+    num_timescales = channels // 2
+    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+        num_timescales - 1
+    )
+    inv_timescales = min_timescale * torch.exp(
+        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+    )
+    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+    signal = F.pad(signal, [0, 0, 0, channels % 2])
+    signal = signal.view(1, channels, length)
+    return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+    b, channels, length = x.size()
+    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length):
+    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+    return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def shift_1d(x):
+    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+    return x
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def generate_path(duration, mask):
+    """
+    duration: [b, 1, t_x]
+    mask: [b, 1, t_y, t_x]
+    """
+    device = duration.device
+
+    b, _, t_y, t_x = mask.shape
+    cum_duration = torch.cumsum(duration, -1)
+
+    cum_duration_flat = cum_duration.view(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+    path = path.view(b, t_x, t_y)
+    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = path.unsqueeze(1).transpose(2, 3) * mask
+    return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    if clip_value is not None:
+        clip_value = float(clip_value)
+
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item() ** norm_type
+        if clip_value is not None:
+            p.grad.data.clamp_(min=-clip_value, max=clip_value)
+    total_norm = total_norm ** (1.0 / norm_type)
+    return total_norm
+
+
+def squeeze(x, x_mask=None, n_sqz=2):
+    b, c, t = x.size()
+
+    t = (t // n_sqz) * n_sqz
+    x = x[:, :, :t]
+    x_sqz = x.view(b, c, t // n_sqz, n_sqz)
+    x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
+
+    if x_mask is not None:
+        x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
+    else:
+        x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
+    return x_sqz * x_mask, x_mask
+
+
+def unsqueeze(x, x_mask=None, n_sqz=2):
+    b, c, t = x.size()
+
+    x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
+    x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
+
+    if x_mask is not None:
+        x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
+    else:
+        x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
+    return x_unsqz * x_mask, x_mask

+ 0 - 315
fish_speech/models/vqgan/modules/convnext.py

@@ -1,315 +0,0 @@
-from functools import partial
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-from vocos.spectral_ops import ISTFT
-
-
-def drop_path(
-    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
-):
-    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-
-    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
-    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
-    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
-    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
-    'survival rate' as the argument.
-
-    """  # noqa: E501
-
-    if drop_prob == 0.0 or not training:
-        return x
-    keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (
-        x.ndim - 1
-    )  # work with diff dim tensors, not just 2D ConvNets
-    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
-    if keep_prob > 0.0 and scale_by_keep:
-        random_tensor.div_(keep_prob)
-    return x * random_tensor
-
-
-class DropPath(nn.Module):
-    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""  # noqa: E501
-
-    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
-        super(DropPath, self).__init__()
-        self.drop_prob = drop_prob
-        self.scale_by_keep = scale_by_keep
-
-    def forward(self, x):
-        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
-
-    def extra_repr(self):
-        return f"drop_prob={round(self.drop_prob,3):0.3f}"
-
-
-class LayerNorm(nn.Module):
-    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
-    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
-    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
-    with shape (batch_size, channels, height, width).
-    """  # noqa: E501
-
-    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(normalized_shape))
-        self.bias = nn.Parameter(torch.zeros(normalized_shape))
-        self.eps = eps
-        self.data_format = data_format
-        if self.data_format not in ["channels_last", "channels_first"]:
-            raise NotImplementedError
-        self.normalized_shape = (normalized_shape,)
-
-    def forward(self, x):
-        if self.data_format == "channels_last":
-            return F.layer_norm(
-                x, self.normalized_shape, self.weight, self.bias, self.eps
-            )
-        elif self.data_format == "channels_first":
-            u = x.mean(1, keepdim=True)
-            s = (x - u).pow(2).mean(1, keepdim=True)
-            x = (x - u) / torch.sqrt(s + self.eps)
-            x = self.weight[:, None] * x + self.bias[:, None]
-            return x
-
-
-class ConvNeXtBlock(nn.Module):
-    r"""ConvNeXt Block. There are two equivalent implementations:
-    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
-    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
-    We use (2) as we find it slightly faster in PyTorch
-
-    Args:
-        dim (int): Number of input channels.
-        drop_path (float): Stochastic depth rate. Default: 0.0
-        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
-        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
-        kernel_size (int): Kernel size for depthwise conv. Default: 7.
-        dilation (int): Dilation for depthwise conv. Default: 1.
-    """  # noqa: E501
-
-    def __init__(
-        self,
-        dim: int,
-        drop_path: float = 0.0,
-        layer_scale_init_value: float = 1e-6,
-        mlp_ratio: float = 4.0,
-        kernel_size: int = 7,
-        dilation: int = 1,
-    ):
-        super().__init__()
-
-        self.dwconv = nn.Conv1d(
-            dim,
-            dim,
-            kernel_size=kernel_size,
-            padding=int(dilation * (kernel_size - 1) / 2),
-            groups=dim,
-        )  # depthwise conv
-        self.norm = LayerNorm(dim, eps=1e-6)
-        self.pwconv1 = nn.Linear(
-            dim, int(mlp_ratio * dim)
-        )  # pointwise/1x1 convs, implemented with linear layers
-        self.act = nn.GELU()
-        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
-        self.gamma = (
-            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
-            if layer_scale_init_value > 0
-            else None
-        )
-        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
-
-    def forward(self, x, apply_residual: bool = True):
-        input = x
-
-        x = self.dwconv(x)
-        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
-        x = self.norm(x)
-        x = self.pwconv1(x)
-        x = self.act(x)
-        x = self.pwconv2(x)
-
-        if self.gamma is not None:
-            x = self.gamma * x
-
-        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
-        x = self.drop_path(x)
-
-        if apply_residual:
-            x = input + x
-
-        return x
-
-
-class ParallelConvNeXtBlock(nn.Module):
-    def __init__(self, kernel_sizes: list[int], *args, **kwargs):
-        super().__init__()
-        self.blocks = nn.ModuleList(
-            [
-                ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
-                for kernel_size in kernel_sizes
-            ]
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return torch.stack(
-            [block(x, apply_residual=False) for block in self.blocks] + [x],
-            dim=1,
-        ).sum(dim=1)
-
-
-class ConvNeXt(nn.Module):
-    def __init__(
-        self,
-        input_channels: int = 3,
-        depths: list[int] = [3, 3, 9, 3],
-        dims: list[int] = [96, 192, 384, 768],
-        drop_path_rate: float = 0.0,
-        layer_scale_init_value: float = 1e-6,
-        kernel_sizes: tuple[int] = (7,),
-    ):
-        super().__init__()
-        assert len(depths) == len(dims)
-
-        self.channel_layers = nn.ModuleList()
-        stem = nn.Sequential(
-            nn.Conv1d(
-                input_channels,
-                dims[0],
-                kernel_size=7,
-                padding=3,
-                # padding_mode="replicate",
-                padding_mode="zeros",
-            ),
-            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
-        )
-        self.channel_layers.append(stem)
-
-        for i in range(len(depths) - 1):
-            mid_layer = nn.Sequential(
-                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
-                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
-            )
-            self.channel_layers.append(mid_layer)
-
-        block_fn = (
-            partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
-            if len(kernel_sizes) == 1
-            else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
-        )
-
-        self.stages = nn.ModuleList()
-        drop_path_rates = [
-            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
-        ]
-
-        cur = 0
-        for i in range(len(depths)):
-            stage = nn.Sequential(
-                *[
-                    block_fn(
-                        dim=dims[i],
-                        drop_path=drop_path_rates[cur + j],
-                        layer_scale_init_value=layer_scale_init_value,
-                    )
-                    for j in range(depths[i])
-                ]
-            )
-            self.stages.append(stage)
-            cur += depths[i]
-
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, (nn.Conv1d, nn.Linear)):
-            nn.init.trunc_normal_(m.weight, std=0.02)
-            nn.init.constant_(m.bias, 0)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        return_features: bool = False,
-    ) -> torch.Tensor:
-        features = []
-
-        for channel_layer, stage in zip(self.channel_layers, self.stages):
-            x = channel_layer(x)
-            x = stage(x)
-
-            if return_features:
-                features.append(x)
-
-        if return_features:
-            return features
-
-        return x
-
-
-class ISTFTHead(nn.Module):
-    """
-    ISTFT Head module for predicting STFT complex coefficients.
-
-    Args:
-        dim (int): Hidden dimension of the model.
-        n_fft (int): Size of Fourier transform.
-        hop_length (int): The distance between neighboring sliding window frames, which should align with
-                          the resolution of the input features.
-        win_length (int): The size of window frame and STFT filter.
-        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
-    """  # noqa: E501
-
-    def __init__(
-        self,
-        dim: int,
-        n_fft: int,
-        hop_length: int,
-        win_length: int,
-        padding: str = "same",
-    ):
-        super().__init__()
-
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-
-        self.istft = ISTFT(
-            n_fft=n_fft,
-            hop_length=hop_length,
-            win_length=win_length,
-            padding=padding,
-        )
-
-        out_dim = n_fft * 2
-        self.out = nn.Conv1d(dim, out_dim, 1)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass of the ISTFTHead module.
-
-        Args:
-            x (Tensor): Input tensor of shape (B, H, L), where B is the batch size,
-                        L is the sequence length, and H denotes the model dimension.
-
-        Returns:
-            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
-        """  # noqa: E501
-
-        x = self.out(x)
-
-        mag, p = x.chunk(2, dim=1)
-        mag = torch.exp(mag)
-        mag = torch.clip(
-            mag, max=1e2
-        )  # safeguard to prevent excessively large magnitudes
-
-        # wrapping happens here. These two lines produce real and imaginary value
-        x = torch.cos(p)
-        y = torch.sin(p)
-
-        S = mag * (x + 1j * y)
-
-        x = self.istft(S)
-        return x.unsqueeze(1)

+ 0 - 315
fish_speech/models/vqgan/modules/decoder_v2.py

@@ -1,315 +0,0 @@
-from functools import partial
-from math import prod
-from typing import Callable
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn import Conv1d
-from torch.nn.utils.parametrizations import weight_norm
-from torch.nn.utils.parametrize import remove_parametrizations
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return (kernel_size * dilation - dilation) // 2
-
-
-class ResBlock(torch.nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super().__init__()
-
-        self.convs1 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[2],
-                        padding=get_padding(kernel_size, dilation[2]),
-                    )
-                ),
-            ]
-        )
-        self.convs1.apply(init_weights)
-
-        self.convs2 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-            ]
-        )
-        self.convs2.apply(init_weights)
-
-    def forward(self, x):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = F.silu(x)
-            xt = c1(xt)
-            xt = F.silu(xt)
-            xt = c2(xt)
-            x = xt + x
-        return x
-
-    def remove_parametrizations(self):
-        for conv in self.convs1:
-            remove_parametrizations(conv)
-        for conv in self.convs2:
-            remove_parametrizations(conv)
-
-
-class ParralelBlock(nn.Module):
-    def __init__(
-        self,
-        channels: int,
-        kernel_sizes: tuple[int] = (3, 7, 11),
-        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
-    ):
-        super().__init__()
-
-        assert len(kernel_sizes) == len(dilation_sizes)
-
-        self.blocks = nn.ModuleList()
-        for k, d in zip(kernel_sizes, dilation_sizes):
-            self.blocks.append(ResBlock(channels, k, d))
-
-    def forward(self, x):
-        xs = [block(x) for block in self.blocks]
-
-        return torch.stack(xs, dim=0).mean(dim=0)
-
-
-class HiFiGANGenerator(nn.Module):
-    def __init__(
-        self,
-        *,
-        hop_length: int = 512,
-        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
-        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
-        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
-        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
-        num_mels: int = 128,
-        upsample_initial_channel: int = 512,
-        use_template: bool = True,
-        pre_conv_kernel_size: int = 7,
-        post_conv_kernel_size: int = 7,
-        post_activation: Callable = partial(nn.SiLU, inplace=True),
-        checkpointing: bool = False,
-        ckpt_path: str = None,
-    ):
-        super().__init__()
-
-        assert (
-            prod(upsample_rates) == hop_length
-        ), f"hop_length must be {prod(upsample_rates)}"
-
-        self.conv_pre = weight_norm(
-            nn.Conv1d(
-                num_mels,
-                upsample_initial_channel,
-                pre_conv_kernel_size,
-                1,
-                padding=get_padding(pre_conv_kernel_size),
-            )
-        )
-
-        self.hop_length = hop_length
-        self.num_upsamples = len(upsample_rates)
-        self.num_kernels = len(resblock_kernel_sizes)
-
-        self.noise_convs = nn.ModuleList()
-        self.use_template = use_template
-        self.ups = nn.ModuleList()
-
-        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            c_cur = upsample_initial_channel // (2 ** (i + 1))
-            self.ups.append(
-                weight_norm(
-                    nn.ConvTranspose1d(
-                        upsample_initial_channel // (2**i),
-                        upsample_initial_channel // (2 ** (i + 1)),
-                        k,
-                        u,
-                        padding=(k - u) // 2,
-                    )
-                )
-            )
-
-            if not use_template:
-                continue
-
-            if i + 1 < len(upsample_rates):
-                stride_f0 = np.prod(upsample_rates[i + 1 :])
-                self.noise_convs.append(
-                    Conv1d(
-                        1,
-                        c_cur,
-                        kernel_size=stride_f0 * 2,
-                        stride=stride_f0,
-                        padding=stride_f0 // 2,
-                    )
-                )
-            else:
-                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
-
-        self.resblocks = nn.ModuleList()
-        for i in range(len(self.ups)):
-            ch = upsample_initial_channel // (2 ** (i + 1))
-            self.resblocks.append(
-                ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
-            )
-
-        self.activation_post = post_activation()
-        self.conv_post = weight_norm(
-            nn.Conv1d(
-                ch,
-                1,
-                post_conv_kernel_size,
-                1,
-                padding=get_padding(post_conv_kernel_size),
-            )
-        )
-        self.ups.apply(init_weights)
-        self.conv_post.apply(init_weights)
-
-        # Gradient checkpointing
-        self.checkpointing = checkpointing
-
-        if ckpt_path is not None:
-            states = torch.load(ckpt_path, map_location="cpu")
-            if "state_dict" in states:
-                states = states["state_dict"]
-            states = {
-                k.replace("generator.", ""): v
-                for k, v in states.items()
-                if k.startswith("generator")
-            }
-            self.load_state_dict(states, strict=True)
-
-    def forward(self, x, template=None):
-        if self.use_template and template is None:
-            length = x.shape[-1] * self.hop_length
-            template = (
-                torch.randn(x.shape[0], 1, length, device=x.device, dtype=x.dtype)
-                * 0.003
-            )
-
-        x = self.conv_pre(x)
-
-        for i in range(self.num_upsamples):
-            x = F.silu(x, inplace=True)
-            x = self.ups[i](x)
-
-            if self.use_template:
-                x = x + self.noise_convs[i](template)
-
-            if self.training and self.checkpointing:
-                x = torch.utils.checkpoint.checkpoint(
-                    self.resblocks[i],
-                    x,
-                    use_reentrant=False,
-                )
-            else:
-                x = self.resblocks[i](x)
-
-        x = self.activation_post(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def remove_parametrizations(self):
-        for up in self.ups:
-            remove_parametrizations(up)
-        for block in self.resblocks:
-            block.remove_parametrizations()
-        remove_parametrizations(self.conv_pre)
-        remove_parametrizations(self.conv_post)
-
-
-if __name__ == "__main__":
-    import torchaudio
-
-    from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
-
-    spec = LogMelSpectrogram(n_mels=160)
-    audio, sr = torchaudio.load("test.wav")
-    audio = audio[None, :]
-    spec = spec(audio, sample_rate=sr)
-
-    model = HiFiGANGenerator(
-        hop_length=512,
-        upsample_rates=(8, 8, 2, 2, 2),
-        upsample_kernel_sizes=(16, 16, 4, 4, 4),
-        resblock_kernel_sizes=(3, 7, 11),
-        resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
-        num_mels=160,
-        upsample_initial_channel=512,
-        use_template=True,
-        pre_conv_kernel_size=7,
-        post_conv_kernel_size=7,
-        post_activation=partial(nn.SiLU, inplace=True),
-        ckpt_path="checkpoints/hifigan-base-comb-mix-lb-020/step_001200000_weights_only.ckpt",
-    )
-
-    print(model)
-
-    out = model(spec)
-    print(out.shape)
-
-    torchaudio.save("out.wav", out[0], 44100)

+ 0 - 141
fish_speech/models/vqgan/modules/encoders.py

@@ -1,141 +0,0 @@
-from math import log2
-from typing import Optional
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-from vector_quantize_pytorch import LFQ, GroupedResidualVQ, VectorQuantize
-
-
-class VQEncoder(nn.Module):
-    def __init__(
-        self,
-        in_channels: int = 1024,
-        vq_channels: int = 1024,
-        codebook_size: int = 2048,
-        downsample: int = 1,
-        codebook_groups: int = 1,
-        codebook_layers: int = 1,
-        threshold_ema_dead_code: int = 2,
-    ):
-        super().__init__()
-
-        if codebook_groups > 1 or codebook_layers > 1:
-            self.vq = GroupedResidualVQ(
-                dim=vq_channels,
-                codebook_size=codebook_size,
-                threshold_ema_dead_code=threshold_ema_dead_code,
-                kmeans_init=True,
-                groups=codebook_groups,
-                num_quantizers=codebook_layers,
-            )
-        else:
-            self.vq = VectorQuantize(
-                dim=vq_channels,
-                codebook_size=codebook_size,
-                threshold_ema_dead_code=threshold_ema_dead_code,
-                kmeans_init=True,
-            )
-
-        self.codebook_groups = codebook_groups
-        self.codebook_layers = codebook_layers
-        self.downsample = downsample
-        self.conv_in = nn.Conv1d(
-            in_channels, vq_channels, kernel_size=downsample, stride=downsample
-        )
-        self.conv_out = nn.Sequential(
-            nn.Upsample(scale_factor=downsample, mode="nearest")
-            if downsample > 1
-            else nn.Identity(),
-            nn.Conv1d(vq_channels, in_channels, kernel_size=1, stride=1),
-        )
-
-    @property
-    def mode(self):
-        if self.codebook_groups > 1 and self.codebook_layers > 1:
-            return "grouped-residual"
-        elif self.codebook_groups > 1:
-            return "grouped"
-        elif self.codebook_layers > 1:
-            return "residual"
-        else:
-            return "single"
-
-    def forward(self, x, x_mask):
-        # x: [B, C, T], x_mask: [B, 1, T]
-        x_len = x.shape[2]
-
-        if x_len % self.downsample != 0:
-            x = F.pad(x, (0, self.downsample - x_len % self.downsample))
-            x_mask = F.pad(x_mask, (0, self.downsample - x_len % self.downsample))
-
-        x = self.conv_in(x)
-        q, indices, loss = self.vq(x.mT)
-        q = q.mT
-
-        if self.codebook_groups > 1:
-            loss = loss.mean()
-
-        x = self.conv_out(q) * x_mask
-        x = x[:, :, :x_len]
-
-        # Post process indices
-        if self.mode == "grouped-residual":
-            indices = rearrange(indices, "g b t r -> b (g r) t")
-        elif self.mode == "grouped":
-            indices = rearrange(indices, "g b t 1 -> b g t")
-        elif self.mode == "residual":
-            indices = rearrange(indices, "1 b t r -> b r t")
-        else:
-            indices = rearrange(indices, "b t -> b 1 t")
-
-        return x, indices, loss
-
-    def decode(self, indices):
-        # Undo rearrange
-        if self.mode == "grouped-residual":
-            indices = rearrange(indices, "b (g r) t -> g b t r", g=self.codebook_groups)
-        elif self.mode == "grouped":
-            indices = rearrange(indices, "b g t -> g b t 1")
-        elif self.mode == "residual":
-            indices = rearrange(indices, "b r t -> 1 b t r")
-        else:
-            indices = rearrange(indices, "b 1 t -> b t")
-
-        q = self.vq.get_output_from_indices(indices)
-
-        # Edge case for single vq
-        if self.mode == "single":
-            q = rearrange(q, "b (t c) -> b t c", t=indices.shape[-1])
-
-        x = self.conv_out(q.mT)
-
-        return x
-
-
-if __name__ == "__main__":
-    # Test VQEncoder
-    for group, layer in [
-        (1, 1),
-        (1, 2),
-        (2, 1),
-        (2, 2),
-        (4, 1),
-        (4, 2),
-    ]:
-        encoder = VQEncoder(
-            in_channels=1024,
-            vq_channels=1024,
-            codebook_size=2048,
-            downsample=1,
-            codebook_groups=group,
-            codebook_layers=layer,
-            threshold_ema_dead_code=2,
-        )
-        x = torch.randn(2, 1024, 100)
-        x_mask = torch.ones(2, 1, 100)
-        x, indices, loss = encoder(x, x_mask)
-        x = encoder.decode(indices)
-        assert x.shape == (2, 1024, 100)

+ 668 - 0
fish_speech/models/vqgan/modules/models.py

@@ -0,0 +1,668 @@
+import copy
+import math
+
+import torch
+from torch import nn
+from torch.cuda.amp import autocast
+from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
+from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
+
+from fish_speech.models.vqgan.modules import attentions, commons, modules
+from fish_speech.models.vqgan.modules.commons import get_padding, init_weights
+from fish_speech.models.vqgan.modules.rvq import DownsampleResidualVectorQuantizer
+
+
class FeatureEncoder(nn.Module):
    """Prior encoder with a residual-VQ bottleneck.

    Pipeline: 1x1 conv projection of the spectrogram -> transformer encoder
    (first half of the layers) -> downsampled residual vector quantization ->
    transformer decoder (second half, conditioned on a global embedding) ->
    1x1 conv projecting to the prior mean / log-std.
    """

    def __init__(
        self,
        spec_channels,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        codebook_size=1024,
        num_codebooks=2,
        gin_channels=0,
    ):
        """
        Args:
            spec_channels: number of input spectrogram channels.
            out_channels: channel count of the produced prior statistics.
            hidden_channels: transformer width.
            filter_channels: FFN width inside the attention blocks.
            n_heads: number of attention heads.
            n_layers: total attention layers, split evenly between the
                encoder and the decoder (n_layers // 2 each).
            kernel_size: FFN conv kernel size.
            p_dropout: dropout probability.
            codebook_size: entries per VQ codebook.
            num_codebooks: number of residual codebooks.
            gin_channels: size of the global conditioning embedding (0 = none).
        """
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.spec_proj = nn.Conv1d(spec_channels, hidden_channels, 1)

        # Encoder half (unconditioned).
        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers // 2,
            kernel_size,
            p_dropout,
        )

        # NOTE(review): min_quantizers == n_codebooks, so every codebook is
        # always active; downsample_factor=(2,) halves the time axis once
        # inside the quantizer — confirm against rvq implementation.
        self.vq = DownsampleResidualVectorQuantizer(
            input_dim=hidden_channels,
            n_codebooks=num_codebooks,
            codebook_size=codebook_size,
            min_quantizers=num_codebooks,
            downsample_factor=(2,),
        )

        # Decoder half, conditioned on the global embedding (isflow=True
        # enables the g-conditioning path in the attention encoder).
        self.decoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers // 2,
            kernel_size,
            p_dropout,
            isflow=True,
            gin_channels=gin_channels,
        )

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, y, y_lengths, ge):
        """Encode and quantize a padded spectrogram batch.

        Args:
            y: (B, spec_channels, T) spectrogram.
            y_lengths: (B,) valid frame counts used to build the mask.
            ge: global conditioning embedding forwarded to the decoder.

        Returns:
            Tuple of (decoded hidden states, prior mean, prior log-std,
            frame mask, quantizer output); `quantized.z` holds the
            quantized latent consumed by the decoder.
        """
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )

        y = self.spec_proj(y * y_mask) * y_mask
        y = self.encoder(y * y_mask, y_mask)
        quantized = self.vq(y)
        y = self.decoder(quantized.z * y_mask, y_mask, g=ge)

        stats = self.proj(y) * y_mask
        # Split the 2*out_channels projection into mean / log-std halves.
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask, quantized
+
+
class ResidualCouplingBlock(nn.Module):
    """Normalizing flow built from affine coupling layers (VITS-style).

    Each flow step is a mean-only residual coupling layer followed by a
    channel flip; `forward` runs the stack forward for training and in
    reverse for sampling.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        # Retain configuration for introspection / serialization.
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.extend([coupling, modules.Flip()])

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the flow stack; `reverse=True` inverts it for inference."""
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=True)
        else:
            for flow in self.flows:
                # Forward couplings also return a log-det term, unused here.
                x, _ = flow(x, x_mask, g=g, reverse=False)
        return x
+
+
class PosteriorEncoder(nn.Module):
    """WaveNet-based posterior encoder q(z | y) (VITS-style).

    Projects the input spectrogram, runs a gated WaveNet stack conditioned
    on a global embedding, then emits a reparameterized Gaussian sample
    together with its statistics.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Encode a padded batch into a posterior sample.

        Args:
            x: (B, in_channels, T) input spectrogram.
            x_lengths: (B,) valid frame counts.
            g: optional global conditioning embedding.

        Returns:
            (z, m, logs, x_mask) — sampled latent, mean, log-std, frame mask.
        """
        # Bug fix: `g` defaults to None but was unconditionally detached,
        # which raised AttributeError whenever the default was used.  The
        # detach (when g is provided) stops posterior gradients from
        # flowing back into the reference encoder.
        if g is not None:
            g = g.detach()
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Reparameterization trick, masked to valid frames.
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
+
+
class WNEncoder(nn.Module):
    """Deterministic WaveNet encoder: conv-in -> WN stack -> conv-out -> LayerNorm."""

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        # Keep the configuration around for introspection.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.norm = modules.LayerNorm(out_channels)

    def forward(self, x, x_lengths, g=None):
        """Encode a padded batch into a normalized (B, out_channels, T) map."""
        mask = commons.sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)
        hidden = self.pre(x) * mask
        hidden = self.enc(hidden, mask, g=g)
        return self.norm(self.proj(hidden) * mask)
+
+
class Generator(torch.nn.Module):
    """HiFi-GAN-style vocoder: transposed-conv upsampling with multi-kernel
    residual blocks, optionally conditioned on a global embedding."""

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        # "1" selects the larger three-layer residual block variant.
        resblock_cls = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        # One weight-normalized transposed conv per upsampling stage; each
        # stage halves the channel count.
        self.ups = nn.ModuleList(
            weight_norm(
                ConvTranspose1d(
                    upsample_initial_channel // (2**stage),
                    upsample_initial_channel // (2 ** (stage + 1)),
                    kernel,
                    rate,
                    padding=(kernel - rate) // 2,
                )
            )
            for stage, (rate, kernel) in enumerate(
                zip(upsample_rates, upsample_kernel_sizes)
            )
        )

        # num_kernels residual blocks per stage, stored flat.
        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            for kernel, dilations in zip(
                resblock_kernel_sizes, resblock_dilation_sizes
            ):
                self.resblocks.append(resblock_cls(ch, kernel, dilations))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Decode a latent (B, initial_channel, T) into a (B, 1, T') waveform."""
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage in range(self.num_upsamples):
            x = self.ups[stage](F.leaky_relu(x, modules.LRELU_SLOPE))
            # Average this stage's residual-block outputs.
            base = stage * self.num_kernels
            total = self.resblocks[base](x)
            for offset in range(1, self.num_kernels):
                total = total + self.resblocks[base + offset](x)
            x = total / self.num_kernels

        x = F.leaky_relu(x)
        return torch.tanh(self.conv_post(x))

    def remove_weight_norm(self):
        """Strip weight norm for inference/export."""
        print("Removing weight norm...")
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
+
+
class DiscriminatorP(torch.nn.Module):
    """Period sub-discriminator (HiFi-GAN MPD member).

    Folds a 1-D waveform into a 2-D (frames x period) map and applies
    strided 2-D convolutions, scoring structure at one fixed period.

    Args:
        period: fold width; samples are grouped into columns of this period.
        kernel_size: kernel height along the time axis.
        stride: stride along the time axis.
        use_spectral_norm: use spectral norm instead of weight norm.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        # Idiom fix: select the normalizer directly instead of `== False`.
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Score a (B, 1, T) waveform.

        Returns:
            (flattened logits, list of intermediate feature maps for the
            feature-matching loss).
        """
        fmap = []

        # Fold 1-D signal into (B, C, T // period, period).
        b, c, t = x.shape
        if t % self.period != 0:
            # Reflect-pad so the length divides evenly into periods.
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:  # renamed from `l` (ambiguous, E741)
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
+
+
class DiscriminatorS(torch.nn.Module):
    """Scale (raw-waveform) sub-discriminator from HiFi-GAN.

    A stack of grouped, strided 1-D convolutions applied directly to the
    waveform.

    Args:
        use_spectral_norm: use spectral norm instead of weight norm.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        # Idiom fix: select the normalizer directly instead of `== False`.
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Score a (B, 1, T) waveform.

        Returns:
            (flattened logits, list of intermediate feature maps for the
            feature-matching loss).
        """
        fmap = []

        for layer in self.convs:  # renamed from `l` (ambiguous, E741)
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
+
+
class EnsembledDiscriminator(torch.nn.Module):
    """Ensemble of one scale discriminator plus one period discriminator
    per entry in `periods`; scores real and generated audio in one pass."""

    def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
        super().__init__()
        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        members += [
            DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in periods
        ]
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        """Run every member on real (y) and generated (y_hat) audio.

        Returns:
            (real logits, generated logits, real feature maps,
            generated feature maps), each a list with one entry per member.
        """
        real_scores, gen_scores = [], []
        real_fmaps, gen_fmaps = [], []
        for disc in self.discriminators:
            score_r, fmap_r = disc(y)
            score_g, fmap_g = disc(y_hat)
            real_scores.append(score_r)
            gen_scores.append(score_g)
            real_fmaps.append(fmap_r)
            gen_fmaps.append(fmap_g)

        return real_scores, gen_scores, real_fmaps, gen_fmaps
+
+
+class SynthesizerTrn(nn.Module):
+    """
+    Synthesizer for Training
+    """
+
    def __init__(
        self,
        *,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
        freeze_quantizer=False,
        codebook_size=1024,
        num_codebooks=2,
    ):
        """Build the VITS-style synthesizer (all arguments keyword-only).

        Args:
            spec_channels: linear-spectrogram channel count fed to the
                encoders and the reference (style) encoder.
            segment_size: latent-frame slice length decoded during training.
            inter_channels: latent channel count shared by prior/posterior.
            hidden_channels: transformer / WN hidden width.
            filter_channels, n_heads, n_layers, kernel_size, p_dropout:
                transformer settings forwarded to the prior encoder.
            resblock, resblock_kernel_sizes, resblock_dilation_sizes,
            upsample_rates, upsample_initial_channel, upsample_kernel_sizes:
                HiFi-GAN generator settings.
            gin_channels: global style-embedding size.
            freeze_quantizer: freeze the prior encoder's projection,
                encoder, and VQ (e.g. when fine-tuning the decoder only).
            codebook_size, num_codebooks: residual-VQ configuration.
        """
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels

        # Prior encoder: spectrogram -> quantized latent -> prior stats.
        self.enc_p = FeatureEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            codebook_size=codebook_size,
            num_codebooks=num_codebooks,
            gin_channels=gin_channels,
        )
        # HiFi-GAN decoder: latent -> waveform.
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        # Posterior encoder q(z|y).
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        # Flow bridging posterior and prior latent spaces.
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )

        # Reference encoder producing the global style embedding `ge`.
        self.ref_enc = modules.MelStyleEncoder(
            spec_channels, style_vector_dim=gin_channels
        )

        if freeze_quantizer:
            # Freeze only the pre-VQ half of the prior encoder; its decoder
            # half and projection stay trainable.
            self.enc_p.spec_proj.requires_grad_(False)
            self.enc_p.encoder.requires_grad_(False)
            self.enc_p.vq.requires_grad_(False)
+
    def forward(self, y, y_lengths):
        """Training forward pass.

        Args:
            y: (B, spec_channels, T) spectrogram batch.
            y_lengths: (B,) valid frame counts.

        Returns:
            o: decoded waveform for a random slice of the posterior latent.
            ids_slice: slice start indices (for aligning loss targets).
            y_mask: frame mask (returned twice — keeps a legacy tuple shape).
            (z, z_p, m_p, logs_p, m_q, logs_q): latents and Gaussian stats
                for the KL term.
            quantized: VQ output from the prior encoder (carries the VQ loss).
        """
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
            y.dtype
        )
        # Global speaker/style embedding from the reference encoder.
        ge = self.ref_enc(y * y_mask, y_mask)

        # Prior (quantized) path and posterior path over the same input.
        # NOTE(review): `x` (decoded prior hidden) is not consumed here.
        x, m_p, logs_p, y_mask, quantized = self.enc_p(y, y_lengths, ge)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
        # Map the posterior sample into the prior space for the KL loss.
        z_p = self.flow(z, y_mask, g=ge)

        # Decode only a random latent segment — presumably to cap vocoder
        # memory during training; confirm slice length vs segment_size units.
        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=ge)

        return (
            o,
            ids_slice,
            y_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            quantized,
        )
+
+    def infer(self, y, y_lengths, noise_scale=0.5):
+        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
+            y.dtype
+        )
+        ge = self.ref_enc(y * y_mask, y_mask)
+        x, m_p, logs_p, y_mask, quantized = self.enc_p(y, y_lengths, ge)
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+
+        z = self.flow(z_p, y_mask, g=ge, reverse=True)
+
+        o = self.dec((z * y_mask)[:, :, :], g=ge)
+        return o, y_mask, (z, z_p, m_p, logs_p)
+
+    def infer_posterior(self, y, y_lengths):
+        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
+            y.dtype
+        )
+        ge = self.ref_enc(y * y_mask, y_mask)
+        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
+        o = self.dec(z * y_mask, g=ge)
+        return o, y_mask, (z, m_q, logs_q)
+
+    # @torch.no_grad()
+    # def decode(self, codes, text, refer, noise_scale=0.5):
+    #     refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
+    #     refer_mask = torch.unsqueeze(
+    #         commons.sequence_mask(refer_lengths, refer.size(2)), 1
+    #     ).to(refer.dtype)
+    #     ge = self.ref_enc(refer * refer_mask, refer_mask)
+
+    #     y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
+    #     text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
+
+    #     quantized = self.quantizer.decode(codes)
+    #     if self.semantic_frame_rate == "25hz":
+    #         quantized = F.interpolate(
+    #             quantized, size=int(quantized.shape[-1] * 2), mode="nearest"
+    #         )
+
+    #     x, m_p, logs_p, y_mask = self.enc_p(
+    #         quantized, y_lengths, text, text_lengths, ge
+    #     )
+    #     z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+
+    #     z = self.flow(z_p, y_mask, g=ge, reverse=True)
+
+    #     o = self.dec((z * y_mask)[:, :, :], g=ge)
+    #     return o
+
+    # def extract_latent(self, x):
+    #     ssl = self.ssl_proj(x)
+    #     quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
+    #     return codes.transpose(0, 1)
+
+
if __name__ == "__main__":
    # Manual smoke test: build the synthesizer, load the pretrained generator
    # weights, and round-trip one utterance through every inference path.
    model = SynthesizerTrn(
        spec_channels=1025,
        segment_size=20480,
        inter_channels=192,
        hidden_channels=192,
        filter_channels=768,
        n_heads=2,
        n_layers=6,
        kernel_size=3,
        p_dropout=0.1,
        resblock="1",
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[10, 8, 2, 2, 2],
        upsample_initial_channel=512,
        upsample_kernel_sizes=[16, 16, 8, 2, 2],
        gin_channels=512,
        freeze_quantizer=True,
    )

    # strict=False: the checkpoint may not cover newly added modules; print
    # the missing/unexpected keys so they can be eyeballed.
    generator_ckpt = torch.load("checkpoints/gpt_sovits_g_488k.pth", map_location="cpu")
    print(model.load_state_dict(generator_ckpt, strict=False))

    import librosa
    import soundfile as sf

    from fish_speech.models.vqgan.spectrogram import LinearSpectrogram

    to_spec = LinearSpectrogram(
        n_fft=2048, win_length=2048, hop_length=640, mode="pow2_sqrt"
    )

    audio, _ = librosa.load(
        "/***REMOVED***/workspace/llm-multimodal-test/data/Rail_ZH/星/dbc16cc114ca1700.wav",
        sr=32000,
    )

    y = to_spec(torch.tensor(audio).unsqueeze(0))
    y_lengths = torch.tensor([y.size(2)])

    # Training-style forward (random slice decode).
    out, ids_slice, y_mask, _, (z, z_p, m_p, logs_p, m_q, logs_q), quantized = model(
        y, y_lengths
    )
    print(out.shape)

    # Prior-path inference.
    out, y_mask, (z, z_p, m_p, logs_p) = model.infer(y, y_lengths)
    print(out.shape)

    # Posterior-path reconstruction; keep the last output for listening.
    out, y_mask, (z, m_q, logs_q) = model.infer_posterior(y, y_lengths)
    print(out.shape)

    # soundfile expects (frames, channels).
    sf.write("test.wav", out.squeeze(0).T.detach().cpu().numpy(), 32000)

+ 616 - 44
fish_speech/models/vqgan/modules/modules.py

@@ -1,100 +1,672 @@
+import numpy as np
 import torch
-import torch.nn as nn
-from torch.nn.utils.parametrizations import weight_norm
-from torch.nn.utils.parametrize import remove_parametrizations
+from torch import nn
+from torch.nn import Conv1d
+from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm, weight_norm
 
-from fish_speech.models.vqgan.utils import fused_add_tanh_sigmoid_multiply
+from fish_speech.models.vqgan.modules.commons import (
+    fused_add_tanh_sigmoid_multiply,
+    get_padding,
+    init_weights,
+)
 
 LRELU_SLOPE = 0.1
 
 
-# ! PosteriorEncoder
-# ! ResidualCouplingLayer
-class WaveNet(nn.Module):
class LayerNorm(nn.Module):
    """Channel-wise LayerNorm for (B, C, T) tensors.

    Moves the channel axis last, applies F.layer_norm with learnable affine
    parameters, and moves it back.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(channels))  # scale
        self.beta = nn.Parameter(torch.zeros(channels))  # shift

    def forward(self, x):
        y = x.transpose(1, -1)
        y = F.layer_norm(y, (self.channels,), self.gamma, self.beta, self.eps)
        return y.transpose(1, -1)
+
+
class ConvReluNorm(nn.Module):
    """Stack of Conv1d -> LayerNorm -> ReLU -> Dropout layers with a
    zero-initialized residual projection.

    Because ``proj`` starts at zero, the module is the identity (up to the
    mask) at initialization, which stabilizes early training.

    Note: the residual `x_org + proj(x)` requires out_channels ==
    in_channels — presumably guaranteed by callers; confirm before reuse.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # Fix: the message previously said "larger than 0" while the check
        # requires at least two layers.
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        # Zero init => the residual branch contributes nothing at start.
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """x: (B, in_channels, T); x_mask: (B, 1, T) binary validity mask."""
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
+
+
class DDSConv(nn.Module):
    """Dilated and depth-separable convolution stack.

    Each layer applies: depthwise conv (dilation grows as kernel_size**i)
    -> LayerNorm -> GELU -> pointwise 1x1 conv -> LayerNorm -> GELU ->
    Dropout, with a residual connection. Optional conditioning `g` is added
    to the input before the stack.
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2  # "same" length
            self.convs_sep.append(
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    groups=channels,  # depthwise
                    dilation=dilation,
                    padding=padding,
                )
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        for sep, point, norm_a, norm_b in zip(
            self.convs_sep, self.convs_1x1, self.norms_1, self.norms_2
        ):
            y = sep(x * x_mask)
            y = F.gelu(norm_a(y))
            y = point(y)
            y = F.gelu(norm_b(y))
            x = x + self.drop(y)
        return x * x_mask
+
+
class WN(torch.nn.Module):
    """WaveNet-style stack of gated dilated convolutions with optional
    global conditioning (as used in VITS / Glow-TTS).

    Args:
        hidden_channels: channel count of the residual/skip pathways.
        kernel_size: odd convolution kernel size.
        dilation_rate: base of the exponentially growing dilation.
        n_layers: number of gated layers.
        gin_channels: size of the global conditioning input (0 disables it).
        p_dropout: dropout applied to the gated activations.
    """

    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        # Fix: a stray trailing comma previously stored this as a 1-tuple.
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # One shared conv emits per-layer (tanh, sigmoid) conditioning.
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # The last layer emits skip features only (no residual half).
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        """x: (B, hidden_channels, T); x_mask: (B, 1, T); g: (B, gin_channels, T')."""
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                # Slice this layer's (tanh, sigmoid) conditioning chunk.
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                # First half feeds the residual path, second half the skips.
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
+
+
class ResBlock1(torch.nn.Module):
    """HiFi-GAN ResBlock type 1.

    Each of the dilated convs in ``convs1`` is paired with a non-dilated
    conv in ``convs2``; both are preceded by LeakyReLU and joined by a
    residual connection.

    Generalized: ``dilation`` may have any length (the original hard-coded
    exactly three entries); the default is unchanged.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.convs1 = nn.ModuleList(
            weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=d,
                    padding=get_padding(kernel_size, d),
                )
            )
            for d in dilation
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=1,
                    padding=get_padding(kernel_size, 1),
                )
            )
            for _ in dilation
        )
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c2(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)
 
-        if self.out_channels is not None:
-            x = self.out_layer(x)
 
class ResBlock2(torch.nn.Module):
    """HiFi-GAN ResBlock type 2: a single chain of dilated convs with
    residual connections (lighter than ResBlock1).

    Generalized: ``dilation`` may have any length (the original hard-coded
    exactly two entries); the default is unchanged.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList(
            weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=d,
                    padding=get_padding(kernel_size, d),
                )
            )
            for d in dilation
        )
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)
+
+
class Flip(nn.Module):
    """Invertible channel-order reversal placed between coupling layers."""

    def forward(self, x, *args, reverse=False, **kwargs):
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped
        # Volume-preserving transform: log-determinant is zero.
        logdet = torch.zeros(x.size(0), dtype=x.dtype, device=x.device)
        return flipped, logdet
+
+
class ResidualCouplingLayer(nn.Module):
    """Affine coupling layer (VITS): the first half of the channels
    parameterizes an affine transform of the second half via a WN stack.

    With ``mean_only`` the transform is a pure shift (logs fixed at zero).
    The zero-initialized ``post`` conv makes the layer start as identity.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            p_dropout=p_dropout,
            gin_channels=gin_channels,
        )
        # Zero init keeps the coupling close to identity at start of training.
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        first, second = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(first) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if self.mean_only:
            m = stats
            logs = torch.zeros_like(m)
        else:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)

        if reverse:
            second = (second - m) * torch.exp(-logs) * x_mask
            return torch.cat([first, second], 1)

        second = m + second * torch.exp(logs) * x_mask
        logdet = torch.sum(logs, [1, 2])
        return torch.cat([first, second], 1), logdet
+
+
class LinearNorm(nn.Module):
    """Thin wrapper over nn.Linear with optional spectral normalization."""

    def __init__(self, in_channels, out_channels, bias=True, spectral_norm=False):
        super(LinearNorm, self).__init__()
        fc = nn.Linear(in_channels, out_channels, bias)
        self.fc = nn.utils.spectral_norm(fc) if spectral_norm else fc

    def forward(self, input):
        return self.fc(input)
+
+
class Mish(nn.Module):
    """Mish activation: x * tanh(softplus(x)).

    See https://arxiv.org/abs/1908.08681.
    """

    def forward(self, x):
        return x * torch.tanh(F.softplus(x))
+
+
class Conv1dGLU(nn.Module):
    """Conv1d + GLU (Gated Linear Unit) with a residual connection.

    A single conv emits both the value and gate halves; the gated output is
    dropped out and added back to the input. See
    https://arxiv.org/abs/1612.08083.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dropout):
        super(Conv1dGLU, self).__init__()
        self.out_channels = out_channels
        self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        value, gate = torch.split(
            self.conv1(x), split_size_or_sections=self.out_channels, dim=1
        )
        return residual + self.dropout(value * torch.sigmoid(gate))
+
+
class ConvNorm(nn.Module):
    """Thin wrapper over nn.Conv1d with automatic "same" padding and
    optional spectral normalization."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        spectral_norm=False,
    ):
        super(ConvNorm, self).__init__()

        if padding is None:
            # Automatic "same" padding only works for odd kernels.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.conv = nn.utils.spectral_norm(conv) if spectral_norm else conv

    def forward(self, input):
        return self.conv(input)
+
+
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention: one input tensor serves as Q, K and V."""

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)

        # Scores are scaled by sqrt(d_model) (not the more common sqrt(d_k)).
        self.attention = ScaledDotProductAttention(
            temperature=np.power(d_model, 0.5), dropout=dropout
        )

        self.fc = nn.Linear(n_head * d_v, d_model)
        self.dropout = nn.Dropout(dropout)

        if spectral_norm:
            self.w_qs = nn.utils.spectral_norm(self.w_qs)
            self.w_ks = nn.utils.spectral_norm(self.w_ks)
            self.w_vs = nn.utils.spectral_norm(self.w_vs)
            self.fc = nn.utils.spectral_norm(self.fc)

    def forward(self, x, mask=None):
        """x: (B, T, d_model); mask: (B, T, T), truthy entries are masked out."""
        n_head, d_k, d_v = self.n_head, self.d_k, self.d_v
        batch, length, _ = x.size()
        residual = x

        # Project and fold heads into the batch dim: (n_head * B, T, d).
        q = self.w_qs(x).view(batch, length, n_head, d_k)
        k = self.w_ks(x).view(batch, length, n_head, d_k)
        v = self.w_vs(x).view(batch, length, n_head, d_v)
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, length, d_k)
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, length, d_k)
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, length, d_v)

        attn_mask = mask.repeat(n_head, 1, 1) if mask is not None else None
        output, attn = self.attention(q, k, v, mask=attn_mask)

        # Unfold heads and concatenate them along the feature dim.
        output = output.view(n_head, batch, length, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(batch, length, -1)

        output = self.dropout(self.fc(output)) + residual
        return output, attn
+
+
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention with optional boolean masking."""

    def __init__(self, temperature, dropout):
        super().__init__()
        self.temperature = temperature
        self.softmax = nn.Softmax(dim=2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        # (B, Lq, Lk) similarity scores, scaled by the temperature.
        scores = torch.bmm(q, k.transpose(1, 2)) / self.temperature

        if mask is not None:
            # NOTE(review): a row that is entirely masked yields NaN after
            # softmax — callers must avoid all-True mask rows.
            scores = scores.masked_fill(mask, -np.inf)

        attn = self.softmax(scores)
        output = torch.bmm(self.dropout(attn), v)
        return output, attn
+
+
class MelStyleEncoder(nn.Module):
    """Extracts a single style vector from a spectrogram.

    Pipeline: per-frame MLP ("spectral") -> Conv1dGLU stack ("temporal")
    -> multi-head self-attention -> linear projection -> masked temporal
    average pooling.
    """

    def __init__(
        self,
        n_mel_channels=80,
        style_hidden=128,
        style_vector_dim=256,
        style_kernel_size=5,
        style_head=2,
        dropout=0.1,
    ):
        super(MelStyleEncoder, self).__init__()
        self.in_dim = n_mel_channels
        self.hidden_dim = style_hidden
        self.out_dim = style_vector_dim
        self.kernel_size = style_kernel_size
        self.n_head = style_head
        self.dropout = dropout

        self.spectral = nn.Sequential(
            LinearNorm(self.in_dim, self.hidden_dim),
            Mish(),
            nn.Dropout(self.dropout),
            LinearNorm(self.hidden_dim, self.hidden_dim),
            Mish(),
            nn.Dropout(self.dropout),
        )

        self.temporal = nn.Sequential(
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
        )

        self.slf_attn = MultiHeadAttention(
            self.n_head,
            self.hidden_dim,
            self.hidden_dim // self.n_head,
            self.hidden_dim // self.n_head,
            self.dropout,
        )

        self.fc = LinearNorm(self.hidden_dim, self.out_dim)

    def temporal_avg_pool(self, x, mask=None):
        """Average (B, T, C) over time, ignoring padded frames if masked."""
        if mask is None:
            return torch.mean(x, dim=1)
        lengths = (~mask).sum(dim=1).unsqueeze(1)
        summed = x.masked_fill(mask.unsqueeze(-1), 0).sum(dim=1)
        return torch.div(summed, lengths)

    def forward(self, x, mask=None):
        """x: (B, in_dim, T); mask: (B, 1, T), nonzero = valid frame.

        Returns a (B, out_dim, 1) style vector.
        """
        x = x.transpose(1, 2)
        if mask is not None:
            # Flip to the "True = padded" convention used below.
            mask = (mask.int() == 0).squeeze(1)
        max_len = x.shape[1]
        slf_attn_mask = (
            mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
        )

        # Per-frame MLP.
        x = self.spectral(x)

        # Conv1dGLU stack operates on (B, C, T).
        x = x.transpose(1, 2)
        x = self.temporal(x)
        x = x.transpose(1, 2)

        # Self-attention over valid frames.
        if mask is not None:
            x = x.masked_fill(mask.unsqueeze(-1), 0)
        x, _ = self.slf_attn(x, mask=slf_attn_mask)

        x = self.fc(x)

        # Temporal average pooling collapses time to one vector.
        w = self.temporal_avg_pool(x, mask=mask)
        return w.unsqueeze(-1)

+ 358 - 0
fish_speech/models/vqgan/modules/rvq.py

@@ -0,0 +1,358 @@
+from dataclasses import dataclass
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+
+
+class VectorQuantize(nn.Module):
+    """
+    Implementation of VQ similar to Karpathy's repo:
+    https://github.com/karpathy/deep-vector-quantization
+    Additionally uses following tricks from Improved VQGAN
+    (https://arxiv.org/pdf/2110.04627.pdf):
+        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
+            for improved codebook usage
+        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
+            improves training stability
+    """
+
+    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
+        super().__init__()
+        self.codebook_size = codebook_size
+        self.codebook_dim = codebook_dim
+
+        self.in_proj = weight_norm(nn.Conv1d(input_dim, codebook_dim, kernel_size=1))
+        self.out_proj = weight_norm(nn.Conv1d(codebook_dim, input_dim, kernel_size=1))
+        self.codebook = nn.Embedding(codebook_size, codebook_dim)
+
+    def forward(self, z):
+        """Quantize the input tensor using a fixed codebook and return
+        the corresponding codebook vectors
+
+        Parameters
+        ----------
+        z : Tensor[B x D x T]
+
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized continuous representation of input
+        Tensor[1]
+            Commitment loss to train encoder to predict vectors closer to codebook
+            entries
+        Tensor[1]
+            Codebook loss to update the codebook
+        Tensor[B x T]
+            Codebook indices (quantized discrete representation of input)
+        Tensor[B x D x T]
+            Projected latents (continuous representation of input before quantization)
+        """
+
+        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
+        z_e = self.in_proj(z)  # z_e : (B x D x T)
+        z_q, indices = self.decode_latents(z_e)
+
+        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
+        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
+
+        z_q = (
+            z_e + (z_q - z_e).detach()
+        )  # noop in forward pass, straight-through gradient estimator in backward pass
+
+        z_q = self.out_proj(z_q)
+
+        return z_q, commitment_loss, codebook_loss, indices, z_e
+
+    def embed_code(self, embed_id):
+        return F.embedding(embed_id, self.codebook.weight)
+
+    def decode_code(self, embed_id):
+        return self.embed_code(embed_id).transpose(1, 2)
+
+    def decode_latents(self, latents):
+        encodings = rearrange(latents, "b d t -> (b t) d")
+        codebook = self.codebook.weight  # codebook: (N x D)
+
+        # L2 normalize encodings and codebook (ViT-VQGAN)
+        encodings = F.normalize(encodings)
+        codebook = F.normalize(codebook)
+
+        # Compute euclidean distance with codebook
+        dist = (
+            encodings.pow(2).sum(1, keepdim=True)
+            - 2 * encodings @ codebook.t()
+            + codebook.pow(2).sum(1, keepdim=True).t()
+        )
+        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
+        z_q = self.decode_code(indices)
+        return z_q, indices
+
+
+@dataclass
+class VQResult:
+    z: torch.Tensor
+    codes: torch.Tensor
+    latents: torch.Tensor
+    commitment_loss: torch.Tensor
+    codebook_loss: torch.Tensor
+
+
+class ResidualVectorQuantize(nn.Module):
+    """
+    Introduced in SoundStream: An End-to-End Neural Audio Codec
+    https://arxiv.org/abs/2107.03312
+    """
+
+    def __init__(
+        self,
+        input_dim: int = 512,
+        n_codebooks: int = 9,
+        codebook_size: int = 1024,
+        codebook_dim: Union[int, list] = 8,
+        quantizer_dropout: float = 0.0,
+        min_quantizers: int = 4,
+    ):
+        super().__init__()
+        if isinstance(codebook_dim, int):
+            codebook_dim = [codebook_dim for _ in range(n_codebooks)]
+
+        self.n_codebooks = n_codebooks
+        self.codebook_dim = codebook_dim
+        self.codebook_size = codebook_size
+
+        self.quantizers = nn.ModuleList(
+            [
+                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
+                for i in range(n_codebooks)
+            ]
+        )
+        self.quantizer_dropout = quantizer_dropout
+        self.min_quantizers = min_quantizers
+
+    def forward(self, z, n_quantizers: int = None) -> VQResult:
+        """Quantized the input tensor using a fixed set of `n` codebooks and returns
+        the corresponding codebook vectors
+        Parameters
+        ----------
+        z : Tensor[B x D x T]
+        n_quantizers : int, optional
+            No. of quantizers to use
+            (n_quantizers < self.n_codebooks, e.g. for quantizer dropout)
+            Note: if `self.quantizer_dropout` is True, this argument is ignored
+                when in training mode, and a random number of quantizers is used.
+        Returns
+        -------
+        """
+        z_q = 0
+        residual = z
+        commitment_loss = 0
+        codebook_loss = 0
+
+        codebook_indices = []
+        latents = []
+
+        if n_quantizers is None:
+            n_quantizers = self.n_codebooks
+
+        if self.training:
+            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
+            dropout = torch.randint(
+                self.min_quantizers, self.n_codebooks + 1, (z.shape[0],)
+            )
+            n_dropout = int(z.shape[0] * self.quantizer_dropout)
+            n_quantizers[:n_dropout] = dropout[:n_dropout]
+            n_quantizers = n_quantizers.to(z.device)
+
+        for i, quantizer in enumerate(self.quantizers):
+            if self.training is False and i >= n_quantizers:
+                break
+
+            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
+                residual
+            )
+
+            # Create mask to apply quantizer dropout
+            mask = (
+                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
+            )
+            z_q = z_q + z_q_i * mask[:, None, None]
+            residual = residual - z_q_i
+
+            # Sum losses
+            commitment_loss += (commitment_loss_i * mask).mean()
+            codebook_loss += (codebook_loss_i * mask).mean()
+
+            codebook_indices.append(indices_i)
+            latents.append(z_e_i)
+
+        codes = torch.stack(codebook_indices, dim=1)
+        latents = torch.cat(latents, dim=1)
+
+        return VQResult(z_q, codes, latents, commitment_loss, codebook_loss)
+
+    def from_codes(self, codes: torch.Tensor):
+        """Given the quantized codes, reconstruct the continuous representation
+        Parameters
+        ----------
+        codes : Tensor[B x N x T]
+            Quantized discrete representation of input
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized continuous representation of input
+        """
+        z_q = 0.0
+        z_p = []
+        n_codebooks = codes.shape[1]
+        for i in range(n_codebooks):
+            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
+            z_p.append(z_p_i)
+
+            z_q_i = self.quantizers[i].out_proj(z_p_i)
+            z_q = z_q + z_q_i
+        return z_q, torch.cat(z_p, dim=1), codes
+
+    def from_latents(self, latents: torch.Tensor):
+        """Given the unquantized latents, reconstruct the
+        continuous representation after quantization.
+
+        Parameters
+        ----------
+        latents : Tensor[B x N x T]
+            Continuous representation of input after projection
+
+        Returns
+        -------
+        Tensor[B x D x T]
+            Quantized representation of full-projected space
+        Tensor[B x D x T]
+            Quantized representation of latent space
+        """
+        z_q = 0
+        z_p = []
+        codes = []
+        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
+
+        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
+            0
+        ]
+        for i in range(n_codebooks):
+            j, k = dims[i], dims[i + 1]
+            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
+            z_p.append(z_p_i)
+            codes.append(codes_i)
+
+            z_q_i = self.quantizers[i].out_proj(z_p_i)
+            z_q = z_q + z_q_i
+
+        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
+
+
+class DownsampleResidualVectorQuantizer(ResidualVectorQuantize):
+    """
+    Downsampled version of ResidualVectorQuantize
+    """
+
+    def __init__(
+        self,
+        input_dim: int = 512,
+        n_codebooks: int = 9,
+        codebook_size: int = 1024,
+        codebook_dim: Union[int, list] = 8,
+        quantizer_dropout: float = 0.0,
+        min_quantizers: int = 4,
+        downsample_factor: tuple[int] = (2, 2),
+        downsample_dims: tuple[int] | None = None,
+    ):
+        if downsample_dims is None:
+            downsample_dims = [input_dim for _ in range(len(downsample_factor))]
+
+        all_dims = (input_dim,) + tuple(downsample_dims)
+
+        super().__init__(
+            all_dims[-1],
+            n_codebooks,
+            codebook_size,
+            codebook_dim,
+            quantizer_dropout,
+            min_quantizers,
+        )
+
+        self.downsample_factor = downsample_factor
+        self.downsample_dims = downsample_dims
+
+        self.downsample = nn.Sequential(
+            *[
+                nn.Conv1d(
+                    all_dims[idx],
+                    all_dims[idx + 1],
+                    kernel_size=factor,
+                    stride=factor,
+                )
+                for idx, factor in enumerate(downsample_factor)
+            ]
+        )
+
+        self.upsample = nn.Sequential(
+            *[
+                nn.ConvTranspose1d(
+                    all_dims[idx + 1],
+                    all_dims[idx],
+                    kernel_size=factor,
+                    stride=factor,
+                )
+                for idx, factor in reversed(list(enumerate(downsample_factor)))
+            ]
+        )
+
+    def forward(self, z, n_quantizers: int = None) -> VQResult:
+        original_shape = z.shape
+        z = self.downsample(z)
+        result = super().forward(z, n_quantizers)
+        result.z = self.upsample(result.z)
+
+        # Pad or crop z to match original shape
+        diff = original_shape[-1] - result.z.shape[-1]
+        left = diff // 2
+        right = diff - left
+
+        if diff > 0:
+            result.z = F.pad(result.z, (left, right))
+        elif diff < 0:
+            result.z = result.z[..., left:-right]
+
+        return result
+
+    def from_codes(self, codes: torch.Tensor):
+        z_q, z_p, codes = super().from_codes(codes)
+        z_q = self.upsample(z_q)
+        return z_q, z_p, codes
+
+    def from_latents(self, latents: torch.Tensor):
+        z_q, z_p, codes = super().from_latents(latents)
+        z_q = self.upsample(z_q)
+        return z_q, z_p, codes
+
+
+if __name__ == "__main__":
+    rvq = DownsampleResidualVectorQuantizer(
+        quantizer_dropout=1.0,
+        min_quantizers=1,
+        codebook_size=256,
+        downsample_factor=(2, 2),
+    )
+    x = torch.randn(16, 512, 80)
+
+    result = rvq(x)
+    print(result.latents.shape, result.codes.shape, result.z.shape)
+
+    y = rvq.from_codes(result.codes)
+    print(y[0].shape)
+
+    y = rvq.from_latents(result.latents)
+    print(y[0].shape)

+ 3 - 2
fish_speech/train.py

@@ -83,8 +83,9 @@ def train(cfg: DictConfig) -> tuple[dict, dict]:
         ckpt_path = cfg.get("ckpt_path")
         auto_resume = False
 
-        if ckpt_path is None:
-            ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
+        resume_ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
+        if resume_ckpt_path is not None:
+            ckpt_path = resume_ckpt_path
             auto_resume = True
 
         if ckpt_path is not None:

+ 0 - 1
pyproject.toml

@@ -21,7 +21,6 @@ dependencies = [
     "natsort>=8.4.0",
     "einops>=0.7.0",
     "librosa>=0.10.1",
-    "vector-quantize-pytorch>=1.10.0",
     "rich>=13.5.3",
     "gradio>=4.0.0",
     "pypinyin>=0.49.0",

+ 112 - 73
tools/llama/generate.py

@@ -283,17 +283,23 @@ def encode_tokens(
 
     # Handle English less frequent words
     # TODO: update tokenizer to handle this
-    sub_strings = string.split(" ")
-    new_tokens = []
-    for i, string in enumerate(sub_strings):
-        tokens = tokenizer.encode(
-            string,
-            add_special_tokens=i == 0 and bos,
-            max_length=10**6,
-            truncation=False,
-        )
-        new_tokens.extend(tokens)
-
+    # sub_strings = string.split(" ")
+    # new_tokens = []
+    # for i, string in enumerate(sub_strings):
+    #     tokens = tokenizer.encode(
+    #         string,
+    #         add_special_tokens=i == 0 and bos,
+    #         max_length=10**6,
+    #         truncation=False,
+    #     )
+    #     new_tokens.extend(tokens)
+
+    new_tokens = tokenizer.encode(
+        string,
+        add_special_tokens=bos,
+        max_length=10**6,
+        truncation=False,
+    )
     tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
 
     # Codebooks
@@ -373,6 +379,25 @@ def load_model(config_name, checkpoint_path, device, precision):
     return model.eval()
 
 
+def split_text(text, min_length):
+    text = clean_text(text)
+    segments = []
+    curr = ""
+    for char in text:
+        curr += char
+        if char not in [".", ",", "!", "?"]:
+            continue
+
+        if len(curr) >= min_length:
+            segments.append(curr)
+            curr = ""
+
+    if curr:
+        segments.append(curr)
+
+    return segments
+
+
 @click.command()
 @click.option("--text", type=str, default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.")
 @click.option("--prompt-text", type=str, default=None)
@@ -440,18 +465,23 @@ def main(
     )
 
     use_prompt = prompt_text is not None and prompt_tokens is not None
+    encoded = []
+    for text in split_text(text, 20):
+        encoded.append(
+            encode_tokens(
+                tokenizer,
+                text,
+                bos=False,
+                device=device,
+                use_g2p=use_g2p,
+                speaker=None,
+                order=order,
+                num_codebooks=model.config.num_codebooks,
+            )
+        )
+        print(f"Encoded text: {text}")
 
     if use_prompt and iterative_prompt:
-        encoded = encode_tokens(
-            tokenizer,
-            text,
-            bos=False,
-            device=device,
-            use_g2p=use_g2p,
-            speaker=None,
-            order=order,
-            num_codebooks=model.config.num_codebooks,
-        )
         encoded_prompt = encode_tokens(
             tokenizer,
             prompt_text,
@@ -463,25 +493,11 @@ def main(
             order=order,
             num_codebooks=model.config.num_codebooks,
         )
-        encoded = torch.cat((encoded_prompt, encoded), dim=1)
-    else:
-        if prompt_text:
-            text = prompt_text + text
 
-        encoded = encode_tokens(
-            tokenizer,
-            text,
-            bos=True,
-            device=device,
-            use_g2p=use_g2p,
-            speaker=speaker,
-            order=order,
-            prompt_tokens=prompt_tokens,
-            num_codebooks=model.config.num_codebooks,
-        )
+        encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
 
-    prompt_length = encoded.size(1)
-    logger.info(f"Encoded prompt shape: {encoded.shape}")
+    # prompt_length = encoded.size(1)
+    # logger.info(f"Encoded prompt shape: {encoded.shape}")
 
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
@@ -494,46 +510,69 @@ def main(
 
     for idx in range(num_samples):
         torch.cuda.synchronize()
+        global_encoded = []
+        all_codes = []
+        seg_idx = 0
+
+        while seg_idx < len(encoded):
+            seg = encoded[seg_idx]
+            global_encoded.append(seg)
+            cat_encoded = torch.cat(global_encoded, dim=1)
+            prompt_length = cat_encoded.size(1)
+
+            t0 = time.perf_counter()
+            y = generate(
+                model=model,
+                prompt=cat_encoded,
+                max_new_tokens=max_new_tokens,
+                eos_token_id=tokenizer.eos_token_id,
+                precision=precision,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+            )
 
-        t0 = time.perf_counter()
-        y = generate(
-            model=model,
-            prompt=encoded,
-            max_new_tokens=max_new_tokens,
-            eos_token_id=tokenizer.eos_token_id,
-            precision=precision,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-        )
-
-        if idx == 0 and compile:
-            logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
+            if idx == 0 and compile:
+                logger.info(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
 
-        torch.cuda.synchronize()
-        t = time.perf_counter() - t0
+            torch.cuda.synchronize()
+            t = time.perf_counter() - t0
 
-        tokens_generated = y.size(1) - prompt_length
-        tokens_sec = tokens_generated / t
-        logger.info(
-            f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
-        )
-        logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
-        logger.info(
-            f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
-        )
-
-        codes = y[1:, prompt_length:-1]
-        new_codes = []
-        for j, code in enumerate(codes):
-            new_codes.append(
-                code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
+            tokens_generated = y.size(1) - prompt_length
+            tokens_sec = tokens_generated / t
+            logger.info(
+                f"Generated {tokens_generated} tokens in {t:.02f} seconds, {tokens_sec:.02f} tokens/sec"
+            )
+            logger.info(
+                f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s"
+            )
+            logger.info(
+                f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB"
             )
 
-        codes = torch.stack(new_codes, dim=0)
-        codes = codes - 2
-        assert (codes >= 0).all(), "Codes should be >= 0"
+            # Collect the tokens generated for this segment
+            codes = y[1:, prompt_length:-1].clone()
+            new_codes = []
+            for j, code in enumerate(codes):
+                new_codes.append(
+                    code[j : codes.shape[1] - (model.config.num_codebooks - j - 1)]
+                )
+
+            codes = torch.stack(new_codes, dim=0)
+            codes = codes - 2
+            if not (codes >= 0).all():
+                global_encoded.pop()
+                logger.warning(f"Negative code found: {codes}, retrying ...")
+                continue
+
+            global_encoded.append(y[:, prompt_length:-1].clone())
+            all_codes.append(codes)
+            seg_idx += 1
+
+        codes = torch.cat(all_codes, dim=1)
+        assert (codes >= 0).all(), f"Negative code found: {codes}"
+        print(codes)
 
         np.save(f"codes_{idx}.npy", codes.cpu().numpy())
         logger.info(f"Saved codes to codes_{idx}.npy")