Sfoglia il codice sorgente

Implement new vq + dit + reflow

Lengyue 2 anni fa
parent
commit
7b5fe470e4

+ 53 - 54
fish_speech/configs/vqgan_pretrain.yaml

@@ -2,36 +2,36 @@ defaults:
   - base
   - base
   - _self_
   - _self_
 
 
-project: vqgan_pretrain_lfq
-ckpt_path: checkpoints/gpt_sovits_488k.pth
-resume_weights_only: true
+project: vq_reflow_debug
 
 
 # Lightning Trainer
 # Lightning Trainer
 trainer:
 trainer:
   accelerator: gpu
   accelerator: gpu
   devices: auto
   devices: auto
   strategy: ddp_find_unused_parameters_true
   strategy: ddp_find_unused_parameters_true
-  precision: 32
+  precision: 16-mixed
   max_steps: 1_000_000
   max_steps: 1_000_000
   val_check_interval: 2000
   val_check_interval: 2000
+  gradient_clip_algorithm: norm
+  gradient_clip_val: 1.0
 
 
-sample_rate: 32000
-hop_length: 640
-num_mels: 128
+sample_rate: 44100
+hop_length: 512
+num_mels: 160
 n_fft: 2048
 n_fft: 2048
 win_length: 2048
 win_length: 2048
 
 
 # Dataset Configuration
 # Dataset Configuration
 train_dataset:
 train_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_train_filelist.txt
+  filelist: /***REMOVED***/workspace/diffusion-test/data/HiFi-TTS/vq_train_filelist.txt
   sample_rate: ${sample_rate}
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   hop_length: ${hop_length}
-  slice_frames: 128
+  slice_frames: 512
 
 
 val_dataset:
 val_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vq_val_filelist.txt
+  filelist: /***REMOVED***/workspace/diffusion-test/data/HiFi-TTS/vq_val_filelist.txt
   sample_rate: ${sample_rate}
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   hop_length: ${hop_length}
 
 
@@ -40,49 +40,48 @@ data:
   train_dataset: ${train_dataset}
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
   num_workers: 4
-  batch_size: 16
+  batch_size: 32
   val_batch_size: 4
   val_batch_size: 4
 
 
 # Model Configuration
 # Model Configuration
 model:
 model:
   _target_: fish_speech.models.vqgan.VQGAN
   _target_: fish_speech.models.vqgan.VQGAN
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  freeze_discriminator: false
 
 
-  weight_mel: 45.0
-  weight_kl: 0.1
+  sampling_rate: ${sample_rate}
+  weight_reflow: 1.0
   weight_vq: 1.0
   weight_vq: 1.0
   weight_aux_mel: 1.0
   weight_aux_mel: 1.0
 
 
-  generator:
-    _target_: fish_speech.models.vqgan.modules.models.SynthesizerTrn
-    spec_channels: 1025
-    segment_size: 32
-    inter_channels: 192
-    prior_hidden_channels: 192
-    posterior_hidden_channels: 192
-    prior_n_layers: 16
-    posterior_n_layers: 16
-    kernel_size: 5
-    p_dropout: 0.1
-    resblock: "1"
-    resblock_kernel_sizes: [3, 7, 11]
-    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    upsample_rates: [10, 8, 2, 2, 2]
-    upsample_initial_channel: 512
-    upsample_kernel_sizes: [16, 16, 8, 2, 2]
-    gin_channels: 512
-    freeze_quantizer: false
-    freeze_decoder: false
-    freeze_posterior_encoder: false
-    codebook_size: 1024
-    num_codebooks: 2
-    aux_spec_channels: ${num_mels}
+  encoder:
+    _target_: fish_speech.models.vqgan.modules.convnext.ConvNeXtEncoder
+    input_channels: ${num_mels}
+    depths: [3, 3, 9, 3]
+    dims: [128, 256, 384, 512]
+  
+  quantizer:
+    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
+    input_dim: 512
+    n_codebooks: 8
+    levels: [8, 5, 5, 5]
+  
+  aux_decoder:
+    _target_: fish_speech.models.vqgan.modules.convnext.ConvNeXtEncoder
+    input_channels: 512
+    output_channels: ${num_mels}
+    depths: [6]
+    dims: [384]
+
+  reflow:
+    _target_: fish_speech.models.vqgan.modules.dit.DiT
+    hidden_size: 768
+    num_heads: 12
+    diffusion_num_layers: 12
+    channels: ${num_mels}
+    condition_dim: 512
 
 
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.models.EnsembledDiscriminator
-    periods: [2, 3, 5, 7, 11]
+  vocoder:
+    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
+    ckpt_path: checkpoints/firefly-gan-base-002000000.ckpt
 
 
   mel_transform:
   mel_transform:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
@@ -92,13 +91,6 @@ model:
     win_length: ${win_length}
     win_length: ${win_length}
     n_mels: ${num_mels}
     n_mels: ${num_mels}
 
 
-  spec_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LinearSpectrogram
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    mode: pow2_sqrt
-  
   optimizer:
   optimizer:
     _target_: torch.optim.AdamW
     _target_: torch.optim.AdamW
     _partial_: true
     _partial_: true
@@ -107,12 +99,19 @@ model:
     eps: 1e-5
     eps: 1e-5
 
 
   lr_scheduler:
   lr_scheduler:
-    _target_: torch.optim.lr_scheduler.ExponentialLR
+    _target_: torch.optim.lr_scheduler.LambdaLR
     _partial_: true
     _partial_: true
-    gamma: 0.99999
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 100
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0
 
 
 callbacks:
 callbacks:
   grad_norm_monitor:
   grad_norm_monitor:
     sub_module: 
     sub_module: 
-      - generator
-      - discriminator
+      - encoder
+      - aux_decoder
+      - quantizer
+      - reflow

+ 169 - 316
fish_speech/models/vqgan/lit_module.py

@@ -9,15 +9,7 @@ import wandb
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from matplotlib import pyplot as plt
 from torch import nn
 from torch import nn
-from torch.utils.checkpoint import checkpoint as gradient_checkpoint
-
-from fish_speech.models.vqgan.losses import (
-    MultiResolutionSTFTLoss,
-    discriminator_loss,
-    feature_loss,
-    generator_loss,
-    kl_loss,
-)
+
 from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
 from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
 
 
 
 
@@ -40,17 +32,16 @@ class VQGAN(L.LightningModule):
         self,
         self,
         optimizer: Callable,
         optimizer: Callable,
         lr_scheduler: Callable,
         lr_scheduler: Callable,
-        generator: nn.Module,
-        discriminator: nn.Module,
+        encoder: nn.Module,
+        quantizer: nn.Module,
+        aux_decoder: nn.Module,
+        reflow: nn.Module,
+        vocoder: nn.Module,
         mel_transform: nn.Module,
         mel_transform: nn.Module,
-        spec_transform: nn.Module,
-        hop_length: int = 640,
-        sample_rate: int = 32000,
-        freeze_discriminator: bool = False,
-        weight_mel: float = 45,
-        weight_kl: float = 0.1,
+        weight_reflow: float = 1.0,
         weight_vq: float = 1.0,
         weight_vq: float = 1.0,
-        weight_aux_mel: float = 20.0,
+        weight_aux_mel: float = 1.0,
+        sampling_rate: int = 44100,
     ):
     ):
         super().__init__()
         super().__init__()
 
 
@@ -58,62 +49,54 @@ class VQGAN(L.LightningModule):
         self.optimizer_builder = optimizer
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
         self.lr_scheduler_builder = lr_scheduler
 
 
-        # Generator and discriminator
-        self.generator = generator
-        self.discriminator = discriminator
+        # Modules
+        self.encoder = encoder
+        self.quantizer = quantizer
+        self.aux_decoder = aux_decoder
+        self.reflow = reflow
         self.mel_transform = mel_transform
         self.mel_transform = mel_transform
-        self.spec_transform = spec_transform
-        self.freeze_discriminator = freeze_discriminator
+        self.vocoder = vocoder
+
+        # Freeze vocoder
+        for param in self.vocoder.parameters():
+            param.requires_grad = False
 
 
         # Loss weights
         # Loss weights
-        self.weight_mel = weight_mel
-        self.weight_kl = weight_kl
+        self.weight_reflow = weight_reflow
         self.weight_vq = weight_vq
         self.weight_vq = weight_vq
         self.weight_aux_mel = weight_aux_mel
         self.weight_aux_mel = weight_aux_mel
 
 
-        # Other parameters
-        self.hop_length = hop_length
-        self.sampling_rate = sample_rate
-
-        # Disable automatic optimization
-        self.automatic_optimization = False
+        self.spec_min = -12
+        self.spec_max = 3
+        self.sampling_rate = sampling_rate
 
 
-        if self.freeze_discriminator:
-            for p in self.discriminator.parameters():
-                p.requires_grad = False
+    def on_save_checkpoint(self, checkpoint):
+        # Do not save vocoder
+        state_dict = checkpoint["state_dict"]
+        for name in list(state_dict.keys()):
+            if "vocoder" in name:
+                state_dict.pop(name)
 
 
     def configure_optimizers(self):
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
         # Need two optimizers and two schedulers
-        optimizer_generator = self.optimizer_builder(self.generator.parameters())
-        optimizer_discriminator = self.optimizer_builder(
-            self.discriminator.parameters()
-        )
-
-        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
-        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
-
-        return (
-            {
-                "optimizer": optimizer_generator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_generator,
-                    "interval": "step",
-                    "name": "optimizer/generator",
-                },
+        optimizer = self.optimizer_builder(self.parameters())
+        lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "interval": "step",
             },
             },
-            {
-                "optimizer": optimizer_discriminator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_discriminator,
-                    "interval": "step",
-                    "name": "optimizer/discriminator",
-                },
-            },
-        )
+        }
 
 
-    def training_step(self, batch, batch_idx):
-        optim_g, optim_d = self.optimizers()
+    def norm_spec(self, x):
+        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
 
 
+    def denorm_spec(self, x):
+        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
+
+    def training_step(self, batch, batch_idx):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
 
 
         audios = audios.float()
         audios = audios.float()
@@ -121,173 +104,84 @@ class VQGAN(L.LightningModule):
 
 
         with torch.no_grad():
         with torch.no_grad():
             gt_mels = self.mel_transform(audios)
             gt_mels = self.mel_transform(audios)
-            gt_specs = self.spec_transform(audios)
-
-        spec_lengths = audio_lengths // self.hop_length
-        spec_masks = torch.unsqueeze(
-            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
-        ).to(gt_mels.dtype)
-        (
-            fake_audios,
-            ids_slice,
-            y_mask,
-            y_mask,
-            (z, z_p, m_p, logs_p, m_q, logs_q),
-            loss_vq,
-            decoded_aux_mels,
-        ) = self.generator(gt_specs, spec_lengths)
 
 
-        gt_mels = slice_segments(gt_mels, ids_slice, self.generator.segment_size)
-        decoded_aux_mels = slice_segments(
-            decoded_aux_mels, ids_slice, self.generator.segment_size
-        )
-        spec_masks = slice_segments(spec_masks, ids_slice, self.generator.segment_size)
-        audios = slice_segments(
-            audios,
-            ids_slice * self.hop_length,
-            self.generator.segment_size * self.hop_length,
-        )
-        fake_mels = self.mel_transform(fake_audios.squeeze(1))
+        mel_lengths = audio_lengths // self.mel_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+        mel_masks_float_conv = mel_masks[:, None, :].float()
 
 
-        assert (
-            audios.shape == fake_audios.shape
-        ), f"{audios.shape} != {fake_audios.shape}"
-
-        # Discriminator
-        if self.freeze_discriminator is False:
-            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(
-                audios, fake_audios.detach()
-            )
+        # Encode
+        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
 
 
-            with torch.autocast(device_type=audios.device.type, enabled=False):
-                loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
-
-            self.log(
-                f"train/discriminator/loss",
-                loss_disc,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
+        # Quantize
+        vq_result = self.quantizer(encoded_features)
+        loss_vq = getattr("vq_result", "loss", 0.0)
+        vq_recon_features = vq_result.z * mel_masks_float_conv
 
 
-            optim_d.zero_grad()
-            self.manual_backward(loss_disc)
-            self.clip_gradients(
-                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-            )
-            optim_d.step()
+        # VQ Decode
+        aux_mel = self.aux_decoder(vq_recon_features)
+        loss_aux_mel = F.l1_loss(
+            aux_mel * mel_masks_float_conv, gt_mels * mel_masks_float_conv
+        )
 
 
-        # Adv Loss
-        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios)
+        # Reflow
+        x_1 = self.norm_spec(gt_mels.mT)
+        t = torch.rand(gt_mels.shape[0], device=gt_mels.device)
+        x_0 = torch.randn_like(x_1)
 
 
-        # Adversarial Loss
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_adv, _ = generator_loss(y_d_hat_g)
+        # X_t = t * X_1 + (1 - t) * X_0
+        x_t = x_0 + t[:, None, None] * (x_1 - x_0)
 
 
-        self.log(
-            f"train/generator/adv",
-            loss_adv,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
+        v_pred = self.reflow(
+            x_t,
+            1000 * t,
+            condition=vq_recon_features.mT,
+            self_mask=mel_masks,
         )
         )
 
 
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_fm = feature_loss(y_d_hat_r, y_d_hat_g)
-
-        self.log(
-            f"train/generator/adv_fm",
-            loss_fm,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
+        # Log L2 loss with
+        weights = 0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / (1 - t)) ** 2)
+        loss_reflow = weights[:, None, None] * F.mse_loss(
+            x_1 - x_0, v_pred, reduction="none"
         )
         )
+        loss_reflow = (loss_reflow * mel_masks_float_conv.mT).mean()
 
 
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_mel = F.l1_loss(gt_mels * spec_masks, fake_mels * spec_masks)
-            loss_aux_mel = F.l1_loss(
-                gt_mels * spec_masks, decoded_aux_mels * spec_masks
-            )
-
-        self.log(
-            "train/generator/loss_mel",
-            loss_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
+        # Total loss
+        loss = (
+            self.weight_vq * loss_vq
+            + self.weight_aux_mel * loss_aux_mel
+            + self.weight_reflow * loss_reflow
         )
         )
 
 
+        # Log losses
         self.log(
         self.log(
-            "train/generator/loss_aux_mel",
-            loss_aux_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
+            "train/loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True
         )
         )
-
         self.log(
         self.log(
-            "train/generator/loss_vq",
+            "train/loss_vq",
             loss_vq,
             loss_vq,
             on_step=True,
             on_step=True,
             on_epoch=False,
             on_epoch=False,
             prog_bar=False,
             prog_bar=False,
             logger=True,
             logger=True,
-            sync_dist=True,
         )
         )
-
-        loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, y_mask)
-
         self.log(
         self.log(
-            "train/generator/loss_kl",
-            loss_kl,
+            "train/loss_aux_mel",
+            loss_aux_mel,
             on_step=True,
             on_step=True,
             on_epoch=False,
             on_epoch=False,
             prog_bar=False,
             prog_bar=False,
             logger=True,
             logger=True,
-            sync_dist=True,
-        )
-
-        loss = (
-            loss_mel * self.weight_mel
-            + loss_aux_mel * self.weight_aux_mel
-            + loss_vq * self.weight_vq
-            + loss_kl * self.weight_kl
-            + loss_adv
-            + loss_fm
         )
         )
         self.log(
         self.log(
-            "train/generator/loss",
-            loss,
+            "train/loss_reflow",
+            loss_reflow,
             on_step=True,
             on_step=True,
             on_epoch=False,
             on_epoch=False,
             prog_bar=False,
             prog_bar=False,
             logger=True,
             logger=True,
-            sync_dist=True,
-        )
-
-        # Backward
-        optim_g.zero_grad()
-
-        self.manual_backward(loss)
-        self.clip_gradients(
-            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
         )
         )
-        optim_g.step()
 
 
-        # Manual LR Scheduler
-        scheduler_g, scheduler_d = self.lr_schedulers()
-        scheduler_g.step()
-        scheduler_d.step()
+        return loss
 
 
     def validation_step(self, batch: Any, batch_idx: int):
     def validation_step(self, batch: Any, batch_idx: int):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
@@ -296,32 +190,25 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]
         audios = audios[:, None, :]
 
 
         gt_mels = self.mel_transform(audios)
         gt_mels = self.mel_transform(audios)
-        gt_specs = self.spec_transform(audios)
-        spec_lengths = audio_lengths // self.hop_length
-        spec_masks = torch.unsqueeze(
-            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
-        ).to(gt_mels.dtype)
-
-        prior_audios, _, _ = self.generator.infer(gt_specs, spec_lengths)
-        posterior_audios, _, _ = self.generator.infer_posterior(gt_specs, spec_lengths)
-        prior_mels = self.mel_transform(prior_audios.squeeze(1))
-        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))
-
-        min_mel_length = min(
-            gt_mels.shape[-1], prior_mels.shape[-1], posterior_mels.shape[-1]
-        )
-        gt_mels = gt_mels[:, :, :min_mel_length]
-        prior_mels = prior_mels[:, :, :min_mel_length]
-        posterior_mels = posterior_mels[:, :, :min_mel_length]
+        mel_lengths = audio_lengths // self.mel_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+
+        # Encode
+        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
 
 
-        prior_mel_loss = F.l1_loss(gt_mels * spec_masks, prior_mels * spec_masks)
-        posterior_mel_loss = F.l1_loss(
-            gt_mels * spec_masks, posterior_mels * spec_masks
+        # Quantize
+        vq_result = self.quantizer(encoded_features)
+
+        # VQ Decode
+        aux_mels = self.aux_decoder(vq_result.z)
+        loss_aux_mel = F.l1_loss(
+            aux_mels * mel_masks_float_conv, gt_mels * mel_masks_float_conv
         )
         )
 
 
         self.log(
         self.log(
-            "val/prior_mel_loss",
-            prior_mel_loss,
+            "val/loss_aux_mel",
+            loss_aux_mel,
             on_step=False,
             on_step=False,
             on_epoch=True,
             on_epoch=True,
             prog_bar=False,
             prog_bar=False,
@@ -329,9 +216,33 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
             sync_dist=True,
         )
         )
 
 
+        # Reflow inference
+        t_start = 0.0
+        infer_step = 20
+        gen_mels = torch.randn(gt_mels.shape, device=gt_mels.device).mT
+        t = torch.zeros(gt_mels.shape[0], device=gt_mels.device)
+        dt = (1.0 - t_start) / infer_step
+
+        for _ in range(infer_step):
+            gen_mels += (
+                self.reflow(
+                    gen_mels,
+                    1000 * t,
+                    condition=vq_result.z.mT,
+                    self_mask=mel_masks,
+                )
+                * dt
+            )
+            t += dt
+
+        gen_mels = self.denorm_spec(gen_mels).mT
+        loss_recon_reflow = F.l1_loss(
+            gen_mels * mel_masks_float_conv, gt_mels * mel_masks_float_conv
+        )
+
         self.log(
         self.log(
-            "val/posterior_mel_loss",
-            posterior_mel_loss,
+            "val/loss_recon_reflow",
+            loss_recon_reflow,
             on_step=False,
             on_step=False,
             on_epoch=True,
             on_epoch=True,
             prog_bar=False,
             prog_bar=False,
@@ -339,41 +250,47 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
             sync_dist=True,
         )
         )
 
 
+        gen_audios = self.vocoder(gen_mels)
+        recon_audios = self.vocoder(gt_mels)
+        aux_audios = self.vocoder(aux_mels)
+
         # only log the first batch
         # only log the first batch
         if batch_idx != 0:
         if batch_idx != 0:
             return
             return
 
 
         for idx, (
         for idx, (
-            mel,
-            prior_mel,
-            posterior_mel,
+            gt_mel,
+            reflow_mel,
+            aux_mel,
             audio,
             audio,
-            prior_audio,
-            posterior_audio,
+            reflow_audio,
+            aux_audio,
+            recon_audio,
             audio_len,
             audio_len,
         ) in enumerate(
         ) in enumerate(
             zip(
             zip(
                 gt_mels,
                 gt_mels,
-                prior_mels,
-                posterior_mels,
-                audios.detach().float(),
-                prior_audios.detach().float(),
-                posterior_audios.detach().float(),
+                gen_mels,
+                aux_mels,
+                audios.float(),
+                gen_audios.float(),
+                aux_audios.float(),
+                recon_audios.float(),
                 audio_lengths,
                 audio_lengths,
             )
             )
         ):
         ):
-            mel_len = audio_len // self.hop_length
+            mel_len = audio_len // self.mel_transform.hop_length
 
 
             image_mels = plot_mel(
             image_mels = plot_mel(
                 [
                 [
-                    prior_mel[:, :mel_len],
-                    posterior_mel[:, :mel_len],
-                    mel[:, :mel_len],
+                    gt_mel[:, :mel_len],
+                    reflow_mel[:, :mel_len],
+                    aux_mel[:, :mel_len],
                 ],
                 ],
                 [
                 [
-                    "Prior (VQ)",
-                    "Posterior (Reconstruction)",
                     "Ground-Truth",
                     "Ground-Truth",
+                    "Reflow",
+                    "Aux",
                 ],
                 ],
             )
             )
 
 
@@ -388,14 +305,19 @@ class VQGAN(L.LightningModule):
                                 caption="gt",
                                 caption="gt",
                             ),
                             ),
                             wandb.Audio(
                             wandb.Audio(
-                                prior_audio[0, :audio_len],
+                                reflow_audio[0, :audio_len],
                                 sample_rate=self.sampling_rate,
                                 sample_rate=self.sampling_rate,
-                                caption="prior",
+                                caption="reflow",
                             ),
                             ),
                             wandb.Audio(
                             wandb.Audio(
-                                posterior_audio[0, :audio_len],
+                                aux_audio[0, :audio_len],
                                 sample_rate=self.sampling_rate,
                                 sample_rate=self.sampling_rate,
-                                caption="posterior",
+                                caption="aux",
+                            ),
+                            wandb.Audio(
+                                recon_audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="recon",
                             ),
                             ),
                         ],
                         ],
                     },
                     },
@@ -414,91 +336,22 @@ class VQGAN(L.LightningModule):
                     sample_rate=self.sampling_rate,
                     sample_rate=self.sampling_rate,
                 )
                 )
                 self.logger.experiment.add_audio(
                 self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/prior",
-                    prior_audio[0, :audio_len],
+                    f"sample-{idx}/wavs/reflow",
+                    reflow_audio[0, :audio_len],
                     self.global_step,
                     self.global_step,
                     sample_rate=self.sampling_rate,
                     sample_rate=self.sampling_rate,
                 )
                 )
                 self.logger.experiment.add_audio(
                 self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/posterior",
-                    posterior_audio[0, :audio_len],
+                    f"sample-{idx}/wavs/aux",
+                    aux_audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/recon",
+                    recon_audio[0, :audio_len],
                     self.global_step,
                     self.global_step,
                     sample_rate=self.sampling_rate,
                     sample_rate=self.sampling_rate,
                 )
                 )
 
 
             plt.close(image_mels)
             plt.close(image_mels)
-
-    # def encode(self, audios, audio_lengths=None):
-    #     if audio_lengths is None:
-    #         audio_lengths = torch.tensor(
-    #             [audios.shape[-1]] * audios.shape[0],
-    #             device=audios.device,
-    #             dtype=torch.long,
-    #         )
-
-    #     with torch.no_grad():
-    #         features = self.mel_transform(audios, sample_rate=self.sampling_rate)
-
-    #     feature_lengths = (
-    #         audio_lengths
-    #         / self.hop_length
-    #         # / self.vq.downsample
-    #     ).long()
-
-    #     # print(features.shape, feature_lengths.shape, torch.max(feature_lengths))
-
-    #     feature_masks = torch.unsqueeze(
-    #         sequence_mask(feature_lengths, features.shape[2]), 1
-    #     ).to(features.dtype)
-
-    #     features = (
-    #         gradient_checkpoint(
-    #             self.encoder, features, feature_masks, use_reentrant=False
-    #         )
-    #         * feature_masks
-    #     )
-    #     vq_features, indices, loss = self.vq(features, feature_masks)
-
-    #     return VQEncodeResult(
-    #         features=vq_features,
-    #         indices=indices,
-    #         loss=loss,
-    #         feature_lengths=feature_lengths,
-    #     )
-
-    # def calculate_audio_lengths(self, feature_lengths):
-    #     return feature_lengths * self.hop_length * self.vq.downsample
-
-    # def decode(
-    #     self,
-    #     indices=None,
-    #     features=None,
-    #     audio_lengths=None,
-    #     feature_lengths=None,
-    #     return_audios=False,
-    # ):
-    #     assert (
-    #         indices is not None or features is not None
-    #     ), "indices or features must be provided"
-    #     assert (
-    #         feature_lengths is not None or audio_lengths is not None
-    #     ), "feature_lengths or audio_lengths must be provided"
-
-    #     if audio_lengths is None:
-    #         audio_lengths = self.calculate_audio_lengths(feature_lengths)
-
-    #     mel_lengths = audio_lengths // self.hop_length
-    #     mel_masks = torch.unsqueeze(
-    #         sequence_mask(mel_lengths, torch.max(mel_lengths)), 1
-    #     ).float()
-
-    #     if indices is not None:
-    #         features = self.vq.decode(indices)
-
-    #     # Sample mels
-    #     decoded = gradient_checkpoint(self.decoder, features, use_reentrant=False)
-
-    #     return VQDecodeResult(
-    #         mels=decoded,
-    #         audios=self.generator(decoded) if return_audios else None,
-    #     )

+ 249 - 0
fish_speech/models/vqgan/modules/convnext.py

@@ -0,0 +1,249 @@
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+def drop_path(
+    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
+    """  # noqa: E501
+
+    # No-op at inference time or when dropping is disabled.
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    # One Bernoulli draw per sample, broadcast over all non-batch dims.
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        # Rescale kept samples so the expected output magnitude is unchanged.
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""  # noqa: E501
+
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob  # probability of dropping the residual branch
+        self.scale_by_keep = scale_by_keep  # rescale survivors by 1/keep_prob
+
+    def forward(self, x):
+        # Delegates to the functional form; active only while self.training is True.
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
+class LayerNorm(nn.Module):
+    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+    with shape (batch_size, channels, height, width).
+    """  # noqa: E501
+
+    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.data_format = data_format
+        if self.data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError
+        self.normalized_shape = (normalized_shape,)
+
+    def forward(self, x):
+        if self.data_format == "channels_last":
+            # Fast path: native layer_norm over the trailing (channel) dim.
+            return F.layer_norm(
+                x, self.normalized_shape, self.weight, self.bias, self.eps
+            )
+        elif self.data_format == "channels_first":
+            # Manual normalization over dim 1 (channels). Used by the 1D conv
+            # stacks in this file, where tensors are (N, C, L).
+            u = x.mean(1, keepdim=True)
+            s = (x - u).pow(2).mean(1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.eps)
+            # Affine params are indexed as (C, 1) so they broadcast over L.
+            x = self.weight[:, None] * x + self.bias[:, None]
+            return x
+
+
+class ConvNeXtBlock(nn.Module):
+    r"""ConvNeXt Block. There are two equivalent implementations:
+    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+    We use (2) as we find it slightly faster in PyTorch
+
+    This is a 1D adaptation: tensors are (N, C, L) / (N, L, C).
+
+    Args:
+        dim (int): Number of input channels.
+        drop_path (float): Stochastic depth rate. Default: 0.0
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+        kernel_size (int): Kernel size for depthwise conv. Default: 7.
+        dilation (int): Dilation for depthwise conv. Default: 1.
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        dim: int,
+        drop_path: float = 0.0,
+        layer_scale_init_value: float = 1e-6,
+        mlp_ratio: float = 4.0,
+        kernel_size: int = 7,
+        dilation: int = 1,
+    ):
+        super().__init__()
+
+        # NOTE(review): `dilation` is only used to compute the padding below and
+        # is never passed to nn.Conv1d, so for dilation != 1 the conv is NOT
+        # dilated and the output length grows — confirm intent.
+        self.dwconv = nn.Conv1d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding=int(dilation * (kernel_size - 1) / 2),
+            groups=dim,
+        )  # depthwise conv
+        self.norm = LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(
+            dim, int(mlp_ratio * dim)
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
+        # Layer Scale: per-channel learnable gain, near-zero at init.
+        self.gamma = (
+            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x, apply_residual: bool = True):
+        # `apply_residual=False` lets ParallelConvNeXtBlock sum several branch
+        # outputs and add the identity exactly once.
+        input = x
+
+        x = self.dwconv(x)
+        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+
+        if self.gamma is not None:
+            x = self.gamma * x
+
+        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
+        x = self.drop_path(x)
+
+        if apply_residual:
+            x = input + x
+
+        return x
+
+
+class ParallelConvNeXtBlock(nn.Module):
+    # Runs several ConvNeXtBlocks with different kernel sizes in parallel and
+    # sums their (non-residual) outputs together with the identity `x`, so the
+    # residual connection is applied exactly once.
+    def __init__(self, kernel_sizes: list[int], *args, **kwargs):
+        super().__init__()
+        self.blocks = nn.ModuleList(
+            [
+                ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
+                for kernel_size in kernel_sizes
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Stack branch outputs plus the input on a new dim, then reduce by sum.
+        return torch.stack(
+            [block(x, apply_residual=False) for block in self.blocks] + [x],
+            dim=1,
+        ).sum(dim=1)
+
+
+class ConvNeXtEncoder(nn.Module):
+    # 1D ConvNeXt backbone: a stem + per-stage channel projections
+    # (no temporal downsampling — all convs here are stride 1), each followed
+    # by a stack of ConvNeXt blocks. Input/output are (N, C, L).
+    def __init__(
+        self,
+        input_channels: int = 3,
+        output_channels: Optional[int] = None,
+        depths: list[int] = [3, 3, 9, 3],
+        dims: list[int] = [96, 192, 384, 768],
+        drop_path_rate: float = 0.0,
+        layer_scale_init_value: float = 1e-6,
+        kernel_sizes: tuple[int] = (7,),
+    ):
+        super().__init__()
+        assert len(depths) == len(dims)
+
+        # channel_layers[i] adapts channels before stage i (stem for stage 0).
+        self.channel_layers = nn.ModuleList()
+        stem = nn.Sequential(
+            nn.Conv1d(
+                input_channels,
+                dims[0],
+                kernel_size=7,
+                padding=3,
+                padding_mode="zeros",
+            ),
+            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+        )
+        self.channel_layers.append(stem)
+
+        for i in range(len(depths) - 1):
+            mid_layer = nn.Sequential(
+                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
+            )
+            self.channel_layers.append(mid_layer)
+
+        # Single kernel size -> plain blocks; several -> parallel multi-kernel blocks.
+        block_fn = (
+            partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
+            if len(kernel_sizes) == 1
+            else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
+        )
+
+        self.stages = nn.ModuleList()
+        # Stochastic-depth rate increases linearly with block index.
+        drop_path_rates = [
+            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+        ]
+
+        cur = 0
+        for i in range(len(depths)):
+            stage = nn.Sequential(
+                *[
+                    block_fn(
+                        dim=dims[i],
+                        drop_path=drop_path_rates[cur + j],
+                        layer_scale_init_value=layer_scale_init_value,
+                    )
+                    for j in range(depths[i])
+                ]
+            )
+            self.stages.append(stage)
+            cur += depths[i]
+
+        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
+
+        # Optional 1x1 projection to a caller-requested channel count.
+        if output_channels is not None:
+            self.output_projection = nn.Conv1d(dims[-1], output_channels, kernel_size=1)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        # Truncated-normal weights, zero biases for all convs/linears.
+        if isinstance(m, (nn.Conv1d, nn.Linear)):
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        # channel_layers and stages have equal length by construction.
+        for channel_layer, stage in zip(self.channel_layers, self.stages):
+            x = channel_layer(x)
+            x = stage(x)
+
+        x = self.norm(x)
+
+        if hasattr(self, "output_projection"):
+            x = self.output_projection(x)
+
+        return x

+ 419 - 0
fish_speech/models/vqgan/modules/dit.py

@@ -0,0 +1,419 @@
+import math
+from typing import Callable, Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def modulate(x, shift, scale):
+    # AdaLN modulation; `1 + scale` keeps the op at identity when the
+    # modulation MLP is zero-initialized (see DiT.initialize_weights).
+    return x * (1 + scale) + shift
+
+
+def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+    # Rotary position embedding: treats consecutive pairs of head-dim features
+    # as complex numbers and rotates them by position-dependent angles.
+    # Assumes x is (batch, seq, heads, head_dim) — see Attention.forward.
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+    # freqs_cis carries (cos, sin) in its last dim; reshape for broadcasting.
+    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+    # Complex multiply (a + bi)(cos + i sin), expressed on real pairs.
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+        ],
+        -1,
+    )
+
+    x_out2 = x_out2.flatten(3)
+    return x_out2.type_as(x)
+
+
+class TimestepEmbedder(nn.Module):
+    """
+    Embeds scalar timesteps into vector representations.
+    """
+
+    def __init__(self, hidden_size, frequency_embedding_size=256):
+        super().__init__()
+        # Sinusoidal features -> gated MLP projection to hidden_size.
+        self.mlp = FeedForward(
+            frequency_embedding_size, hidden_size, out_dim=hidden_size
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    @staticmethod
+    def timestep_embedding(t, dim, max_period=10000):
+        """
+        Create sinusoidal timestep embeddings.
+        :param t: a 1-D Tensor of N indices, one per batch element.
+                          These may be fractional.
+        :param dim: the dimension of the output.
+        :param max_period: controls the minimum frequency of the embeddings.
+        :return: an (N, D) Tensor of positional embeddings.
+        """
+        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+        half = dim // 2
+        # Geometric frequency ladder from 1 down to 1/max_period.
+        freqs = torch.exp(
+            -math.log(max_period)
+            * torch.arange(start=0, end=half, dtype=torch.float32)
+            / half
+        ).to(device=t.device)
+        args = t[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            # Odd dim: pad with one zero column to reach exactly `dim`.
+            embedding = torch.cat(
+                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+            )
+        return embedding
+
+    def forward(self, t):
+        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+        t_emb = self.mlp(t_freq)
+        return t_emb
+
+
+def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> torch.Tensor:
+    # Precompute the rotary-embedding (cos, sin) table for all positions.
+    # Returns shape (seq_len, n_elem // 2, 2) with last dim = (real, imag).
+    freqs = 1.0 / (
+        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+    )
+    t = torch.arange(seq_len, device=freqs.device)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+    # NOTE(review): cache is stored in bfloat16 regardless of model dtype;
+    # apply_rotary_emb upcasts during the multiply — confirm precision is OK.
+    return cache.to(dtype=torch.bfloat16)
+
+
+class Attention(nn.Module):
+    # Multi-head attention with rotary position embeddings.
+    # Self-attention when kv is None; cross-attention when kv is provided
+    # (rotary is applied to both q and k using their own lengths).
+    def __init__(
+        self,
+        dim,
+        n_head,
+    ):
+        super().__init__()
+        assert dim % n_head == 0
+
+        self.dim = dim
+        self.n_head = n_head
+        self.head_dim = dim // n_head
+
+        self.wq = nn.Linear(dim, dim)
+        self.wk = nn.Linear(dim, dim)
+        self.wv = nn.Linear(dim, dim)
+        self.wo = nn.Linear(dim, dim)
+
+    def forward(self, q, freqs_cis, mask=None, kv=None):
+        bsz, seqlen, _ = q.shape
+
+        if kv is None:
+            kv = q
+
+        kv_seqlen = kv.shape[1]
+
+        q = self.wq(q).view(bsz, seqlen, self.n_head, self.head_dim)
+        k = self.wk(kv).view(bsz, kv_seqlen, self.n_head, self.head_dim)
+        v = self.wv(kv).view(bsz, kv_seqlen, self.n_head, self.head_dim)
+
+        # Rotary embedding on queries and keys only, truncated to each length.
+        q = apply_rotary_emb(q, freqs_cis[:seqlen])
+        k = apply_rotary_emb(k, freqs_cis[:kv_seqlen])
+
+        # (B, S, H, D) -> (B, H, S, D) for scaled_dot_product_attention.
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+        y = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+        )
+
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+        y = self.wo(y)
+        return y
+
+
+class FeedForward(nn.Module):
+    # SwiGLU-style gated MLP: w2(silu(w1(x)) * w3(x)).
+    # out_dim defaults to in_dim, so it doubles as a projection layer.
+    def __init__(self, in_dim, intermediate_size, out_dim=None):
+        super().__init__()
+        self.w1 = nn.Linear(in_dim, intermediate_size)
+        self.w3 = nn.Linear(in_dim, intermediate_size)
+        self.w2 = nn.Linear(intermediate_size, out_dim or in_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class DiTBlock(nn.Module):
+    # DiT transformer block with adaLN-Zero conditioning.
+    # The "mixing" sub-layer is either self-attention or a depthwise conv
+    # (use_self_attention=False); cross-attention to an extra condition stream
+    # is optional. All adaLN MLPs are zero-initialized by the parent DiT,
+    # so each block starts as an identity mapping.
+    def __init__(
+        self,
+        hidden_size,
+        num_heads,
+        mlp_ratio=4.0,
+        use_self_attention=True,
+        use_cross_attention=False,
+    ):
+        super().__init__()
+
+        self.use_self_attention = use_self_attention
+        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+        if use_self_attention:
+            self.mix = Attention(hidden_size, num_heads)
+        else:
+            # Cheap local mixing: depthwise 1D conv over the sequence axis.
+            self.mix = nn.Conv1d(
+                hidden_size,
+                hidden_size,
+                kernel_size=7,
+                padding=3,
+                bias=True,
+                groups=hidden_size,
+            )
+
+        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.mlp = FeedForward(hidden_size, int(hidden_size * mlp_ratio))
+        # 6 chunks: shift/scale/gate for the mixing path and for the MLP path.
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+        )
+
+        self.use_cross_attention = use_cross_attention
+        if self.use_cross_attention:
+            self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+            self.norm4 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+            self.cross_attn = Attention(hidden_size, num_heads)
+            # shift/scale/gate for the cross-attention query side.
+            self.adaLN_modulation_cross = nn.Sequential(
+                nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True)
+            )
+            # shift/scale for the cross-condition (key/value) side.
+            self.adaLN_modulation_cross_condition = nn.Sequential(
+                nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+            )
+
+    def forward(
+        self,
+        x,
+        condition,
+        freqs_cis,
+        self_mask=None,
+        cross_condition=None,
+        cross_mask=None,
+    ):
+        (
+            shift_msa,
+            scale_msa,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+        ) = self.adaLN_modulation(condition).chunk(6, dim=-1)
+
+        # Self-attention
+        inp = modulate(self.norm1(x), shift_msa, scale_msa)
+        if self.use_self_attention:
+            inp = self.mix(inp, freqs_cis=freqs_cis, mask=self_mask)
+        else:
+            # Conv1d expects (N, C, L): .mT swaps the last two dims both ways.
+            inp = self.mix(inp.mT).mT
+        x = x + gate_msa * inp
+
+        # Cross-attention
+        if self.use_cross_attention:
+            (
+                shift_cross,
+                scale_cross,
+                gate_cross,
+            ) = self.adaLN_modulation_cross(
+                condition
+            ).chunk(3, dim=-1)
+
+            (
+                shift_cross_condition,
+                scale_cross_condition,
+            ) = self.adaLN_modulation_cross_condition(cross_condition).chunk(2, dim=-1)
+
+            inp = modulate(self.norm3(x), shift_cross, scale_cross)
+            inp = self.cross_attn(
+                inp,
+                freqs_cis=freqs_cis,
+                kv=modulate(
+                    self.norm4(cross_condition),
+                    shift_cross_condition,
+                    scale_cross_condition,
+                ),
+                mask=cross_mask,
+            )
+            x = x + gate_cross * inp
+
+        # MLP
+        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+
+        return x
+
+
+class FinalLayer(nn.Module):
+    """
+    The final layer of DiT.
+    """
+
+    # adaLN-modulated norm followed by a linear projection back to the
+    # output channel count; zero-initialized by DiT so it starts at zero output.
+    def __init__(self, hidden_size, out_channels):
+        super().__init__()
+        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+        )
+
+    def forward(self, x, c):
+        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+        x = modulate(self.norm_final(x), shift, scale)
+        return self.linear(x)
+
+
+class DiT(nn.Module):
+    # Diffusion Transformer operating on (batch, seq, channels) features
+    # (channels=160 matches the mel-band count used elsewhere in this commit).
+    # Conditioning: per-frame `condition`, scalar `time`, optional global
+    # `style` vector, optional `cross_condition` sequence for cross-attention.
+    def __init__(
+        self,
+        hidden_size,
+        num_heads,
+        diffusion_num_layers,
+        channels=160,
+        mlp_ratio=4.0,
+        max_seq_len=16384,
+        condition_dim=512,
+        style_dim=None,
+        cross_condition_dim=None,
+    ):
+        super().__init__()
+
+        self.max_seq_len = max_seq_len
+
+        self.time_embedder = TimestepEmbedder(hidden_size)
+        self.condition_embedder = FeedForward(
+            condition_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
+        )
+
+        if cross_condition_dim is not None:
+            self.cross_condition_embedder = FeedForward(
+                cross_condition_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
+            )
+
+        self.use_style = style_dim is not None
+        if self.use_style:
+            self.style_embedder = FeedForward(
+                style_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
+            )
+
+        # Only every 4th block uses self-attention; the rest use depthwise
+        # conv mixing (see DiTBlock) to cut compute.
+        self.diffusion_blocks = nn.ModuleList(
+            [
+                DiTBlock(
+                    hidden_size,
+                    num_heads,
+                    mlp_ratio,
+                    use_self_attention=i % 4 == 0,
+                    use_cross_attention=cross_condition_dim is not None,
+                )
+                for i in range(diffusion_num_layers)
+            ]
+        )
+
+        # Downsample & upsample blocks
+        self.input_embedder = FeedForward(
+            channels, int(hidden_size * mlp_ratio), out_dim=hidden_size
+        )
+        self.final_layer = FinalLayer(hidden_size, channels)
+
+        # Rotary table registered as a buffer so it follows .to(device).
+        self.register_buffer(
+            "freqs_cis", precompute_freqs_cis(max_seq_len, hidden_size // num_heads)
+        )
+
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        # Initialize input embedding:
+        self.input_embedder.apply(self.init_weight)
+        self.time_embedder.mlp.apply(self.init_weight)
+        self.condition_embedder.apply(self.init_weight)
+
+        if self.use_style:
+            self.style_embedder.apply(self.init_weight)
+
+        if hasattr(self, "cross_condition_embedder"):
+            self.cross_condition_embedder.apply(self.init_weight)
+
+        # adaLN-Zero: zero the modulation projections so every block starts
+        # as an identity mapping.
+        for block in self.diffusion_blocks:
+            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+            block.mix.apply(self.init_weight)
+
+        # Zero-out output layers:
+        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+        self.final_layer.linear.apply(self.init_weight)
+
+    def init_weight(self, m):
+        if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear)):
+            nn.init.normal_(m.weight, 0, 0.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(
+        self,
+        x,
+        time,
+        condition,
+        style=None,
+        self_mask=None,
+        cross_condition=None,
+        cross_mask=None,
+    ):
+        # Embed inputs
+        x = self.input_embedder(x)
+        t = self.time_embedder(time)
+
+        condition = self.condition_embedder(condition)
+
+        if self.use_style:
+            style = self.style_embedder(style)
+
+        if cross_condition is not None:
+            cross_condition = self.cross_condition_embedder(cross_condition)
+            # Time embedding broadcast over the cross-condition sequence too.
+            cross_condition = t[:, None, :] + cross_condition
+
+        # Merge t, condition, and style
+        condition = t[:, None, :] + condition
+        if self.use_style:
+            condition = condition + style[:, None, :]
+
+        # Expand (B, S) boolean masks to SDPA's (B, 1, 1, S) broadcast shape.
+        if self_mask is not None:
+            self_mask = self_mask[:, None, None, :]
+
+        if cross_mask is not None:
+            cross_mask = cross_mask[:, None, None, :]
+
+        # DiT
+        for block in self.diffusion_blocks:
+            x = block(
+                x,
+                condition,
+                self.freqs_cis,
+                self_mask=self_mask,
+                cross_condition=cross_condition,
+                cross_mask=cross_mask,
+            )
+
+        x = self.final_layer(x, condition)
+
+        return x
+
+
+if __name__ == "__main__":
+    # Smoke test: build a small DiT, run one forward pass, report param count.
+    model = DiT(
+        hidden_size=384,
+        num_heads=6,
+        diffusion_num_layers=12,
+        channels=160,
+        condition_dim=512,
+        style_dim=256,
+    )
+    bs, seq_len = 8, 1024
+    x = torch.randn(bs, seq_len, 160)
+    condition = torch.randn(bs, seq_len, 512)
+    style = torch.randn(bs, 256)
+    # Mask out most of sample 0 to exercise the padding path.
+    mask = torch.ones(bs, seq_len, dtype=torch.bool)
+    mask[0, 5:] = False
+    time = torch.arange(bs)
+    print(time)
+    out = model(x, time, condition, style, self_mask=mask)
+    print(out.shape)  # torch.Size([8, 1024, 160])
+
+    # Print model size
+    num_params = sum(p.numel() for p in model.parameters())
+    print(f"Number of parameters: {num_params / 1e6:.1f}M")

+ 63 - 0
fish_speech/models/vqgan/modules/firefly.py

@@ -0,0 +1,63 @@
+import torch
+from torch import nn
+
+from .convnext import ConvNeXtEncoder
+from .hifigan import HiFiGANGenerator
+
+
+class FireflyBase(nn.Module):
+    # Mel-to-waveform vocoder: ConvNeXt backbone over 160-band mels feeding a
+    # HiFi-GAN head with hop_length 512 (matches the 44.1 kHz config in this
+    # commit). Optionally restores weights from a Lightning checkpoint.
+    def __init__(self, ckpt_path: str = None):
+        super().__init__()
+
+        self.backbone = ConvNeXtEncoder(
+            input_channels=160,
+            depths=[3, 3, 9, 3],
+            dims=[128, 256, 384, 512],
+            drop_path_rate=0.2,
+            kernel_sizes=[7],
+        )
+
+        self.head = HiFiGANGenerator(
+            hop_length=512,
+            upsample_rates=[8, 8, 2, 2, 2],  # prod = 512 = hop_length
+            upsample_kernel_sizes=[16, 16, 4, 4, 4],
+            resblock_kernel_sizes=[3, 7, 11],
+            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            num_mels=512,
+            upsample_initial_channel=512,
+            use_template=True,
+            pre_conv_kernel_size=13,
+            post_conv_kernel_size=13,
+        )
+
+        if ckpt_path is None:
+            return
+
+        state_dict = torch.load(ckpt_path, map_location="cpu")
+
+        # Lightning checkpoints nest weights under "state_dict".
+        if "state_dict" in state_dict:
+            state_dict = state_dict["state_dict"]
+
+        # Keep only the generator's weights and strip their "generator." prefix.
+        if any("generator." in k for k in state_dict):
+            state_dict = {
+                k.replace("generator.", ""): v
+                for k, v in state_dict.items()
+                if "generator." in k
+            }
+
+        self.load_state_dict(state_dict, strict=True)
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        # Mel spectrogram (N, 160, T) -> latent features via the backbone.
+        x = self.backbone(x)
+        return x
+
+    def decode(self, x: torch.Tensor) -> torch.Tensor:
+        # Latents -> waveform; ensure a channel dim (N, 1, samples).
+        x = self.head(x)
+        if x.ndim == 2:
+            x = x[:, None, :]
+        return x
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.encode(x)
+        x = self.decode(x)
+        return x

+ 128 - 0
fish_speech/models/vqgan/modules/fsq.py

@@ -0,0 +1,128 @@
+from dataclasses import dataclass
+from typing import Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn.utils import weight_norm
+from vector_quantize_pytorch import ResidualFSQ
+
+from .convnext import ConvNeXtBlock
+
+
+@dataclass
+class FSQResult:
+    # Output bundle of DownsampleFiniteScalarQuantize.forward.
+    z: torch.Tensor  # quantized features, upsampled back to the input length
+    codes: torch.Tensor  # discrete FSQ indices (per quantizer)
+    latents: torch.Tensor  # pre-quantization downsampled features
+
+
+class DownsampleFiniteScalarQuantize(nn.Module):
+    # Strided-conv downsampling -> ResidualFSQ quantization -> transposed-conv
+    # upsampling back to (approximately) the original temporal length.
+    def __init__(
+        self,
+        input_dim: int = 512,
+        n_codebooks: int = 9,
+        levels: tuple[int] = (8, 5, 5, 5),  # Approximate 2**10
+        downsample_factor: tuple[int] = (2, 2),
+        downsample_dims: tuple[int] | None = None,
+    ):
+        super().__init__()
+
+        if downsample_dims is None:
+            downsample_dims = [input_dim for _ in range(len(downsample_factor))]
+
+        all_dims = (input_dim,) + tuple(downsample_dims)
+
+        self.residual_fsq = ResidualFSQ(
+            dim=all_dims[-1],
+            levels=levels,
+            num_quantizers=n_codebooks,
+        )
+
+        self.downsample_factor = downsample_factor
+        self.downsample_dims = downsample_dims
+
+        # Each stage: stride-`factor` conv (temporal downsample + channel
+        # change) followed by two ConvNeXt blocks.
+        self.downsample = nn.Sequential(
+            *[
+                nn.Sequential(
+                    nn.Conv1d(
+                        all_dims[idx],
+                        all_dims[idx + 1],
+                        kernel_size=factor,
+                        stride=factor,
+                    ),
+                    ConvNeXtBlock(dim=all_dims[idx + 1]),
+                    ConvNeXtBlock(dim=all_dims[idx + 1]),
+                )
+                for idx, factor in enumerate(downsample_factor)
+            ]
+        )
+
+        # Mirror of `downsample`, applied in reverse stage order.
+        self.upsample = nn.Sequential(
+            *[
+                nn.Sequential(
+                    nn.ConvTranspose1d(
+                        all_dims[idx + 1],
+                        all_dims[idx],
+                        kernel_size=factor,
+                        stride=factor,
+                    ),
+                    ConvNeXtBlock(dim=all_dims[idx]),
+                    ConvNeXtBlock(dim=all_dims[idx]),
+                )
+                for idx, factor in reversed(list(enumerate(downsample_factor)))
+            ]
+        )
+
+    def forward(self, z) -> FSQResult:
+        original_shape = z.shape
+        z = self.downsample(z)
+        # ResidualFSQ works channels-last, hence the .mT round-trips.
+        quantized, indices = self.residual_fsq(z.mT)
+        result = FSQResult(
+            z=quantized.mT,
+            codes=indices.mT,
+            latents=z,
+        )
+        result.z = self.upsample(result.z)
+
+        # Pad or crop z to match original shape
+        diff = original_shape[-1] - result.z.shape[-1]
+        left = diff // 2
+        right = diff - left
+
+        if diff > 0:
+            result.z = F.pad(result.z, (left, right))
+        elif diff < 0:
+            # NOTE(review): for diff < 0 `left` is negative and `right` can be
+            # 0, making `[..., left:-right]` an empty slice. With stride-k
+            # conv + matching transpose the upsampled length never exceeds the
+            # original, so this branch looks unreachable — confirm and fix if
+            # other factors are ever used.
+            result.z = result.z[..., left:-right]
+
+        return result
+
+    # def from_codes(self, codes: torch.Tensor):
+    #     z_q, z_p, codes = self.residual_fsq.get_output_from_indices(codes)
+    #     z_q = self.upsample(z_q)
+    #     return z_q, z_p, codes
+
+    # def from_latents(self, latents: torch.Tensor):
+    #     z_q, z_p, codes = super().from_latents(latents)
+    #     z_q = self.upsample(z_q)
+    #     return z_q, z_p, codes
+
+
+if __name__ == "__main__":
+    # Smoke test: quantize a random (batch, dim, time) tensor and print shapes.
+    rvq = DownsampleFiniteScalarQuantize(
+        n_codebooks=1,
+        downsample_factor=(2, 2),
+    )
+    x = torch.randn(16, 512, 80)
+
+    result = rvq(x)
+    print(rvq)
+    print(result.latents.shape, result.codes.shape, result.z.shape)
+
+    # y = rvq.from_codes(result.codes)
+    # print(y[0].shape)
+
+    # y = rvq.from_latents(result.latents)
+    # print(y[0].shape)

+ 278 - 0
fish_speech/models/vqgan/modules/hifigan.py

@@ -0,0 +1,278 @@
+from functools import partial
+from math import prod
+from typing import Callable, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    """Initialize conv-layer weights in-place from N(mean, std).
+
+    Intended for use with `module.apply(init_weights)`; modules whose class
+    name does not contain "Conv" are left untouched.
+    """
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    """Padding that keeps a stride-1 dilated conv length-preserving ("same")."""
+    return (kernel_size * dilation - dilation) // 2
+
+
+class ResBlock(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super().__init__()
+
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.silu(x)
+            xt = c1(xt)
+            xt = F.silu(xt)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_parametrizations(self):
+        for conv in self.convs1:
+            remove_parametrizations(conv)
+        for conv in self.convs2:
+            remove_parametrizations(conv)
+
+
+class ParralelBlock(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        kernel_sizes: tuple[int] = (3, 7, 11),
+        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+    ):
+        super().__init__()
+
+        assert len(kernel_sizes) == len(dilation_sizes)
+
+        self.blocks = nn.ModuleList()
+        for k, d in zip(kernel_sizes, dilation_sizes):
+            self.blocks.append(ResBlock(channels, k, d))
+
+    def forward(self, x):
+        xs = [block(x) for block in self.blocks]
+
+        return torch.stack(xs, dim=0).mean(dim=0)
+
+
+class HiFiGANGenerator(nn.Module):
+    def __init__(
+        self,
+        *,
+        hop_length: int = 512,
+        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
+        upsample_kernel_sizes: tuple[int] = (16, 16, 4, 4, 4),
+        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
+        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+        num_mels: int = 160,
+        upsample_initial_channel: int = 512,
+        use_template: bool = True,
+        pre_conv_kernel_size: int = 7,
+        post_conv_kernel_size: int = 7,
+        post_activation: Callable = partial(nn.SiLU, inplace=True),
+        checkpointing: bool = False,
+        condition_dim: Optional[int] = None,
+    ):
+        super().__init__()
+
+        assert (
+            prod(upsample_rates) == hop_length
+        ), f"hop_length must be {prod(upsample_rates)}"
+
+        self.conv_pre = weight_norm(
+            nn.Conv1d(
+                num_mels,
+                upsample_initial_channel,
+                pre_conv_kernel_size,
+                1,
+                padding=get_padding(pre_conv_kernel_size),
+            )
+        )
+
+        self.hop_length = hop_length
+        self.num_upsamples = len(upsample_rates)
+        self.num_kernels = len(resblock_kernel_sizes)
+
+        self.noise_convs = nn.ModuleList()
+        self.use_template = use_template
+        self.ups = nn.ModuleList()
+        self.condition_dim = condition_dim
+
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            c_cur = upsample_initial_channel // (2 ** (i + 1))
+            self.ups.append(
+                weight_norm(
+                    nn.ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+            if not use_template:
+                continue
+
+            if i + 1 < len(upsample_rates):
+                stride_f0 = np.prod(upsample_rates[i + 1 :])
+                self.noise_convs.append(
+                    Conv1d(
+                        1,
+                        c_cur,
+                        kernel_size=stride_f0 * 2,
+                        stride=stride_f0,
+                        padding=stride_f0 // 2,
+                    )
+                )
+            else:
+                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            self.resblocks.append(
+                ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
+            )
+
+        self.activation_post = post_activation()
+        self.conv_post = weight_norm(
+            nn.Conv1d(
+                ch,
+                1,
+                post_conv_kernel_size,
+                1,
+                padding=get_padding(post_conv_kernel_size),
+            )
+        )
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+        if condition_dim is not None:
+            self.condition = nn.Conv1d(condition_dim, upsample_initial_channel, 1)
+
+        # Gradient checkpointing
+        self.checkpointing = checkpointing
+
+    def forward(self, x, template=None, condition=None):
+        if self.use_template and template is None:
+            length = x.shape[-1] * self.hop_length
+            template = (
+                torch.randn(x.shape[0], 1, length, device=x.device, dtype=x.dtype)
+                * 0.003
+            )
+
+        if self.condition_dim is not None:
+            x = x + self.condition(condition)
+
+        x = self.conv_pre(x)
+
+        for i in range(self.num_upsamples):
+            x = F.silu(x, inplace=True)
+            x = self.ups[i](x)
+
+            if self.use_template:
+                x = x + self.noise_convs[i](template)
+
+            if self.training and self.checkpointing:
+                x = torch.utils.checkpoint.checkpoint(
+                    self.resblocks[i],
+                    x,
+                    use_reentrant=False,
+                )
+            else:
+                x = self.resblocks[i](x)
+
+        x = self.activation_post(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_parametrizations(self):
+        for up in self.ups:
+            remove_parametrizations(up)
+        for block in self.resblocks:
+            block.remove_parametrizations()
+        remove_parametrizations(self.conv_pre)
+        remove_parametrizations(self.conv_post)

+ 20 - 11
fish_speech/models/vqgan/spectrogram.py

@@ -21,7 +21,7 @@ class LinearSpectrogram(nn.Module):
         self.center = center
         self.center = center
         self.mode = mode
         self.mode = mode
 
 
-        self.register_buffer("window", torch.hann_window(win_length))
+        self.register_buffer("window", torch.hann_window(win_length), persistent=False)
 
 
     def forward(self, y: Tensor) -> Tensor:
     def forward(self, y: Tensor) -> Tensor:
         if y.ndim == 3:
         if y.ndim == 3:
@@ -78,17 +78,23 @@ class LogMelSpectrogram(nn.Module):
         self.center = center
         self.center = center
         self.n_mels = n_mels
         self.n_mels = n_mels
         self.f_min = f_min
         self.f_min = f_min
-        self.f_max = f_max or sample_rate // 2
+        self.f_max = f_max or float(sample_rate // 2)
 
 
         self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
         self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
-        self.mel_scale = MelScale(
-            self.n_mels,
-            self.sample_rate,
-            self.f_min,
-            self.f_max,
-            self.n_fft // 2 + 1,
-            "slaney",
-            "slaney",
+
+        fb = F.melscale_fbanks(
+            n_freqs=self.n_fft // 2 + 1,
+            f_min=self.f_min,
+            f_max=self.f_max,
+            n_mels=self.n_mels,
+            sample_rate=self.sample_rate,
+            norm="slaney",
+            mel_scale="slaney",
+        )
+        self.register_buffer(
+            "fb",
+            fb,
+            persistent=False,
         )
         )
 
 
     def compress(self, x: Tensor) -> Tensor:
     def compress(self, x: Tensor) -> Tensor:
@@ -97,6 +103,9 @@ class LogMelSpectrogram(nn.Module):
     def decompress(self, x: Tensor) -> Tensor:
     def decompress(self, x: Tensor) -> Tensor:
         return torch.exp(x)
         return torch.exp(x)
 
 
+    def apply_mel_scale(self, x: Tensor) -> Tensor:
+        return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
+
     def forward(
     def forward(
         self, x: Tensor, return_linear: bool = False, sample_rate: int = None
         self, x: Tensor, return_linear: bool = False, sample_rate: int = None
     ) -> Tensor:
     ) -> Tensor:
@@ -104,7 +113,7 @@ class LogMelSpectrogram(nn.Module):
             x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
             x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
 
 
         linear = self.spectrogram(x)
         linear = self.spectrogram(x)
-        x = self.mel_scale(linear)
+        x = self.apply_mel_scale(linear)
         x = self.compress(x)
         x = self.compress(x)
 
 
         if return_linear:
         if return_linear: