Lengyue пре 2 година
родитељ
комит
8de4b48c02

+ 37 - 47
fish_speech/configs/vqgan_pretrain.yaml

@@ -2,36 +2,20 @@ defaults:
   - base
   - _self_
 
-project: vq_reflow_wavenet_group_fsq
-ckpt_path: results/vq_reflow_bf16/checkpoints/step_000248000.ckpt
-resume_weights_only: true
+project: vq_reflow_shallow_group_fsq_8x1024_wavenet
 
 # Lightning Trainer
 trainer:
   accelerator: gpu
   devices: auto
-  precision: 32
+  precision: bf16-mixed
   max_steps: 1_000_000
-  # max_steps: 100
-  val_check_interval: 2000
-  gradient_clip_algorithm: norm
-  gradient_clip_val: 1.0
-  # limit_val_batches: 0.0
-
-  strategy: ddp #_find_unused_parameters_true
-  # strategy:
-  #   _target_: lightning.pytorch.strategies.DeepSpeedStrategy
-  #   stage: 1
-  #   overlap_comm: true
-
-  # profiler:
-  #   _target_: lightning.pytorch.profilers.PyTorchProfiler
-  #   export_to_chrome: true
-  #   filename: prof.txt
+  val_check_interval: 1000
+  strategy: ddp_find_unused_parameters_true
 
 sample_rate: 44100
 hop_length: 512
-num_mels: 160
+num_mels: 128
 n_fft: 2048
 win_length: 2048
 
@@ -64,46 +48,48 @@ model:
   sampling_rate: ${sample_rate}
   weight_reflow: 1.0
   weight_vq: 1.0
-  weight_aux_mel: 1.0
+  weight_mel: 1.0
+  freeze_encoder: false
+
+  # Reflow configs
+  reflow_use_shallow: true
+  reflow_inference_steps: 10
+  reflow_inference_start_t: 0.5
 
   encoder:
-    _target_: fish_speech.models.vqgan.modules.convnext.ConvNeXtEncoder
+    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
     input_channels: ${num_mels}
-    depths: [3, 3, 9, 3]
-    dims: [128, 256, 384, 512]
+    residual_channels: 512
+    residual_layers: 20
+    dilation_cycle: 4
   
   quantizer:
     _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
     input_dim: 512
-    n_codebooks: 1
-    n_groups: 8
+    n_codebooks: 8
+    n_groups: 1
     levels: [8, 5, 5, 5]
   
-  aux_decoder:
-    _target_: fish_speech.models.vqgan.modules.convnext.ConvNeXtEncoder
-    input_channels: 512
+  decoder:
+    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
     output_channels: ${num_mels}
-    depths: [6]
-    dims: [384]
-
-  # reflow:
-  #   _target_: fish_speech.models.vqgan.modules.dit.DiT
-  #   hidden_size: 768
-  #   num_heads: 12
-  #   diffusion_num_layers: 12
-  #   channels: ${num_mels}
-  #   condition_dim: 512
+    residual_channels: 512
+    residual_layers: 20
+    dilation_cycle: 4
 
   reflow:
     _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
-    mel_channels: ${num_mels}
-    d_encoder: 512
+    input_channels: ${num_mels}
+    output_channels: ${num_mels}
     residual_channels: 512
+    condition_channels: 512
     residual_layers: 20
+    dilation_cycle: 4
+    is_diffusion: true
 
   vocoder:
     _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
-    ckpt_path: checkpoints/firefly-gan-base-002000000.ckpt
+    ckpt_path: null # You may download the pretrained vocoder and set the path here
 
   mel_transform:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
@@ -132,12 +118,16 @@ model:
       final_lr_ratio: 0
 
 callbacks:
+  model_summary:
+    _target_: lightning.pytorch.callbacks.ModelSummary
+    max_depth: 1
+
+  model_checkpoint:
+    every_n_train_steps: ${trainer.val_check_interval}
+
   grad_norm_monitor:
     sub_module: 
       - encoder
-      - aux_decoder
+      - decoder
       - quantizer
       - reflow
-
-  model_checkpoint:
-    every_n_train_steps: ${trainer.val_check_interval}

+ 99 - 84
fish_speech/models/vqgan/lit_module.py

@@ -10,21 +10,8 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
 
-from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
-
-
-@dataclass
-class VQEncodeResult:
-    features: torch.Tensor
-    indices: torch.Tensor
-    loss: torch.Tensor
-    feature_lengths: torch.Tensor
-
-
-@dataclass
-class VQDecodeResult:
-    mels: torch.Tensor
-    audios: Optional[torch.Tensor] = None
+from fish_speech.models.vqgan.modules.wavenet import WaveNet
+from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
 
 
 class VQGAN(L.LightningModule):
@@ -32,16 +19,20 @@ class VQGAN(L.LightningModule):
         self,
         optimizer: Callable,
         lr_scheduler: Callable,
-        encoder: nn.Module,
+        encoder: WaveNet,
         quantizer: nn.Module,
-        aux_decoder: nn.Module,
+        decoder: WaveNet,
         reflow: nn.Module,
         vocoder: nn.Module,
         mel_transform: nn.Module,
         weight_reflow: float = 1.0,
         weight_vq: float = 1.0,
-        weight_aux_mel: float = 1.0,
+        weight_mel: float = 1.0,
         sampling_rate: int = 44100,
+        freeze_encoder: bool = False,
+        reflow_use_shallow: bool = False,
+        reflow_inference_steps: int = 10,
+        reflow_inference_start_t: float = 0.5,
     ):
         super().__init__()
 
@@ -52,10 +43,10 @@ class VQGAN(L.LightningModule):
         # Modules
         self.encoder = encoder
         self.quantizer = quantizer
-        self.aux_decoder = aux_decoder
+        self.decoder = decoder
+        self.vocoder = vocoder
         self.reflow = reflow
         self.mel_transform = mel_transform
-        self.vocoder = vocoder
 
         # Freeze vocoder
         for param in self.vocoder.parameters():
@@ -64,13 +55,27 @@ class VQGAN(L.LightningModule):
         # Loss weights
         self.weight_reflow = weight_reflow
         self.weight_vq = weight_vq
-        self.weight_aux_mel = weight_aux_mel
+        self.weight_mel = weight_mel
 
+        # Other parameters
         self.spec_min = -12
         self.spec_max = 3
         self.sampling_rate = sampling_rate
+        self.reflow_use_shallow = reflow_use_shallow
+        self.reflow_inference_steps = reflow_inference_steps
+        self.reflow_inference_start_t = reflow_inference_start_t
+
+        # Disable strict loading
         self.strict_loading = False
 
+        # If encoder is frozen
+        if freeze_encoder:
+            for param in self.encoder.parameters():
+                param.requires_grad = False
+
+            for param in self.quantizer.parameters():
+                param.requires_grad = False
+
     def on_save_checkpoint(self, checkpoint):
         # Do not save vocoder
         state_dict = checkpoint["state_dict"]
@@ -79,7 +84,6 @@ class VQGAN(L.LightningModule):
                 state_dict.pop(name)
 
     def configure_optimizers(self):
-        # Need two optimizers and two schedulers
         optimizer = self.optimizer_builder(self.parameters())
         lr_scheduler = self.lr_scheduler_builder(optimizer)
 
@@ -97,7 +101,6 @@ class VQGAN(L.LightningModule):
     def denorm_spec(self, x):
         return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
 
-    # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
     def training_step(self, batch, batch_idx):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
 
@@ -110,6 +113,7 @@ class VQGAN(L.LightningModule):
         mel_lengths = audio_lengths // self.mel_transform.hop_length
         mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
         mel_masks_float_conv = mel_masks[:, None, :].float()
+        gt_mels = gt_mels * mel_masks_float_conv
 
         # Encode
         encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
@@ -120,25 +124,31 @@ class VQGAN(L.LightningModule):
         vq_recon_features = vq_result.z * mel_masks_float_conv
 
         # VQ Decode
-        aux_mel = self.aux_decoder(vq_recon_features)
-        loss_aux_mel = F.l1_loss(
-            aux_mel * mel_masks_float_conv, gt_mels * mel_masks_float_conv
-        )
+        gen_mel = self.decoder(vq_recon_features) * mel_masks_float_conv
 
-        # Reflow
+        # Mel Loss
+        loss_mel = (gen_mel - gt_mels).abs().mean(
+            dim=1, keepdim=True
+        ).sum() / mel_masks_float_conv.sum()
+
+        # Reflow, given x_1_aux, we want to reconstruct x_1
         x_1 = self.norm_spec(gt_mels)
+
+        if self.reflow_use_shallow:
+            x_1_aux = self.norm_spec(gen_mel)
+        else:
+            x_1_aux = x_1
+
         t = torch.rand(gt_mels.shape[0], device=gt_mels.device)
         x_0 = torch.randn_like(x_1)
 
         # X_t = t * X_1_aux + (1 - t) * X_0  (X_1_aux is X_1 unless reflow_use_shallow is set)
-        x_t = x_0 + t[:, None, None] * (x_1 - x_0)
+        x_t = x_0 + t[:, None, None] * (x_1_aux - x_0)
 
         v_pred = self.reflow(
             x_t,
             1000 * t,
-            vq_recon_features,  # .detach()
-            x_masks=mel_masks_float_conv,
-            cond_masks=mel_masks_float_conv,
+            vq_recon_features,
         )
 
         # Log L2 loss with
@@ -146,21 +156,28 @@ class VQGAN(L.LightningModule):
         loss_reflow = weights[:, None, None] * F.mse_loss(
             x_1 - x_0, v_pred, reduction="none"
         )
-        loss_reflow = (loss_reflow * mel_masks_float_conv).mean()
+        loss_reflow = (loss_reflow * mel_masks_float_conv).mean(
+            dim=1
+        ).sum() / mel_masks_float_conv.sum()
 
         # Total loss
         loss = (
             self.weight_vq * loss_vq
-            + self.weight_aux_mel * loss_aux_mel
+            + self.weight_mel * loss_mel
             + self.weight_reflow * loss_reflow
         )
 
         # Log losses
         self.log(
-            "train/loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True
+            "train/generator/loss",
+            loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
         )
         self.log(
-            "train/loss_vq",
+            "train/generator/loss_vq",
             loss_vq,
             on_step=True,
             on_epoch=False,
@@ -168,15 +185,15 @@ class VQGAN(L.LightningModule):
             logger=True,
         )
         self.log(
-            "train/loss_aux_mel",
-            loss_aux_mel,
+            "train/generator/loss_mel",
+            loss_mel,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
             logger=True,
         )
         self.log(
-            "train/loss_reflow",
+            "train/generator/loss_reflow",
             loss_reflow,
             on_step=True,
             on_epoch=False,
@@ -196,22 +213,23 @@ class VQGAN(L.LightningModule):
         mel_lengths = audio_lengths // self.mel_transform.hop_length
         mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
         mel_masks_float_conv = mel_masks[:, None, :].float()
+        gt_mels = gt_mels * mel_masks_float_conv
 
         # Encode
         encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
 
         # Quantize
-        vq_result = self.quantizer(encoded_features)
+        vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
 
         # VQ Decode
-        aux_mels = self.aux_decoder(vq_result.z)
-        loss_aux_mel = F.l1_loss(
-            aux_mels * mel_masks_float_conv, gt_mels * mel_masks_float_conv
-        )
+        gen_aux_mels = self.decoder(vq_recon_features) * mel_masks_float_conv
+        loss_mel = (gen_aux_mels - gt_mels).abs().mean(
+            dim=1, keepdim=True
+        ).sum() / mel_masks_float_conv.sum()
 
         self.log(
-            "val/loss_aux_mel",
-            loss_aux_mel,
+            "val/loss_mel",
+            loss_mel,
             on_step=False,
             on_epoch=True,
             prog_bar=False,
@@ -220,37 +238,34 @@ class VQGAN(L.LightningModule):
         )
 
         # Reflow inference
-        t_start = 0.0
-        infer_step = 10
+        t_start = self.reflow_inference_start_t if self.reflow_use_shallow else 0.0
 
-        x_1 = self.norm_spec(aux_mels)
+        x_1 = self.norm_spec(gen_aux_mels)
         x_0 = torch.randn_like(x_1)
-        gen_mels = (1 - t_start) * x_0 + t_start * x_1
+        gen_reflow_mels = (1 - t_start) * x_0 + t_start * x_1
 
         t = torch.zeros(gt_mels.shape[0], device=gt_mels.device)
-        dt = (1.0 - t_start) / infer_step
+        dt = (1.0 - t_start) / self.reflow_inference_steps
 
-        for _ in range(infer_step):
-            gen_mels += (
+        for _ in range(self.reflow_inference_steps):
+            gen_reflow_mels += (
                 self.reflow(
-                    gen_mels,
+                    gen_reflow_mels,
                     1000 * t,
-                    vq_result.z,
-                    x_masks=mel_masks_float_conv,
-                    cond_masks=mel_masks_float_conv,
+                    vq_recon_features,
                 )
                 * dt
             )
             t += dt
 
-        gen_mels = self.denorm_spec(gen_mels)
-        loss_recon_reflow = F.l1_loss(
-            gen_mels * mel_masks_float_conv, gt_mels * mel_masks_float_conv
-        )
+        gen_reflow_mels = self.denorm_spec(gen_reflow_mels) * mel_masks_float_conv
+        loss_reflow_mel = (gen_reflow_mels - gt_mels).abs().mean(
+            dim=1, keepdim=True
+        ).sum() / mel_masks_float_conv.sum()
 
         self.log(
-            "val/loss_recon_reflow",
-            loss_recon_reflow,
+            "val/loss_reflow_mel",
+            loss_reflow_mel,
             on_step=False,
             on_epoch=True,
             prog_bar=False,
@@ -258,9 +273,9 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
-        gen_audios = self.vocoder(gen_mels)
         recon_audios = self.vocoder(gt_mels)
-        aux_audios = self.vocoder(aux_mels)
+        gen_aux_audios = self.vocoder(gen_aux_mels)
+        gen_reflow_audios = self.vocoder(gen_reflow_mels)
 
         # only log the first batch
         if batch_idx != 0:
@@ -268,21 +283,21 @@ class VQGAN(L.LightningModule):
 
         for idx, (
             gt_mel,
-            reflow_mel,
-            aux_mel,
+            gen_aux_mel,
+            gen_reflow_mel,
             audio,
-            reflow_audio,
-            aux_audio,
+            gen_aux_audio,
+            gen_reflow_audio,
             recon_audio,
             audio_len,
         ) in enumerate(
             zip(
                 gt_mels,
-                gen_mels,
-                aux_mels,
+                gen_aux_mels,
+                gen_reflow_mels,
                 audios.float(),
-                gen_audios.float(),
-                aux_audios.float(),
+                gen_aux_audios.float(),
+                gen_reflow_audios.float(),
                 recon_audios.float(),
                 audio_lengths,
             )
@@ -292,13 +307,13 @@ class VQGAN(L.LightningModule):
             image_mels = plot_mel(
                 [
                     gt_mel[:, :mel_len],
-                    reflow_mel[:, :mel_len],
-                    aux_mel[:, :mel_len],
+                    gen_aux_mel[:, :mel_len],
+                    gen_reflow_mel[:, :mel_len],
                 ],
                 [
                     "Ground-Truth",
+                    "Auxiliary",
                     "Reflow",
-                    "Aux",
                 ],
             )
 
@@ -313,14 +328,14 @@ class VQGAN(L.LightningModule):
                                 caption="gt",
                             ),
                             wandb.Audio(
-                                reflow_audio[0, :audio_len],
+                                gen_aux_audio[0, :audio_len],
                                 sample_rate=self.sampling_rate,
-                                caption="reflow",
+                                caption="aux",
                             ),
                             wandb.Audio(
-                                aux_audio[0, :audio_len],
+                                gen_reflow_audio[0, :audio_len],
                                 sample_rate=self.sampling_rate,
-                                caption="aux",
+                                caption="reflow",
                             ),
                             wandb.Audio(
                                 recon_audio[0, :audio_len],
@@ -344,14 +359,14 @@ class VQGAN(L.LightningModule):
                     sample_rate=self.sampling_rate,
                 )
                 self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/reflow",
-                    reflow_audio[0, :audio_len],
+                    f"sample-{idx}/wavs/gen",
+                    gen_aux_audio[0, :audio_len],
                     self.global_step,
                     sample_rate=self.sampling_rate,
                 )
                 self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/aux",
-                    aux_audio[0, :audio_len],
+                    f"sample-{idx}/wavs/reflow",
+                    gen_reflow_audio[0, :audio_len],
                     self.global_step,
                     sample_rate=self.sampling_rate,
                 )

+ 0 - 419
fish_speech/models/vqgan/modules/dit.py

@@ -1,419 +0,0 @@
-import math
-from typing import Callable, Optional, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-def modulate(x, shift, scale):
-    return x * (1 + scale) + shift
-
-
-def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
-    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
-    x_out2 = torch.stack(
-        [
-            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
-            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
-        ],
-        -1,
-    )
-
-    x_out2 = x_out2.flatten(3)
-    return x_out2.type_as(x)
-
-
-class TimestepEmbedder(nn.Module):
-    """
-    Embeds scalar timesteps into vector representations.
-    """
-
-    def __init__(self, hidden_size, frequency_embedding_size=256):
-        super().__init__()
-        self.mlp = FeedForward(
-            frequency_embedding_size, hidden_size, out_dim=hidden_size
-        )
-        self.frequency_embedding_size = frequency_embedding_size
-
-    @staticmethod
-    def timestep_embedding(t, dim, max_period=10000):
-        """
-        Create sinusoidal timestep embeddings.
-        :param t: a 1-D Tensor of N indices, one per batch element.
-                          These may be fractional.
-        :param dim: the dimension of the output.
-        :param max_period: controls the minimum frequency of the embeddings.
-        :return: an (N, D) Tensor of positional embeddings.
-        """
-        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
-        half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period)
-            * torch.arange(start=0, end=half, dtype=torch.float32)
-            / half
-        ).to(device=t.device)
-        args = t[:, None].float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        if dim % 2:
-            embedding = torch.cat(
-                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
-            )
-        return embedding
-
-    def forward(self, t):
-        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
-        t_emb = self.mlp(t_freq)
-        return t_emb
-
-
-def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> torch.Tensor:
-    freqs = 1.0 / (
-        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
-    )
-    t = torch.arange(seq_len, device=freqs.device)
-    freqs = torch.outer(t, freqs)
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
-    return cache.to(dtype=torch.bfloat16)
-
-
-class Attention(nn.Module):
-    def __init__(
-        self,
-        dim,
-        n_head,
-    ):
-        super().__init__()
-        assert dim % n_head == 0
-
-        self.dim = dim
-        self.n_head = n_head
-        self.head_dim = dim // n_head
-
-        self.wq = nn.Linear(dim, dim)
-        self.wk = nn.Linear(dim, dim)
-        self.wv = nn.Linear(dim, dim)
-        self.wo = nn.Linear(dim, dim)
-
-    def forward(self, q, freqs_cis, kv=None, mask=None):
-        bsz, seqlen, _ = q.shape
-
-        if kv is None:
-            kv = q
-
-        kv_seqlen = kv.shape[1]
-
-        q = self.wq(q).view(bsz, seqlen, self.n_head, self.head_dim)
-        k = self.wk(kv).view(bsz, kv_seqlen, self.n_head, self.head_dim)
-        v = self.wv(kv).view(bsz, kv_seqlen, self.n_head, self.head_dim)
-
-        q = apply_rotary_emb(q, freqs_cis[:seqlen])
-        k = apply_rotary_emb(k, freqs_cis[:kv_seqlen])
-
-        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
-        y = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
-        )
-
-        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-        y = self.wo(y)
-        return y
-
-
-class FeedForward(nn.Module):
-    def __init__(self, in_dim, intermediate_size, out_dim=None):
-        super().__init__()
-        self.w1 = nn.Linear(in_dim, intermediate_size)
-        self.w3 = nn.Linear(in_dim, intermediate_size)
-        self.w2 = nn.Linear(intermediate_size, out_dim or in_dim)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-
-class DiTBlock(nn.Module):
-    def __init__(
-        self,
-        hidden_size,
-        num_heads,
-        mlp_ratio=4.0,
-        use_self_attention=True,
-        use_cross_attention=False,
-    ):
-        super().__init__()
-
-        self.use_self_attention = use_self_attention
-        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-
-        if use_self_attention:
-            self.mix = Attention(hidden_size, num_heads)
-        else:
-            self.mix = nn.Conv1d(
-                hidden_size,
-                hidden_size,
-                kernel_size=7,
-                padding=3,
-                bias=True,
-                groups=hidden_size,
-            )
-
-        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.mlp = FeedForward(hidden_size, int(hidden_size * mlp_ratio))
-        self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
-        )
-
-        self.use_cross_attention = use_cross_attention
-        if self.use_cross_attention:
-            self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-            self.norm4 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-            self.cross_attn = Attention(hidden_size, num_heads)
-            self.adaLN_modulation_cross = nn.Sequential(
-                nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size, bias=True)
-            )
-            self.adaLN_modulation_cross_condition = nn.Sequential(
-                nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
-            )
-
-    def forward(
-        self,
-        x,
-        condition,
-        freqs_cis,
-        self_mask=None,
-        cross_condition=None,
-        cross_mask=None,
-    ):
-        (
-            shift_msa,
-            scale_msa,
-            gate_msa,
-            shift_mlp,
-            scale_mlp,
-            gate_mlp,
-        ) = self.adaLN_modulation(condition).chunk(6, dim=-1)
-
-        # Self-attention
-        inp = modulate(self.norm1(x), shift_msa, scale_msa)
-        if self.use_self_attention:
-            inp = self.mix(inp, freqs_cis=freqs_cis, mask=self_mask)
-        else:
-            inp = self.mix(inp.mT).mT
-        x = x + gate_msa * inp
-
-        # Cross-attention
-        if self.use_cross_attention:
-            (
-                shift_cross,
-                scale_cross,
-                gate_cross,
-            ) = self.adaLN_modulation_cross(
-                condition
-            ).chunk(3, dim=-1)
-
-            (
-                shift_cross_condition,
-                scale_cross_condition,
-            ) = self.adaLN_modulation_cross_condition(cross_condition).chunk(2, dim=-1)
-
-            inp = modulate(self.norm3(x), shift_cross, scale_cross)
-            inp = self.cross_attn(
-                inp,
-                freqs_cis=freqs_cis,
-                kv=modulate(
-                    self.norm4(cross_condition),
-                    shift_cross_condition,
-                    scale_cross_condition,
-                ),
-                mask=cross_mask,
-            )
-            x = x + gate_cross * inp
-
-        # MLP
-        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
-
-        return x
-
-
-class FinalLayer(nn.Module):
-    """
-    The final layer of DiT.
-    """
-
-    def __init__(self, hidden_size, out_channels):
-        super().__init__()
-        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
-        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
-        self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
-        )
-
-    def forward(self, x, c):
-        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
-        x = modulate(self.norm_final(x), shift, scale)
-        return self.linear(x)
-
-
-class DiT(nn.Module):
-    def __init__(
-        self,
-        hidden_size,
-        num_heads,
-        diffusion_num_layers,
-        channels=160,
-        mlp_ratio=4.0,
-        max_seq_len=16384,
-        condition_dim=512,
-        style_dim=None,
-        cross_condition_dim=None,
-    ):
-        super().__init__()
-
-        self.max_seq_len = max_seq_len
-
-        self.time_embedder = TimestepEmbedder(hidden_size)
-        self.condition_embedder = FeedForward(
-            condition_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
-        )
-
-        if cross_condition_dim is not None:
-            self.cross_condition_embedder = FeedForward(
-                cross_condition_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
-            )
-
-        self.use_style = style_dim is not None
-        if self.use_style:
-            self.style_embedder = FeedForward(
-                style_dim, int(hidden_size * mlp_ratio), out_dim=hidden_size
-            )
-
-        self.diffusion_blocks = nn.ModuleList(
-            [
-                DiTBlock(
-                    hidden_size,
-                    num_heads,
-                    mlp_ratio,
-                    use_self_attention=i % 4 == 0,
-                    use_cross_attention=cross_condition_dim is not None,
-                )
-                for i in range(diffusion_num_layers)
-            ]
-        )
-
-        # Downsample & upsample blocks
-        self.input_embedder = FeedForward(
-            channels, int(hidden_size * mlp_ratio), out_dim=hidden_size
-        )
-        self.final_layer = FinalLayer(hidden_size, channels)
-
-        self.register_buffer(
-            "freqs_cis", precompute_freqs_cis(max_seq_len, hidden_size // num_heads)
-        )
-
-        self.initialize_weights()
-
-    def initialize_weights(self):
-        # Initialize input embedding:
-        self.input_embedder.apply(self.init_weight)
-        self.time_embedder.mlp.apply(self.init_weight)
-        self.condition_embedder.apply(self.init_weight)
-
-        if self.use_style:
-            self.style_embedder.apply(self.init_weight)
-
-        if hasattr(self, "cross_condition_embedder"):
-            self.cross_condition_embedder.apply(self.init_weight)
-
-        for block in self.diffusion_blocks:
-            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
-            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-            block.mix.apply(self.init_weight)
-
-        # Zero-out output layers:
-        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
-        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
-        self.final_layer.linear.apply(self.init_weight)
-
-    def init_weight(self, m):
-        if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d, nn.Linear)):
-            nn.init.normal_(m.weight, 0, 0.02)
-            if m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-
-    def forward(
-        self,
-        x,
-        time,
-        condition,
-        style=None,
-        self_mask=None,
-        cross_condition=None,
-        cross_mask=None,
-    ):
-        # Embed inputs
-        x = self.input_embedder(x)
-        t = self.time_embedder(time)
-
-        condition = self.condition_embedder(condition)
-
-        if self.use_style:
-            style = self.style_embedder(style)
-
-        if cross_condition is not None:
-            cross_condition = self.cross_condition_embedder(cross_condition)
-            cross_condition = t[:, None, :] + cross_condition
-
-        # Merge t, condition, and style
-        condition = t[:, None, :] + condition
-        if self.use_style:
-            condition = condition + style[:, None, :]
-
-        if self_mask is not None:
-            self_mask = self_mask[:, None, None, :]
-
-        if cross_mask is not None:
-            cross_mask = cross_mask[:, None, None, :]
-
-        # DiT
-        for block in self.diffusion_blocks:
-            x = block(
-                x,
-                condition,
-                self.freqs_cis,
-                self_mask=self_mask,
-                cross_condition=cross_condition,
-                cross_mask=cross_mask,
-            )
-
-        x = self.final_layer(x, condition)
-
-        return x
-
-
-if __name__ == "__main__":
-    model = DiT(
-        hidden_size=384,
-        num_heads=6,
-        diffusion_num_layers=12,
-        channels=160,
-        condition_dim=512,
-        style_dim=256,
-    )
-    bs, seq_len = 8, 1024
-    x = torch.randn(bs, seq_len, 160)
-    condition = torch.randn(bs, seq_len, 512)
-    style = torch.randn(bs, 256)
-    mask = torch.ones(bs, seq_len, dtype=torch.bool)
-    mask[0, 5:] = False
-    time = torch.arange(bs)
-    print(time)
-    out = model(x, time, condition, style, self_mask=mask)
-    print(out.shape)  # torch.Size([2, 100, 160])
-
-    # Print model size
-    num_params = sum(p.numel() for p in model.parameters())
-    print(f"Number of parameters: {num_params / 1e6:.1f}M")

+ 8 - 3
fish_speech/models/vqgan/modules/fsq.py

@@ -9,7 +9,7 @@ from einops import rearrange
 from torch.nn.utils import weight_norm
 from vector_quantize_pytorch import GroupedResidualFSQ
 
-from .convnext import ConvNeXtBlock
+from .firefly import ConvNeXtBlock
 
 
 @dataclass
@@ -56,7 +56,6 @@ class DownsampleFiniteScalarQuantize(nn.Module):
                         stride=factor,
                     ),
                     ConvNeXtBlock(dim=all_dims[idx + 1]),
-                    ConvNeXtBlock(dim=all_dims[idx + 1]),
                 )
                 for idx, factor in enumerate(downsample_factor)
             ]
@@ -72,12 +71,18 @@ class DownsampleFiniteScalarQuantize(nn.Module):
                         stride=factor,
                     ),
                     ConvNeXtBlock(dim=all_dims[idx]),
-                    ConvNeXtBlock(dim=all_dims[idx]),
                 )
                 for idx, factor in reversed(list(enumerate(downsample_factor)))
             ]
         )
 
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        """Initialise one submodule; used as the callback for ``self.apply``.
+
+        Conv1d and Linear layers get truncated-normal weights (std=0.02) and
+        zeroed biases. Every other module type is left untouched.
+        """
+        if not isinstance(m, (nn.Conv1d, nn.Linear)):
+            return
+
+        nn.init.trunc_normal_(m.weight, std=0.02)
+        # Layers built with bias=False expose ``bias`` as None, which cannot
+        # be handed to nn.init.constant_, so only zero a bias that exists.
+        if m.bias is not None:
+            nn.init.constant_(m.bias, 0)
+
     def forward(self, z) -> FSQResult:
         original_shape = z.shape
         z = self.downsample(z)

+ 84 - 95
fish_speech/models/vqgan/modules/wavenet.py

@@ -1,4 +1,5 @@
 import math
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -83,7 +84,14 @@ class ConvNorm(nn.Module):
 class ResidualBlock(nn.Module):
     """Residual Block"""
 
-    def __init__(self, d_encoder, residual_channels, use_linear_bias=False, dilation=1):
+    def __init__(
+        self,
+        residual_channels,
+        use_linear_bias=False,
+        dilation=1,
+        has_condition=True,
+        condition_channels=None,
+    ):
         super(ResidualBlock, self).__init__()
         self.conv_layer = ConvNorm(
             residual_channels,
@@ -93,23 +101,31 @@ class ResidualBlock(nn.Module):
             padding=dilation,
             dilation=dilation,
         )
-        self.diffusion_projection = LinearNorm(
-            residual_channels, residual_channels, use_linear_bias
-        )
-        self.conditioner_projection = ConvNorm(
-            d_encoder, 2 * residual_channels, kernel_size=1
-        )
+
+        if has_condition:
+            self.diffusion_projection = LinearNorm(
+                residual_channels, residual_channels, use_linear_bias
+            )
+            self.condition_projection = ConvNorm(
+                condition_channels, 2 * residual_channels, kernel_size=1
+            )
+
         self.output_projection = ConvNorm(
             residual_channels, 2 * residual_channels, kernel_size=1
         )
 
-    def forward(self, x, conditioner, diffusion_step):
-        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
-        conditioner = self.conditioner_projection(conditioner)
+    def forward(self, x, condition=None, diffusion_step=None):
+        y = x
+
+        if diffusion_step is not None:
+            diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+            y = y + diffusion_step
 
-        y = x + diffusion_step
+        y = self.conv_layer(y)
 
-        y = self.conv_layer(y) + conditioner
+        if condition is not None:
+            condition = self.condition_projection(condition)
+            y = y + condition
 
         gate, filter = torch.chunk(y, 2, dim=1)
         y = torch.sigmoid(gate) * torch.tanh(filter)
@@ -120,117 +136,90 @@ class ResidualBlock(nn.Module):
         return (x + residual) / math.sqrt(2.0), skip
 
 
-class SpectrogramUpsampler(nn.Module):
-    def __init__(self, hop_size):
+class WaveNet(nn.Module):
+    def __init__(
+        self,
+        input_channels: Optional[int] = None,
+        output_channels: Optional[int] = None,
+        residual_channels: int = 512,
+        residual_layers: int = 20,
+        dilation_cycle: Optional[int] = 4,
+        is_diffusion: bool = False,
+        condition_channels: Optional[int] = None,
+    ):
+        """Build the residual stack plus optional input/output projections.
+
+        Args:
+            input_channels: when given and different from ``residual_channels``,
+                a kernel-size-1 ``ConvNorm`` maps the input to
+                ``residual_channels``; otherwise no input projection is built.
+            output_channels: when given and different from
+                ``residual_channels``, a kernel-size-1 ``ConvNorm`` maps the
+                result to ``output_channels``; otherwise no output projection
+                is built.
+            residual_channels: channel width used by every residual block and
+                by the skip projection.
+            residual_layers: number of ``ResidualBlock`` instances to stack.
+            dilation_cycle: layer ``i`` uses dilation
+                ``2 ** (i % dilation_cycle)``; a falsy value (``None`` or 0)
+                gives every layer dilation 1.
+            is_diffusion: when true, the diffusion-step embedding and its MLP
+                are created and every residual block is built with
+                ``has_condition=True``.
+            condition_channels: forwarded to every ``ResidualBlock``.
+                NOTE(review): presumably must be an int whenever
+                ``is_diffusion`` is true, since the block then builds a
+                projection from it -- confirm against ``ResidualBlock``.
+        """
         super().__init__()
 
-        if hop_size == 256:
-            self.conv1 = nn.ConvTranspose2d(
-                1, 1, [3, 32], stride=[1, 16], padding=[1, 8]
-            )
-        elif hop_size == 512:
-            self.conv1 = nn.ConvTranspose2d(
-                1, 1, [3, 64], stride=[1, 32], padding=[1, 16]
+        # Input projection
+        self.input_projection = None
+        if input_channels is not None and input_channels != residual_channels:
+            self.input_projection = ConvNorm(
+                input_channels, residual_channels, kernel_size=1
             )
-        else:
-            raise ValueError(f"Unsupported hop_size: {hop_size}")
 
-        self.conv2 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
-
-    def forward(self, x):
-        x = torch.unsqueeze(x, 1)
-        x = self.conv1(x)
-        x = F.leaky_relu(x, 0.4)
-        x = self.conv2(x)
-        x = F.leaky_relu(x, 0.4)
-        x = torch.squeeze(x, 1)
+        # NOTE(review): input_channels is not read again below in this
+        # constructor, so this fallback currently has no effect.
+        if input_channels is None:
+            input_channels = residual_channels
 
-        return x
-
-
-class WaveNet(nn.Module):
-    """
-    WaveNet
-    https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio
-    """
-
-    def __init__(
-        self,
-        mel_channels=128,
-        d_encoder=256,
-        residual_channels=512,
-        residual_layers=20,
-        use_linear_bias=False,
-        dilation_cycle=None,
-    ):
-        super(WaveNet, self).__init__()
-
-        self.input_projection = ConvNorm(mel_channels, residual_channels, kernel_size=1)
-        self.diffusion_embedding = DiffusionEmbedding(residual_channels)
-        self.mlp = nn.Sequential(
-            LinearNorm(residual_channels, residual_channels * 4, use_linear_bias),
-            Mish(),
-            LinearNorm(residual_channels * 4, residual_channels, use_linear_bias),
-        )
+        # Residual layers
         self.residual_layers = nn.ModuleList(
             [
                 ResidualBlock(
-                    d_encoder,
-                    residual_channels,
-                    use_linear_bias=use_linear_bias,
+                    residual_channels=residual_channels,
+                    use_linear_bias=False,
                     dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
+                    has_condition=is_diffusion,
+                    condition_channels=condition_channels,
                 )
                 for i in range(residual_layers)
             ]
         )
+
+        # Skip projection
         self.skip_projection = ConvNorm(
             residual_channels, residual_channels, kernel_size=1
         )
-        self.output_projection = ConvNorm(
-            residual_channels, mel_channels, kernel_size=1
-        )
-        nn.init.zeros_(self.output_projection.conv.weight)
-
-    def forward(self, x, diffusion_step, conditioner, x_masks=None, cond_masks=None):
-        """
 
-        :param x: [B, M, T]
-        :param diffusion_step: [B,]
-        :param conditioner: [B, M, T]
-        :return:
-        """
-
-        # To keep compatibility with DiffSVC, [B, 1, M, T]
-        use_4_dim = False
-        if x.dim() == 4:
-            x = x[:, 0]
-            use_4_dim = True
+        # Output projection
+        self.output_projection = None
+        if output_channels is not None and output_channels != residual_channels:
+            self.output_projection = ConvNorm(
+                residual_channels, output_channels, kernel_size=1
+            )
 
-        assert x.dim() == 3, f"mel must be 3 dim tensor, but got {x.dim()}"
+        if is_diffusion:
+            self.diffusion_embedding = DiffusionEmbedding(residual_channels)
+            self.mlp = nn.Sequential(
+                LinearNorm(residual_channels, residual_channels * 4, False),
+                Mish(),
+                LinearNorm(residual_channels * 4, residual_channels, False),
+            )
 
-        x = self.input_projection(x)  # x [B, residual_channel, T]
-        x = F.relu(x)
+        self.apply(self._init_weights)
 
-        diffusion_step = self.diffusion_embedding(diffusion_step)
-        diffusion_step = self.mlp(diffusion_step)
+    def _init_weights(self, m):
+        """Initialise one submodule; called for every submodule via ``self.apply``.
+
+        ``nn.Conv1d`` and ``nn.Linear`` weights are drawn from a truncated
+        normal distribution (std=0.02). A bias is zeroed only when the layer
+        actually has one. Every other module type is left untouched.
+        """
+        if isinstance(m, (nn.Conv1d, nn.Linear)):
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            if getattr(m, "bias", None) is not None:
+                nn.init.constant_(m.bias, 0)
 
-        if x_masks is not None:
-            x = x * x_masks
+    def forward(self, x, t=None, condition=None):
+        """Run the residual stack and return the projected sum of skip outputs.
+
+        Args:
+            x: input tensor. Presumably (batch, channels, time), given the
+                kernel-size-1 conv projections -- TODO confirm.
+            t: optional diffusion step. When given it is passed through
+                ``self.diffusion_embedding`` and ``self.mlp``, which the
+                constructor only creates when ``is_diffusion`` is true.
+            condition: optional conditioning input, handed unchanged to every
+                residual layer.
+
+        Returns:
+            The skip-connection aggregate after ``skip_projection`` and, when
+            an output projection exists, SiLU followed by that projection.
+        """
+        # The SiLU is applied only when an input projection exists.
+        if self.input_projection is not None:
+            x = self.input_projection(x)
+            x = F.silu(x)
 
-        if cond_masks is not None:
-            conditioner = conditioner * cond_masks
+        if t is not None:
+            t = self.diffusion_embedding(t)
+            t = self.mlp(t)
 
         skip = []
         for layer in self.residual_layers:
-            x, skip_connection = layer(x, conditioner, diffusion_step)
+            x, skip_connection = layer(x, condition, t)
             skip.append(skip_connection)
 
+        # Sum the per-layer skip outputs and scale by 1 / sqrt(number of layers).
         x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
         x = self.skip_projection(x)
-        x = F.relu(x)
-        x = self.output_projection(x)  # [B, 128, T]
 
-        if x_masks is not None:
-            x = x * x_masks
+        if self.output_projection is not None:
+            x = F.silu(x)
+            x = self.output_projection(x)
 
-        return x[:, None] if use_4_dim else x
+        return x