Quellcode durchsuchen

Implement new vq + dit + reflow

Lengyue vor 2 Jahren
Ursprung
Commit
7b5fe470e4

+ 53 - 54
fish_speech/configs/vqgan_pretrain.yaml

@@ -2,36 +2,36 @@ defaults:
   - base
   - _self_
 
-project: vqgan_pretrain_lfq
-ckpt_path: checkpoints/gpt_sovits_488k.pth
-resume_weights_only: true
+project: vq_reflow_debug
 
 # Lightning Trainer
 trainer:
   accelerator: gpu
   devices: auto
   strategy: ddp_find_unused_parameters_true
-  precision: 32
+  precision: 16-mixed
   max_steps: 1_000_000
   val_check_interval: 2000
+  gradient_clip_algorithm: norm
+  gradient_clip_val: 1.0
 
-sample_rate: 32000
-hop_length: 640
-num_mels: 128
+sample_rate: 44100
+hop_length: 512
+num_mels: 160
 n_fft: 2048
 win_length: 2048
 
 # Dataset Configuration
 train_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_train_filelist.txt
+  filelist: /***REMOVED***/workspace/diffusion-test/data/HiFi-TTS/vq_train_filelist.txt
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
-  slice_frames: 128
+  slice_frames: 512
 
 val_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_val_filelist.txt
+  filelist: /***REMOVED***/workspace/diffusion-test/data/HiFi-TTS/vq_val_filelist.txt
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
 
@@ -40,49 +40,48 @@ data:
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
-  batch_size: 16
+  batch_size: 32
   val_batch_size: 4
 
 # Model Configuration
 model:
   _target_: fish_speech.models.vqgan.VQGAN
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  freeze_discriminator: false
 
-  weight_mel: 45.0
-  weight_kl: 0.1
+  sampling_rate: ${sample_rate}
+  weight_reflow: 1.0
   weight_vq: 1.0
   weight_aux_mel: 1.0
 
-  generator:
-    _target_: fish_speech.models.vqgan.modules.models.SynthesizerTrn
-    spec_channels: 1025
-    segment_size: 32
-    inter_channels: 192
-    prior_hidden_channels: 192
-    posterior_hidden_channels: 192
-    prior_n_layers: 16
-    posterior_n_layers: 16
-    kernel_size: 5
-    p_dropout: 0.1
-    resblock: "1"
-    resblock_kernel_sizes: [3, 7, 11]
-    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    upsample_rates: [10, 8, 2, 2, 2]
-    upsample_initial_channel: 512
-    upsample_kernel_sizes: [16, 16, 8, 2, 2]
-    gin_channels: 512
-    freeze_quantizer: false
-    freeze_decoder: false
-    freeze_posterior_encoder: false
-    codebook_size: 1024
-    num_codebooks: 2
-    aux_spec_channels: ${num_mels}
+  encoder:
+    _target_: fish_speech.models.vqgan.modules.convnext.ConvNeXtEncoder
+    input_channels: ${num_mels}
+    depths: [3, 3, 9, 3]
+    dims: [128, 256, 384, 512]
+  
+  quantizer:
+    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
+    input_dim: 512
+    n_codebooks: 8
+    levels: [8, 5, 5, 5]
+  
+  aux_decoder:
+    _target_: fish_speech.models.vqgan.modules.convnext.ConvNeXtEncoder
+    input_channels: 512
+    output_channels: ${num_mels}
+    depths: [6]
+    dims: [384]
+
+  reflow:
+    _target_: fish_speech.models.vqgan.modules.dit.DiT
+    hidden_size: 768
+    num_heads: 12
+    diffusion_num_layers: 12
+    channels: ${num_mels}
+    condition_dim: 512
 
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.models.EnsembledDiscriminator
-    periods: [2, 3, 5, 7, 11]
+  vocoder:
+    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
+    ckpt_path: checkpoints/firefly-gan-base-002000000.ckpt
 
   mel_transform:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
@@ -92,13 +91,6 @@ model:
     win_length: ${win_length}
     n_mels: ${num_mels}
 
-  spec_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LinearSpectrogram
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    mode: pow2_sqrt
-  
   optimizer:
     _target_: torch.optim.AdamW
     _partial_: true
@@ -107,12 +99,19 @@ model:
     eps: 1e-5
 
   lr_scheduler:
-    _target_: torch.optim.lr_scheduler.ExponentialLR
+    _target_: torch.optim.lr_scheduler.LambdaLR
     _partial_: true
-    gamma: 0.99999
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 100
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0
 
 callbacks:
   grad_norm_monitor:
     sub_module: 
-      - generator
-      - discriminator
+      - encoder
+      - aux_decoder
+      - quantizer
+      - reflow

+ 169 - 316
fish_speech/models/vqgan/lit_module.py

@@ -9,15 +9,7 @@ import wandb
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
-from torch.utils.checkpoint import checkpoint as gradient_checkpoint
-
-from fish_speech.models.vqgan.losses import (
-    MultiResolutionSTFTLoss,
-    discriminator_loss,
-    feature_loss,
-    generator_loss,
-    kl_loss,
-)
+
 from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
 
 
@@ -40,17 +32,16 @@ class VQGAN(L.LightningModule):
         self,
         optimizer: Callable,
         lr_scheduler: Callable,
-        generator: nn.Module,
-        discriminator: nn.Module,
+        encoder: nn.Module,
+        quantizer: nn.Module,
+        aux_decoder: nn.Module,
+        reflow: nn.Module,
+        vocoder: nn.Module,
         mel_transform: nn.Module,
-        spec_transform: nn.Module,
-        hop_length: int = 640,
-        sample_rate: int = 32000,
-        freeze_discriminator: bool = False,
-        weight_mel: float = 45,
-        weight_kl: float = 0.1,
+        weight_reflow: float = 1.0,
         weight_vq: float = 1.0,
-        weight_aux_mel: float = 20.0,
+        weight_aux_mel: float = 1.0,
+        sampling_rate: int = 44100,
     ):
         super().__init__()
 
@@ -58,62 +49,54 @@ class VQGAN(L.LightningModule):
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
 
-        # Generator and discriminator
-        self.generator = generator
-        self.discriminator = discriminator
+        # Modules
+        self.encoder = encoder
+        self.quantizer = quantizer
+        self.aux_decoder = aux_decoder
+        self.reflow = reflow
         self.mel_transform = mel_transform
-        self.spec_transform = spec_transform
-        self.freeze_discriminator = freeze_discriminator
+        self.vocoder = vocoder
+
+        # Freeze vocoder
+        for param in self.vocoder.parameters():
+            param.requires_grad = False
 
         # Loss weights
-        self.weight_mel = weight_mel
-        self.weight_kl = weight_kl
+        self.weight_reflow = weight_reflow
         self.weight_vq = weight_vq
         self.weight_aux_mel = weight_aux_mel
 
-        # Other parameters
-        self.hop_length = hop_length
-        self.sampling_rate = sample_rate
-
-        # Disable automatic optimization
-        self.automatic_optimization = False
+        self.spec_min = -12
+        self.spec_max = 3
+        self.sampling_rate = sampling_rate
 
-        if self.freeze_discriminator:
-            for p in self.discriminator.parameters():
-                p.requires_grad = False
+    def on_save_checkpoint(self, checkpoint):
+        # Do not save vocoder
+        state_dict = checkpoint["state_dict"]
+        for name in list(state_dict.keys()):
+            if "vocoder" in name:
+                state_dict.pop(name)
 
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
-        optimizer_generator = self.optimizer_builder(self.generator.parameters())
-        optimizer_discriminator = self.optimizer_builder(
-            self.discriminator.parameters()
-        )
-
-        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
-        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
-
-        return (
-            {
-                "optimizer": optimizer_generator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_generator,
-                    "interval": "step",
-                    "name": "optimizer/generator",
-                },
+        optimizer = self.optimizer_builder(self.parameters())
+        lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "interval": "step",
             },
-            {
-                "optimizer": optimizer_discriminator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_discriminator,
-                    "interval": "step",
-                    "name": "optimizer/discriminator",
-                },
-            },
-        )
+        }
 
-    def training_step(self, batch, batch_idx):
-        optim_g, optim_d = self.optimizers()
+    def norm_spec(self, x):
+        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
 
+    def denorm_spec(self, x):
+        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
+
+    def training_step(self, batch, batch_idx):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
 
         audios = audios.float()
@@ -121,173 +104,84 @@ class VQGAN(L.LightningModule):
 
         with torch.no_grad():
             gt_mels = self.mel_transform(audios)
-            gt_specs = self.spec_transform(audios)
-
-        spec_lengths = audio_lengths // self.hop_length
-        spec_masks = torch.unsqueeze(
-            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
-        ).to(gt_mels.dtype)
-        (
-            fake_audios,
-            ids_slice,
-            y_mask,
-            y_mask,
-            (z, z_p, m_p, logs_p, m_q, logs_q),
-            loss_vq,
-            decoded_aux_mels,
-        ) = self.generator(gt_specs, spec_lengths)
 
-        gt_mels = slice_segments(gt_mels, ids_slice, self.generator.segment_size)
-        decoded_aux_mels = slice_segments(
-            decoded_aux_mels, ids_slice, self.generator.segment_size
-        )
-        spec_masks = slice_segments(spec_masks, ids_slice, self.generator.segment_size)
-        audios = slice_segments(
-            audios,
-            ids_slice * self.hop_length,
-            self.generator.segment_size * self.hop_length,
-        )
-        fake_mels = self.mel_transform(fake_audios.squeeze(1))
+        mel_lengths = audio_lengths // self.mel_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+        mel_masks_float_conv = mel_masks[:, None, :].float()
 
-        assert (
-            audios.shape == fake_audios.shape
-        ), f"{audios.shape} != {fake_audios.shape}"
-
-        # Discriminator
-        if self.freeze_discriminator is False:
-            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(
-                audios, fake_audios.detach()
-            )
+        # Encode
+        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
 
-            with torch.autocast(device_type=audios.device.type, enabled=False):
-                loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
-
-            self.log(
-                f"train/discriminator/loss",
-                loss_disc,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
+        # Quantize
+        vq_result = self.quantizer(encoded_features)
+        loss_vq = getattr(vq_result, "loss", 0.0)
+        vq_recon_features = vq_result.z * mel_masks_float_conv
 
-            optim_d.zero_grad()
-            self.manual_backward(loss_disc)
-            self.clip_gradients(
-                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-            )
-            optim_d.step()
+        # VQ Decode
+        aux_mel = self.aux_decoder(vq_recon_features)
+        loss_aux_mel = F.l1_loss(
+            aux_mel * mel_masks_float_conv, gt_mels * mel_masks_float_conv
+        )
 
-        # Adv Loss
-        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios)
+        # Reflow
+        x_1 = self.norm_spec(gt_mels.mT)
+        t = torch.rand(gt_mels.shape[0], device=gt_mels.device)
+        x_0 = torch.randn_like(x_1)
 
-        # Adversarial Loss
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_adv, _ = generator_loss(y_d_hat_g)
+        # X_t = t * X_1 + (1 - t) * X_0
+        x_t = x_0 + t[:, None, None] * (x_1 - x_0)
 
-        self.log(
-            f"train/generator/adv",
-            loss_adv,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
+        v_pred = self.reflow(
+            x_t,
+            1000 * t,
+            condition=vq_recon_features.mT,
+            self_mask=mel_masks,
         )
 
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_fm = feature_loss(y_d_hat_r, y_d_hat_g)
-
-        self.log(
-            f"train/generator/adv_fm",
-            loss_fm,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
+        # Weight the L2 loss with a logit-normal density over t (0.398942 = 1/sqrt(2*pi))
+        weights = 0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / (1 - t)) ** 2)
+        loss_reflow = weights[:, None, None] * F.mse_loss(
+            x_1 - x_0, v_pred, reduction="none"
         )
+        loss_reflow = (loss_reflow * mel_masks_float_conv.mT).mean()
 
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_mel = F.l1_loss(gt_mels * spec_masks, fake_mels * spec_masks)
-            loss_aux_mel = F.l1_loss(
-                gt_mels * spec_masks, decoded_aux_mels * spec_masks
-            )
-
-        self.log(
-            "train/generator/loss_mel",
-            loss_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
+        # Total loss
+        loss = (
+            self.weight_vq * loss_vq
+            + self.weight_aux_mel * loss_aux_mel
+            + self.weight_reflow * loss_reflow
         )
 
+        # Log losses
         self.log(
-            "train/generator/loss_aux_mel",
-            loss_aux_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
+            "train/loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True
         )
-
         self.log(
-            "train/generator/loss_vq",
+            "train/loss_vq",
             loss_vq,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
             logger=True,
-            sync_dist=True,
         )
-
-        loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, y_mask)
-
         self.log(
-            "train/generator/loss_kl",
-            loss_kl,
+            "train/loss_aux_mel",
+            loss_aux_mel,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
             logger=True,
-            sync_dist=True,
-        )
-
-        loss = (
-            loss_mel * self.weight_mel
-            + loss_aux_mel * self.weight_aux_mel
-            + loss_vq * self.weight_vq
-            + loss_kl * self.weight_kl
-            + loss_adv
-            + loss_fm
         )
         self.log(
-            "train/generator/loss",
-            loss,
+            "train/loss_reflow",
+            loss_reflow,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
             logger=True,
-            sync_dist=True,
-        )
-
-        # Backward
-        optim_g.zero_grad()
-
-        self.manual_backward(loss)
-        self.clip_gradients(
-            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
         )
-        optim_g.step()
 
-        # Manual LR Scheduler
-        scheduler_g, scheduler_d = self.lr_schedulers()
-        scheduler_g.step()
-        scheduler_d.step()
+        return loss
 
     def validation_step(self, batch: Any, batch_idx: int):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
@@ -296,32 +190,25 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]
 
         gt_mels = self.mel_transform(audios)
-        gt_specs = self.spec_transform(audios)
-        spec_lengths = audio_lengths // self.hop_length
-        spec_masks = torch.unsqueeze(
-            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
-        ).to(gt_mels.dtype)
-
-        prior_audios, _, _ = self.generator.infer(gt_specs, spec_lengths)
-        posterior_audios, _, _ = self.generator.infer_posterior(gt_specs, spec_lengths)
-        prior_mels = self.mel_transform(prior_audios.squeeze(1))
-        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))
-
-        min_mel_length = min(
-            gt_mels.shape[-1], prior_mels.shape[-1], posterior_mels.shape[-1]
-        )
-        gt_mels = gt_mels[:, :, :min_mel_length]
-        prior_mels = prior_mels[:, :, :min_mel_length]
-        posterior_mels = posterior_mels[:, :, :min_mel_length]
+        mel_lengths = audio_lengths // self.mel_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+
+        # Encode
+        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
 
-        prior_mel_loss = F.l1_loss(gt_mels * spec_masks, prior_mels * spec_masks)
-        posterior_mel_loss = F.l1_loss(
-            gt_mels * spec_masks, posterior_mels * spec_masks
+        # Quantize
+        vq_result = self.quantizer(encoded_features)
+
+        # VQ Decode
+        aux_mels = self.aux_decoder(vq_result.z)
+        loss_aux_mel = F.l1_loss(
+            aux_mels * mel_masks_float_conv, gt_mels * mel_masks_float_conv
         )
 
         self.log(
-            "val/prior_mel_loss",
-            prior_mel_loss,
+            "val/loss_aux_mel",
+            loss_aux_mel,
             on_step=False,
             on_epoch=True,
             prog_bar=False,
@@ -329,9 +216,33 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
+        # Reflow inference
+        t_start = 0.0
+        infer_step = 20
+        gen_mels = torch.randn(gt_mels.shape, device=gt_mels.device).mT
+        t = torch.zeros(gt_mels.shape[0], device=gt_mels.device)
+        dt = (1.0 - t_start) / infer_step
+
+        for _ in range(infer_step):
+            gen_mels += (
+                self.reflow(
+                    gen_mels,
+                    1000 * t,
+                    condition=vq_result.z.mT,
+                    self_mask=mel_masks,
+                )
+                * dt
+            )
+            t += dt
+
+        gen_mels = self.denorm_spec(gen_mels).mT
+        loss_recon_reflow = F.l1_loss(
+            gen_mels * mel_masks_float_conv, gt_mels * mel_masks_float_conv
+        )
+
         self.log(
-            "val/posterior_mel_loss",
-            posterior_mel_loss,
+            "val/loss_recon_reflow",
+            loss_recon_reflow,
             on_step=False,
             on_epoch=True,
             prog_bar=False,
@@ -339,41 +250,47 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
+        gen_audios = self.vocoder(gen_mels)
+        recon_audios = self.vocoder(gt_mels)
+        aux_audios = self.vocoder(aux_mels)
+
         # only log the first batch
         if batch_idx != 0:
             return
 
         for idx, (
-            mel,
-            prior_mel,
-            posterior_mel,
+            gt_mel,
+            reflow_mel,
+            aux_mel,
             audio,
-            prior_audio,
-            posterior_audio,
+            reflow_audio,
+            aux_audio,
+            recon_audio,
             audio_len,
         ) in enumerate(
             zip(
                 gt_mels,
-                prior_mels,
-                posterior_mels,
-                audios.detach().float(),
-                prior_audios.detach().float(),
-                posterior_audios.detach().float(),
+                gen_mels,
+                aux_mels,
+                audios.float(),
+                gen_audios.float(),
+                aux_audios.float(),
+                recon_audios.float(),
                 audio_lengths,
             )
         ):
-            mel_len = audio_len // self.hop_length
+            mel_len = audio_len // self.mel_transform.hop_length
 
             image_mels = plot_mel(
                 [
-                    prior_mel[:, :mel_len],
-                    posterior_mel[:, :mel_len],
-                    mel[:, :mel_len],
+                    gt_mel[:, :mel_len],
+                    reflow_mel[:, :mel_len],
+                    aux_mel[:, :mel_len],
                 ],
                 [
-                    "Prior (VQ)",
-                    "Posterior (Reconstruction)",
                     "Ground-Truth",
+                    "Reflow",
+                    "Aux",
                 ],
             )
 
@@ -388,14 +305,19 @@ class VQGAN(L.LightningModule):
                                 caption="gt",
                             ),
                             wandb.Audio(
-                                prior_audio[0, :audio_len],
+                                reflow_audio[0, :audio_len],
                                 sample_rate=self.sampling_rate,
-                                caption="prior",
+                                caption="reflow",
                             ),
                             wandb.Audio(
-                                posterior_audio[0, :audio_len],
+                                aux_audio[0, :audio_len],
                                 sample_rate=self.sampling_rate,
-                                caption="posterior",
+                                caption="aux",
+                            ),
+                            wandb.Audio(
+                                recon_audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="recon",
                             ),
                         ],
                     },
@@ -414,91 +336,22 @@ class VQGAN(L.LightningModule):
                     sample_rate=self.sampling_rate,
                 )
                 self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/prior",
-                    prior_audio[0, :audio_len],
+                    f"sample-{idx}/wavs/reflow",
+                    reflow_audio[0, :audio_len],
                     self.global_step,
                     sample_rate=self.sampling_rate,
                 )
                 self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/posterior",
-                    posterior_audio[0, :audio_len],
+                    f"sample-{idx}/wavs/aux",
+                    aux_audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/recon",
+                    recon_audio[0, :audio_len],
                     self.global_step,
                     sample_rate=self.sampling_rate,
                 )
 
             plt.close(image_mels)
-
-    # def encode(self, audios, audio_lengths=None):
-    #     if audio_lengths is None:
-    #         audio_lengths = torch.tensor(
-    #             [audios.shape[-1]] * audios.shape[0],
-    #             device=audios.device,
-    #             dtype=torch.long,
-    #         )
-
-    #     with torch.no_grad():
-    #         features = self.mel_transform(audios, sample_rate=self.sampling_rate)
-
-    #     feature_lengths = (
-    #         audio_lengths
-    #         / self.hop_length
-    #         # / self.vq.downsample
-    #     ).long()
-
-    #     # print(features.shape, feature_lengths.shape, torch.max(feature_lengths))
-
-    #     feature_masks = torch.unsqueeze(
-    #         sequence_mask(feature_lengths, features.shape[2]), 1
-    #     ).to(features.dtype)
-
-    #     features = (
-    #         gradient_checkpoint(
-    #             self.encoder, features, feature_masks, use_reentrant=False
-    #         )
-    #         * feature_masks
-    #     )
-    #     vq_features, indices, loss = self.vq(features, feature_masks)
-
-    #     return VQEncodeResult(
-    #         features=vq_features,
-    #         indices=indices,
-    #         loss=loss,
-    #         feature_lengths=feature_lengths,
-    #     )
-
-    # def calculate_audio_lengths(self, feature_lengths):
-    #     return feature_lengths * self.hop_length * self.vq.downsample
-
-    # def decode(
-    #     self,
-    #     indices=None,
-    #     features=None,
-    #     audio_lengths=None,
-    #     feature_lengths=None,
-    #     return_audios=False,
-    # ):
-    #     assert (
-    #         indices is not None or features is not None
-    #     ), "indices or features must be provided"
-    #     assert (
-    #         feature_lengths is not None or audio_lengths is not None
-    #     ), "feature_lengths or audio_lengths must be provided"
-
-    #     if audio_lengths is None:
-    #         audio_lengths = self.calculate_audio_lengths(feature_lengths)
-
-    #     mel_lengths = audio_lengths // self.hop_length
-    #     mel_masks = torch.unsqueeze(
-    #         sequence_mask(mel_lengths, torch.max(mel_lengths)), 1
-    #     ).float()
-
-    #     if indices is not None:
-    #         features = self.vq.decode(indices)
-
-    #     # Sample mels
-    #     decoded = gradient_checkpoint(self.decoder, features, use_reentrant=False)
-
-    #     return VQDecodeResult(
-    #         mels=decoded,
-    #         audios=self.generator(decoded) if return_audios else None,
-    #     )

+ 249 - 0
fish_speech/models/vqgan/modules/convnext.py

@@ -0,0 +1,249 @@
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Randomly zero whole samples (stochastic depth) on a residual path.

    Same scheme as the DropConnect-style implementation used for EfficientNet
    et al.; see https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    for the naming discussion. When ``scale_by_keep`` is set, surviving samples
    are rescaled by ``1 / keep_prob`` so the expected activation is unchanged.
    Identity at eval time or when ``drop_prob`` is zero.
    """  # noqa: E501

    if not training or drop_prob == 0.0:
        return x

    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast across all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if scale_by_keep and keep_prob > 0.0:
        mask.div_(keep_prob)

    return x * mask
+
+
class DropPath(nn.Module):
    """Per-sample stochastic depth on the main path of a residual block.

    Identity when ``drop_prob == 0`` or the module is in eval mode; otherwise
    each sample is kept with probability ``1 - drop_prob`` and (optionally)
    rescaled by the keep probability so the expectation is preserved.
    """  # noqa: E501

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        if not self.training or self.drop_prob == 0.0:
            return x

        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over the remaining dims.
        mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
        if self.scale_by_keep and keep_prob > 0.0:
            mask.div_(keep_prob)

        return x * mask

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
class LayerNorm(nn.Module):
    r"""LayerNorm supporting two data layouts.

    ``channels_last`` (default) expects the channel axis last and delegates to
    ``F.layer_norm``; ``channels_first`` expects ``(N, C, L)`` and normalizes
    over dim 1 manually, applying the affine parameters with a broadcast over
    the trailing length axis.
    """  # noqa: E501

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )

        # channels_first: normalize over the channel dim (dim 1) by hand.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None] * normed + self.bias[:, None]
+
+
class ConvNeXtBlock(nn.Module):
    r"""ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()

        # Depthwise conv. FIX: padding was computed for `dilation`, but
        # `dilation` was never passed to Conv1d, so any dilation != 1 changed
        # the output length and broke the residual add. Pass it through so the
        # conv stays length-preserving ("same" padding).
        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=int(dilation * (kernel_size - 1) / 2),
            dilation=dilation,
            groups=dim,
        )
        # channels_last LayerNorm (applied after the permute below); identical
        # in behavior to the project-local LayerNorm in its default layout.
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, int(mlp_ratio * dim)
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
        # Layer-scale parameter; disabled when layer_scale_init_value <= 0.
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        """Apply the block to ``x`` of shape (N, C, L); optionally skip the residual."""
        input = x

        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.gamma is not None:
            x = self.gamma * x  # layer scale

        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        x = self.drop_path(x)

        if apply_residual:
            x = input + x

        return x
+
+
class ParallelConvNeXtBlock(nn.Module):
    """Run several ConvNeXt blocks with different kernel sizes in parallel.

    Each branch is applied without its own residual; the branch outputs plus
    the input ``x`` are summed, giving a single shared residual connection.
    """

    def __init__(self, kernel_sizes: list[int], *args, **kwargs):
        super().__init__()
        self.blocks = nn.ModuleList(
            ConvNeXtBlock(kernel_size=k, *args, **kwargs) for k in kernel_sizes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch_outputs = [block(x, apply_residual=False) for block in self.blocks]
        branch_outputs.append(x)  # shared residual
        return torch.stack(branch_outputs, dim=1).sum(dim=1)
+
+
class ConvNeXtEncoder(nn.Module):
    """1D ConvNeXt backbone.

    A 7-wide stem followed by ``len(depths)`` stages of ConvNeXt blocks, with
    LayerNorm + 1x1-conv channel projections between stages, a final
    channels_first LayerNorm, and an optional 1x1 output projection when
    ``output_channels`` is given. Input/output layout is (N, C, L).
    """

    def __init__(
        self,
        input_channels: int = 3,
        output_channels: Optional[int] = None,
        depths: list[int] = [3, 3, 9, 3],
        dims: list[int] = [96, 192, 384, 768],
        drop_path_rate: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        kernel_sizes: tuple[int] = (7,),
    ):
        super().__init__()
        assert len(depths) == len(dims)

        # Stem + inter-stage channel projections, registered in stage order.
        self.channel_layers = nn.ModuleList()
        self.channel_layers.append(
            nn.Sequential(
                nn.Conv1d(
                    input_channels,
                    dims[0],
                    kernel_size=7,
                    padding=3,
                    padding_mode="zeros",
                ),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
            )
        )
        for dim_in, dim_out in zip(dims[:-1], dims[1:]):
            self.channel_layers.append(
                nn.Sequential(
                    LayerNorm(dim_in, eps=1e-6, data_format="channels_first"),
                    nn.Conv1d(dim_in, dim_out, kernel_size=1),
                )
            )

        # Single kernel size -> plain blocks; several -> parallel multi-kernel blocks.
        if len(kernel_sizes) == 1:
            block_fn = partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
        else:
            block_fn = partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)

        # Stochastic-depth rate increases linearly across all blocks.
        rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        self.stages = nn.ModuleList()
        offset = 0
        for depth, dim in zip(depths, dims):
            self.stages.append(
                nn.Sequential(
                    *(
                        block_fn(
                            dim=dim,
                            drop_path=rates[offset + j],
                            layer_scale_init_value=layer_scale_init_value,
                        )
                        for j in range(depth)
                    )
                )
            )
            offset += depth

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")

        if output_channels is not None:
            self.output_projection = nn.Conv1d(dims[-1], output_channels, kernel_size=1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal weights and zero bias for every conv and linear.
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        for channel_layer, stage in zip(self.channel_layers, self.stages):
            x = stage(channel_layer(x))

        x = self.norm(x)

        if hasattr(self, "output_projection"):
            x = self.output_projection(x)

        return x

+ 419 - 0
fish_speech/models/vqgan/modules/dit.py

@@ -0,0 +1,419 @@
+import math
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
def modulate(x, shift, scale):
    """Apply adaLN-style affine modulation: ``x * (1 + scale) + shift``."""
    scaled = x * (scale + 1)
    return scaled + shift
+
+
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate consecutive feature pairs of ``x`` by the angles in ``freqs_cis``.

    ``x`` is (batch, seq, heads, head_dim); ``freqs_cis`` holds (cos, sin)
    pairs in its last dimension and is broadcast over batch and heads.
    Computation happens in float32 and is cast back to ``x``'s dtype.
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rot = freqs_cis.view(1, pairs.size(1), 1, pairs.size(3), 2)

    cos, sin = rot[..., 0], rot[..., 1]
    real, imag = pairs[..., 0], pairs[..., 1]

    # Complex multiplication (real + i*imag) * (cos + i*sin), kept as reals.
    rotated = torch.stack(
        (real * cos - imag * sin, imag * cos + real * sin),
        -1,
    )

    return rotated.flatten(3).type_as(x)
+
+
class TimestepEmbedder(nn.Module):
    """Maps scalar diffusion timesteps to ``hidden_size``-dim vectors.

    A fixed sinusoidal encoding (DDPM/GLIDE style) is fed through a small
    SwiGLU MLP to produce the conditioning embedding.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = FeedForward(
            frequency_embedding_size, hidden_size, out_dim=hidden_size
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Build sinusoidal embeddings for (possibly fractional) timesteps.

        Args:
            t: 1-D tensor of N timestep values.
            dim: output embedding dimension.
            max_period: controls the minimum embedding frequency.

        Returns:
            Tensor of shape (N, dim): ``dim // 2`` cosines followed by
            ``dim // 2`` sines, zero-padded by one column when dim is odd.
        """
        # See https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat((torch.cos(args), torch.sin(args)), dim=-1)
        if dim % 2:
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat((embedding, pad), dim=-1)
        return embedding

    def forward(self, t):
        sinusoid = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(sinusoid)
+
+
def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> torch.Tensor:
    """Precompute the rotary-embedding (cos, sin) cache.

    Returns a bfloat16 tensor of shape (seq_len, n_elem // 2, 2) where the
    last dimension holds the real (cos) and imaginary (sin) components of
    ``exp(i * position * inv_freq)``.
    """
    half = n_elem // 2
    inv_freq = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[:half].float() / n_elem)
    )
    positions = torch.arange(seq_len, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq)
    unit = torch.polar(torch.ones_like(angles), angles)
    cache = torch.stack((unit.real, unit.imag), dim=-1)
    return cache.to(dtype=torch.bfloat16)
+
+
class Attention(nn.Module):
    """Multi-head attention with rotary position embeddings.

    Acts as self-attention when ``kv`` is None, otherwise as cross-attention
    (queries from ``q``, keys/values from ``kv``).
    """

    def __init__(
        self,
        dim,
        n_head,
    ):
        super().__init__()
        assert dim % n_head == 0

        self.dim = dim
        self.n_head = n_head
        self.head_dim = dim // n_head

        # Separate query/key/value/output projections.
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wv = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)

    def forward(self, q, freqs_cis, kv=None, mask=None):
        """Attend over ``kv`` (defaults to ``q``).

        Inputs are (batch, seq, dim). ``mask`` is forwarded unchanged to
        scaled_dot_product_attention, so it must be broadcastable to
        (batch, heads, q_len, kv_len) — callers in this file pass
        (batch, 1, 1, kv_len) boolean masks.
        """
        bsz, seqlen, _ = q.shape

        if kv is None:
            kv = q

        kv_seqlen = kv.shape[1]

        q = self.wq(q).view(bsz, seqlen, self.n_head, self.head_dim)
        k = self.wk(kv).view(bsz, kv_seqlen, self.n_head, self.head_dim)
        v = self.wv(kv).view(bsz, kv_seqlen, self.n_head, self.head_dim)

        # Rotary embeddings are applied to q and k per absolute position.
        q = apply_rotary_emb(q, freqs_cis[:seqlen])
        k = apply_rotary_emb(k, freqs_cis[:kv_seqlen])

        # (batch, seq, head, head_dim) -> (batch, head, seq, head_dim)
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
        )

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        y = self.wo(y)
        return y
+
+
+class FeedForward(nn.Module):
+    def __init__(self, in_dim, intermediate_size, out_dim=None):
+        super().__init__()
+        self.w1 = nn.Linear(in_dim, intermediate_size)
+        self.w3 = nn.Linear(in_dim, intermediate_size)
+        self.w2 = nn.Linear(intermediate_size, out_dim or in_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
class DiTBlock(nn.Module):
    """DiT transformer block with adaLN-Zero conditioning.

    The mixing layer is either rotary self-attention or, when
    ``use_self_attention`` is False, a depthwise 7-tap Conv1d. An optional
    cross-attention sub-layer attends to ``cross_condition``, with both the
    queries and the cross keys/values modulated by their own adaLN heads.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        use_self_attention=True,
        use_cross_attention=False,
    ):
        super().__init__()

        self.use_self_attention = use_self_attention
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        if use_self_attention:
            self.mix = Attention(hidden_size, num_heads)
        else:
            # Cheap local mixing: depthwise conv instead of attention.
            self.mix = nn.Conv1d(
                hidden_size,
                hidden_size,
                kernel_size=7,
                padding=3,
                bias=True,
                groups=hidden_size,
            )

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = FeedForward(hidden_size, int(hidden_size * mlp_ratio))
        # Produces shift/scale/gate for the mixing layer and the MLP.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

        self.use_cross_attention = use_cross_attention
        if self.use_cross_attention:
            self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            self.norm4 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            self.cross_attn = Attention(hidden_size, num_heads)
            # shift/scale/gate for the query side of cross-attention.
            self.adaLN_modulation_cross = nn.Sequential(
                nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True)
            )
            # shift/scale for the key/value (cross_condition) side.
            self.adaLN_modulation_cross_condition = nn.Sequential(
                nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
            )

    def forward(
        self,
        x,
        condition,
        freqs_cis,
        self_mask=None,
        cross_condition=None,
        cross_mask=None,
    ):
        """One block pass; ``x`` and ``condition`` are (batch, seq, hidden)."""
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(condition).chunk(6, dim=-1)

        # Self-attention (or depthwise-conv mixing)
        inp = modulate(self.norm1(x), shift_msa, scale_msa)
        if self.use_self_attention:
            inp = self.mix(inp, freqs_cis=freqs_cis, mask=self_mask)
        else:
            # Conv1d expects channels-first: (B, T, C) -> (B, C, T) -> back.
            inp = self.mix(inp.mT).mT
        x = x + gate_msa * inp

        # Cross-attention
        if self.use_cross_attention:
            (
                shift_cross,
                scale_cross,
                gate_cross,
            ) = self.adaLN_modulation_cross(
                condition
            ).chunk(3, dim=-1)

            (
                shift_cross_condition,
                scale_cross_condition,
            ) = self.adaLN_modulation_cross_condition(cross_condition).chunk(2, dim=-1)

            inp = modulate(self.norm3(x), shift_cross, scale_cross)
            inp = self.cross_attn(
                inp,
                freqs_cis=freqs_cis,
                kv=modulate(
                    self.norm4(cross_condition),
                    shift_cross_condition,
                    scale_cross_condition,
                ),
                mask=cross_mask,
            )
            x = x + gate_cross * inp

        # MLP
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))

        return x
+
+
class FinalLayer(nn.Module):
    """Output head of DiT: adaLN-modulated LayerNorm + linear projection."""

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        normed = modulate(self.norm_final(x), shift, scale)
        return self.linear(normed)
+
+
class DiT(nn.Module):
    """Diffusion Transformer over 1-D (time, channels) features.

    Conditioning combines the timestep embedding, a per-frame condition
    sequence, and an optional global style vector; an optional cross-attention
    path attends to a separate ``cross_condition`` sequence.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        diffusion_num_layers,
        channels=160,
        mlp_ratio=4.0,
        max_seq_len=16384,
        condition_dim=512,
        style_dim=None,
        cross_condition_dim=None,
    ):
        super().__init__()

        self.max_seq_len = max_seq_len

        self.time_embedder = TimestepEmbedder(hidden_size)
        self.condition_embedder = FeedForward(
            condition_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
        )

        if cross_condition_dim is not None:
            self.cross_condition_embedder = FeedForward(
                cross_condition_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
            )

        self.use_style = style_dim is not None
        if self.use_style:
            self.style_embedder = FeedForward(
                style_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
            )

        # Self-attention only on every 4th block (i % 4 == 0); the rest use
        # the cheaper depthwise-conv mixing inside DiTBlock.
        self.diffusion_blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio,
                    use_self_attention=i % 4 == 0,
                    use_cross_attention=cross_condition_dim is not None,
                )
                for i in range(diffusion_num_layers)
            ]
        )

        # Input/output projections between `channels` and `hidden_size`.
        self.input_embedder = FeedForward(
            channels, int(hidden_size * mlp_ratio), out_dim=hidden_size
        )
        self.final_layer = FinalLayer(hidden_size, channels)

        # Rotary cos/sin cache, stored in bfloat16 (see precompute_freqs_cis).
        self.register_buffer(
            "freqs_cis", precompute_freqs_cis(max_seq_len, hidden_size // num_heads)
        )

        self.initialize_weights()

    def initialize_weights(self):
        # Initialize input embedding:
        self.input_embedder.apply(self.init_weight)
        self.time_embedder.mlp.apply(self.init_weight)
        self.condition_embedder.apply(self.init_weight)

        if self.use_style:
            self.style_embedder.apply(self.init_weight)

        if hasattr(self, "cross_condition_embedder"):
            self.cross_condition_embedder.apply(self.init_weight)

        # adaLN-Zero: zero the modulation heads so each block starts as
        # (approximately) the identity mapping.
        for block in self.diffusion_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
            block.mix.apply(self.init_weight)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        self.final_layer.linear.apply(self.init_weight)

    def init_weight(self, m):
        # N(0, 0.02) for all conv/linear weights, zero bias.
        if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear)):
            nn.init.normal_(m.weight, 0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x,
        time,
        condition,
        style=None,
        self_mask=None,
        cross_condition=None,
        cross_mask=None,
    ):
        """Denoise ``x`` (batch, seq, channels) at the given timesteps.

        ``condition`` is (batch, seq, condition_dim); ``style`` is
        (batch, style_dim) and required when the model was built with
        ``style_dim``; masks are (batch, seq) booleans.
        """
        # Embed inputs
        x = self.input_embedder(x)
        t = self.time_embedder(time)

        condition = self.condition_embedder(condition)

        if self.use_style:
            style = self.style_embedder(style)

        if cross_condition is not None:
            cross_condition = self.cross_condition_embedder(cross_condition)
            cross_condition = t[:, None, :] + cross_condition

        # Merge t, condition, and style
        condition = t[:, None, :] + condition
        if self.use_style:
            condition = condition + style[:, None, :]

        # Expand (batch, seq) masks to (batch, 1, 1, seq) for SDPA broadcast.
        if self_mask is not None:
            self_mask = self_mask[:, None, None, :]

        if cross_mask is not None:
            cross_mask = cross_mask[:, None, None, :]

        # DiT
        for block in self.diffusion_blocks:
            x = block(
                x,
                condition,
                self.freqs_cis,
                self_mask=self_mask,
                cross_condition=cross_condition,
                cross_mask=cross_mask,
            )

        x = self.final_layer(x, condition)

        return x
+
+
if __name__ == "__main__":
    # Smoke test: build a small DiT and run a single forward pass.
    model = DiT(
        hidden_size=384,
        num_heads=6,
        diffusion_num_layers=12,
        channels=160,
        condition_dim=512,
        style_dim=256,
    )
    bs, seq_len = 8, 1024
    x = torch.randn(bs, seq_len, 160)
    condition = torch.randn(bs, seq_len, 512)
    style = torch.randn(bs, 256)
    mask = torch.ones(bs, seq_len, dtype=torch.bool)
    mask[0, 5:] = False  # exercise padding in one batch element
    time = torch.arange(bs)
    print(time)
    out = model(x, time, condition, style, self_mask=mask)
    print(out.shape)  # torch.Size([8, 1024, 160])

    # Print model size
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {num_params / 1e6:.1f}M")

+ 63 - 0
fish_speech/models/vqgan/modules/firefly.py

@@ -0,0 +1,63 @@
+import torch
+from torch import nn
+
+from .convnext import ConvNeXtEncoder
+from .hifigan import HiFiGANGenerator
+
+
class FireflyBase(nn.Module):
    """ConvNeXt mel-spectrogram encoder followed by a HiFi-GAN vocoder head.

    If ``ckpt_path`` is given, generator weights are loaded from it; raw
    state dicts, Lightning-style ``{"state_dict": ...}`` wrappers, and keys
    prefixed with ``generator.`` are all accepted.
    """

    # Fix: the default was annotated ``str`` but is None — ``str | None``
    # matches the actual contract (file already uses 3.10+ union syntax).
    def __init__(self, ckpt_path: str | None = None):
        super().__init__()

        self.backbone = ConvNeXtEncoder(
            input_channels=160,
            depths=[3, 3, 9, 3],
            dims=[128, 256, 384, 512],
            drop_path_rate=0.2,
            kernel_sizes=[7],
        )

        self.head = HiFiGANGenerator(
            hop_length=512,
            upsample_rates=[8, 8, 2, 2, 2],
            upsample_kernel_sizes=[16, 16, 4, 4, 4],
            resblock_kernel_sizes=[3, 7, 11],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            num_mels=512,
            upsample_initial_channel=512,
            use_template=True,
            pre_conv_kernel_size=13,
            post_conv_kernel_size=13,
        )

        if ckpt_path is None:
            return

        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints here.
        state_dict = torch.load(ckpt_path, map_location="cpu")

        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]

        if any("generator." in k for k in state_dict):
            state_dict = {
                k.replace("generator.", ""): v
                for k, v in state_dict.items()
                if "generator." in k
            }

        self.load_state_dict(state_dict, strict=True)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a mel spectrogram (batch, 160, frames) to backbone features."""
        return self.backbone(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Vocode backbone features into a waveform (batch, 1, samples)."""
        x = self.head(x)
        if x.ndim == 2:
            # Restore the channel dimension if the head squeezed it out.
            x = x[:, None, :]
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full mel-to-waveform pass: encode then decode."""
        return self.decode(self.encode(x))

+ 128 - 0
fish_speech/models/vqgan/modules/fsq.py

@@ -0,0 +1,128 @@
+from dataclasses import dataclass
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+from vector_quantize_pytorch import ResidualFSQ
+
+from .convnext import ConvNeXtBlock
+
+
@dataclass
class FSQResult:
    """Output bundle of DownsampleFiniteScalarQuantize.forward."""

    # Quantized features, upsampled back to the input's channel dim/length.
    z: torch.Tensor
    # Integer codebook indices from the residual FSQ (transposed to
    # codebook-first layout via .mT in forward).
    codes: torch.Tensor
    # Pre-quantization latents at the downsampled resolution.
    latents: torch.Tensor
+
+
class DownsampleFiniteScalarQuantize(nn.Module):
    """Residual FSQ quantizer wrapped in strided downsample/upsample convs.

    Features are downsampled by ``prod(downsample_factor)`` before
    quantization and upsampled back afterwards; the output is padded or
    cropped so its temporal length matches the input exactly.
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10 codes per book
        downsample_factor: tuple[int] = (2, 2),
        downsample_dims: tuple[int] | None = None,
    ):
        super().__init__()

        if downsample_dims is None:
            downsample_dims = [input_dim for _ in range(len(downsample_factor))]

        all_dims = (input_dim,) + tuple(downsample_dims)

        self.residual_fsq = ResidualFSQ(
            dim=all_dims[-1],
            levels=levels,
            num_quantizers=n_codebooks,
        )

        self.downsample_factor = downsample_factor
        self.downsample_dims = downsample_dims

        # One strided conv + two ConvNeXt blocks per downsampling stage.
        self.downsample = nn.Sequential(
            *[
                nn.Sequential(
                    nn.Conv1d(
                        all_dims[idx],
                        all_dims[idx + 1],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                )
                for idx, factor in enumerate(downsample_factor)
            ]
        )

        # Mirror of `downsample`, applied in reverse stage order.
        self.upsample = nn.Sequential(
            *[
                nn.Sequential(
                    nn.ConvTranspose1d(
                        all_dims[idx + 1],
                        all_dims[idx],
                        kernel_size=factor,
                        stride=factor,
                    ),
                    ConvNeXtBlock(dim=all_dims[idx]),
                    ConvNeXtBlock(dim=all_dims[idx]),
                )
                for idx, factor in reversed(list(enumerate(downsample_factor)))
            ]
        )

    def forward(self, z) -> FSQResult:
        """Quantize (batch, input_dim, time) features; see FSQResult."""
        original_shape = z.shape
        z = self.downsample(z)
        # ResidualFSQ expects (batch, time, dim); .mT flips the last two dims.
        quantized, indices = self.residual_fsq(z.mT)
        result = FSQResult(
            z=quantized.mT,
            codes=indices.mT,
            latents=z,
        )
        result.z = self.upsample(result.z)

        # Pad or crop so result.z matches the input's temporal length; the
        # conv stacks may not exactly invert each other for odd lengths.
        diff = original_shape[-1] - result.z.shape[-1]
        if diff > 0:
            left = diff // 2
            right = diff - left
            result.z = F.pad(result.z, (left, right))
        elif diff < 0:
            # Fix: the original computed a negative `left` and sliced
            # `[..., left:-right]`, which for diff == -1 gave the empty
            # slice [..., -1:0] and for diff == -2 the wrong [..., -1:1].
            # Crop symmetrically down to exactly the original length.
            left = (-diff) // 2
            result.z = result.z[..., left : left + original_shape[-1]]

        return result
+
+
if __name__ == "__main__":
    # Smoke test: quantize random features and inspect round-trip shapes.
    rvq = DownsampleFiniteScalarQuantize(
        n_codebooks=1,
        downsample_factor=(2, 2),
    )
    x = torch.randn(16, 512, 80)

    result = rvq(x)
    print(rvq)
    print(result.latents.shape, result.codes.shape, result.z.shape)

    # y = rvq.from_codes(result.codes)
    # print(y[0].shape)

    # y = rvq.from_latents(result.latents)
    # print(y[0].shape)

+ 278 - 0
fish_speech/models/vqgan/modules/hifigan.py

@@ -0,0 +1,278 @@
+from functools import partial
+from math import prod
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+
+
def init_weights(m, mean=0.0, std=0.01):
    """Initialize any Conv* module's weight from N(mean, std) in place.

    Modules whose class name does not contain "Conv" are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
+
+
def get_padding(kernel_size, dilation=1):
    """Return the "same" padding for a dilated conv of odd kernel size."""
    return dilation * (kernel_size - 1) // 2
+
+
class ResBlock(torch.nn.Module):
    """HiFi-GAN residual block.

    Each pass applies SiLU -> dilated conv -> SiLU -> plain conv and adds the
    result back to the input, once per entry in ``dilation``.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        # One dilated conv per dilation value. The original hand-wrote three
        # identical weight_norm(Conv1d(...)) entries; building them in a
        # comprehension also generalizes to any len(dilation).
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        # Matching non-dilated convs, one per dilated conv above.
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.silu(x)
            xt = c1(xt)
            xt = F.silu(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_parametrizations(self):
        """Bake weight norm into plain weights (for inference/export).

        Fix: torch's remove_parametrizations requires the parametrized
        tensor name; the original calls omitted it and raised TypeError.
        """
        for conv in self.convs1:
            remove_parametrizations(conv, "weight")
        for conv in self.convs2:
            remove_parametrizations(conv, "weight")
+
+
class ParralelBlock(nn.Module):
    """Averages the outputs of several ResBlocks with different kernels.

    (Name typo is preserved: it is the public interface used elsewhere.)
    """

    def __init__(
        self,
        channels: int,
        kernel_sizes: tuple[int] = (3, 7, 11),
        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
    ):
        super().__init__()

        assert len(kernel_sizes) == len(dilation_sizes)

        self.blocks = nn.ModuleList(
            ResBlock(channels, kernel, dilations)
            for kernel, dilations in zip(kernel_sizes, dilation_sizes)
        )

    def forward(self, x):
        outputs = torch.stack([block(x) for block in self.blocks], dim=0)
        return outputs.mean(dim=0)
+
+
class HiFiGANGenerator(nn.Module):
    """HiFi-GAN vocoder generator.

    Upsamples a feature map (batch, num_mels, frames) to a waveform
    (batch, 1, frames * hop_length) through a stack of transposed
    convolutions and parallel multi-receptive-field residual blocks,
    optionally mixing in a sample-rate noise "template" at every stage.
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: tuple[int] = (16, 16, 4, 4, 4),
        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 160,
        upsample_initial_channel: int = 512,
        use_template: bool = True,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
        checkpointing: bool = False,
        condition_dim: Optional[int] = None,
    ):
        super().__init__()

        # The combined upsampling must exactly reconstruct the hop length.
        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.conv_pre = weight_norm(
            nn.Conv1d(
                num_mels,
                upsample_initial_channel,
                pre_conv_kernel_size,
                1,
                padding=get_padding(pre_conv_kernel_size),
            )
        )

        self.hop_length = hop_length
        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        self.noise_convs = nn.ModuleList()
        self.use_template = use_template
        self.ups = nn.ModuleList()
        self.condition_dim = condition_dim

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channels halve at every upsampling stage.
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

            if not use_template:
                continue

            # Template branch: strided conv that brings the sample-rate
            # template down to this stage's temporal resolution.
            if i + 1 < len(upsample_rates):
                stride_f0 = np.prod(upsample_rates[i + 1 :])
                self.noise_convs.append(
                    Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            self.resblocks.append(
                ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
            )

        self.activation_post = post_activation()
        self.conv_post = weight_norm(
            nn.Conv1d(
                ch,
                1,
                post_conv_kernel_size,
                1,
                padding=get_padding(post_conv_kernel_size),
            )
        )
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

        if condition_dim is not None:
            self.condition = nn.Conv1d(condition_dim, upsample_initial_channel, 1)

        # Gradient checkpointing trades compute for memory during training.
        self.checkpointing = checkpointing

    def forward(self, x, template=None, condition=None):
        """Vocode features ``x`` (batch, num_mels, frames) to a waveform.

        ``template`` defaults to low-amplitude Gaussian noise at sample rate;
        ``condition`` is required when the model was built with condition_dim.
        Output is tanh-bounded in [-1, 1].
        """
        if self.use_template and template is None:
            length = x.shape[-1] * self.hop_length
            template = (
                torch.randn(x.shape[0], 1, length, device=x.device, dtype=x.dtype)
                * 0.003
            )

        if self.condition_dim is not None:
            x = x + self.condition(condition)

        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.use_template:
                x = x + self.noise_convs[i](template)

            if self.training and self.checkpointing:
                x = torch.utils.checkpoint.checkpoint(
                    self.resblocks[i],
                    x,
                    use_reentrant=False,
                )
            else:
                x = self.resblocks[i](x)

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_parametrizations(self):
        """Bake weight norm into plain weights (for inference/export).

        Fix: torch's remove_parametrizations requires the parametrized
        tensor name ("weight" here); the original calls omitted it and
        would raise TypeError.
        """
        for up in self.ups:
            remove_parametrizations(up, "weight")
        for block in self.resblocks:
            block.remove_parametrizations()
        remove_parametrizations(self.conv_pre, "weight")
        remove_parametrizations(self.conv_post, "weight")

+ 20 - 11
fish_speech/models/vqgan/spectrogram.py

@@ -21,7 +21,7 @@ class LinearSpectrogram(nn.Module):
         self.center = center
         self.mode = mode
 
-        self.register_buffer("window", torch.hann_window(win_length))
+        self.register_buffer("window", torch.hann_window(win_length), persistent=False)
 
     def forward(self, y: Tensor) -> Tensor:
         if y.ndim == 3:
@@ -78,17 +78,23 @@ class LogMelSpectrogram(nn.Module):
         self.center = center
         self.n_mels = n_mels
         self.f_min = f_min
-        self.f_max = f_max or sample_rate // 2
+        self.f_max = f_max or float(sample_rate // 2)
 
         self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
-        self.mel_scale = MelScale(
-            self.n_mels,
-            self.sample_rate,
-            self.f_min,
-            self.f_max,
-            self.n_fft // 2 + 1,
-            "slaney",
-            "slaney",
+
+        fb = F.melscale_fbanks(
+            n_freqs=self.n_fft // 2 + 1,
+            f_min=self.f_min,
+            f_max=self.f_max,
+            n_mels=self.n_mels,
+            sample_rate=self.sample_rate,
+            norm="slaney",
+            mel_scale="slaney",
+        )
+        self.register_buffer(
+            "fb",
+            fb,
+            persistent=False,
         )
 
     def compress(self, x: Tensor) -> Tensor:
@@ -97,6 +103,9 @@ class LogMelSpectrogram(nn.Module):
     def decompress(self, x: Tensor) -> Tensor:
         return torch.exp(x)
 
+    def apply_mel_scale(self, x: Tensor) -> Tensor:
+        return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
+
     def forward(
         self, x: Tensor, return_linear: bool = False, sample_rate: int = None
     ) -> Tensor:
@@ -104,7 +113,7 @@ class LogMelSpectrogram(nn.Module):
             x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
 
         linear = self.spectrogram(x)
-        x = self.mel_scale(linear)
+        x = self.apply_mel_scale(linear)
         x = self.compress(x)
 
         if return_linear: