Bladeren bron

Borrow VITS components from another repo

Lengyue 2 jaren geleden
bovenliggende
commit
e0d8c72b24

+ 22 - 49
fish_speech/configs/hubert_vq.yaml

@@ -13,7 +13,7 @@ trainer:
     static_graph: true
   precision: 32
   max_steps: 1_000_000
-  val_check_interval: 1000
+  val_check_interval: 5000
 
 sample_rate: 32000
 hop_length: 640
@@ -27,14 +27,13 @@ train_dataset:
   filelist: data/vq_train_filelist.txt
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
-  slice_frames: 32
+  slice_frames: 512
 
 val_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
   filelist: data/vq_val_filelist.txt
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
-  slice_frames: 256
 
 data:
   _target_: fish_speech.datasets.vqgan.VQGANDataModule
@@ -51,57 +50,31 @@ model:
   hop_length: ${hop_length}
   segment_size: 20480
 
-  semantic_encoder:
-    _target_: fish_speech.models.vqgan.modules.SemanticEncoder
+  generator:
+    _target_: fish_speech.models.vqgan.modules.models.SynthesizerTrn
     in_channels: 1024
-    hidden_channels: 384
-    out_channels: 192
-    num_heads: 2
-    num_layers: 8
-    input_downsample: true
-    code_book_size: 2048
-    freeze_vq: false
-    gin_channels: ${model.speaker_encoder.out_channels}
-
-  posterior_encoder:
-    _target_: fish_speech.models.vqgan.modules.PosteriorEncoder
-    in_channels: "${eval: '${n_fft} // 2 + 1'}"
-    hidden_channels: 192
-    out_channels: 192
-    gin_channels: ${model.speaker_encoder.out_channels}
-  
-  speaker_encoder:
-    _target_: fish_speech.models.vqgan.modules.SpeakerEncoder
-    in_channels: ${num_mels}
+    spec_channels: ${num_mels}
+    segment_size: "${eval: '${model.segment_size} // ${hop_length}'}"
+    inter_channels: 192
     hidden_channels: 192
-    out_channels: 512
-
-  # flow:
-  #   _target_: fish_speech.models.vqgan.modules.ResidualCouplingBlock
-  #   channels: 192
-  #   hidden_channels: 192
-  #   kernel_size: 5
-  #   dilation_rate: 1
-  #   n_layers: 4
-  #   n_flows: 4
-  #   gin_channels: ${model.speaker_encoder.out_channels}
-
-  generator:
-    _target_: fish_speech.models.vqgan.modules.Generator
-    initial_channel: 192
-    gin_channels: ${model.speaker_encoder.out_channels}
+    filter_channels: 768
+    n_heads: 4
+    n_layers: 8
+    n_layers_q: 16
+    n_layers_spk: 6
+    kernel_size: 3
+    p_dropout: 0.1
+    speaker_cond_layer: 0
     resblock: "1"
     resblock_kernel_sizes: [3, 7, 11]
-    resblock_dilation_sizes: 
-      - [1, 3, 5]
-      - [1, 3, 5]
-      - [1, 3, 5]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
     upsample_rates: [10, 8, 2, 2, 2]
     upsample_initial_channel: 512
     upsample_kernel_sizes: [16, 16, 8, 2, 2]
+    gin_channels: 512 # basically the speaker embedding size
 
   discriminator:
-    _target_: fish_speech.models.vqgan.modules.EnsembleDiscriminator
+    _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
 
   mel_transform:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
@@ -114,7 +87,7 @@ model:
   optimizer:
     _target_: torch.optim.AdamW
     _partial_: true
-    lr: 1e-4
+    lr: 2e-4
     betas: [0.8, 0.99]
     eps: 1e-5
 
@@ -124,7 +97,7 @@ model:
     lr_lambda:
       _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
       _partial_: true
-      num_warmup_steps: 0
+      num_warmup_steps: 1000
       num_training_steps: ${trainer.max_steps}
       final_lr_ratio: 0.05
 
@@ -133,5 +106,5 @@ callbacks:
     sub_module: generator
 
 # Resume from rcell's checkpoint
-ckpt_path: results/hubert-vq-pretrain/rcell/ckpt_23000_pl.pth
-resume_weights_only: true
+# ckpt_path: results/hubert-vq-pretrain/rcell/ckpt_23000_pl.pth
+# resume_weights_only: true

+ 5 - 1
fish_speech/datasets/vqgan.py

@@ -26,7 +26,11 @@ class VQGANDataset(Dataset):
         filelist = Path(filelist)
         root = filelist.parent
 
-        self.files = [root / line.strip() for line in filelist.read_text().splitlines()]
+        self.files = [
+            root / line.strip()
+            for line in filelist.read_text().splitlines()
+            if ("Genshin" in line or "StarRail" in line)
+        ]
         self.sample_rate = sample_rate
         self.hop_length = hop_length
         self.slice_frames = slice_frames

+ 73 - 217
fish_speech/models/vqgan/lit_module.py

@@ -8,15 +8,17 @@ import wandb
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
+from vector_quantize_pytorch import ResidualLFQ
 
-from fish_speech.models.vqgan.modules import (
-    EnsembleDiscriminator,
-    Generator,
-    PosteriorEncoder,
-    SemanticEncoder,
-    SpeakerEncoder,
+from fish_speech.models.vqgan.losses import (
+    discriminator_loss,
+    feature_loss,
+    generator_loss,
+    kl_loss_normal,
 )
-from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
+from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
+from fish_speech.models.vqgan.modules.models import SynthesizerTrn
+from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
 
 
 class VQGAN(L.LightningModule):
@@ -24,11 +26,7 @@ class VQGAN(L.LightningModule):
         self,
         optimizer: Callable,
         lr_scheduler: Callable,
-        semantic_encoder: SemanticEncoder,
-        posterior_encoder: PosteriorEncoder,
-        speaker_encoder: SpeakerEncoder,
-        # flow: nn.Module,
-        generator: Generator,
+        generator: SynthesizerTrn,
         discriminator: EnsembleDiscriminator,
         mel_transform: nn.Module,
         segment_size: int = 20480,
@@ -42,11 +40,6 @@ class VQGAN(L.LightningModule):
         self.lr_scheduler_builder = lr_scheduler
 
         # Generator and discriminators
-        # Compile generator so that snake can save memory
-        self.semantic_encoder = semantic_encoder
-        self.posterior_encoder = posterior_encoder
-        self.speaker_encoder = speaker_encoder
-        # self.flow = flow
         self.generator = generator
         self.discriminator = discriminator
         self.mel_transform = mel_transform
@@ -61,15 +54,7 @@ class VQGAN(L.LightningModule):
 
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
-        optimizer_generator = self.optimizer_builder(
-            itertools.chain(
-                self.semantic_encoder.parameters(),
-                self.posterior_encoder.parameters(),
-                self.speaker_encoder.parameters(),
-                self.generator.parameters(),
-                # self.flow.parameters(),
-            )
-        )
+        optimizer_generator = self.optimizer_builder(self.generator.parameters())
         optimizer_discriminator = self.optimizer_builder(
             self.discriminator.parameters()
         )
@@ -96,127 +81,41 @@ class VQGAN(L.LightningModule):
             },
         )
 
-    @staticmethod
-    def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-        loss = 0
-        r_losses = []
-        g_losses = []
-        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-            dr = dr.float()
-            dg = dg.float()
-            r_loss = torch.mean((1 - dr) ** 2)
-            g_loss = torch.mean(dg**2)
-            loss += r_loss + g_loss
-            r_losses.append(r_loss.item())
-            g_losses.append(g_loss.item())
-
-        return loss, r_losses, g_losses
-
-    @staticmethod
-    def generator_loss(disc_outputs):
-        loss = 0
-        gen_losses = []
-        for dg in disc_outputs:
-            dg = dg.float()
-            l = torch.mean((1 - dg) ** 2)
-            gen_losses.append(l)
-            loss += l
-
-        return loss, gen_losses
-
-    @staticmethod
-    def feature_loss(fmap_r, fmap_g):
-        loss = 0
-        for dr, dg in zip(fmap_r, fmap_g):
-            for rl, gl in zip(dr, dg):
-                rl = rl.float().detach()
-                gl = gl.float()
-                loss += torch.mean(torch.abs(rl - gl))
-
-        return loss * 2
-
-    @staticmethod
-    def kl_loss(m_q, logs_q, m_p, logs_p, z_mask):
-        """
-        m_q, logs_q: [b, h, t_t]
-        m_p, logs_p: [b, h, t_t]
-        """
-        m_q = m_q.float()
-        logs_q = logs_q.float()
-        m_p = m_p.float()
-        logs_p = logs_p.float()
-        z_mask = z_mask.float()
-
-        kl = 0.5 * (
-            (m_q - m_p) ** 2 / torch.exp(logs_p)
-            + torch.exp(logs_q) / torch.exp(logs_p)
-            - 1
-            - logs_q
-            + logs_p
-        )
-
-        kl = torch.sum(kl * z_mask)
-        l = kl / torch.sum(z_mask)
-
-        return l
-
     def training_step(self, batch, batch_idx):
         optim_g, optim_d = self.optimizers()
 
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
         features, feature_lengths = batch["features"], batch["feature_lengths"]
+        audios = audios[:, None, :]
 
         audios = audios.float()
         features = features.float()
 
         with torch.no_grad():
-            gt_mels, gt_specs = self.mel_transform(audios, return_linear=True)
-            gt_mels = gt_mels.transpose(1, 2)
-            key_padding_mask = sequence_mask(feature_lengths)
-            mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
-            audio_masks = sequence_mask(audio_lengths)[:, None]
-
-            assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
-            gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
-            gt_mels = gt_mels[:, :gt_mel_length]
-            gt_specs = gt_specs[:, :, :gt_mel_length]
-            mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
-
-            assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
-            gt_feature_length = min(features.shape[1], key_padding_mask.shape[1])
-            features = features[:, :gt_feature_length]
-            key_padding_mask = key_padding_mask[:, :gt_feature_length]
-
-        audios = audios[:, None, :]
-
-        speaker = self.speaker_encoder(gt_mels, mels_key_padding_mask)[:, :, None]
-        prior = self.semantic_encoder(
-            x=features,
-            key_padding_mask=key_padding_mask,
-            g=speaker,
-        )
-
-        posterior_key_padding_mask = (~mels_key_padding_mask).float()[:, None]
-        posterior = self.posterior_encoder(
-            gt_specs, posterior_key_padding_mask, g=speaker
-        )
-        # z_p = self.flow(posterior.mean, posterior_key_padding_mask, g=speaker)
-        fake_audios = self.generator(posterior.z, g=speaker)
-
-        min_audio_length = min(audios.shape[-1], fake_audios.shape[-1])
-        audios = audios[:, :, :min_audio_length]
-        fake_audios = fake_audios[:, :, :min_audio_length]
-        audio_masks = audio_masks[:, :, :min_audio_length]
-
-        audio = torch.masked_fill(audios, audio_masks, 0.0)
-        fake_audios = torch.masked_fill(fake_audios, audio_masks, 0.0)
-        assert fake_audios.shape == audio.shape
+            gt_mels = self.mel_transform(audios)
+            assert (
+                gt_mels.shape[2] == features.shape[1]
+            ), f"Shapes do not match: {gt_mels.shape}, {features.shape}"
+
+        (
+            y_hat,
+            ids_slice,
+            x_mask,
+            y_mask,
+            (z_q_audio, z_p),
+            (m_p_text, logs_p_text),
+            (m_q, logs_q),
+        ) = self.generator(features, feature_lengths, gt_mels, feature_lengths)
+
+        y_hat_mel = self.mel_transform(y_hat.squeeze(1))
+        y_mel = slice_segments(gt_mels, ids_slice, self.segment_size // self.hop_length)
+        y = slice_segments(audios, ids_slice * self.hop_length, self.segment_size)
 
         # Discriminator
-        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audio, fake_audios.detach())
+        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())
 
         with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_disc_all, _, _ = self.discriminator_loss(y_d_hat_r, y_d_hat_g)
+            loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
 
         self.log(
             "train/discriminator/loss",
@@ -235,32 +134,26 @@ class VQGAN(L.LightningModule):
         )
         optim_d.step()
 
-        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(audios, fake_audios)
-        fake_mels = self.mel_transform(fake_audios.squeeze(1)).transpose(1, 2)
-
-        # Min mel length
-        min_mel_length = min(gt_mels.shape[1], fake_mels.shape[1])
-        gt_mels = gt_mels[:, :min_mel_length]
-        fake_mels = fake_mels[:, :min_mel_length]
-        mels_key_padding_mask = mels_key_padding_mask[:, :min_mel_length]
-
-        # Fill mel mask
-        fake_mels = torch.masked_fill(fake_mels, mels_key_padding_mask[:, :, None], 0.0)
-        gt_mels = torch.masked_fill(gt_mels, mels_key_padding_mask[:, :, None], 0.0)
+        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)
 
         with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_mel = F.l1_loss(gt_mels, fake_mels)
-            loss_adv, _ = self.generator_loss(y_d_hat_g)
-            loss_fm = self.feature_loss(fmap_r, fmap_g)
-            loss_kl = self.kl_loss(
-                posterior.mean,
-                posterior.logs,
-                prior.mean,
-                prior.logs,
-                posterior_key_padding_mask,
+            loss_mel = F.l1_loss(y_mel, y_hat_mel)
+            loss_adv, _ = generator_loss(y_d_hat_g)
+            loss_fm = feature_loss(fmap_r, fmap_g)
+            # x_mask,
+            # y_mask,
+            # (z_q_audio, z_p),
+            # (m_p_text, logs_p_text),
+            # (m_q, logs_q),
+            loss_kl = kl_loss_normal(
+                m_q,
+                logs_q,
+                m_p_text,
+                logs_p_text,
+                x_mask,
             )
 
-            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + prior.loss + loss_kl
+            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + loss_kl * 0.05
 
         self.log(
             "train/generator/loss",
@@ -307,15 +200,15 @@ class VQGAN(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
-        self.log(
-            "train/generator/loss_vq",
-            prior.loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
+        # self.log(
+        #     "train/generator/loss_vq",
+        #     prior.loss,
+        #     on_step=True,
+        #     on_epoch=False,
+        #     prog_bar=False,
+        #     logger=True,
+        #     sync_dist=True,
+        # )
 
         optim_g.zero_grad()
         self.manual_backward(loss_gen_all)
@@ -335,60 +228,23 @@ class VQGAN(L.LightningModule):
 
         audios = audios.float()
         features = features.float()
-
-        with torch.no_grad():
-            gt_mels, gt_specs = self.mel_transform(audios, return_linear=True)
-            gt_mels = gt_mels.transpose(1, 2)
-            key_padding_mask = sequence_mask(feature_lengths)
-            mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
-
-            assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
-            gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
-            gt_mels = gt_mels[:, :gt_mel_length]
-            gt_specs = gt_specs[:, :, :gt_mel_length]
-            mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
-
-            assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
-            gt_feature_length = min(features.shape[1], key_padding_mask.shape[1])
-            features = features[:, :gt_feature_length]
-            key_padding_mask = key_padding_mask[:, :gt_feature_length]
-
-        # Generator
-        # speaker: (B, C, 1)
-        speaker = self.speaker_encoder(gt_mels, mels_key_padding_mask)[:, :, None]
-        posterior_key_padding_mask = (~mels_key_padding_mask).float()[:, None]
-
-        z_gen = self.semantic_encoder(
-            x=features,
-            key_padding_mask=key_padding_mask,
-            g=speaker,
-        ).z
-
-        # z_gen = self.flow(z_gen, posterior_key_padding_mask, g=speaker, reverse=True)
-
-        z_posterior = self.posterior_encoder(
-            gt_specs, posterior_key_padding_mask, g=speaker
-        ).mean
-
         audios = audios[:, None, :]
-        fake_audios = self.generator(z_gen, g=speaker)
-        posterior_audios = self.generator(z_posterior)
-        min_audio_length = min(
-            audios.shape[-1], fake_audios.shape[-1], posterior_audios.shape[-1]
-        )
 
-        audios = audios[:, :, :min_audio_length]
-        fake_audios = fake_audios[:, :, :min_audio_length]
-        posterior_audios = posterior_audios[:, :, :min_audio_length]
-        assert fake_audios.shape == audios.shape == posterior_audios.shape
+        gt_mels = self.mel_transform(audios)
+        assert (
+            gt_mels.shape[2] == features.shape[1]
+        ), f"Shapes do not match: {gt_mels.shape}, {features.shape}"
+
+        fake_audios = self.generator.infer(features, feature_lengths, gt_mels)
+        posterior_audios = self.generator.reconstruct(gt_mels, feature_lengths)
 
-        fake_mels = self.mel_transform(fake_audios.squeeze(1)).transpose(1, 2)
-        posterior_mels = self.mel_transform(posterior_audios.squeeze(1)).transpose(1, 2)
+        fake_mels = self.mel_transform(fake_audios.squeeze(1))
+        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))
 
-        min_mel_length = min(gt_mels.shape[1], fake_mels.shape[1])
-        gt_mels = gt_mels[:, :min_mel_length]
-        fake_mels = fake_mels[:, :min_mel_length]
-        posterior_mels = posterior_mels[:, :min_mel_length]
+        min_mel_length = min(gt_mels.shape[-1], fake_mels.shape[-1])
+        gt_mels = gt_mels[:, :, :min_mel_length]
+        fake_mels = fake_mels[:, :, :min_mel_length]
+        posterior_mels = posterior_mels[:, :, :min_mel_length]
 
         mel_loss = F.l1_loss(gt_mels, fake_mels)
         self.log(
@@ -411,9 +267,9 @@ class VQGAN(L.LightningModule):
             audio_len,
         ) in enumerate(
             zip(
-                gt_mels.transpose(1, 2),
-                fake_mels.transpose(1, 2),
-                posterior_mels.transpose(1, 2),
+                gt_mels,
+                fake_mels,
+                posterior_mels,
                 audios,
                 fake_audios,
                 posterior_audios,

+ 92 - 0
fish_speech/models/vqgan/losses.py

@@ -0,0 +1,92 @@
+from typing import List
+
+import torch
+
+
+def feature_loss(fmap_r: List[torch.Tensor], fmap_g: List[torch.Tensor]):
+    loss = 0
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            rl = rl.float().detach()
+            gl = gl.float()
+            loss += torch.mean(torch.abs(rl - gl))
+
+    return loss * 2
+
+
+def discriminator_loss(
+    disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
+):
+    loss = 0
+    r_losses = []
+    g_losses = []
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        dr = dr.float()
+        dg = dg.float()
+        r_loss = torch.mean((1 - dr) ** 2)
+        g_loss = torch.mean(dg**2)
+        loss += r_loss + g_loss
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())
+
+    return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs: List[torch.Tensor]):
+    loss = 0
+    gen_losses = []
+    for dg in disc_outputs:
+        dg = dg.float()
+        l = torch.mean((1 - dg) ** 2)
+        gen_losses.append(l)
+        loss += l
+
+    return loss, gen_losses
+
+
+def kl_loss(
+    z_p: torch.Tensor,
+    logs_q: torch.Tensor,
+    m_p: torch.Tensor,
+    logs_p: torch.Tensor,
+    z_mask: torch.Tensor,
+):
+    """
+    z_p, logs_q: [b, h, t_t]
+    m_p, logs_p: [b, h, t_t]
+    """
+    z_p = z_p.float()
+    logs_q = logs_q.float()
+    m_p = m_p.float()
+    logs_p = logs_p.float()
+    z_mask = z_mask.float()
+
+    kl = logs_p - logs_q - 0.5
+    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+    kl = torch.sum(kl * z_mask)
+    l = kl / torch.sum(z_mask)
+    return l
+
+
+def kl_loss_normal(
+    m_q: torch.Tensor,
+    logs_q: torch.Tensor,
+    m_p: torch.Tensor,
+    logs_p: torch.Tensor,
+    z_mask: torch.Tensor,
+):
+    """
+    z_p, logs_q: [b, h, t_t]
+    m_p, logs_p: [b, h, t_t]
+    """
+    m_q = m_q.float()
+    logs_q = logs_q.float()
+    m_p = m_p.float()
+    logs_p = logs_p.float()
+    z_mask = z_mask.float()
+
+    kl = logs_p - logs_q - 0.5
+    kl += 0.5 * (torch.exp(2.0 * logs_q) + (m_q - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+    kl = torch.sum(kl * z_mask)
+    l = kl / torch.sum(z_mask)
+    return l

+ 0 - 1090
fish_speech/models/vqgan/modules.py

@@ -1,1090 +0,0 @@
-import math
-from dataclasses import dataclass
-
-import torch
-from encodec.quantization.core_vq import VectorQuantization
-from torch import nn
-from torch.nn import Conv1d, Conv2d, ConvTranspose1d
-from torch.nn import functional as F
-from torch.nn.utils.parametrizations import spectral_norm, weight_norm
-from torch.nn.utils.parametrize import remove_parametrizations
-
-from fish_speech.models.vqgan.utils import (
-    convert_pad_shape,
-    fused_add_tanh_sigmoid_multiply,
-    get_padding,
-    init_weights,
-)
-
-LRELU_SLOPE = 0.1
-
-
-class ResidualCouplingBlock(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size: int = 5,
-        dilation_rate: int = 1,
-        n_layers: int = 4,
-        n_flows: int = 4,
-        gin_channels: int = 512,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.n_flows = n_flows
-        self.gin_channels = gin_channels
-
-        self.flows = nn.ModuleList()
-        for i in range(n_flows):
-            self.flows.append(
-                ResidualCouplingLayer(
-                    channels,
-                    hidden_channels,
-                    kernel_size,
-                    dilation_rate,
-                    n_layers,
-                    gin_channels=gin_channels,
-                    mean_only=True,
-                )
-            )
-            self.flows.append(Flip())
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        if not reverse:
-            for flow in self.flows:
-                x, _ = flow(x, x_mask, g=g, reverse=reverse)
-        else:
-            for flow in reversed(self.flows):
-                x = flow(x, x_mask, g=g, reverse=reverse)
-        return x
-
-
-class Flip(nn.Module):
-    def forward(self, x, *args, reverse=False, **kwargs):
-        x = torch.flip(x, [1])
-        if not reverse:
-            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
-            return x, logdet
-        else:
-            return x
-
-
-class ResidualCouplingLayer(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        p_dropout=0,
-        gin_channels=0,
-        mean_only=False,
-    ):
-        assert channels % 2 == 0, "channels should be divisible by 2"
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.half_channels = channels // 2
-        self.mean_only = mean_only
-
-        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-        self.enc = WaveNet(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            p_dropout=p_dropout,
-            gin_channels=gin_channels,
-        )
-        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-        self.post.weight.data.zero_()
-        self.post.bias.data.zero_()
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-        h = self.pre(x0) * x_mask
-        h = self.enc(h, x_mask, g=g)
-        stats = self.post(h) * x_mask
-        if not self.mean_only:
-            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
-        else:
-            m = stats
-            logs = torch.zeros_like(m)
-
-        if not reverse:
-            x1 = m + x1 * torch.exp(logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            logdet = torch.sum(logs, [1, 2])
-            return x, logdet
-        else:
-            x1 = (x1 - m) * torch.exp(-logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            return x
-
-
-@dataclass
-class SemanticEncoderOutput:
-    loss: torch.Tensor
-    mean: torch.Tensor
-    logs: torch.Tensor
-    z: torch.Tensor
-
-
-class SemanticEncoder(nn.Module):
-    def __init__(
-        self,
-        in_channels: int = 1024,
-        hidden_channels: int = 384,
-        out_channels: int = 192,
-        num_heads: int = 2,
-        num_layers: int = 8,
-        input_downsample: bool = True,
-        code_book_size: int = 2048,
-        freeze_vq: bool = False,
-        gin_channels: int = 512,
-    ):
-        super().__init__()
-
-        # Feature Encoder
-        down_sample = 2 if input_downsample else 1
-
-        self.in_proj = nn.Conv1d(
-            in_channels, in_channels, kernel_size=down_sample, stride=down_sample
-        )
-        self.vq = VectorQuantization(
-            dim=in_channels,
-            codebook_size=code_book_size,
-            threshold_ema_dead_code=2,
-            kmeans_init=False,
-            kmeans_iters=50,
-        )
-
-        # Init weights of in_proj to mimic the effect of avg pooling
-        nn.init.normal_(
-            self.in_proj.weight, mean=1 / (down_sample * in_channels), std=0.01
-        )
-        self.in_proj.bias.data.zero_()
-
-        self.feature_in = nn.Linear(in_channels, hidden_channels)
-        self.g_in = nn.Conv1d(gin_channels, hidden_channels, 1)
-        self.blocks = nn.ModuleList(
-            [
-                TransformerBlock(
-                    hidden_channels,
-                    num_heads,
-                    window_size=4,
-                    window_heads_share=True,
-                    proximal_init=True,
-                    proximal_bias=False,
-                    use_relative_attn=True,
-                )
-                for _ in range(num_layers)
-            ]
-        )
-
-        self.out_proj = nn.Linear(hidden_channels, out_channels * 2)
-
-        self.input_downsample = input_downsample
-
-        if freeze_vq:
-            for p in self.vq.parameters():
-                p.requires_grad = False
-
-            for p in self.vq_in.parameters():
-                p.requires_grad = False
-
-    def forward(self, x, key_padding_mask=None, g=None) -> SemanticEncoderOutput:
-        # x: (batch, seq_len, channels)
-
-        assert key_padding_mask.size(1) == x.size(
-            1
-        ), f"key_padding_mask shape {key_padding_mask.size()} does not match features shape {x.size()}"
-
-        # Encode Features
-        features = self.in_proj(x.transpose(1, 2))
-        features, _, loss = self.vq(features)
-        features = features.transpose(1, 2)
-
-        if self.input_downsample:
-            features = F.interpolate(
-                features.transpose(1, 2), scale_factor=2, mode="nearest"
-            ).transpose(1, 2)
-
-        # Shape may change due to downsampling, let's cut it to the same size
-        if features.shape[1] != key_padding_mask.shape[1]:
-            assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
-            min_len = min(features.shape[1], key_padding_mask.shape[1])
-            features = features[:, :min_len]
-            key_padding_mask = key_padding_mask[:, :min_len]
-
-        features = self.feature_in(features)
-        g = self.g_in(g).transpose(1, 2)
-        features = features + g
-
-        for block in self.blocks:
-            features = block(features, key_padding_mask=key_padding_mask)
-
-        stats = self.out_proj(features).transpose(1, 2)
-        stats = torch.masked_fill(stats, key_padding_mask.unsqueeze(1), 0)
-        mean, logs = torch.chunk(stats, 2, dim=1)
-
-        return SemanticEncoderOutput(
-            loss=loss,
-            mean=mean,
-            logs=logs,
-            z=mean + torch.randn_like(mean) * torch.exp(logs) * 0.5,
-        )
-
-
-class WaveNet(nn.Module):
-    def __init__(
-        self,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        gin_channels=0,
-        p_dropout=0,
-    ):
-        super(WaveNet, self).__init__()
-        assert kernel_size % 2 == 1
-        self.hidden_channels = hidden_channels
-        self.kernel_size = (kernel_size,)
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.gin_channels = gin_channels
-        self.p_dropout = p_dropout
-
-        self.in_layers = nn.ModuleList()
-        self.res_skip_layers = nn.ModuleList()
-        self.drop = nn.Dropout(p_dropout)
-
-        if gin_channels != 0:
-            self.cond_layer = weight_norm(
-                nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
-            )
-
-        for i in range(n_layers):
-            dilation = dilation_rate**i
-            padding = int((kernel_size * dilation - dilation) / 2)
-            in_layer = weight_norm(
-                nn.Conv1d(
-                    hidden_channels,
-                    2 * hidden_channels,
-                    kernel_size,
-                    dilation=dilation,
-                    padding=padding,
-                )
-            )
-            self.in_layers.append(in_layer)
-
-            # last one is not necessary
-            if i < n_layers - 1:
-                res_skip_channels = 2 * hidden_channels
-            else:
-                res_skip_channels = hidden_channels
-
-            res_skip_layer = weight_norm(
-                nn.Conv1d(hidden_channels, res_skip_channels, 1), name="weight"
-            )
-            self.res_skip_layers.append(res_skip_layer)
-
-    def forward(self, x, x_mask, g=None):
-        output = torch.zeros_like(x)
-        n_channels_tensor = torch.IntTensor([self.hidden_channels])
-
-        if g is not None:
-            g = self.cond_layer(g)
-
-        for i in range(self.n_layers):
-            x_in = self.in_layers[i](x)
-            if g is not None:
-                cond_offset = i * 2 * self.hidden_channels
-                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-            else:
-                g_l = torch.zeros_like(x_in)
-
-            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
-            acts = self.drop(acts)
-
-            res_skip_acts = self.res_skip_layers[i](acts)
-            if i < self.n_layers - 1:
-                res_acts = res_skip_acts[:, : self.hidden_channels, :]
-                x = (x + res_acts) * x_mask
-                output = output + res_skip_acts[:, self.hidden_channels :, :]
-            else:
-                output = output + res_skip_acts
-
-        return output * x_mask
-
-    def remove_parametrizations(self):
-        if self.gin_channels != 0:
-            nn.utils.remove_parametrizations(self.cond_layer)
-        for l in self.in_layers:
-            nn.utils.remove_parametrizations(l)
-        for l in self.res_skip_layers:
-            nn.utils.remove_parametrizations(l)
-
-
-@dataclass
-class PosteriorEncoderOutput:
-    z: torch.Tensor
-    mean: torch.Tensor
-    logs: torch.Tensor
-
-
-class PosteriorEncoder(nn.Module):
-    def __init__(
-        self,
-        in_channels: int = 1024,
-        out_channels: int = 192,
-        hidden_channels: int = 192,
-        kernel_size: int = 5,
-        dilation_rate: int = 1,
-        n_layers: int = 16,
-        gin_channels: int = 512,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.gin_channels = gin_channels
-
-        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
-        self.enc = WaveNet(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            gin_channels=gin_channels,
-        )
-        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-    def forward(self, x, x_mask, g=None):
-        g = g.detach()
-        x = self.pre(x) * x_mask
-        x = self.enc(x, x_mask, g=g)
-        stats = self.proj(x) * x_mask
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
-
-        return PosteriorEncoderOutput(
-            z=z,
-            mean=m,
-            logs=logs,
-        )
-
-
-class SpeakerEncoder(nn.Module):
-    def __init__(
-        self,
-        in_channels: int = 128,
-        hidden_channels: int = 192,
-        out_channels: int = 512,
-        num_heads: int = 2,
-        num_layers: int = 4,
-    ) -> None:
-        super().__init__()
-
-        self.query = nn.Parameter(torch.randn(1, 1, hidden_channels))
-        self.in_proj = nn.Linear(in_channels, hidden_channels)
-        self.blocks = nn.ModuleList(
-            [
-                TransformerBlock(
-                    hidden_channels,
-                    num_heads,
-                    use_relative_attn=False,
-                )
-                for _ in range(num_layers)
-            ]
-        )
-        self.out_proj = nn.Linear(hidden_channels, out_channels)
-
-    def forward(self, mels, mels_key_padding_mask=None):
-        x = self.in_proj(mels)
-        x = torch.cat([self.query.expand(x.shape[0], -1, -1), x], dim=1)
-
-        mels_key_padding_mask = torch.cat(
-            [
-                torch.ones(x.shape[0], 1, dtype=torch.bool, device=x.device),
-                mels_key_padding_mask,
-            ],
-            dim=1,
-        )
-        for block in self.blocks:
-            x = block(x, key_padding_mask=mels_key_padding_mask)
-
-        x = self.out_proj(x[:, 0])
-
-        return x
-
-
-class TransformerBlock(nn.Module):
-    def __init__(
-        self,
-        channels,
-        n_heads,
-        mlp_ratio=4 * 2 / 3,
-        p_dropout=0.0,
-        window_size=4,
-        window_heads_share=True,
-        proximal_init=True,
-        proximal_bias=False,
-        use_relative_attn=True,
-    ):
-        super().__init__()
-
-        self.attn_norm = RMSNorm(channels)
-
-        if use_relative_attn:
-            self.attn = RelativeAttention(
-                channels,
-                n_heads,
-                p_dropout,
-                window_size,
-                window_heads_share,
-                proximal_init,
-                proximal_bias,
-            )
-        else:
-            self.attn = nn.MultiheadAttention(
-                embed_dim=channels,
-                num_heads=n_heads,
-                dropout=p_dropout,
-                batch_first=True,
-            )
-
-        self.mlp_norm = RMSNorm(channels)
-        self.mlp = SwiGLU(channels, int(channels * mlp_ratio), channels, drop=p_dropout)
-
-    def forward(self, x, key_padding_mask=None):
-        norm_x = self.attn_norm(x)
-
-        if isinstance(self.attn, RelativeAttention):
-            attn = self.attn(norm_x, key_padding_mask=key_padding_mask)
-        else:
-            attn, _ = self.attn(
-                norm_x, norm_x, norm_x, key_padding_mask=key_padding_mask
-            )
-
-        x = x + attn
-        x = x + self.mlp(self.mlp_norm(x))
-
-        return x
-
-
-class SwiGLU(nn.Module):
-    """
-    Swish-Gated Linear Unit (SwiGLU) activation function
-    """
-
-    def __init__(
-        self,
-        in_features,
-        hidden_features=None,
-        out_features=None,
-        bias=True,
-        drop=0.0,
-    ):
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        assert hidden_features % 2 == 0
-
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
-        self.act = nn.SiLU()
-        self.drop1 = nn.Dropout(drop)
-        self.norm = RMSNorm(hidden_features // 2)
-        self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias)
-        self.drop2 = nn.Dropout(drop)
-
-    def init_weights(self):
-        # override init of fc1 w/ gate portion set to weight near zero, bias=1
-        fc1_mid = self.fc1.bias.shape[0] // 2
-        nn.init.ones_(self.fc1.bias[fc1_mid:])
-        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x1, x2 = x.chunk(2, dim=-1)
-
-        x = x1 * self.act(x2)
-        x = self.drop1(x)
-        x = self.norm(x)
-        x = self.fc2(x)
-        x = self.drop2(x)
-
-        return x
-
-
-class RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        LlamaRMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-
-        return self.weight * hidden_states.to(input_dtype)
-
-
-class RelativeAttention(nn.Module):
-    def __init__(
-        self,
-        channels,
-        n_heads,
-        p_dropout=0.0,
-        window_size=4,
-        window_heads_share=True,
-        proximal_init=True,
-        proximal_bias=False,
-    ):
-        super().__init__()
-        assert channels % n_heads == 0
-
-        self.channels = channels
-        self.n_heads = n_heads
-        self.p_dropout = p_dropout
-        self.window_size = window_size
-        self.heads_share = window_heads_share
-        self.proximal_init = proximal_init
-        self.proximal_bias = proximal_bias
-
-        self.k_channels = channels // n_heads
-        self.qkv = nn.Linear(channels, channels * 3)
-        self.drop = nn.Dropout(p_dropout)
-
-        if window_size is not None:
-            n_heads_rel = 1 if window_heads_share else n_heads
-            rel_stddev = self.k_channels**-0.5
-            self.emb_rel_k = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-            self.emb_rel_v = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-
-        nn.init.xavier_uniform_(self.qkv.weight)
-
-        if proximal_init:
-            with torch.no_grad():
-                # Sync qk weights
-                self.qkv.weight.data[: self.channels] = self.qkv.weight.data[
-                    self.channels : self.channels * 2
-                ]
-                self.qkv.bias.data[: self.channels] = self.qkv.bias.data[
-                    self.channels : self.channels * 2
-                ]
-
-    def forward(self, x, key_padding_mask=None):
-        # x: (batch, seq_len, channels)
-        batch_size, seq_len, _ = x.size()
-        qkv = (
-            self.qkv(x)
-            .reshape(batch_size, seq_len, 3, self.n_heads, self.k_channels)
-            .permute(2, 0, 3, 1, 4)
-        )
-        query, key, value = torch.unbind(qkv, dim=0)
-
-        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
-
-        if self.window_size is not None:
-            key_relative_embeddings = self._get_relative_embeddings(
-                self.emb_rel_k, seq_len
-            )
-            rel_logits = self._matmul_with_relative_keys(
-                query / math.sqrt(self.k_channels), key_relative_embeddings
-            )
-            scores_local = self._relative_position_to_absolute_position(rel_logits)
-            scores = scores + scores_local
-
-        if self.proximal_bias:
-            scores = scores + self._attention_bias_proximal(seq_len).to(
-                device=scores.device, dtype=scores.dtype
-            )
-
-        # key_padding_mask: (batch, seq_len)
-        if key_padding_mask is not None:
-            assert key_padding_mask.size() == (
-                batch_size,
-                seq_len,
-            ), f"key_padding_mask shape {key_padding_mask.size()} does not match x shape {x.size()}"
-            assert (
-                key_padding_mask.dtype == torch.bool
-            ), f"key_padding_mask dtype {key_padding_mask.dtype} is not bool"
-
-            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
-                -1, self.n_heads, -1, -1
-            )
-            scores = scores.masked_fill(key_padding_mask, float("-inf"))
-
-        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
-        p_attn = self.drop(p_attn)
-        output = torch.matmul(p_attn, value)
-
-        if self.window_size is not None:
-            relative_weights = self._absolute_position_to_relative_position(p_attn)
-            value_relative_embeddings = self._get_relative_embeddings(
-                self.emb_rel_v, seq_len
-            )
-            output = output + self._matmul_with_relative_values(
-                relative_weights, value_relative_embeddings
-            )
-
-        return output.reshape(batch_size, seq_len, self.n_heads * self.k_channels)
-
-    def _matmul_with_relative_values(self, x, y):
-        """
-        x: [b, h, l, m]
-        y: [h or 1, m, d]
-        ret: [b, h, l, d]
-        """
-        ret = torch.matmul(x, y.unsqueeze(0))
-        return ret
-
-    def _matmul_with_relative_keys(self, x, y):
-        """
-        x: [b, h, l, d]
-        y: [h or 1, m, d]
-        ret: [b, h, l, m]
-        """
-        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
-        return ret
-
-    def _get_relative_embeddings(self, relative_embeddings, length):
-        max_relative_position = 2 * self.window_size + 1
-        # Pad first before slice to avoid using cond ops.
-        pad_length = max(length - (self.window_size + 1), 0)
-        slice_start_position = max((self.window_size + 1) - length, 0)
-        slice_end_position = slice_start_position + 2 * length - 1
-        if pad_length > 0:
-            padded_relative_embeddings = F.pad(
-                relative_embeddings,
-                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
-            )
-        else:
-            padded_relative_embeddings = relative_embeddings
-        used_relative_embeddings = padded_relative_embeddings[
-            :, slice_start_position:slice_end_position
-        ]
-        return used_relative_embeddings
-
-    def _relative_position_to_absolute_position(self, x):
-        """
-        x: [b, h, l, 2*l-1]
-        ret: [b, h, l, l]
-        """
-        batch, heads, length, _ = x.size()
-        # Concat columns of pad to shift from relative to absolute indexing.
-        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
-
-        # Concat extra elements so to add up to shape (len+1, 2*len-1).
-        x_flat = x.view([batch, heads, length * 2 * length])
-        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
-
-        # Reshape and slice out the padded elements.
-        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
-            :, :, :length, length - 1 :
-        ]
-        return x_final
-
-    def _absolute_position_to_relative_position(self, x):
-        """
-        x: [b, h, l, l]
-        ret: [b, h, l, 2*l-1]
-        """
-        batch, heads, length, _ = x.size()
-        # pad along column
-        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
-        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
-        # add 0's in the beginning that will skew the elements after reshape
-        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
-        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
-        return x_final
-
-    def _attention_bias_proximal(self, length):
-        """Bias for self-attention to encourage attention to close positions.
-        Args:
-          length: an integer scalar.
-        Returns:
-          a Tensor with shape [1, 1, length, length]
-        """
-        r = torch.arange(length, dtype=torch.float32)
-        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
-        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
-
-
-class ResBlock1(nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super(ResBlock1, self).__init__()
-        self.convs1 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[2],
-                        padding=get_padding(kernel_size, dilation[2]),
-                    )
-                ),
-            ]
-        )
-        self.convs1.apply(init_weights)
-
-        self.convs2 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-            ]
-        )
-        self.convs2.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c1(xt)
-            xt = F.leaky_relu(xt, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c2(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_parametrizations(self):
-        for l in self.convs1:
-            remove_parametrizations(l)
-        for l in self.convs2:
-            remove_parametrizations(l)
-
-
-class ResBlock2(nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-        super(ResBlock2, self).__init__()
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-            ]
-        )
-        self.convs.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c in self.convs:
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_parametrizations(self):
-        for l in self.convs:
-            remove_parametrizations(l)
-
-
-class Generator(nn.Module):
-    def __init__(
-        self,
-        initial_channel,
-        resblock,
-        resblock_kernel_sizes,
-        resblock_dilation_sizes,
-        upsample_rates,
-        upsample_initial_channel,
-        upsample_kernel_sizes,
-        gin_channels,
-    ):
-        super(Generator, self).__init__()
-        self.num_kernels = len(resblock_kernel_sizes)
-        self.num_upsamples = len(upsample_rates)
-        self.conv_pre = Conv1d(
-            initial_channel, upsample_initial_channel, 7, 1, padding=3
-        )
-        resblock = ResBlock1 if resblock == "1" else ResBlock2
-
-        self.ups = nn.ModuleList()
-        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            self.ups.append(
-                weight_norm(
-                    ConvTranspose1d(
-                        upsample_initial_channel // (2**i),
-                        upsample_initial_channel // (2 ** (i + 1)),
-                        k,
-                        u,
-                        padding=(k - u) // 2,
-                    )
-                )
-            )
-
-        self.resblocks = nn.ModuleList()
-        for i in range(len(self.ups)):
-            ch = upsample_initial_channel // (2 ** (i + 1))
-            for j, (k, d) in enumerate(
-                zip(resblock_kernel_sizes, resblock_dilation_sizes)
-            ):
-                self.resblocks.append(resblock(ch, k, d))
-
-        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
-        self.ups.apply(init_weights)
-
-        self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
-
-    def forward(self, x, g=None):
-        x = self.conv_pre(x)
-        if g is not None:
-            g = self.cond(g)
-
-        for i in range(self.num_upsamples):
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            x = self.ups[i](x)
-            xs = None
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
-                else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-            x = xs / self.num_kernels
-        x = F.leaky_relu(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def remove_parametrizations(self):
-        print("Removing weight norm...")
-        for l in self.ups:
-            remove_parametrizations(l)
-        for l in self.resblocks:
-            l.remove_parametrizations()
-
-
-class DiscriminatorP(nn.Module):
-    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
-        super(DiscriminatorP, self).__init__()
-        self.period = period
-        self.use_spectral_norm = use_spectral_norm
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(
-                    Conv2d(
-                        1,
-                        32,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        32,
-                        128,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        128,
-                        512,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        512,
-                        1024,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-                norm_f(
-                    Conv2d(
-                        1024,
-                        1024,
-                        (kernel_size, 1),
-                        1,
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                ),
-            ]
-        )
-        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
-
-    def forward(self, x):
-        fmap = []
-
-        # 1d to 2d
-        b, c, t = x.shape
-        if t % self.period != 0:  # pad first
-            n_pad = self.period - (t % self.period)
-            x = F.pad(x, (0, n_pad), "reflect")
-            t = t + n_pad
-        x = x.view(b, c, t // self.period, self.period)
-
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class DiscriminatorS(nn.Module):
-    def __init__(self, use_spectral_norm=False):
-        super(DiscriminatorS, self).__init__()
-        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
-                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
-                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
-                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
-                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
-                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
-            ]
-        )
-        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
-
-    def forward(self, x):
-        fmap = []
-
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class EnsembleDiscriminator(nn.Module):
-    def __init__(self, use_spectral_norm=False):
-        super(EnsembleDiscriminator, self).__init__()
-        periods = [2, 3, 5, 7, 11]
-
-        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
-        discs = discs + [
-            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
-        ]
-        self.discriminators = nn.ModuleList(discs)
-
-    def forward(self, y, y_hat):
-        y_d_rs = []
-        y_d_gs = []
-        fmap_rs = []
-        fmap_gs = []
-        for i, d in enumerate(self.discriminators):
-            y_d_r, fmap_r = d(y)
-            y_d_g, fmap_g = d(y_hat)
-            y_d_rs.append(y_d_r)
-            y_d_gs.append(y_d_g)
-            fmap_rs.append(fmap_r)
-            fmap_gs.append(fmap_g)
-
-        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

+ 37 - 0
fish_speech/models/vqgan/modules/condition.py

@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+
+
+class MultiCondLayer(nn.Module):
+    def __init__(
+        self,
+        gin_channels: int,
+        out_channels: int,
+        n_cond: int,
+    ):
+        """MultiCondLayer of VITS model.
+
+        Args:
+            gin_channels (int): Number of conditioning tensor channels.
+            out_channels (int): Number of output tensor channels.
+            n_cond (int): Number of conditions.
+        """
+        super().__init__()
+        self.n_cond = n_cond
+
+        self.cond_layers = nn.ModuleList()
+        for _ in range(n_cond):
+            self.cond_layers.append(nn.Linear(gin_channels, out_channels))
+
+    def forward(self, cond: torch.Tensor, x_mask: torch.Tensor):
+        """
+        Shapes:
+            - cond: :math:`[B, C, N]`
+            - x_mask: :math`[B, 1, T]`
+        """
+
+        cond_out = torch.zeros_like(cond)
+        for i in range(self.n_cond):
+            cond_in = self.cond_layers[i](cond.mT).mT
+            cond_out = cond_out + cond_in
+        return cond_out * x_mask

+ 227 - 0
fish_speech/models/vqgan/modules/decoder.py

@@ -0,0 +1,227 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
+
+from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
+from fish_speech.models.vqgan.utils import get_padding, init_weights
+
+
+class Generator(nn.Module):
+    def __init__(
+        self,
+        initial_channel,
+        resblock,
+        resblock_kernel_sizes,
+        resblock_dilation_sizes,
+        upsample_rates,
+        upsample_initial_channel,
+        upsample_kernel_sizes,
+        gin_channels=0,
+    ):
+        super(Generator, self).__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.conv_pre = nn.Conv1d(
+            initial_channel, upsample_initial_channel, 7, 1, padding=3
+        )
+        resblock = ResBlock1 if resblock == "1" else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    nn.ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(
+                zip(resblock_kernel_sizes, resblock_dilation_sizes)
+            ):
+                self.resblocks.append(resblock(ch, k, d))
+
+        self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.ups.apply(init_weights)
+
+        if gin_channels != 0:
+            self.cond = nn.Linear(gin_channels, upsample_initial_channel)
+
+    def forward(self, x, g=None):
+        x = self.conv_pre(x)
+        if g is not None:
+            x = x + self.cond(g.mT).mT
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print("Removing weight norm...")
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+
+
+class ResBlock1(nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1, self).__init__()
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c2(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2, self).__init__()
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+            ]
+        )
+        self.convs.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)

+ 142 - 0
fish_speech/models/vqgan/modules/discriminator.py

@@ -0,0 +1,142 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.parametrizations import spectral_norm, weight_norm
+
+from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
+from fish_speech.models.vqgan.utils import get_padding
+
+
class DiscriminatorP(nn.Module):
    """HiFi-GAN period discriminator.

    Folds the 1-d waveform into a 2-d map of shape (T // period, period) and
    runs strided 2-d convolutions over it, so each column only sees every
    `period`-th sample.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        # (in_channels, out_channels) for the strided layers.
        channel_plan = [(1, 32), (32, 128), (128, 512), (512, 1024)]
        layers = [
            norm_f(
                nn.Conv2d(
                    c_in,
                    c_out,
                    (kernel_size, 1),
                    (stride, 1),
                    padding=(get_padding(kernel_size, 1), 0),
                )
            )
            for c_in, c_out in channel_plan
        ]
        # Final conv keeps resolution (stride 1).
        layers.append(
            norm_f(
                nn.Conv2d(
                    1024,
                    1024,
                    (kernel_size, 1),
                    1,
                    padding=(get_padding(kernel_size, 1), 0),
                )
            )
        )
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return (flattened logits, feature maps) for waveform x [B, 1, T]."""
        fmap = []

        # Reflect-pad so T is a multiple of the period, then fold to 2-d.
        batch, channels, length = x.shape
        remainder = length % self.period
        if remainder != 0:
            pad = self.period - remainder
            x = F.pad(x, (0, pad), "reflect")
            length = length + pad
        x = x.view(batch, channels, length // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
+
+
class DiscriminatorS(nn.Module):
    """HiFi-GAN scale discriminator: a stack of grouped, strided 1-d
    convolutions applied directly to the raw waveform."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        # (in_ch, out_ch, kernel, stride, groups, padding) per layer.
        layer_cfg = [
            (1, 16, 15, 1, 1, 7),
            (16, 64, 41, 4, 4, 20),
            (64, 256, 41, 4, 16, 20),
            (256, 1024, 41, 4, 64, 20),
            (1024, 1024, 41, 4, 256, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList(
            [
                norm_f(nn.Conv1d(ci, co, k, s, groups=g, padding=p))
                for ci, co, k, s, g, p in layer_cfg
            ]
        )
        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return (flattened logits, feature maps) for waveform x [B, 1, T]."""
        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
+
+
class EnsembleDiscriminator(nn.Module):
    """Multi-period + multi-scale discriminator ensemble (HiFi-GAN style).

    One scale discriminator on the raw waveform plus one period
    discriminator per prime period.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        periods = (2, 3, 5, 7, 11)

        self.discriminators = nn.ModuleList(
            [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
            + [
                DiscriminatorP(p, use_spectral_norm=use_spectral_norm)
                for p in periods
            ]
        )

    def forward(self, y, y_hat):
        """Run every sub-discriminator on real (y) and generated (y_hat) audio.

        Returns:
            Four lists, one entry per sub-discriminator: real logits, fake
            logits, real feature maps, fake feature maps.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            logits_real, fmap_real = disc(y)
            logits_fake, fmap_fake = disc(y_hat)
            y_d_rs.append(logits_real)
            y_d_gs.append(logits_fake)
            fmap_rs.append(fmap_real)
            fmap_gs.append(fmap_fake)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

+ 207 - 0
fish_speech/models/vqgan/modules/encoders.py

@@ -0,0 +1,207 @@
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+
+from fish_speech.models.vqgan.modules.modules import WN
+from fish_speech.models.vqgan.modules.transformer import RelativePositionTransformer
+from fish_speech.models.vqgan.utils import sequence_mask
+
+
# * Ready and Tested
class TextEncoder(nn.Module):
    def __init__(
        self,
        n_vocab: int,
        out_channels: int,
        hidden_channels: int,
        hidden_channels_ffn: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        dropout: float,
        gin_channels=0,
        lang_channels=0,
        speaker_cond_layer=0,
    ):
        """Prior (text/feature) encoder for the VITS model.

        Projects dense input features to the hidden width, runs a
        relative-position Transformer, and predicts the mean / log-std of
        the prior distribution.

        Args:
            n_vocab (int): Input feature dimension. Despite the name, the
                input is projected with nn.Linear (dense features such as
                HuBERT embeddings), not looked up in an embedding table.
            out_channels (int): Number of channels for the output.
            hidden_channels (int): Number of channels for the hidden layers.
            hidden_channels_ffn (int): Width of the Transformer FFN layers.
            n_heads (int): Number of attention heads per Transformer layer.
            n_layers (int): Number of Transformer layers.
            kernel_size (int): Kernel size of the FFN layers.
            dropout (float): Dropout rate for the Transformer layers.
            gin_channels (int, optional): Speaker-embedding channels. Defaults to 0.
            lang_channels (int, optional): Language-embedding channels. Defaults to 0.
            speaker_cond_layer (int, optional): 1-based layer index at which
                the speaker embedding is injected (0 disables it).
        """
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels

        # Dense input projection. The previous call passed a stray third
        # positional argument (`1`), which nn.Linear interpreted as the
        # `bias` flag — harmless (truthy), but wrong API use; dropped.
        self.emb = nn.Linear(n_vocab, hidden_channels)

        self.encoder = RelativePositionTransformer(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            hidden_channels=hidden_channels,
            hidden_channels_ffn=hidden_channels_ffn,
            n_heads=n_heads,
            n_layers=n_layers,
            kernel_size=kernel_size,
            dropout=dropout,
            window_size=4,
            gin_channels=gin_channels,
            lang_channels=lang_channels,
            speaker_cond_layer=speaker_cond_layer,
        )
        # Predicts concatenated [mean, log-std] of the prior.
        self.proj = nn.Linear(hidden_channels, out_channels * 2)

    def forward(
        self,
        x: torch.Tensor,
        x_lengths: torch.Tensor,
        g: torch.Tensor = None,
        lang: torch.Tensor = None,
    ):
        """
        Shapes:
            - x: :math:`[B, T, C_{in}]` (dense features; see __init__)
            - x_lengths: :math:`[B]`
            - g: :math:`[B, C_g, 1]` optional speaker embedding

        Returns:
            (z, m, logs, x, x_mask): sampled latent, prior mean, prior
            log-std, encoder hidden states, and the [B, 1, T] frame mask.
        """
        x = self.emb(x).mT  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.encoder(x, x_mask, g=g, lang=lang)
        stats = self.proj(x.mT).mT * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Reparameterized sample z ~ N(m, exp(logs)^2), masked to valid frames.
        z = m + torch.randn_like(m) * torch.exp(logs) * x_mask
        return z, m, logs, x, x_mask
+
+
# * Ready and Tested
class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        gin_channels=0,
    ):
        """Posterior Encoder of VITS model.

        ::
            x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z

        Args:
            in_channels (int): Number of input tensor channels.
            out_channels (int): Number of output tensor channels.
            hidden_channels (int): Number of hidden channels.
            kernel_size (int): Kernel size of the WaveNet convolution layers.
            dilation_rate (int): Dilation rate of the WaveNet layers.
            n_layers (int): Number of WaveNet layers.
            gin_channels (int, optional): Number of conditioning (e.g. speaker)
                channels. Defaults to 0 (no conditioning).
        """
        super().__init__()
        self.out_channels = out_channels

        # Pointwise projection into the WaveNet width (a Linear over the
        # channel axis; inputs are transposed with .mT around it).
        self.pre = nn.Linear(in_channels, hidden_channels)
        self.enc = WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        # Predicts concatenated [mean, log-std] of the posterior.
        self.proj = nn.Linear(hidden_channels, out_channels * 2)

    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g=None):
        """
        Shapes:
            - x: :math:`[B, C, T]`
            - x_lengths: :math:`[B, 1]`
            - g: :math:`[B, C, 1]`
        """
        # Frame-validity mask: 1 for valid frames, 0 for padding.
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x.mT).mT * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x.mT).mT * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Reparameterized sample z ~ N(m, exp(logs)^2), masked to valid frames.
        z = m + torch.randn_like(m) * torch.exp(logs) * x_mask
        return z, m, logs, x_mask
+
+
# TODO: Ready for testing
class SpeakerEncoder(nn.Module):
    """Attention-pooling speaker encoder.

    A small convolutional stack projects the input mel frames, a learned
    query token is prepended, and after several self-attention layers the
    output at the query position is taken as the utterance-level speaker
    embedding.
    """

    def __init__(
        self,
        in_channels: int = 128,
        hidden_channels: int = 192,
        out_channels: int = 512,
        num_heads: int = 2,
        num_layers: int = 4,
        p_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        # Learned [CLS]-style query token used for pooling.
        self.query = nn.Parameter(torch.randn(1, 1, hidden_channels))
        self.in_proj = nn.Sequential(
            nn.Conv1d(in_channels, hidden_channels, 1),
            nn.SiLU(),
            nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
            nn.SiLU(),
            nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
            nn.SiLU(),
            nn.Dropout(p_dropout),
        )

        self.blocks = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=hidden_channels,
                    num_heads=num_heads,
                    dropout=p_dropout,
                    batch_first=True,
                )
                for _ in range(num_layers)
            ]
        )
        self.out_proj = nn.Linear(hidden_channels, out_channels)

    def forward(self, mels, mel_lengths: torch.Tensor):
        """
        Shapes:
            - mels: :math:`[B, C, T]`
            - mel_lengths: :math:`[B]`

        Returns:
            Speaker embedding of shape [B, out_channels].
        """

        # key_padding_mask convention: True marks padding, so the validity
        # mask from sequence_mask is inverted here.
        x_mask = ~(sequence_mask(mel_lengths, mels.size(2)).bool())

        x = self.in_proj(mels).transpose(1, 2)
        x = torch.cat([self.query.expand(x.shape[0], -1, -1), x], dim=1)

        # The prepended query token is never masked out.
        x_mask = torch.cat(
            [
                torch.zeros(x.shape[0], 1, dtype=torch.bool, device=x.device),
                x_mask,
            ],
            dim=1,
        )

        for block in self.blocks:
            x = block(x, x, x, key_padding_mask=x_mask)[0]

        # Pool by reading the query position only.
        x = self.out_proj(x[:, 0])

        return x

+ 115 - 0
fish_speech/models/vqgan/modules/models.py

@@ -0,0 +1,115 @@
+import torch
+from torch import nn
+
+from fish_speech.models.vqgan.modules.decoder import Generator
+from fish_speech.models.vqgan.modules.encoders import (
+    PosteriorEncoder,
+    SpeakerEncoder,
+    TextEncoder,
+)
+from fish_speech.models.vqgan.utils import rand_slice_segments
+
+
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training

    VITS-style architecture wired for feature-to-speech:
      - enc_p: prior encoder over input features (relative-position transformer)
      - enc_spk: attention-pooling speaker encoder over the spectrogram
      - enc_q: posterior encoder over the spectrogram (WaveNet)
      - dec: HiFi-GAN style waveform generator
    """

    def __init__(
        self,
        in_channels,  # input feature dim fed to the prior encoder
        spec_channels,  # spectrogram channel count
        segment_size,  # decoder training slice length, in latent frames
        inter_channels,  # latent (z) channels
        hidden_channels,
        filter_channels,  # FFN width of the prior transformer
        n_heads,
        n_layers,
        n_layers_q,  # posterior-encoder WaveNet depth
        n_layers_spk,  # speaker-encoder attention depth
        kernel_size,
        p_dropout,
        speaker_cond_layer,  # prior layer receiving speaker conditioning
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels,  # speaker-embedding size
    ):
        super().__init__()

        self.segment_size = segment_size

        self.enc_p = TextEncoder(
            in_channels,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=gin_channels,
            speaker_cond_layer=speaker_cond_layer,
        )
        self.enc_spk = SpeakerEncoder(
            in_channels=spec_channels,
            hidden_channels=inter_channels,
            out_channels=gin_channels,
            num_heads=n_heads,
            num_layers=n_layers_spk,
            p_dropout=p_dropout,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,  # WaveNet kernel size (fixed)
            1,  # WaveNet dilation rate (fixed)
            n_layers_q,
            gin_channels=gin_channels,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )

    def forward(self, x, x_lengths, y):
        """Training pass: encode prior + posterior, decode a random slice.

        NOTE(review): x_lengths is reused as the spectrogram length for both
        enc_spk and enc_q, i.e. input-feature frames and spec frames are
        assumed to be aligned 1:1 — confirm upstream feature extraction.
        """
        g = self.enc_spk(y, x_lengths)  # utterance-level speaker embedding
        z_p, m_p, logs_p, _, x_mask = self.enc_p(x, x_lengths, g=g)
        z_q, m_q, logs_q, y_mask = self.enc_q(y, x_lengths, g=g)
        # Decode only a random segment of the posterior latent to bound
        # memory; ids_slice lets the caller slice the target audio to match.
        z_slice, ids_slice = rand_slice_segments(z_q, x_lengths, self.segment_size)
        o = self.dec(z_slice, g=g)

        return (
            o,
            ids_slice,
            x_mask,
            y_mask,
            (z_q, z_p),
            (m_p, logs_p),
            (m_q, logs_q),
        )

    def infer(self, x, x_lengths, y, max_len=None):
        """Synthesize from input features x, using y only as speaker reference."""
        g = self.enc_spk(y, x_lengths)
        z_p, m_p, logs_p, h_text, x_mask = self.enc_p(x, x_lengths, g=g)
        # z_p_audio, m_p_audio, logs_p_audio = self.flow(z_p_text, m_p_text, logs_p_text, x_mask, g=g, reverse=True)

        o = self.dec((z_p * x_mask)[:, :, :max_len], g=g)
        return o

    def reconstruct(self, x, x_lengths, max_len=None):
        """Autoencode a spectrogram x through the posterior path (debug/eval)."""
        g = self.enc_spk(x, x_lengths)
        z_q, m_q, logs_q, x_mask = self.enc_q(x, x_lengths, g=g)
        o = self.dec((z_q * x_mask)[:, :, :max_len], g=g)

        return o

+ 105 - 0
fish_speech/models/vqgan/modules/modules.py

@@ -0,0 +1,105 @@
+import torch
+import torch.nn as nn
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+
+from fish_speech.models.vqgan.utils import fused_add_tanh_sigmoid_multiply
+
+LRELU_SLOPE = 0.1
+
+
+# ! PosteriorEncoder
+# ! ResidualCouplingLayer
+class WN(nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        gin_channels=0,
+        p_dropout=0,
+    ):
+        super(WN, self).__init__()
+        assert kernel_size % 2 == 1
+        self.hidden_channels = hidden_channels
+        self.kernel_size = (kernel_size,)
+        self.n_layers = n_layers
+        self.gin_channels = gin_channels
+
+        self.in_layers = nn.ModuleList()
+        self.res_skip_layers = nn.ModuleList()
+        self.drop = nn.Dropout(p_dropout)
+
+        if gin_channels != 0:
+            cond_layer = nn.Linear(gin_channels, 2 * hidden_channels * n_layers)
+            self.cond_layer = weight_norm(cond_layer, name="weight")
+
+        for i in range(n_layers):
+            dilation = dilation_rate**i
+            padding = int((kernel_size * dilation - dilation) / 2)
+            in_layer = nn.Conv1d(
+                hidden_channels,
+                2 * hidden_channels,
+                kernel_size,
+                dilation=dilation,
+                padding=padding,
+            )
+            in_layer = weight_norm(in_layer, name="weight")
+            self.in_layers.append(in_layer)
+
+            # last one is not necessary
+            res_skip_channels = (
+                2 * hidden_channels if i < n_layers - 1 else hidden_channels
+            )
+            res_skip_layer = nn.Linear(hidden_channels, res_skip_channels)
+            res_skip_layer = weight_norm(res_skip_layer, name="weight")
+            self.res_skip_layers.append(res_skip_layer)
+
+    def forward(self, x, x_mask, g=None, **kwargs):
+        output = torch.zeros_like(x)
+        n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+        if g is not None:
+            g = self.cond_layer(g.mT).mT
+
+        for i in range(self.n_layers):
+            x_in = self.in_layers[i](x)
+            if g is not None:
+                cond_offset = i * 2 * self.hidden_channels
+                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+            else:
+                g_l = torch.zeros_like(x_in)
+
+            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+            acts = self.drop(acts)
+
+            res_skip_acts = self.res_skip_layers[i](acts.mT).mT
+            if i < self.n_layers - 1:
+                res_acts = res_skip_acts[:, : self.hidden_channels, :]
+                x = (x + res_acts) * x_mask
+                output = output + res_skip_acts[:, self.hidden_channels :, :]
+            else:
+                output = output + res_skip_acts
+        return output * x_mask
+
+    def remove_weight_norm(self):
+        if self.gin_channels != 0:
+            remove_parametrizations(self.cond_layer)
+        for l in self.in_layers:
+            remove_parametrizations(l)
+        for l in self.res_skip_layers:
+            remove_parametrizations(l)
+
+
# ! StochasticDurationPredictor
# ! ResidualCouplingBlock
# TODO convert to class method
class Flip(nn.Module):
    """Parameter-free flow step that reverses the channel dimension.

    The operation is a permutation, so its log-determinant is zero; the
    forward direction returns (flipped, logdet) while the reverse direction
    returns only the flipped tensor.
    """

    def forward(self, x, *args, reverse=False, **kwargs):
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped
        logdet = torch.zeros(flipped.size(0)).to(dtype=x.dtype, device=x.device)
        return flipped, logdet

+ 34 - 0
fish_speech/models/vqgan/modules/normalization.py

@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
class LayerNorm(nn.Module):
    """Layer normalization over the channel axis of a [B, C, T] tensor.

    Transposes to [B, T, C], applies F.layer_norm over the last dimension
    with learnable affine parameters, then transposes back.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))  # learnable scale
        self.beta = nn.Parameter(torch.zeros(channels))  # learnable shift

    def forward(self, x: torch.Tensor):
        normed = F.layer_norm(x.mT, (self.channels,), self.gamma, self.beta, self.eps)
        return normed.mT
+
+
class CondLayerNorm(nn.Module):
    # Layer norm over the channel axis whose affine parameters (gamma/beta)
    # are predicted from a conditioning vector instead of learned directly.
    def __init__(self, channels, eps=1e-5, cond_channels=0):
        super().__init__()
        self.channels = channels
        self.eps = eps

        # Map the conditioning vector to per-channel scale and shift.
        self.linear_gamma = nn.Linear(cond_channels, channels)
        self.linear_beta = nn.Linear(cond_channels, channels)

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """
        Shapes:
            - x: [B, C, T]
            - cond: conditioning vector. NOTE(review): F.layer_norm expects
              weight/bias of shape (channels,), so cond is presumably an
              unbatched [cond_channels] vector — a batched [B, cond_channels]
              cond would yield [B, channels] gamma/beta, which F.layer_norm
              rejects. Confirm against callers.
        """
        gamma = self.linear_gamma(cond)
        beta = self.linear_beta(cond)

        x = F.layer_norm(x.mT, (self.channels,), gamma, beta, self.eps)
        return x.mT

+ 324 - 0
fish_speech/models/vqgan/modules/transformer.py

@@ -0,0 +1,324 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from fish_speech.models.vqgan.modules.normalization import LayerNorm
+from fish_speech.models.vqgan.utils import convert_pad_shape
+
+
# TODO add conditioning on language
# TODO check whether we need to stop gradient for speaker embedding
class RelativePositionTransformer(nn.Module):
    """Transformer encoder with windowed relative-position attention
    (VITS-style), operating on channel-first [B, C, T] tensors.

    Optionally adds a global (speaker) conditioning vector to the residual
    stream just before layer `speaker_cond_layer` (1-based; 0 disables it).
    NOTE(review): `self.cond` only exists when gin_channels != 0 — passing g
    with speaker_cond_layer > 0 but gin_channels == 0 would raise
    AttributeError; confirm configs always set both together.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        hidden_channels_ffn: int,
        n_heads: int,
        n_layers: int,
        kernel_size=1,
        dropout=0.0,
        window_size=4,
        gin_channels=0,
        lang_channels=0,
        speaker_cond_layer=0,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.speaker_cond_layer = speaker_cond_layer

        self.drop = nn.Dropout(dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    # Only the first layer maps from in_channels.
                    hidden_channels if i != 0 else in_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    hidden_channels_ffn,
                    kernel_size,
                    p_dropout=dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))
        if gin_channels != 0:
            self.cond = nn.Linear(gin_channels, hidden_channels)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: torch.Tensor = None,
        lang: torch.Tensor = None,
    ):
        """
        Shapes:
            - x: [B, C, T]
            - x_mask: [B, 1, T], 1 for valid frames
            - g: [B, C_g, 1] optional speaker embedding
            - lang: currently unused (see TODO above)
        """
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            # TODO consider using other conditioning
            # TODO https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/modules/attentions.py#L12
            if i == self.speaker_cond_layer - 1 and g is not None:
                # ! g = torch.detach(g)
                x = x + self.cond(g.mT).mT
                x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)  # post-norm residual

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
+
+
class MultiHeadAttention(nn.Module):
    """Multi-head attention with optional windowed relative positional
    embeddings, proximal bias and block-local masking (VITS-style).

    Inputs and outputs are channel-first: [B, C, T].
    """

    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size  # relative-position window radius
        self.heads_share = heads_share  # share relative embeddings across heads
        self.block_length = block_length  # local attention width (self-attn only)
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None  # last attention weights, kept for inspection

        self.k_channels = channels // n_heads
        self.conv_q = nn.Linear(channels, channels)
        self.conv_k = nn.Linear(channels, channels)
        self.conv_v = nn.Linear(channels, channels)
        self.conv_o = nn.Linear(channels, out_channels)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            # Initialize keys identical to queries.
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        """Attend from x (queries) to c (keys/values); both [B, C, T]."""
        q = self.conv_q(x.mT).mT
        k = self.conv_k(c.mT).mT
        v = self.conv_v(c.mT).mT

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x.mT).mT
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).mT
        key = key.view(b, self.n_heads, self.k_channels, t_s).mT
        value = value.view(b, self.n_heads, self.k_channels, t_s).mT

        # Scaled dot-product scores [b, n_h, t_t, t_s].
        scores = torch.matmul(query / math.sqrt(self.k_channels), key.mT)
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            # -1e4 rather than -inf — presumably for reduced-precision safety.
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = output.mT.contiguous().view(
            b, d, t_t
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x: torch.Tensor, y: torch.Tensor):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        return torch.matmul(x, y.unsqueeze(0))

    def _matmul_with_relative_keys(self, x: torch.Tensor, y: torch.Tensor):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        return torch.matmul(x, y.unsqueeze(0).mT)

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Slice (and pad if needed) the learned table to 2*length-1 offsets."""
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
class FFN(nn.Module):
    """Position-wise feed-forward block built from two 1-d convolutions.

    Supports either causal (left-only) or same padding; the input mask is
    re-applied before each convolution and on the output.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        causal=False,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        # Select the padding strategy once, at construction time.
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        hidden = self.conv_1(self.padding(x * x_mask))
        hidden = self.drop(torch.relu(hidden))
        hidden = self.conv_2(self.padding(hidden * x_mask))
        return hidden * x_mask

    def _causal_padding(self, x):
        # All padding on the left so no future frame leaks into the output.
        if self.kernel_size == 1:
            return x
        pad_spec = [[0, 0], [0, 0], [self.kernel_size - 1, 0]]
        return F.pad(x, convert_pad_shape(pad_spec))

    def _same_padding(self, x):
        # Symmetric padding; the right side takes the extra column when the
        # kernel size is even.
        if self.kernel_size == 1:
            return x
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size // 2
        return F.pad(x, convert_pad_shape([[0, 0], [0, 0], [left, right]]))

+ 21 - 1
fish_speech/models/vqgan/utils.py

@@ -15,7 +15,7 @@ def sequence_mask(length, max_length=None):
     if max_length is None:
         max_length = length.max()
     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) >= length.unsqueeze(1)
+    return x.unsqueeze(0) < length.unsqueeze(1)
 
 
 def init_weights(m, mean=0.0, std=0.01):
@@ -52,6 +52,26 @@ def plot_mel(data, titles=None):
     return fig
 
 
def slice_segments(x, ids_str, segment_size=4):
    """Gather a fixed-length time slice from each batch element.

    Args:
        x: Tensor of shape [B, C, T].
        ids_str: Per-example start frame, shape [B].
        segment_size: Slice length along the time axis.

    Returns:
        Tensor of shape [B, C, segment_size] where row i equals
        x[i, :, ids_str[i] : ids_str[i] + segment_size].
    """
    out = torch.zeros_like(x[:, :, :segment_size])
    for i, start in enumerate(ids_str):
        out[i] = x[i, :, start : start + segment_size]

    return out
+
+
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Randomly slice a segment_size window from each batch element.

    Args:
        x: Tensor of shape [B, C, T].
        x_lengths: Valid length per example, shape [B]; defaults to the full
            length T for every example.
        segment_size: Number of frames to keep.

    Returns:
        (segments, ids_str): segments has shape [B, C, segment_size] and
        ids_str holds the sampled start offsets, shape [B].
    """
    b, d, t = x.size()
    if x_lengths is None:
        # Previously the plain int `t` was passed straight into torch.clamp,
        # which only accepts tensors (TypeError); build a [B] length tensor.
        x_lengths = torch.full((b,), t, dtype=torch.long, device=x.device)
    # Highest valid start index per example (0 when the example is shorter
    # than segment_size).
    ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str
+
+
 @torch.jit.script
 def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
     n_channels_int = n_channels[0]

+ 6 - 3
tools/vqgan/migrate_from_vits.py

@@ -35,11 +35,14 @@ def main(cfg: DictConfig):
 
     # Posterior Encoder
     encoder_state = {
-        k[6:]: v for k, v in generator_weights.items() if k.startswith("enc_q.")
+        k[6:]: v
+        for k, v in generator_weights.items()
+        if k.startswith("enc_q.")
+        if not k.startswith("enc_q.proj.")
     }
     logger.info(f"Found {len(encoder_state)} posterior encoder weights, restoring...")
-    model.posterior_encoder.load_state_dict(encoder_state, strict=True)
-    logger.info("Posterior encoder weights restored.")
+    x = model.posterior_encoder.load_state_dict(encoder_state, strict=False)
+    logger.info(f"Posterior encoder weights restored. {x}")
 
     # Flow
     # flow_state = {