Просмотр исходного кода

borrow VITS components from other repo

Lengyue 2 года назад
Родитель
Коммит
e0d8c72b24

+ 22 - 49
fish_speech/configs/hubert_vq.yaml

@@ -13,7 +13,7 @@ trainer:
     static_graph: true
     static_graph: true
   precision: 32
   precision: 32
   max_steps: 1_000_000
   max_steps: 1_000_000
-  val_check_interval: 1000
+  val_check_interval: 5000
 
 
 sample_rate: 32000
 sample_rate: 32000
 hop_length: 640
 hop_length: 640
@@ -27,14 +27,13 @@ train_dataset:
   filelist: data/vq_train_filelist.txt
   filelist: data/vq_train_filelist.txt
   sample_rate: ${sample_rate}
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   hop_length: ${hop_length}
-  slice_frames: 32
+  slice_frames: 512
 
 
 val_dataset:
 val_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
   _target_: fish_speech.datasets.vqgan.VQGANDataset
   filelist: data/vq_val_filelist.txt
   filelist: data/vq_val_filelist.txt
   sample_rate: ${sample_rate}
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   hop_length: ${hop_length}
-  slice_frames: 256
 
 
 data:
 data:
   _target_: fish_speech.datasets.vqgan.VQGANDataModule
   _target_: fish_speech.datasets.vqgan.VQGANDataModule
@@ -51,57 +50,31 @@ model:
   hop_length: ${hop_length}
   hop_length: ${hop_length}
   segment_size: 20480
   segment_size: 20480
 
 
-  semantic_encoder:
-    _target_: fish_speech.models.vqgan.modules.SemanticEncoder
+  generator:
+    _target_: fish_speech.models.vqgan.modules.models.SynthesizerTrn
     in_channels: 1024
     in_channels: 1024
-    hidden_channels: 384
-    out_channels: 192
-    num_heads: 2
-    num_layers: 8
-    input_downsample: true
-    code_book_size: 2048
-    freeze_vq: false
-    gin_channels: ${model.speaker_encoder.out_channels}
-
-  posterior_encoder:
-    _target_: fish_speech.models.vqgan.modules.PosteriorEncoder
-    in_channels: "${eval: '${n_fft} // 2 + 1'}"
-    hidden_channels: 192
-    out_channels: 192
-    gin_channels: ${model.speaker_encoder.out_channels}
-  
-  speaker_encoder:
-    _target_: fish_speech.models.vqgan.modules.SpeakerEncoder
-    in_channels: ${num_mels}
+    spec_channels: ${num_mels}
+    segment_size: "${eval: '${model.segment_size} // ${hop_length}'}"
+    inter_channels: 192
     hidden_channels: 192
     hidden_channels: 192
-    out_channels: 512
-
-  # flow:
-  #   _target_: fish_speech.models.vqgan.modules.ResidualCouplingBlock
-  #   channels: 192
-  #   hidden_channels: 192
-  #   kernel_size: 5
-  #   dilation_rate: 1
-  #   n_layers: 4
-  #   n_flows: 4
-  #   gin_channels: ${model.speaker_encoder.out_channels}
-
-  generator:
-    _target_: fish_speech.models.vqgan.modules.Generator
-    initial_channel: 192
-    gin_channels: ${model.speaker_encoder.out_channels}
+    filter_channels: 768
+    n_heads: 4
+    n_layers: 8
+    n_layers_q: 16
+    n_layers_spk: 6
+    kernel_size: 3
+    p_dropout: 0.1
+    speaker_cond_layer: 0
     resblock: "1"
     resblock: "1"
     resblock_kernel_sizes: [3, 7, 11]
     resblock_kernel_sizes: [3, 7, 11]
-    resblock_dilation_sizes: 
-      - [1, 3, 5]
-      - [1, 3, 5]
-      - [1, 3, 5]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
     upsample_rates: [10, 8, 2, 2, 2]
     upsample_rates: [10, 8, 2, 2, 2]
     upsample_initial_channel: 512
     upsample_initial_channel: 512
     upsample_kernel_sizes: [16, 16, 8, 2, 2]
     upsample_kernel_sizes: [16, 16, 8, 2, 2]
+    gin_channels: 512 # basically the speaker embedding size
 
 
   discriminator:
   discriminator:
-    _target_: fish_speech.models.vqgan.modules.EnsembleDiscriminator
+    _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
 
 
   mel_transform:
   mel_transform:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
@@ -114,7 +87,7 @@ model:
   optimizer:
   optimizer:
     _target_: torch.optim.AdamW
     _target_: torch.optim.AdamW
     _partial_: true
     _partial_: true
-    lr: 1e-4
+    lr: 2e-4
     betas: [0.8, 0.99]
     betas: [0.8, 0.99]
     eps: 1e-5
     eps: 1e-5
 
 
@@ -124,7 +97,7 @@ model:
     lr_lambda:
     lr_lambda:
       _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
       _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
       _partial_: true
       _partial_: true
-      num_warmup_steps: 0
+      num_warmup_steps: 1000
       num_training_steps: ${trainer.max_steps}
       num_training_steps: ${trainer.max_steps}
       final_lr_ratio: 0.05
       final_lr_ratio: 0.05
 
 
@@ -133,5 +106,5 @@ callbacks:
     sub_module: generator
     sub_module: generator
 
 
 # Resume from rcell's checkpoint
 # Resume from rcell's checkpoint
-ckpt_path: results/hubert-vq-pretrain/rcell/ckpt_23000_pl.pth
-resume_weights_only: true
+# ckpt_path: results/hubert-vq-pretrain/rcell/ckpt_23000_pl.pth
+# resume_weights_only: true

+ 5 - 1
fish_speech/datasets/vqgan.py

@@ -26,7 +26,11 @@ class VQGANDataset(Dataset):
         filelist = Path(filelist)
         filelist = Path(filelist)
         root = filelist.parent
         root = filelist.parent
 
 
-        self.files = [root / line.strip() for line in filelist.read_text().splitlines()]
+        self.files = [
+            root / line.strip()
+            for line in filelist.read_text().splitlines()
+            if ("Genshin" in line or "StarRail" in line)
+        ]
         self.sample_rate = sample_rate
         self.sample_rate = sample_rate
         self.hop_length = hop_length
         self.hop_length = hop_length
         self.slice_frames = slice_frames
         self.slice_frames = slice_frames

+ 73 - 217
fish_speech/models/vqgan/lit_module.py

@@ -8,15 +8,17 @@ import wandb
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from matplotlib import pyplot as plt
 from torch import nn
 from torch import nn
+from vector_quantize_pytorch import ResidualLFQ
 
 
-from fish_speech.models.vqgan.modules import (
-    EnsembleDiscriminator,
-    Generator,
-    PosteriorEncoder,
-    SemanticEncoder,
-    SpeakerEncoder,
+from fish_speech.models.vqgan.losses import (
+    discriminator_loss,
+    feature_loss,
+    generator_loss,
+    kl_loss_normal,
 )
 )
-from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
+from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
+from fish_speech.models.vqgan.modules.models import SynthesizerTrn
+from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
 
 
 
 
 class VQGAN(L.LightningModule):
 class VQGAN(L.LightningModule):
@@ -24,11 +26,7 @@ class VQGAN(L.LightningModule):
         self,
         self,
         optimizer: Callable,
         optimizer: Callable,
         lr_scheduler: Callable,
         lr_scheduler: Callable,
-        semantic_encoder: SemanticEncoder,
-        posterior_encoder: PosteriorEncoder,
-        speaker_encoder: SpeakerEncoder,
-        # flow: nn.Module,
-        generator: Generator,
+        generator: SynthesizerTrn,
         discriminator: EnsembleDiscriminator,
         discriminator: EnsembleDiscriminator,
         mel_transform: nn.Module,
         mel_transform: nn.Module,
         segment_size: int = 20480,
         segment_size: int = 20480,
@@ -42,11 +40,6 @@ class VQGAN(L.LightningModule):
         self.lr_scheduler_builder = lr_scheduler
         self.lr_scheduler_builder = lr_scheduler
 
 
         # Generator and discriminators
         # Generator and discriminators
-        # Compile generator so that snake can save memory
-        self.semantic_encoder = semantic_encoder
-        self.posterior_encoder = posterior_encoder
-        self.speaker_encoder = speaker_encoder
-        # self.flow = flow
         self.generator = generator
         self.generator = generator
         self.discriminator = discriminator
         self.discriminator = discriminator
         self.mel_transform = mel_transform
         self.mel_transform = mel_transform
@@ -61,15 +54,7 @@ class VQGAN(L.LightningModule):
 
 
     def configure_optimizers(self):
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
         # Need two optimizers and two schedulers
-        optimizer_generator = self.optimizer_builder(
-            itertools.chain(
-                self.semantic_encoder.parameters(),
-                self.posterior_encoder.parameters(),
-                self.speaker_encoder.parameters(),
-                self.generator.parameters(),
-                # self.flow.parameters(),
-            )
-        )
+        optimizer_generator = self.optimizer_builder(self.generator.parameters())
         optimizer_discriminator = self.optimizer_builder(
         optimizer_discriminator = self.optimizer_builder(
             self.discriminator.parameters()
             self.discriminator.parameters()
         )
         )
@@ -96,127 +81,41 @@ class VQGAN(L.LightningModule):
             },
             },
         )
         )
 
 
-    @staticmethod
-    def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-        loss = 0
-        r_losses = []
-        g_losses = []
-        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-            dr = dr.float()
-            dg = dg.float()
-            r_loss = torch.mean((1 - dr) ** 2)
-            g_loss = torch.mean(dg**2)
-            loss += r_loss + g_loss
-            r_losses.append(r_loss.item())
-            g_losses.append(g_loss.item())
-
-        return loss, r_losses, g_losses
-
-    @staticmethod
-    def generator_loss(disc_outputs):
-        loss = 0
-        gen_losses = []
-        for dg in disc_outputs:
-            dg = dg.float()
-            l = torch.mean((1 - dg) ** 2)
-            gen_losses.append(l)
-            loss += l
-
-        return loss, gen_losses
-
-    @staticmethod
-    def feature_loss(fmap_r, fmap_g):
-        loss = 0
-        for dr, dg in zip(fmap_r, fmap_g):
-            for rl, gl in zip(dr, dg):
-                rl = rl.float().detach()
-                gl = gl.float()
-                loss += torch.mean(torch.abs(rl - gl))
-
-        return loss * 2
-
-    @staticmethod
-    def kl_loss(m_q, logs_q, m_p, logs_p, z_mask):
-        """
-        m_q, logs_q: [b, h, t_t]
-        m_p, logs_p: [b, h, t_t]
-        """
-        m_q = m_q.float()
-        logs_q = logs_q.float()
-        m_p = m_p.float()
-        logs_p = logs_p.float()
-        z_mask = z_mask.float()
-
-        kl = 0.5 * (
-            (m_q - m_p) ** 2 / torch.exp(logs_p)
-            + torch.exp(logs_q) / torch.exp(logs_p)
-            - 1
-            - logs_q
-            + logs_p
-        )
-
-        kl = torch.sum(kl * z_mask)
-        l = kl / torch.sum(z_mask)
-
-        return l
-
     def training_step(self, batch, batch_idx):
     def training_step(self, batch, batch_idx):
         optim_g, optim_d = self.optimizers()
         optim_g, optim_d = self.optimizers()
 
 
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
         features, feature_lengths = batch["features"], batch["feature_lengths"]
         features, feature_lengths = batch["features"], batch["feature_lengths"]
+        audios = audios[:, None, :]
 
 
         audios = audios.float()
         audios = audios.float()
         features = features.float()
         features = features.float()
 
 
         with torch.no_grad():
         with torch.no_grad():
-            gt_mels, gt_specs = self.mel_transform(audios, return_linear=True)
-            gt_mels = gt_mels.transpose(1, 2)
-            key_padding_mask = sequence_mask(feature_lengths)
-            mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
-            audio_masks = sequence_mask(audio_lengths)[:, None]
-
-            assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
-            gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
-            gt_mels = gt_mels[:, :gt_mel_length]
-            gt_specs = gt_specs[:, :, :gt_mel_length]
-            mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
-
-            assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
-            gt_feature_length = min(features.shape[1], key_padding_mask.shape[1])
-            features = features[:, :gt_feature_length]
-            key_padding_mask = key_padding_mask[:, :gt_feature_length]
-
-        audios = audios[:, None, :]
-
-        speaker = self.speaker_encoder(gt_mels, mels_key_padding_mask)[:, :, None]
-        prior = self.semantic_encoder(
-            x=features,
-            key_padding_mask=key_padding_mask,
-            g=speaker,
-        )
-
-        posterior_key_padding_mask = (~mels_key_padding_mask).float()[:, None]
-        posterior = self.posterior_encoder(
-            gt_specs, posterior_key_padding_mask, g=speaker
-        )
-        # z_p = self.flow(posterior.mean, posterior_key_padding_mask, g=speaker)
-        fake_audios = self.generator(posterior.z, g=speaker)
-
-        min_audio_length = min(audios.shape[-1], fake_audios.shape[-1])
-        audios = audios[:, :, :min_audio_length]
-        fake_audios = fake_audios[:, :, :min_audio_length]
-        audio_masks = audio_masks[:, :, :min_audio_length]
-
-        audio = torch.masked_fill(audios, audio_masks, 0.0)
-        fake_audios = torch.masked_fill(fake_audios, audio_masks, 0.0)
-        assert fake_audios.shape == audio.shape
+            gt_mels = self.mel_transform(audios)
+            assert (
+                gt_mels.shape[2] == features.shape[1]
+            ), f"Shapes do not match: {gt_mels.shape}, {features.shape}"
+
+        (
+            y_hat,
+            ids_slice,
+            x_mask,
+            y_mask,
+            (z_q_audio, z_p),
+            (m_p_text, logs_p_text),
+            (m_q, logs_q),
+        ) = self.generator(features, feature_lengths, gt_mels, feature_lengths)
+
+        y_hat_mel = self.mel_transform(y_hat.squeeze(1))
+        y_mel = slice_segments(gt_mels, ids_slice, self.segment_size // self.hop_length)
+        y = slice_segments(audios, ids_slice * self.hop_length, self.segment_size)
 
 
         # Discriminator
         # Discriminator
-        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audio, fake_audios.detach())
+        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())
 
 
         with torch.autocast(device_type=audios.device.type, enabled=False):
         with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_disc_all, _, _ = self.discriminator_loss(y_d_hat_r, y_d_hat_g)
+            loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
 
 
         self.log(
         self.log(
             "train/discriminator/loss",
             "train/discriminator/loss",
@@ -235,32 +134,26 @@ class VQGAN(L.LightningModule):
         )
         )
         optim_d.step()
         optim_d.step()
 
 
-        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(audios, fake_audios)
-        fake_mels = self.mel_transform(fake_audios.squeeze(1)).transpose(1, 2)
-
-        # Min mel length
-        min_mel_length = min(gt_mels.shape[1], fake_mels.shape[1])
-        gt_mels = gt_mels[:, :min_mel_length]
-        fake_mels = fake_mels[:, :min_mel_length]
-        mels_key_padding_mask = mels_key_padding_mask[:, :min_mel_length]
-
-        # Fill mel mask
-        fake_mels = torch.masked_fill(fake_mels, mels_key_padding_mask[:, :, None], 0.0)
-        gt_mels = torch.masked_fill(gt_mels, mels_key_padding_mask[:, :, None], 0.0)
+        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)
 
 
         with torch.autocast(device_type=audios.device.type, enabled=False):
         with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_mel = F.l1_loss(gt_mels, fake_mels)
-            loss_adv, _ = self.generator_loss(y_d_hat_g)
-            loss_fm = self.feature_loss(fmap_r, fmap_g)
-            loss_kl = self.kl_loss(
-                posterior.mean,
-                posterior.logs,
-                prior.mean,
-                prior.logs,
-                posterior_key_padding_mask,
+            loss_mel = F.l1_loss(y_mel, y_hat_mel)
+            loss_adv, _ = generator_loss(y_d_hat_g)
+            loss_fm = feature_loss(fmap_r, fmap_g)
+            # x_mask,
+            # y_mask,
+            # (z_q_audio, z_p),
+            # (m_p_text, logs_p_text),
+            # (m_q, logs_q),
+            loss_kl = kl_loss_normal(
+                m_q,
+                logs_q,
+                m_p_text,
+                logs_p_text,
+                x_mask,
             )
             )
 
 
-            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + prior.loss + loss_kl
+            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + loss_kl * 0.05
 
 
         self.log(
         self.log(
             "train/generator/loss",
             "train/generator/loss",
@@ -307,15 +200,15 @@ class VQGAN(L.LightningModule):
             logger=True,
             logger=True,
             sync_dist=True,
             sync_dist=True,
         )
         )
-        self.log(
-            "train/generator/loss_vq",
-            prior.loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
+        # self.log(
+        #     "train/generator/loss_vq",
+        #     prior.loss,
+        #     on_step=True,
+        #     on_epoch=False,
+        #     prog_bar=False,
+        #     logger=True,
+        #     sync_dist=True,
+        # )
 
 
         optim_g.zero_grad()
         optim_g.zero_grad()
         self.manual_backward(loss_gen_all)
         self.manual_backward(loss_gen_all)
@@ -335,60 +228,23 @@ class VQGAN(L.LightningModule):
 
 
         audios = audios.float()
         audios = audios.float()
         features = features.float()
         features = features.float()
-
-        with torch.no_grad():
-            gt_mels, gt_specs = self.mel_transform(audios, return_linear=True)
-            gt_mels = gt_mels.transpose(1, 2)
-            key_padding_mask = sequence_mask(feature_lengths)
-            mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
-
-            assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
-            gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
-            gt_mels = gt_mels[:, :gt_mel_length]
-            gt_specs = gt_specs[:, :, :gt_mel_length]
-            mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
-
-            assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
-            gt_feature_length = min(features.shape[1], key_padding_mask.shape[1])
-            features = features[:, :gt_feature_length]
-            key_padding_mask = key_padding_mask[:, :gt_feature_length]
-
-        # Generator
-        # speaker: (B, C, 1)
-        speaker = self.speaker_encoder(gt_mels, mels_key_padding_mask)[:, :, None]
-        posterior_key_padding_mask = (~mels_key_padding_mask).float()[:, None]
-
-        z_gen = self.semantic_encoder(
-            x=features,
-            key_padding_mask=key_padding_mask,
-            g=speaker,
-        ).z
-
-        # z_gen = self.flow(z_gen, posterior_key_padding_mask, g=speaker, reverse=True)
-
-        z_posterior = self.posterior_encoder(
-            gt_specs, posterior_key_padding_mask, g=speaker
-        ).mean
-
         audios = audios[:, None, :]
         audios = audios[:, None, :]
-        fake_audios = self.generator(z_gen, g=speaker)
-        posterior_audios = self.generator(z_posterior)
-        min_audio_length = min(
-            audios.shape[-1], fake_audios.shape[-1], posterior_audios.shape[-1]
-        )
 
 
-        audios = audios[:, :, :min_audio_length]
-        fake_audios = fake_audios[:, :, :min_audio_length]
-        posterior_audios = posterior_audios[:, :, :min_audio_length]
-        assert fake_audios.shape == audios.shape == posterior_audios.shape
+        gt_mels = self.mel_transform(audios)
+        assert (
+            gt_mels.shape[2] == features.shape[1]
+        ), f"Shapes do not match: {gt_mels.shape}, {features.shape}"
+
+        fake_audios = self.generator.infer(features, feature_lengths, gt_mels)
+        posterior_audios = self.generator.reconstruct(gt_mels, feature_lengths)
 
 
-        fake_mels = self.mel_transform(fake_audios.squeeze(1)).transpose(1, 2)
-        posterior_mels = self.mel_transform(posterior_audios.squeeze(1)).transpose(1, 2)
+        fake_mels = self.mel_transform(fake_audios.squeeze(1))
+        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))
 
 
-        min_mel_length = min(gt_mels.shape[1], fake_mels.shape[1])
-        gt_mels = gt_mels[:, :min_mel_length]
-        fake_mels = fake_mels[:, :min_mel_length]
-        posterior_mels = posterior_mels[:, :min_mel_length]
+        min_mel_length = min(gt_mels.shape[-1], fake_mels.shape[-1])
+        gt_mels = gt_mels[:, :, :min_mel_length]
+        fake_mels = fake_mels[:, :, :min_mel_length]
+        posterior_mels = posterior_mels[:, :, :min_mel_length]
 
 
         mel_loss = F.l1_loss(gt_mels, fake_mels)
         mel_loss = F.l1_loss(gt_mels, fake_mels)
         self.log(
         self.log(
@@ -411,9 +267,9 @@ class VQGAN(L.LightningModule):
             audio_len,
             audio_len,
         ) in enumerate(
         ) in enumerate(
             zip(
             zip(
-                gt_mels.transpose(1, 2),
-                fake_mels.transpose(1, 2),
-                posterior_mels.transpose(1, 2),
+                gt_mels,
+                fake_mels,
+                posterior_mels,
                 audios,
                 audios,
                 fake_audios,
                 fake_audios,
                 posterior_audios,
                 posterior_audios,

+ 92 - 0
fish_speech/models/vqgan/losses.py

@@ -0,0 +1,92 @@
+from typing import List
+
+import torch
+
+
+def feature_loss(fmap_r: List[torch.Tensor], fmap_g: List[torch.Tensor]):
+    loss = 0
+    for dr, dg in zip(fmap_r, fmap_g):
+        for rl, gl in zip(dr, dg):
+            rl = rl.float().detach()
+            gl = gl.float()
+            loss += torch.mean(torch.abs(rl - gl))
+
+    return loss * 2
+
+
+def discriminator_loss(
+    disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
+):
+    loss = 0
+    r_losses = []
+    g_losses = []
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        dr = dr.float()
+        dg = dg.float()
+        r_loss = torch.mean((1 - dr) ** 2)
+        g_loss = torch.mean(dg**2)
+        loss += r_loss + g_loss
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())
+
+    return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs: List[torch.Tensor]):
+    loss = 0
+    gen_losses = []
+    for dg in disc_outputs:
+        dg = dg.float()
+        l = torch.mean((1 - dg) ** 2)
+        gen_losses.append(l)
+        loss += l
+
+    return loss, gen_losses
+
+
+def kl_loss(
+    z_p: torch.Tensor,
+    logs_q: torch.Tensor,
+    m_p: torch.Tensor,
+    logs_p: torch.Tensor,
+    z_mask: torch.Tensor,
+):
+    """
+    z_p, logs_q: [b, h, t_t]
+    m_p, logs_p: [b, h, t_t]
+    """
+    z_p = z_p.float()
+    logs_q = logs_q.float()
+    m_p = m_p.float()
+    logs_p = logs_p.float()
+    z_mask = z_mask.float()
+
+    kl = logs_p - logs_q - 0.5
+    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+    kl = torch.sum(kl * z_mask)
+    l = kl / torch.sum(z_mask)
+    return l
+
+
+def kl_loss_normal(
+    m_q: torch.Tensor,
+    logs_q: torch.Tensor,
+    m_p: torch.Tensor,
+    logs_p: torch.Tensor,
+    z_mask: torch.Tensor,
+):
+    """
+    m_q, logs_q: [b, h, t_t]
+    m_p, logs_p: [b, h, t_t]
+    """
+    m_q = m_q.float()
+    logs_q = logs_q.float()
+    m_p = m_p.float()
+    logs_p = logs_p.float()
+    z_mask = z_mask.float()
+
+    kl = logs_p - logs_q - 0.5
+    kl += 0.5 * (torch.exp(2.0 * logs_q) + (m_q - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+    kl = torch.sum(kl * z_mask)
+    l = kl / torch.sum(z_mask)
+    return l

+ 0 - 1090
fish_speech/models/vqgan/modules.py

@@ -1,1090 +0,0 @@
-import math
-from dataclasses import dataclass
-
-import torch
-from encodec.quantization.core_vq import VectorQuantization
-from torch import nn
-from torch.nn import Conv1d, Conv2d, ConvTranspose1d
-from torch.nn import functional as F
-from torch.nn.utils.parametrizations import spectral_norm, weight_norm
-from torch.nn.utils.parametrize import remove_parametrizations
-
-from fish_speech.models.vqgan.utils import (
-    convert_pad_shape,
-    fused_add_tanh_sigmoid_multiply,
-    get_padding,
-    init_weights,
-)
-
-LRELU_SLOPE = 0.1
-
-
-class ResidualCouplingBlock(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size: int = 5,
-        dilation_rate: int = 1,
-        n_layers: int = 4,
-        n_flows: int = 4,
-        gin_channels: int = 512,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.n_flows = n_flows
-        self.gin_channels = gin_channels
-
-        self.flows = nn.ModuleList()
-        for i in range(n_flows):
-            self.flows.append(
-                ResidualCouplingLayer(
-                    channels,
-                    hidden_channels,
-                    kernel_size,
-                    dilation_rate,
-                    n_layers,
-                    gin_channels=gin_channels,
-                    mean_only=True,
-                )
-            )
-            self.flows.append(Flip())
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        if not reverse:
-            for flow in self.flows:
-                x, _ = flow(x, x_mask, g=g, reverse=reverse)
-        else:
-            for flow in reversed(self.flows):
-                x = flow(x, x_mask, g=g, reverse=reverse)
-        return x
-
-
-class Flip(nn.Module):
-    def forward(self, x, *args, reverse=False, **kwargs):
-        x = torch.flip(x, [1])
-        if not reverse:
-            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
-            return x, logdet
-        else:
-            return x
-
-
-class ResidualCouplingLayer(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        p_dropout=0,
-        gin_channels=0,
-        mean_only=False,
-    ):
-        assert channels % 2 == 0, "channels should be divisible by 2"
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.half_channels = channels // 2
-        self.mean_only = mean_only
-
-        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-        self.enc = WaveNet(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            p_dropout=p_dropout,
-            gin_channels=gin_channels,
-        )
-        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-        self.post.weight.data.zero_()
-        self.post.bias.data.zero_()
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-        h = self.pre(x0) * x_mask
-        h = self.enc(h, x_mask, g=g)
-        stats = self.post(h) * x_mask
-        if not self.mean_only:
-            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
-        else:
-            m = stats
-            logs = torch.zeros_like(m)
-
-        if not reverse:
-            x1 = m + x1 * torch.exp(logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            logdet = torch.sum(logs, [1, 2])
-            return x, logdet
-        else:
-            x1 = (x1 - m) * torch.exp(-logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            return x
-
-
-@dataclass
-class SemanticEncoderOutput:
-    loss: torch.Tensor
-    mean: torch.Tensor
-    logs: torch.Tensor
-    z: torch.Tensor
-
-
-class SemanticEncoder(nn.Module):
-    def __init__(
-        self,
-        in_channels: int = 1024,
-        hidden_channels: int = 384,
-        out_channels: int = 192,
-        num_heads: int = 2,
-        num_layers: int = 8,
-        input_downsample: bool = True,
-        code_book_size: int = 2048,
-        freeze_vq: bool = False,
-        gin_channels: int = 512,
-    ):
-        super().__init__()
-
-        # Feature Encoder
-        down_sample = 2 if input_downsample else 1
-
-        self.in_proj = nn.Conv1d(
-            in_channels, in_channels, kernel_size=down_sample, stride=down_sample
-        )
-        self.vq = VectorQuantization(
-            dim=in_channels,
-            codebook_size=code_book_size,
-            threshold_ema_dead_code=2,
-            kmeans_init=False,
-            kmeans_iters=50,
-        )
-
-        # Init weights of in_proj to mimic the effect of avg pooling
-        nn.init.normal_(
-            self.in_proj.weight, mean=1 / (down_sample * in_channels), std=0.01
-        )
-        self.in_proj.bias.data.zero_()
-
-        self.feature_in = nn.Linear(in_channels, hidden_channels)
-        self.g_in = nn.Conv1d(gin_channels, hidden_channels, 1)
-        self.blocks = nn.ModuleList(
-            [
-                TransformerBlock(
-                    hidden_channels,
-                    num_heads,
-                    window_size=4,
-                    window_heads_share=True,
-                    proximal_init=True,
-                    proximal_bias=False,
-                    use_relative_attn=True,
-                )
-                for _ in range(num_layers)
-            ]
-        )
-
-        self.out_proj = nn.Linear(hidden_channels, out_channels * 2)
-
-        self.input_downsample = input_downsample
-
-        if freeze_vq:
-            for p in self.vq.parameters():
-                p.requires_grad = False
-
-            for p in self.vq_in.parameters():
-                p.requires_grad = False
-
-    def forward(self, x, key_padding_mask=None, g=None) -> SemanticEncoderOutput:
-        # x: (batch, seq_len, channels)
-
-        assert key_padding_mask.size(1) == x.size(
-            1
-        ), f"key_padding_mask shape {key_padding_mask.size()} does not match features shape {x.size()}"
-
-        # Encode Features
-        features = self.in_proj(x.transpose(1, 2))
-        features, _, loss = self.vq(features)
-        features = features.transpose(1, 2)
-
-        if self.input_downsample:
-            features = F.interpolate(
-                features.transpose(1, 2), scale_factor=2, mode="nearest"
-            ).transpose(1, 2)
-
-        # Shape may change due to downsampling, let's cut it to the same size
-        if features.shape[1] != key_padding_mask.shape[1]:
-            assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
-            min_len = min(features.shape[1], key_padding_mask.shape[1])
-            features = features[:, :min_len]
-            key_padding_mask = key_padding_mask[:, :min_len]
-
-        features = self.feature_in(features)
-        g = self.g_in(g).transpose(1, 2)
-        features = features + g
-
-        for block in self.blocks:
-            features = block(features, key_padding_mask=key_padding_mask)
-
-        stats = self.out_proj(features).transpose(1, 2)
-        stats = torch.masked_fill(stats, key_padding_mask.unsqueeze(1), 0)
-        mean, logs = torch.chunk(stats, 2, dim=1)
-
-        return SemanticEncoderOutput(
-            loss=loss,
-            mean=mean,
-            logs=logs,
-            z=mean + torch.randn_like(mean) * torch.exp(logs) * 0.5,
-        )
-
-
-class WaveNet(nn.Module):
-    """Stack of gated dilated 1-D convolutions with residual and skip paths.
-
-    Args:
-        hidden_channels: Channels of the input, of the residual path and of
-            the returned tensor.
-        kernel_size: Kernel size of the dilated convolutions; must be odd.
-        dilation_rate: Base of the dilation, layer ``i`` uses
-            ``dilation_rate**i``.
-        n_layers: Number of gated layers.
-        gin_channels: Channels of the conditioning input ``g``. ``0`` means
-            no conditioning layer is created.
-        p_dropout: Dropout probability applied to the gated activations.
-    """
-
-    def __init__(
-        self,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        gin_channels=0,
-        p_dropout=0,
-    ):
-        super(WaveNet, self).__init__()
-        assert kernel_size % 2 == 1
-        self.hidden_channels = hidden_channels
-        # NOTE(review): stored as a 1-tuple, exactly as before, so that any
-        # external reader of this attribute keeps seeing the same value.
-        self.kernel_size = (kernel_size,)
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.gin_channels = gin_channels
-        self.p_dropout = p_dropout
-
-        self.in_layers = nn.ModuleList()
-        self.res_skip_layers = nn.ModuleList()
-        self.drop = nn.Dropout(p_dropout)
-
-        if gin_channels != 0:
-            # One projection yields the conditioning for all layers at once;
-            # forward() slices out 2 * hidden_channels channels per layer.
-            self.cond_layer = weight_norm(
-                nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
-            )
-
-        for i in range(n_layers):
-            dilation = dilation_rate**i
-            padding = int((kernel_size * dilation - dilation) / 2)
-            in_layer = weight_norm(
-                nn.Conv1d(
-                    hidden_channels,
-                    2 * hidden_channels,
-                    kernel_size,
-                    dilation=dilation,
-                    padding=padding,
-                )
-            )
-            self.in_layers.append(in_layer)
-
-            # The last layer only feeds the skip path, so it needs no
-            # residual half.
-            if i < n_layers - 1:
-                res_skip_channels = 2 * hidden_channels
-            else:
-                res_skip_channels = hidden_channels
-
-            res_skip_layer = weight_norm(
-                nn.Conv1d(hidden_channels, res_skip_channels, 1), name="weight"
-            )
-            self.res_skip_layers.append(res_skip_layer)
-
-    def forward(self, x, x_mask, g=None):
-        """Run the gated stack and return the masked sum of all skip outputs.
-
-        Args:
-            x: Input with ``hidden_channels`` channels on dim 1
-                (presumably ``[B, hidden_channels, T]`` -- confirm with callers).
-            x_mask: Mask multiplied into the residual path and the output.
-            g: Optional conditioning input for ``cond_layer``; when ``None``
-                the conditioning term is all zeros.
-        """
-        output = torch.zeros_like(x)
-        n_channels_tensor = torch.IntTensor([self.hidden_channels])
-
-        if g is not None:
-            g = self.cond_layer(g)
-
-        for i in range(self.n_layers):
-            x_in = self.in_layers[i](x)
-            if g is not None:
-                cond_offset = i * 2 * self.hidden_channels
-                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-            else:
-                g_l = torch.zeros_like(x_in)
-
-            # fused_add_tanh_sigmoid_multiply is defined elsewhere in the project.
-            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
-            acts = self.drop(acts)
-
-            res_skip_acts = self.res_skip_layers[i](acts)
-            if i < self.n_layers - 1:
-                # First half updates the residual stream, second half is the skip.
-                res_acts = res_skip_acts[:, : self.hidden_channels, :]
-                x = (x + res_acts) * x_mask
-                output = output + res_skip_acts[:, self.hidden_channels :, :]
-            else:
-                output = output + res_skip_acts
-
-        return output * x_mask
-
-    def remove_parametrizations(self):
-        """Fold the weight-norm parametrization of every layer into a plain weight."""
-        # torch.nn.utils exposes no ``remove_parametrizations``; the function
-        # lives in torch.nn.utils.parametrize and requires the tensor name.
-        from torch.nn.utils.parametrize import remove_parametrizations
-
-        if self.gin_channels != 0:
-            remove_parametrizations(self.cond_layer, "weight")
-        for layer in self.in_layers:
-            remove_parametrizations(layer, "weight")
-        for layer in self.res_skip_layers:
-            remove_parametrizations(layer, "weight")
-
-
-@dataclass
-class PosteriorEncoderOutput:
-    """Result bundle returned by ``PosteriorEncoder.forward``."""
-
-    # Masked sample: (mean + noise * exp(logs)) * x_mask.
-    z: torch.Tensor
-    # First ``out_channels`` channels of the projected statistics.
-    mean: torch.Tensor
-    # Remaining channels of the projected statistics, used as a log-scale.
-    logs: torch.Tensor
-
-
-class PosteriorEncoder(nn.Module):
-    """Encode a feature sequence into a sampled latent plus its statistics.
-
-    A 1x1 convolution lifts the input to ``hidden_channels``, a ``WaveNet``
-    stack processes it, and a final 1x1 convolution emits
-    ``2 * out_channels`` channels that are split into mean and log-scale.
-
-    Args:
-        in_channels: Channels of the input features.
-        out_channels: Channels of the latent ``z``.
-        hidden_channels: Channels used inside the ``WaveNet`` stack.
-        kernel_size: Kernel size forwarded to ``WaveNet``.
-        dilation_rate: Dilation base forwarded to ``WaveNet``.
-        n_layers: Number of ``WaveNet`` layers.
-        gin_channels: Channels of the conditioning input ``g``.
-    """
-
-    def __init__(
-        self,
-        in_channels: int = 1024,
-        out_channels: int = 192,
-        hidden_channels: int = 192,
-        kernel_size: int = 5,
-        dilation_rate: int = 1,
-        n_layers: int = 16,
-        gin_channels: int = 512,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.gin_channels = gin_channels
-
-        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
-        self.enc = WaveNet(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            gin_channels=gin_channels,
-        )
-        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-    def forward(self, x, x_mask, g=None):
-        """Return a ``PosteriorEncoderOutput`` for ``x``.
-
-        Args:
-            x: Input features with ``in_channels`` channels on dim 1.
-            x_mask: Mask multiplied into every intermediate and into ``z``.
-            g: Optional conditioning input. It is detached, so no gradient
-                flows from this encoder into whatever produced it.
-        """
-        # ``g`` defaults to None; detaching unconditionally raised
-        # AttributeError for unconditioned calls.
-        if g is not None:
-            g = g.detach()
-        x = self.pre(x) * x_mask
-        x = self.enc(x, x_mask, g=g)
-        stats = self.proj(x) * x_mask
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        # Sample with the reparameterisation m + eps * exp(logs).
-        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
-
-        return PosteriorEncoderOutput(
-            z=z,
-            mean=m,
-            logs=logs,
-        )
-
-
-class SpeakerEncoder(nn.Module):
-    """Pool a feature sequence into one vector through a learned query token.
-
-    The query token is prepended to the projected sequence, the transformer
-    blocks run over the extended sequence, and the state at position 0 is
-    projected to ``out_channels``.
-
-    Args:
-        in_channels: Size of the last dimension of ``mels``.
-        hidden_channels: Width of the transformer blocks.
-        out_channels: Size of the returned vector.
-        num_heads: Attention heads per block.
-        num_layers: Number of transformer blocks.
-    """
-
-    def __init__(
-        self,
-        in_channels: int = 128,
-        hidden_channels: int = 192,
-        out_channels: int = 512,
-        num_heads: int = 2,
-        num_layers: int = 4,
-    ) -> None:
-        super().__init__()
-
-        self.query = nn.Parameter(torch.randn(1, 1, hidden_channels))
-        self.in_proj = nn.Linear(in_channels, hidden_channels)
-        self.blocks = nn.ModuleList(
-            [
-                TransformerBlock(
-                    hidden_channels,
-                    num_heads,
-                    use_relative_attn=False,
-                )
-                for _ in range(num_layers)
-            ]
-        )
-        self.out_proj = nn.Linear(hidden_channels, out_channels)
-
-    def forward(self, mels, mels_key_padding_mask=None):
-        """Return one ``out_channels`` vector per batch item.
-
-        Args:
-            mels: Features laid out as ``(batch, seq_len, in_channels)``.
-            mels_key_padding_mask: Optional boolean ``(batch, seq_len)`` mask.
-                When omitted, no position of ``mels`` is masked.
-        """
-        x = self.in_proj(mels)
-
-        # The parameter defaults to None, which torch.cat below cannot take;
-        # fall back to an all-False mask over the mel positions.
-        if mels_key_padding_mask is None:
-            mels_key_padding_mask = torch.zeros(
-                x.shape[0], x.shape[1], dtype=torch.bool, device=x.device
-            )
-
-        x = torch.cat([self.query.expand(x.shape[0], -1, -1), x], dim=1)
-
-        # NOTE(review): the query slot is flagged True, which is the "ignore"
-        # value in nn.MultiheadAttention's key_padding_mask convention. Kept
-        # exactly as in the original code -- confirm this is intended.
-        mels_key_padding_mask = torch.cat(
-            [
-                torch.ones(x.shape[0], 1, dtype=torch.bool, device=x.device),
-                mels_key_padding_mask,
-            ],
-            dim=1,
-        )
-        for block in self.blocks:
-            x = block(x, key_padding_mask=mels_key_padding_mask)
-
-        # Only the query position carries the pooled representation.
-        x = self.out_proj(x[:, 0])
-
-        return x
-
-
-class TransformerBlock(nn.Module):
-    """Pre-norm transformer block: self-attention then SwiGLU MLP, each residual.
-
-    ``use_relative_attn`` selects ``RelativeAttention`` (which receives the
-    window/proximal options) or ``nn.MultiheadAttention`` with
-    ``batch_first=True``.
-    """
-
-    def __init__(
-        self,
-        channels,
-        n_heads,
-        mlp_ratio=4 * 2 / 3,
-        p_dropout=0.0,
-        window_size=4,
-        window_heads_share=True,
-        proximal_init=True,
-        proximal_bias=False,
-        use_relative_attn=True,
-    ):
-        super().__init__()
-
-        self.attn_norm = RMSNorm(channels)
-
-        if not use_relative_attn:
-            self.attn = nn.MultiheadAttention(
-                embed_dim=channels,
-                num_heads=n_heads,
-                dropout=p_dropout,
-                batch_first=True,
-            )
-        else:
-            self.attn = RelativeAttention(
-                channels,
-                n_heads,
-                p_dropout=p_dropout,
-                window_size=window_size,
-                window_heads_share=window_heads_share,
-                proximal_init=proximal_init,
-                proximal_bias=proximal_bias,
-            )
-
-        self.mlp_norm = RMSNorm(channels)
-        mlp_hidden = int(channels * mlp_ratio)
-        self.mlp = SwiGLU(channels, mlp_hidden, channels, drop=p_dropout)
-
-    def forward(self, x, key_padding_mask=None):
-        """Apply attention and MLP sub-layers to ``x`` and return the result."""
-        normed = self.attn_norm(x)
-
-        if not isinstance(self.attn, RelativeAttention):
-            # nn.MultiheadAttention returns (output, weights); keep the output.
-            attn_out = self.attn(
-                normed, normed, normed, key_padding_mask=key_padding_mask
-            )[0]
-        else:
-            attn_out = self.attn(normed, key_padding_mask=key_padding_mask)
-
-        x = x + attn_out
-        return x + self.mlp(self.mlp_norm(x))
-
-
-class SwiGLU(nn.Module):
-    """
-    Swish-gated linear unit feed-forward layer.
-
-    ``fc1`` produces ``hidden_features`` channels that are split in two
-    halves; the first half is multiplied by SiLU of the second, normalised
-    with RMSNorm and projected to ``out_features`` by ``fc2``.
-    """
-
-    def __init__(
-        self,
-        in_features,
-        hidden_features=None,
-        out_features=None,
-        bias=True,
-        drop=0.0,
-    ):
-        super().__init__()
-        # Falsy sizes fall back to the input width.
-        if not out_features:
-            out_features = in_features
-        if not hidden_features:
-            hidden_features = in_features
-        assert hidden_features % 2 == 0
-        gated_features = hidden_features // 2
-
-        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
-        self.act = nn.SiLU()
-        self.drop1 = nn.Dropout(drop)
-        self.norm = RMSNorm(gated_features)
-        self.fc2 = nn.Linear(gated_features, out_features, bias=bias)
-        self.drop2 = nn.Dropout(drop)
-
-    def init_weights(self):
-        # Re-initialise the gate half of fc1: weights near zero, bias set to 1.
-        gate_start = self.fc1.bias.shape[0] // 2
-        gate_bias = self.fc1.bias[gate_start:]
-        nn.init.ones_(gate_bias)
-        gate_weight = self.fc1.weight[gate_start:]
-        nn.init.normal_(gate_weight, std=1e-6)
-
-    def forward(self, x):
-        value, gate = self.fc1(x).chunk(2, dim=-1)
-
-        gated = self.drop1(value * self.act(gate))
-        projected = self.fc2(self.norm(gated))
-
-        return self.drop2(projected)
-
-
-class RMSNorm(nn.Module):
-    """Root-mean-square normalisation over the last dimension with a learned scale."""
-
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        LlamaRMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-
-        self.variance_epsilon = eps
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-
-    def forward(self, hidden_states):
-        # Statistics are computed in float32, then cast back to the input dtype.
-        original_dtype = hidden_states.dtype
-        as_fp32 = hidden_states.to(torch.float32)
-        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
-        normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
-
-        return self.weight * normalized.to(original_dtype)
-
-
-class RelativeAttention(nn.Module):
-    """Multi-head self-attention with optional windowed relative-position terms.
-
-    Args:
-        channels: Model width; must be divisible by ``n_heads``.
-        n_heads: Number of attention heads.
-        p_dropout: Dropout probability applied to the attention weights.
-        window_size: Relative positions in ``[-window_size, window_size]`` get
-            learned key/value embeddings. ``None`` disables the relative terms.
-        window_heads_share: Share one set of relative embeddings across heads.
-        proximal_init: Start with the query projection equal to the key
-            projection.
-        proximal_bias: Add a ``-log1p(|i - j|)`` bias to the attention scores.
-    """
-
-    def __init__(
-        self,
-        channels,
-        n_heads,
-        p_dropout=0.0,
-        window_size=4,
-        window_heads_share=True,
-        proximal_init=True,
-        proximal_bias=False,
-    ):
-        super().__init__()
-        assert channels % n_heads == 0
-
-        self.channels = channels
-        self.n_heads = n_heads
-        self.p_dropout = p_dropout
-        self.window_size = window_size
-        self.heads_share = window_heads_share
-        self.proximal_init = proximal_init
-        self.proximal_bias = proximal_bias
-
-        self.k_channels = channels // n_heads
-        self.qkv = nn.Linear(channels, channels * 3)
-        self.drop = nn.Dropout(p_dropout)
-
-        if window_size is not None:
-            n_heads_rel = 1 if window_heads_share else n_heads
-            rel_stddev = self.k_channels**-0.5
-            self.emb_rel_k = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-            self.emb_rel_v = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-
-        nn.init.xavier_uniform_(self.qkv.weight)
-
-        if proximal_init:
-            with torch.no_grad():
-                # Sync qk weights: rows [0, C) of the fused projection are the
-                # queries, rows [C, 2C) the keys.
-                self.qkv.weight.data[: self.channels] = self.qkv.weight.data[
-                    self.channels : self.channels * 2
-                ]
-                self.qkv.bias.data[: self.channels] = self.qkv.bias.data[
-                    self.channels : self.channels * 2
-                ]
-
-    def forward(self, x, key_padding_mask=None):
-        """Attend over ``x`` and return a tensor of the same shape.
-
-        Args:
-            x: Input of shape ``(batch, seq_len, channels)``.
-            key_padding_mask: Optional boolean ``(batch, seq_len)`` mask; keys
-                at ``True`` positions receive ``-inf`` scores.
-        """
-        # x: (batch, seq_len, channels)
-        batch_size, seq_len, _ = x.size()
-        qkv = (
-            self.qkv(x)
-            .reshape(batch_size, seq_len, 3, self.n_heads, self.k_channels)
-            .permute(2, 0, 3, 1, 4)
-        )
-        # Each of query/key/value: (batch, n_heads, seq_len, k_channels)
-        query, key, value = torch.unbind(qkv, dim=0)
-
-        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
-
-        if self.window_size is not None:
-            key_relative_embeddings = self._get_relative_embeddings(
-                self.emb_rel_k, seq_len
-            )
-            rel_logits = self._matmul_with_relative_keys(
-                query / math.sqrt(self.k_channels), key_relative_embeddings
-            )
-            scores_local = self._relative_position_to_absolute_position(rel_logits)
-            scores = scores + scores_local
-
-        if self.proximal_bias:
-            scores = scores + self._attention_bias_proximal(seq_len).to(
-                device=scores.device, dtype=scores.dtype
-            )
-
-        # key_padding_mask: (batch, seq_len)
-        if key_padding_mask is not None:
-            assert key_padding_mask.size() == (
-                batch_size,
-                seq_len,
-            ), f"key_padding_mask shape {key_padding_mask.size()} does not match x shape {x.size()}"
-            assert (
-                key_padding_mask.dtype == torch.bool
-            ), f"key_padding_mask dtype {key_padding_mask.dtype} is not bool"
-
-            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(
-                -1, self.n_heads, -1, -1
-            )
-            scores = scores.masked_fill(key_padding_mask, float("-inf"))
-
-        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
-        p_attn = self.drop(p_attn)
-        output = torch.matmul(p_attn, value)
-
-        if self.window_size is not None:
-            relative_weights = self._absolute_position_to_relative_position(p_attn)
-            value_relative_embeddings = self._get_relative_embeddings(
-                self.emb_rel_v, seq_len
-            )
-            output = output + self._matmul_with_relative_values(
-                relative_weights, value_relative_embeddings
-            )
-
-        # output is [b, n_h, t, k]. The time axis has to come before the head
-        # axis when merging heads; reshaping directly would spread the
-        # features of one position over several output positions.
-        return output.transpose(1, 2).reshape(
-            batch_size, seq_len, self.n_heads * self.k_channels
-        )
-
-    def _matmul_with_relative_values(self, x, y):
-        """
-        x: [b, h, l, m]
-        y: [h or 1, m, d]
-        ret: [b, h, l, d]
-        """
-        ret = torch.matmul(x, y.unsqueeze(0))
-        return ret
-
-    def _matmul_with_relative_keys(self, x, y):
-        """
-        x: [b, h, l, d]
-        y: [h or 1, m, d]
-        ret: [b, h, l, m]
-        """
-        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
-        return ret
-
-    def _get_relative_embeddings(self, relative_embeddings, length):
-        """Pad and slice the embedding table to ``2 * length - 1`` positions."""
-        # Pad first before slice to avoid using cond ops.
-        pad_length = max(length - (self.window_size + 1), 0)
-        slice_start_position = max((self.window_size + 1) - length, 0)
-        slice_end_position = slice_start_position + 2 * length - 1
-        if pad_length > 0:
-            padded_relative_embeddings = F.pad(
-                relative_embeddings,
-                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
-            )
-        else:
-            padded_relative_embeddings = relative_embeddings
-        used_relative_embeddings = padded_relative_embeddings[
-            :, slice_start_position:slice_end_position
-        ]
-        return used_relative_embeddings
-
-    def _relative_position_to_absolute_position(self, x):
-        """
-        x: [b, h, l, 2*l-1]
-        ret: [b, h, l, l]
-        """
-        batch, heads, length, _ = x.size()
-        # Concat columns of pad to shift from relative to absolute indexing.
-        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
-
-        # Concat extra elements so to add up to shape (len+1, 2*len-1).
-        x_flat = x.view([batch, heads, length * 2 * length])
-        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
-
-        # Reshape and slice out the padded elements.
-        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
-            :, :, :length, length - 1 :
-        ]
-        return x_final
-
-    def _absolute_position_to_relative_position(self, x):
-        """
-        x: [b, h, l, l]
-        ret: [b, h, l, 2*l-1]
-        """
-        batch, heads, length, _ = x.size()
-        # pad along column
-        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
-        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
-        # add 0's in the beginning that will skew the elements after reshape
-        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
-        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
-        return x_final
-
-    def _attention_bias_proximal(self, length):
-        """Bias for self-attention to encourage attention to close positions.
-        Args:
-          length: an integer scalar.
-        Returns:
-          a Tensor with shape [1, 1, length, length]
-        """
-        r = torch.arange(length, dtype=torch.float32)
-        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
-        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
-
-
-class ResBlock1(nn.Module):
-    """Residual block made of three (dilated conv, plain conv) pairs.
-
-    Args:
-        channels: Input and output channels of every convolution.
-        kernel_size: Kernel size of every convolution.
-        dilation: Dilations of the three first-stage convolutions; the
-            second-stage convolutions always use dilation 1.
-    """
-
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super(ResBlock1, self).__init__()
-
-        def make_conv(d):
-            # Weight-normalised stride-1 convolution with dilation ``d``.
-            return weight_norm(
-                Conv1d(
-                    channels,
-                    channels,
-                    kernel_size,
-                    1,
-                    dilation=d,
-                    padding=get_padding(kernel_size, d),
-                )
-            )
-
-        # Exactly the first three dilations are used.
-        self.convs1 = nn.ModuleList(
-            [make_conv(d) for d in (dilation[0], dilation[1], dilation[2])]
-        )
-        self.convs1.apply(init_weights)
-
-        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
-        self.convs2.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        """Apply the three residual pairs; ``x_mask``, if given, is multiplied in."""
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c1(xt)
-            xt = F.leaky_relu(xt, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c2(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_parametrizations(self):
-        """Fold the weight-norm parametrization of every convolution."""
-        # torch's remove_parametrizations requires the parametrized tensor
-        # name; weight_norm parametrizes "weight".
-        for conv in self.convs1:
-            remove_parametrizations(conv, "weight")
-        for conv in self.convs2:
-            remove_parametrizations(conv, "weight")
-
-
-class ResBlock2(nn.Module):
-    """Residual block made of two dilated convolutions.
-
-    Args:
-        channels: Input and output channels of every convolution.
-        kernel_size: Kernel size of every convolution.
-        dilation: Dilations of the two convolutions.
-    """
-
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-        super(ResBlock2, self).__init__()
-        # Exactly the first two dilations are used.
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=d,
-                        padding=get_padding(kernel_size, d),
-                    )
-                )
-                for d in (dilation[0], dilation[1])
-            ]
-        )
-        self.convs.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        """Apply both residual convolutions; ``x_mask``, if given, is multiplied in."""
-        for c in self.convs:
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_parametrizations(self):
-        """Fold the weight-norm parametrization of every convolution."""
-        # torch's remove_parametrizations requires the parametrized tensor
-        # name; weight_norm parametrizes "weight".
-        for conv in self.convs:
-            remove_parametrizations(conv, "weight")
-
-
-class Generator(nn.Module):
-    """Upsampling waveform decoder built from transposed convs and residual blocks.
-
-    Args:
-        initial_channel: Channels of the input.
-        resblock: ``"1"`` selects ``ResBlock1``; anything else ``ResBlock2``.
-        resblock_kernel_sizes: Kernel size of each residual block per stage.
-        resblock_dilation_sizes: Dilations of each residual block per stage.
-        upsample_rates: Stride of each transposed convolution.
-        upsample_initial_channel: Channels after ``conv_pre``; halved at
-            every upsampling stage.
-        upsample_kernel_sizes: Kernel size of each transposed convolution.
-        gin_channels: Channels of the conditioning input ``g``.
-    """
-
-    def __init__(
-        self,
-        initial_channel,
-        resblock,
-        resblock_kernel_sizes,
-        resblock_dilation_sizes,
-        upsample_rates,
-        upsample_initial_channel,
-        upsample_kernel_sizes,
-        gin_channels,
-    ):
-        super(Generator, self).__init__()
-        self.num_kernels = len(resblock_kernel_sizes)
-        self.num_upsamples = len(upsample_rates)
-        self.conv_pre = Conv1d(
-            initial_channel, upsample_initial_channel, 7, 1, padding=3
-        )
-        resblock = ResBlock1 if resblock == "1" else ResBlock2
-
-        self.ups = nn.ModuleList()
-        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            self.ups.append(
-                weight_norm(
-                    ConvTranspose1d(
-                        upsample_initial_channel // (2**i),
-                        upsample_initial_channel // (2 ** (i + 1)),
-                        k,
-                        u,
-                        padding=(k - u) // 2,
-                    )
-                )
-            )
-
-        # num_kernels residual blocks per upsampling stage, stored flat.
-        self.resblocks = nn.ModuleList()
-        for i in range(len(self.ups)):
-            ch = upsample_initial_channel // (2 ** (i + 1))
-            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
-                self.resblocks.append(resblock(ch, k, d))
-
-        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
-        self.ups.apply(init_weights)
-
-        self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
-
-    def forward(self, x, g=None):
-        """Decode ``x`` into a single-channel signal squashed by ``tanh``.
-
-        Args:
-            x: Input with ``initial_channel`` channels on dim 1.
-            g: Optional conditioning input with ``gin_channels`` channels;
-                its projection is added to the output of ``conv_pre``.
-        """
-        x = self.conv_pre(x)
-        if g is not None:
-            # The projection used to be computed and then dropped, so the
-            # conditioning never reached the decoder.
-            x = x + self.cond(g)
-
-        for i in range(self.num_upsamples):
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            x = self.ups[i](x)
-            # Average the outputs of this stage's residual blocks.
-            xs = None
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
-                else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-            x = xs / self.num_kernels
-        # NOTE: uses F.leaky_relu's default slope rather than LRELU_SLOPE,
-        # as in the original code.
-        x = F.leaky_relu(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def remove_parametrizations(self):
-        """Fold weight norm out of the upsamplers and the residual blocks."""
-        print("Removing weight norm...")
-        # torch's remove_parametrizations requires the parametrized tensor
-        # name; weight_norm parametrizes "weight".
-        for up in self.ups:
-            remove_parametrizations(up, "weight")
-        for block in self.resblocks:
-            block.remove_parametrizations()
-
-
-class DiscriminatorP(nn.Module):
-    """Period discriminator: folds the signal by ``period`` and applies 2-D convs."""
-
-    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
-        super(DiscriminatorP, self).__init__()
-        self.period = period
-        self.use_spectral_norm = use_spectral_norm
-        if use_spectral_norm == False:
-            norm_f = weight_norm
-        else:
-            norm_f = spectral_norm
-
-        # (in_channels, out_channels, stride) of each convolution.
-        layer_specs = [
-            (1, 32, (stride, 1)),
-            (32, 128, (stride, 1)),
-            (128, 512, (stride, 1)),
-            (512, 1024, (stride, 1)),
-            (1024, 1024, 1),
-        ]
-        self.convs = nn.ModuleList(
-            [
-                norm_f(
-                    Conv2d(
-                        c_in,
-                        c_out,
-                        (kernel_size, 1),
-                        conv_stride,
-                        padding=(get_padding(kernel_size, 1), 0),
-                    )
-                )
-                for c_in, c_out, conv_stride in layer_specs
-            ]
-        )
-        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
-
-    def forward(self, x):
-        """Return the flattened score and the list of intermediate feature maps."""
-        feature_maps = []
-
-        # Fold the 1-D signal into (frames, period), reflect-padding the
-        # tail when the length is not a multiple of the period.
-        batch, channels, length = x.shape
-        remainder = length % self.period
-        if remainder != 0:
-            n_pad = self.period - remainder
-            x = F.pad(x, (0, n_pad), "reflect")
-            length = length + n_pad
-        x = x.view(batch, channels, length // self.period, self.period)
-
-        for conv in self.convs:
-            x = F.leaky_relu(conv(x), LRELU_SLOPE)
-            feature_maps.append(x)
-        x = self.conv_post(x)
-        feature_maps.append(x)
-
-        return torch.flatten(x, 1, -1), feature_maps
-
-
-class DiscriminatorS(nn.Module):
-    """Scale discriminator: a stack of strided, grouped 1-D convolutions."""
-
-    def __init__(self, use_spectral_norm=False):
-        super(DiscriminatorS, self).__init__()
-        if use_spectral_norm == False:
-            norm_f = weight_norm
-        else:
-            norm_f = spectral_norm
-
-        # (in_channels, out_channels, kernel, stride, groups, padding)
-        layer_specs = [
-            (1, 16, 15, 1, 1, 7),
-            (16, 64, 41, 4, 4, 20),
-            (64, 256, 41, 4, 16, 20),
-            (256, 1024, 41, 4, 64, 20),
-            (1024, 1024, 41, 4, 256, 20),
-            (1024, 1024, 5, 1, 1, 2),
-        ]
-        self.convs = nn.ModuleList(
-            [
-                norm_f(Conv1d(c_in, c_out, kernel, step, groups=groups, padding=pad))
-                for c_in, c_out, kernel, step, groups, pad in layer_specs
-            ]
-        )
-        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
-
-    def forward(self, x):
-        """Return the flattened score and the list of intermediate feature maps."""
-        feature_maps = []
-
-        for conv in self.convs:
-            x = F.leaky_relu(conv(x), LRELU_SLOPE)
-            feature_maps.append(x)
-        x = self.conv_post(x)
-        feature_maps.append(x)
-
-        return torch.flatten(x, 1, -1), feature_maps
-
-
-class EnsembleDiscriminator(nn.Module):
-    """One scale discriminator followed by period discriminators (2, 3, 5, 7, 11)."""
-
-    def __init__(self, use_spectral_norm=False):
-        super(EnsembleDiscriminator, self).__init__()
-
-        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
-        for period in (2, 3, 5, 7, 11):
-            members.append(
-                DiscriminatorP(period, use_spectral_norm=use_spectral_norm)
-            )
-        self.discriminators = nn.ModuleList(members)
-
-    def forward(self, y, y_hat):
-        """Run every discriminator on ``y`` and ``y_hat``.
-
-        Returns four lists with one entry per discriminator: scores for
-        ``y``, scores for ``y_hat``, feature maps for ``y`` and feature maps
-        for ``y_hat``.
-        """
-        real_scores, fake_scores = [], []
-        real_fmaps, fake_fmaps = [], []
-        for disc in self.discriminators:
-            score_real, fmap_real = disc(y)
-            score_fake, fmap_fake = disc(y_hat)
-            real_scores.append(score_real)
-            fake_scores.append(score_fake)
-            real_fmaps.append(fmap_real)
-            fake_fmaps.append(fmap_fake)
-
-        return real_scores, fake_scores, real_fmaps, fake_fmaps

+ 37 - 0
fish_speech/models/vqgan/modules/condition.py

@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+
+
+class MultiCondLayer(nn.Module):
+    def __init__(
+        self,
+        gin_channels: int,
+        out_channels: int,
+        n_cond: int,
+    ):
+        """MultiCondLayer of VITS model.
+
+        Args:
+            gin_channels (int): Number of conditioning tensor channels.
+            out_channels (int): Number of output tensor channels.
+            n_cond (int): Number of conditions.
+        """
+        super().__init__()
+        self.n_cond = n_cond
+
+        self.cond_layers = nn.ModuleList(
+            [nn.Linear(gin_channels, out_channels) for _ in range(n_cond)]
+        )
+
+    def forward(self, cond: torch.Tensor, x_mask: torch.Tensor):
+        """Sum the projections of ``cond`` from every layer and mask the result.
+
+        Shapes:
+            - cond: :math:`[B, C, N]` with ``C == gin_channels``
+            - x_mask: :math:`[B, 1, T]`
+            - return: ``[B, out_channels, N]`` broadcast against ``x_mask``
+        """
+        # Accumulate from the projections themselves. A ``zeros_like(cond)``
+        # seed has gin_channels channels and cannot be added to the
+        # projections when out_channels differs from gin_channels.
+        cond_out = sum(layer(cond.mT).mT for layer in self.cond_layers)
+        return cond_out * x_mask

+ 227 - 0
fish_speech/models/vqgan/modules/decoder.py

@@ -0,0 +1,227 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
+
+from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
+from fish_speech.models.vqgan.utils import get_padding, init_weights
+
+
+class Generator(nn.Module):
+    """Waveform decoder: conv_pre -> repeated (upsample + residual blocks) -> conv_post -> tanh.
+
+    Args:
+        initial_channel: Number of input channels fed to ``conv_pre``.
+        resblock: ``"1"`` selects ``ResBlock1``; any other value selects ``ResBlock2``.
+        resblock_kernel_sizes: Kernel size of each residual block in a stage.
+        resblock_dilation_sizes: Dilation tuple of each residual block in a stage.
+        upsample_rates: Stride of each transposed-conv upsampling stage.
+        upsample_initial_channel: Channel count after ``conv_pre``; halved at every stage.
+        upsample_kernel_sizes: Kernel size of each upsampling stage.
+        gin_channels: Width of the optional conditioning input ``g``; 0 disables it.
+    """
+
+    def __init__(
+        self,
+        initial_channel,
+        resblock,
+        resblock_kernel_sizes,
+        resblock_dilation_sizes,
+        upsample_rates,
+        upsample_initial_channel,
+        upsample_kernel_sizes,
+        gin_channels=0,
+    ):
+        super(Generator, self).__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.conv_pre = nn.Conv1d(
+            initial_channel, upsample_initial_channel, 7, 1, padding=3
+        )
+        resblock = ResBlock1 if resblock == "1" else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    nn.ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+        # One group of `num_kernels` residual blocks per upsampling stage,
+        # stored flat: stage i owns indices [i * num_kernels, (i + 1) * num_kernels).
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
+                self.resblocks.append(resblock(ch, k, d))
+
+        # `ch` is the channel count of the last stage.
+        self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.ups.apply(init_weights)
+
+        if gin_channels != 0:
+            self.cond = nn.Linear(gin_channels, upsample_initial_channel)
+
+    def forward(self, x, g=None):
+        """Decode ``x`` into a single-channel signal in (-1, 1).
+
+        ``g`` (optional) is projected along its channel axis and added to the
+        output of ``conv_pre``; it must be channels-first (``g.mT`` is taken
+        before the linear layer).
+        """
+        x = self.conv_pre(x)
+        if g is not None:
+            x = x + self.cond(g.mT).mT
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            # Average the outputs of this stage's residual blocks.
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        # NOTE: this activation uses F.leaky_relu's default slope, not
+        # LRELU_SLOPE; left unchanged so outputs stay the same.
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        """Strip the weight-norm parametrization from the upsamplers and residual blocks."""
+        print("Removing weight norm...")
+        for layer in self.ups:
+            # remove_parametrizations requires the name of the parametrized tensor.
+            remove_weight_norm(layer, "weight")
+        for block in self.resblocks:
+            block.remove_weight_norm()
+
+
+class ResBlock1(nn.Module):
+    """Residual block of three (dilated conv -> dilation-1 conv) pairs.
+
+    Args:
+        channels: Input and output channel count of every convolution.
+        kernel_size: Kernel size shared by all convolutions.
+        dilation: Dilations of the three first-in-pair convolutions
+            (exactly the first three entries are used).
+    """
+
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1, self).__init__()
+
+        def make_conv(d):
+            # Weight-normalized, length-preserving Conv1d with dilation `d`.
+            return weight_norm(
+                nn.Conv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    1,
+                    dilation=d,
+                    padding=get_padding(kernel_size, d),
+                )
+            )
+
+        self.convs1 = nn.ModuleList([make_conv(dilation[i]) for i in range(3)])
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
+        self.convs2.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        """Apply the three residual pairs; ``x_mask`` (if given) is multiplied in before each conv and on the output."""
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c2(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        """Strip the weight-norm parametrization from every convolution."""
+        for conv in self.convs1:
+            # remove_parametrizations requires the name of the parametrized tensor.
+            remove_weight_norm(conv, "weight")
+        for conv in self.convs2:
+            remove_weight_norm(conv, "weight")
+
+
+class ResBlock2(nn.Module):
+    """Residual block of two dilated convolutions, each with its own skip connection.
+
+    Args:
+        channels: Input and output channel count of every convolution.
+        kernel_size: Kernel size shared by both convolutions.
+        dilation: Dilations of the two convolutions (exactly the first two
+            entries are used).
+    """
+
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2, self).__init__()
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[i],
+                        padding=get_padding(kernel_size, dilation[i]),
+                    )
+                )
+                for i in range(2)
+            ]
+        )
+        self.convs.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        """Apply both residual convolutions; ``x_mask`` (if given) is multiplied in before each conv and on the output."""
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        """Strip the weight-norm parametrization from both convolutions."""
+        for conv in self.convs:
+            # remove_parametrizations requires the name of the parametrized tensor.
+            remove_weight_norm(conv, "weight")

+ 142 - 0
fish_speech/models/vqgan/modules/discriminator.py

@@ -0,0 +1,142 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.parametrizations import spectral_norm, weight_norm
+
+from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
+from fish_speech.models.vqgan.utils import get_padding
+
+
+class DiscriminatorP(nn.Module):
+    """Discriminator that folds the time axis into rows of ``period`` samples and runs 2-D convs over them."""
+
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+
+        # The `== False` comparison is kept verbatim so the selection rule is unchanged.
+        if use_spectral_norm == False:
+            norm_f = weight_norm
+        else:
+            norm_f = spectral_norm
+
+        # (in_channels, out_channels, stride) for each layer; only the last one is unstrided.
+        layer_specs = [
+            (1, 32, (stride, 1)),
+            (32, 128, (stride, 1)),
+            (128, 512, (stride, 1)),
+            (512, 1024, (stride, 1)),
+            (1024, 1024, 1),
+        ]
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    nn.Conv2d(
+                        c_in,
+                        c_out,
+                        (kernel_size, 1),
+                        layer_stride,
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                )
+                for c_in, c_out, layer_stride in layer_specs
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        """Return ``(flattened score, feature maps)`` for a ``[b, c, t]`` input."""
+        b, c, t = x.shape
+
+        # Reflect-pad the tail so the length is a multiple of the period,
+        # then reshape 1-D time into a (t // period, period) grid.
+        remainder = t % self.period
+        if remainder != 0:
+            n_pad = self.period - remainder
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        fmap = []
+        for conv in self.convs:
+            x = F.leaky_relu(conv(x), LRELU_SLOPE)
+            fmap.append(x)
+
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        return torch.flatten(x, 1, -1), fmap
+
+
+class DiscriminatorS(nn.Module):
+    """Discriminator built from a stack of (mostly grouped, strided) 1-D convolutions."""
+
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+
+        # The `== False` comparison is kept verbatim so the selection rule is unchanged.
+        if use_spectral_norm == False:
+            norm_f = weight_norm
+        else:
+            norm_f = spectral_norm
+
+        # (in_channels, out_channels, kernel_size, stride, groups, padding)
+        layer_specs = [
+            (1, 16, 15, 1, 1, 7),
+            (16, 64, 41, 4, 4, 20),
+            (64, 256, 41, 4, 16, 20),
+            (256, 1024, 41, 4, 64, 20),
+            (1024, 1024, 41, 4, 256, 20),
+            (1024, 1024, 5, 1, 1, 2),
+        ]
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv1d(c_in, c_out, k, s, groups=groups, padding=pad))
+                for c_in, c_out, k, s, groups, pad in layer_specs
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        """Return ``(flattened score, feature maps)``; the feature maps include the final conv output."""
+        fmap = []
+        for conv in self.convs:
+            x = F.leaky_relu(conv(x), LRELU_SLOPE)
+            fmap.append(x)
+
+        x = self.conv_post(x)
+        fmap.append(x)
+
+        return torch.flatten(x, 1, -1), fmap
+
+
+class EnsembleDiscriminator(nn.Module):
+    """One ``DiscriminatorS`` followed by a ``DiscriminatorP`` for each period in (2, 3, 5, 7, 11)."""
+
+    def __init__(self, use_spectral_norm=False):
+        super(EnsembleDiscriminator, self).__init__()
+        periods = [2, 3, 5, 7, 11]
+
+        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+        for period in periods:
+            members.append(
+                DiscriminatorP(period, use_spectral_norm=use_spectral_norm)
+            )
+        self.discriminators = nn.ModuleList(members)
+
+    def forward(self, y, y_hat):
+        """Run every discriminator on ``y`` and ``y_hat``.
+
+        Returns four lists with one entry per discriminator:
+        scores for ``y``, scores for ``y_hat``, feature maps for ``y``,
+        feature maps for ``y_hat``.
+        """
+        y_d_rs, y_d_gs = [], []
+        fmap_rs, fmap_gs = [], []
+
+        for disc in self.discriminators:
+            score_r, feats_r = disc(y)
+            score_g, feats_g = disc(y_hat)
+
+            y_d_rs.append(score_r)
+            fmap_rs.append(feats_r)
+            y_d_gs.append(score_g)
+            fmap_gs.append(feats_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs

+ 207 - 0
fish_speech/models/vqgan/modules/encoders.py

@@ -0,0 +1,207 @@
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+
+from fish_speech.models.vqgan.modules.modules import WN
+from fish_speech.models.vqgan.modules.transformer import RelativePositionTransformer
+from fish_speech.models.vqgan.utils import sequence_mask
+
+
+# * Ready and Tested
+class TextEncoder(nn.Module):
+    def __init__(
+        self,
+        n_vocab: int,
+        out_channels: int,
+        hidden_channels: int,
+        hidden_channels_ffn: int,
+        n_heads: int,
+        n_layers: int,
+        kernel_size: int,
+        dropout: float,
+        gin_channels=0,
+        lang_channels=0,
+        speaker_cond_layer=0,
+    ):
+        """Text Encoder for VITS model.
+
+        Args:
+            n_vocab (int): Size of the last (feature) axis of the input; it is the
+                input width of the linear input projection.
+            out_channels (int): Number of channels for the output.
+            hidden_channels (int): Number of channels for the hidden layers.
+            hidden_channels_ffn (int): Number of channels for the convolutional layers.
+            n_heads (int): Number of attention heads for the Transformer layers.
+            n_layers (int): Number of Transformer layers.
+            kernel_size (int): Kernel size for the FFN layers in Transformer network.
+            dropout (float): Dropout rate for the Transformer layers.
+            gin_channels (int, optional): Number of channels for speaker embedding. Defaults to 0.
+            lang_channels (int, optional): Number of channels for language embedding. Defaults to 0.
+            speaker_cond_layer (int, optional): Forwarded unchanged to
+                ``RelativePositionTransformer``. Defaults to 0.
+        """
+        super().__init__()
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+
+        # Input projection over the feature axis. This used to be spelled
+        # nn.Linear(n_vocab, hidden_channels, 1): the positional 1 lands on the
+        # `bias` flag (truthy), so bias=True is the same layer, stated explicitly.
+        # The custom normal init and the sqrt(hidden_channels) output scaling
+        # are currently disabled.
+        self.emb = nn.Linear(n_vocab, hidden_channels, bias=True)
+
+        self.encoder = RelativePositionTransformer(
+            in_channels=hidden_channels,
+            out_channels=hidden_channels,
+            hidden_channels=hidden_channels,
+            hidden_channels_ffn=hidden_channels_ffn,
+            n_heads=n_heads,
+            n_layers=n_layers,
+            kernel_size=kernel_size,
+            dropout=dropout,
+            window_size=4,
+            gin_channels=gin_channels,
+            lang_channels=lang_channels,
+            speaker_cond_layer=speaker_cond_layer,
+        )
+        self.proj = nn.Linear(hidden_channels, out_channels * 2)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lengths: torch.Tensor,
+        g: torch.Tensor = None,
+        lang: torch.Tensor = None,
+    ):
+        """Encode ``x`` and sample a latent from the predicted mean / log-scale.
+
+        Shapes:
+            - x: :math:`[B, T, n_vocab]` (the linear projection acts on the last axis)
+            - x_length: :math:`[B]`
+
+        Returns:
+            ``(z, m, logs, x, x_mask)``: the sampled latent, its mean and
+            log-scale (each with ``out_channels`` channels), the encoder
+            hidden states, and the mask of shape :math:`[B, 1, T]`.
+        """
+        x = self.emb(x).mT  # [b, h, t]
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+        x = self.encoder(x, x_mask, g=g, lang=lang)
+        stats = self.proj(x.mT).mT * x_mask
+
+        # First half of the channels is the mean, second half the log-scale.
+        m, logs = torch.split(stats, self.out_channels, dim=1)
+        z = m + torch.randn_like(m) * torch.exp(logs) * x_mask
+        return z, m, logs, x, x_mask
+
+
+# * Ready and Tested
+class PosteriorEncoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        hidden_channels: int,
+        kernel_size: int,
+        dilation_rate: int,
+        n_layers: int,
+        gin_channels=0,
+    ):
+        """Posterior Encoder of VITS model.
+
+        ::
+            x -> linear() -> WaveNet() (non-causal) -> linear() -> split() -> [m, s] -> sample(m, s) -> z
+
+        Args:
+            in_channels (int): Number of input tensor channels.
+            out_channels (int): Number of output tensor channels.
+            hidden_channels (int): Number of hidden channels.
+            kernel_size (int): Kernel size of the WaveNet convolution layers.
+            dilation_rate (int): Dilation rate of the WaveNet layers.
+            n_layers (int): Number of the WaveNet layers.
+            gin_channels (int, optional): Number of conditioning tensor channels. Defaults to 0.
+        """
+        super().__init__()
+        self.out_channels = out_channels
+
+        # Both projections are nn.Linear, applied along the channel axis via .mT in forward().
+        self.pre = nn.Linear(in_channels, hidden_channels)
+        self.enc = WN(
+            hidden_channels,
+            kernel_size,
+            dilation_rate,
+            n_layers,
+            gin_channels=gin_channels,
+        )
+        self.proj = nn.Linear(hidden_channels, out_channels * 2)
+
+    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g=None):
+        """Encode ``x`` and sample a latent from the predicted mean / log-scale.
+
+        Shapes:
+            - x: :math:`[B, C, T]`
+            - x_lengths: :math:`[B, 1]` (NOTE(review): TextEncoder documents the
+              same argument as :math:`[B]`; confirm against ``sequence_mask``)
+            - g: :math:`[B, C, 1]`
+
+        Returns:
+            ``(z, m, logs, x_mask)``: the sampled latent, its mean and
+            log-scale, and the mask of shape :math:`[B, 1, T]`.
+        """
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+        x = self.pre(x.mT).mT * x_mask
+        x = self.enc(x, x_mask, g=g)
+        stats = self.proj(x.mT).mT * x_mask
+        # First half of the channels is the mean, second half the log-scale.
+        m, logs = torch.split(stats, self.out_channels, dim=1)
+        z = m + torch.randn_like(m) * torch.exp(logs) * x_mask
+        return z, m, logs, x_mask
+
+
+# TODO: Ready for testing
+class SpeakerEncoder(nn.Module):
+    """Summarize a feature sequence into one vector via a learned query token and self-attention."""
+
+    def __init__(
+        self,
+        in_channels: int = 128,
+        hidden_channels: int = 192,
+        out_channels: int = 512,
+        num_heads: int = 2,
+        num_layers: int = 4,
+        p_dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        # Learned token prepended to every sequence; its final state is the output.
+        self.query = nn.Parameter(torch.randn(1, 1, hidden_channels))
+
+        # Pointwise conv, then two width-5 convs, each followed by SiLU; dropout last.
+        conv_stack = [nn.Conv1d(in_channels, hidden_channels, 1), nn.SiLU()]
+        for _ in range(2):
+            conv_stack.append(nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2))
+            conv_stack.append(nn.SiLU())
+        conv_stack.append(nn.Dropout(p_dropout))
+        self.in_proj = nn.Sequential(*conv_stack)
+
+        self.blocks = nn.ModuleList()
+        for _ in range(num_layers):
+            self.blocks.append(
+                nn.MultiheadAttention(
+                    embed_dim=hidden_channels,
+                    num_heads=num_heads,
+                    dropout=p_dropout,
+                    batch_first=True,
+                )
+            )
+        self.out_proj = nn.Linear(hidden_channels, out_channels)
+
+    def forward(self, mels, mel_lengths: torch.Tensor):
+        """
+        Shapes:
+            - mels: :math:`[B, C, T]`
+            - mel_lengths: :math:`[B, 1]`
+            - returns: :math:`[B, out_channels]`
+        """
+
+        # True marks padded frames, which attention must ignore.
+        pad_mask = ~(sequence_mask(mel_lengths, mels.size(2)).bool())
+
+        frames = self.in_proj(mels).transpose(1, 2)
+        batch = frames.shape[0]
+        x = torch.cat([self.query.expand(batch, -1, -1), frames], dim=1)
+
+        # The query token sits at position 0 and is never masked out.
+        query_slot = torch.zeros(batch, 1, dtype=torch.bool, device=x.device)
+        pad_mask = torch.cat([query_slot, pad_mask], dim=1)
+
+        for attn in self.blocks:
+            x = attn(x, x, x, key_padding_mask=pad_mask)[0]
+
+        return self.out_proj(x[:, 0])

+ 115 - 0
fish_speech/models/vqgan/modules/models.py

@@ -0,0 +1,115 @@
+import torch
+from torch import nn
+
+from fish_speech.models.vqgan.modules.decoder import Generator
+from fish_speech.models.vqgan.modules.encoders import (
+    PosteriorEncoder,
+    SpeakerEncoder,
+    TextEncoder,
+)
+from fish_speech.models.vqgan.utils import rand_slice_segments
+
+
+class SynthesizerTrn(nn.Module):
+    """
+    Synthesizer for Training
+
+    Wires together a prior encoder (``enc_p``), a speaker encoder (``enc_spk``),
+    a posterior encoder (``enc_q``) and a waveform decoder (``dec``).
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        spec_channels,
+        segment_size,
+        inter_channels,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        n_layers_q,
+        n_layers_spk,
+        kernel_size,
+        p_dropout,
+        speaker_cond_layer,
+        resblock,
+        resblock_kernel_sizes,
+        resblock_dilation_sizes,
+        upsample_rates,
+        upsample_initial_channel,
+        upsample_kernel_sizes,
+        gin_channels,
+    ):
+        super().__init__()
+
+        self.segment_size = segment_size
+
+        self.enc_p = TextEncoder(
+            in_channels,
+            inter_channels,
+            hidden_channels,
+            filter_channels,
+            n_heads,
+            n_layers,
+            kernel_size,
+            p_dropout,
+            gin_channels=gin_channels,
+            speaker_cond_layer=speaker_cond_layer,
+        )
+        self.enc_spk = SpeakerEncoder(
+            in_channels=spec_channels,
+            hidden_channels=inter_channels,
+            out_channels=gin_channels,
+            num_heads=n_heads,
+            num_layers=n_layers_spk,
+            p_dropout=p_dropout,
+        )
+        # Posterior encoder with fixed kernel_size=5 and dilation_rate=1.
+        self.enc_q = PosteriorEncoder(
+            spec_channels,
+            inter_channels,
+            hidden_channels,
+            5,
+            1,
+            n_layers_q,
+            gin_channels=gin_channels,
+        )
+        self.dec = Generator(
+            inter_channels,
+            resblock,
+            resblock_kernel_sizes,
+            resblock_dilation_sizes,
+            upsample_rates,
+            upsample_initial_channel,
+            upsample_kernel_sizes,
+            gin_channels=gin_channels,
+        )
+
+    def _speaker_embedding(self, spec, lengths):
+        """Return the speaker embedding as a channels-first ``[B, gin_channels, 1]`` tensor."""
+        g = self.enc_spk(spec, lengths)
+        # SpeakerEncoder yields [B, gin_channels]; WN and Generator take `g.mT`
+        # before a Linear layer, which needs a trailing length-1 time axis.
+        # NOTE(review): RelativePositionTransformer's use of `g` was not visible
+        # when this was written -- confirm it expects the same layout.
+        if g.dim() == 2:
+            g = g.unsqueeze(-1)
+        return g
+
+    def forward(self, x, x_lengths, y):
+        """Training pass.
+
+        ``x`` feeds the prior encoder, ``y`` feeds the speaker and posterior
+        encoders; the same ``x_lengths`` is used for all three. A random slice
+        of the posterior latent is decoded.
+
+        Returns:
+            ``(o, ids_slice, x_mask, y_mask, (z_q, z_p), (m_p, logs_p), (m_q, logs_q))``
+        """
+        g = self._speaker_embedding(y, x_lengths)
+        z_p, m_p, logs_p, _, x_mask = self.enc_p(x, x_lengths, g=g)
+        z_q, m_q, logs_q, y_mask = self.enc_q(y, x_lengths, g=g)
+        z_slice, ids_slice = rand_slice_segments(z_q, x_lengths, self.segment_size)
+        o = self.dec(z_slice, g=g)
+
+        return (
+            o,
+            ids_slice,
+            x_mask,
+            y_mask,
+            (z_q, z_p),
+            (m_p, logs_p),
+            (m_q, logs_q),
+        )
+
+    def infer(self, x, x_lengths, y, max_len=None):
+        """Decode from the prior latent of ``x``, with the speaker embedding taken from ``y``."""
+        g = self._speaker_embedding(y, x_lengths)
+        z_p, m_p, logs_p, h_text, x_mask = self.enc_p(x, x_lengths, g=g)
+        # NOTE: no flow is applied here; the prior sample is decoded directly.
+
+        o = self.dec((z_p * x_mask)[:, :, :max_len], g=g)
+        return o
+
+    def reconstruct(self, x, x_lengths, max_len=None):
+        """Encode ``x`` with the posterior encoder and decode it back."""
+        g = self._speaker_embedding(x, x_lengths)
+        z_q, m_q, logs_q, x_mask = self.enc_q(x, x_lengths, g=g)
+        o = self.dec((z_q * x_mask)[:, :, :max_len], g=g)
+
+        return o

+ 105 - 0
fish_speech/models/vqgan/modules/modules.py

@@ -0,0 +1,105 @@
+import torch
+import torch.nn as nn
+from torch.nn.utils.parametrizations import weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
+
+from fish_speech.models.vqgan.utils import fused_add_tanh_sigmoid_multiply
+
+LRELU_SLOPE = 0.1
+
+
+# ! PosteriorEncoder
+# ! ResidualCouplingLayer
+class WN(nn.Module):
+    """Stack of gated, dilated 1-D convolutions with residual and skip paths.
+
+    Args:
+        hidden_channels: Channel count of the input, residual and output tensors.
+        kernel_size: Kernel size of every dilated convolution; must be odd.
+        dilation_rate: Base of the dilation schedule (layer ``i`` uses ``dilation_rate**i``).
+        n_layers: Number of layers.
+        gin_channels: Width of the optional conditioning input ``g``; 0 disables it.
+        p_dropout: Dropout probability applied to the gated activations.
+    """
+
+    def __init__(
+        self,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        gin_channels=0,
+        p_dropout=0,
+    ):
+        super(WN, self).__init__()
+        assert kernel_size % 2 == 1
+        self.hidden_channels = hidden_channels
+        # NOTE: stored as a 1-tuple (trailing comma); kept as-is.
+        self.kernel_size = (kernel_size,)
+        self.n_layers = n_layers
+        self.gin_channels = gin_channels
+
+        self.in_layers = nn.ModuleList()
+        self.res_skip_layers = nn.ModuleList()
+        self.drop = nn.Dropout(p_dropout)
+
+        if gin_channels != 0:
+            # One projection yields the conditioning for all layers at once;
+            # forward() slices out 2 * hidden_channels per layer.
+            cond_layer = nn.Linear(gin_channels, 2 * hidden_channels * n_layers)
+            self.cond_layer = weight_norm(cond_layer, name="weight")
+
+        for i in range(n_layers):
+            dilation = dilation_rate**i
+            padding = int((kernel_size * dilation - dilation) / 2)
+            in_layer = nn.Conv1d(
+                hidden_channels,
+                2 * hidden_channels,
+                kernel_size,
+                dilation=dilation,
+                padding=padding,
+            )
+            in_layer = weight_norm(in_layer, name="weight")
+            self.in_layers.append(in_layer)
+
+            # last one is not necessary
+            res_skip_channels = (
+                2 * hidden_channels if i < n_layers - 1 else hidden_channels
+            )
+            res_skip_layer = nn.Linear(hidden_channels, res_skip_channels)
+            res_skip_layer = weight_norm(res_skip_layer, name="weight")
+            self.res_skip_layers.append(res_skip_layer)
+
+    def forward(self, x, x_mask, g=None, **kwargs):
+        """Return the sum of all layers' skip outputs, multiplied by ``x_mask``.
+
+        ``g`` (optional) must be channels-first: it is projected via ``g.mT``
+        and then sliced along dim 1 per layer.
+        """
+        output = torch.zeros_like(x)
+        n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+        if g is not None:
+            g = self.cond_layer(g.mT).mT
+
+        for i in range(self.n_layers):
+            x_in = self.in_layers[i](x)
+            if g is not None:
+                cond_offset = i * 2 * self.hidden_channels
+                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+            else:
+                g_l = torch.zeros_like(x_in)
+
+            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+            acts = self.drop(acts)
+
+            res_skip_acts = self.res_skip_layers[i](acts.mT).mT
+            if i < self.n_layers - 1:
+                # First half updates the residual stream, second half is the skip output.
+                res_acts = res_skip_acts[:, : self.hidden_channels, :]
+                x = (x + res_acts) * x_mask
+                output = output + res_skip_acts[:, self.hidden_channels :, :]
+            else:
+                # The last layer only produces a skip output.
+                output = output + res_skip_acts
+        return output * x_mask
+
+    def remove_weight_norm(self):
+        """Strip the weight-norm parametrization from every layer that has one."""
+        # remove_parametrizations requires the name of the parametrized tensor.
+        if self.gin_channels != 0:
+            remove_parametrizations(self.cond_layer, "weight")
+        for layer in self.in_layers:
+            remove_parametrizations(layer, "weight")
+        for layer in self.res_skip_layers:
+            remove_parametrizations(layer, "weight")
+
+
+# ! StochasticDurationPredictor
+# ! ResidualCouplingBlock
+# TODO convert to class method
+class Flip(nn.Module):
+    """Reverse the order of dim 1; in the forward direction also return a zero log-determinant."""
+
+    def forward(self, x, *args, reverse=False, **kwargs):
+        flipped = torch.flip(x, [1])
+        if reverse:
+            return flipped
+        # One zero per batch element, matching the input's dtype and device.
+        logdet = torch.zeros(flipped.size(0)).to(dtype=flipped.dtype, device=flipped.device)
+        return flipped, logdet

+ 34 - 0
fish_speech/models/vqgan/modules/normalization.py

@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LayerNorm(nn.Module):
+    """Layer normalization over the second-to-last axis, with learned scale and shift."""
+
+    def __init__(self, channels, eps=1e-5):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = nn.Parameter(torch.ones(channels))
+        self.beta = nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x: torch.Tensor):
+        # F.layer_norm normalizes the last axis, so move channels there and back.
+        channels_last = x.mT
+        normed = F.layer_norm(
+            channels_last,
+            (self.channels,),
+            weight=self.gamma,
+            bias=self.beta,
+            eps=self.eps,
+        )
+        return normed.mT
+
+
+class CondLayerNorm(nn.Module):
+    """Layer normalization whose scale and shift are predicted from a conditioning tensor.
+
+    Args:
+        channels: Size of the normalized (second-to-last) axis of ``x``.
+        eps: Numerical-stability term of the normalization.
+        cond_channels: Size of the last axis of ``cond``.
+    """
+
+    def __init__(self, channels, eps=1e-5, cond_channels=0):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.linear_gamma = nn.Linear(cond_channels, channels)
+        self.linear_beta = nn.Linear(cond_channels, channels)
+
+    def forward(self, x: torch.Tensor, cond: torch.Tensor):
+        """Normalize ``x`` over its channel axis, then scale and shift it from ``cond``.
+
+        ``cond`` is channels-last (``[..., cond_channels]``). It may be a single
+        vector, one vector per batch element, or already broadcastable against
+        the channels-last view of ``x``.
+        """
+        gamma = self.linear_gamma(cond)
+        beta = self.linear_beta(cond)
+
+        # F.layer_norm only accepts weight/bias shaped exactly (channels,), so a
+        # per-sample gamma/beta cannot be passed to it; normalize without an
+        # affine term and apply the conditioned affine by broadcasting.
+        x = F.layer_norm(x.mT, (self.channels,), eps=self.eps)
+        if gamma.dim() > 1 and gamma.dim() == x.dim() - 1:
+            # One vector per batch element: add the missing time axis.
+            gamma = gamma.unsqueeze(-2)
+            beta = beta.unsqueeze(-2)
+        x = x * gamma + beta
+        return x.mT

+ 324 - 0
fish_speech/models/vqgan/modules/transformer.py

@@ -0,0 +1,324 @@
+import math
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from fish_speech.models.vqgan.modules.normalization import LayerNorm
+from fish_speech.models.vqgan.utils import convert_pad_shape
+
+
+# TODO add conditioning on language
+# TODO check whether we need to stop gradient for speaker embedding
+class RelativePositionTransformer(nn.Module):
+    """Stack of self-attention and feed-forward layers with optional conditioning.
+
+    Every layer is a ``MultiHeadAttention`` followed by an ``FFN``; each
+    sub-layer output goes through dropout, is added to its input and then
+    normalised with ``LayerNorm(hidden_channels)``.
+
+    Args:
+        in_channels: input feature size of the first attention layer. The
+            residual add in ``forward`` sums the layer input with an output of
+            ``hidden_channels`` features, so the two sizes have to be
+            compatible.
+        hidden_channels: feature size used by every layer.
+        out_channels: accepted but not used by this class.
+        hidden_channels_ffn: inner (filter) size of each ``FFN``.
+        n_heads: number of attention heads.
+        n_layers: number of attention/FFN layer pairs.
+        kernel_size: kernel size handed to each ``FFN``.
+        dropout: rate of the shared dropout, also passed as ``p_dropout`` to
+            the attention and FFN layers.
+        window_size: passed to each ``MultiHeadAttention``.
+        gin_channels: when non-zero, a ``Linear(gin_channels,
+            hidden_channels)`` projection ``cond`` is created.
+        lang_channels: accepted but not used by this class.
+        speaker_cond_layer: 1-based index of the layer before which the
+            projected ``g`` is added; the default ``0`` never matches.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+        hidden_channels_ffn: int,
+        n_heads: int,
+        n_layers: int,
+        kernel_size=1,
+        dropout=0.0,
+        window_size=4,
+        gin_channels=0,
+        lang_channels=0,
+        speaker_cond_layer=0,
+    ):
+        super().__init__()
+        self.n_layers = n_layers
+        self.speaker_cond_layer = speaker_cond_layer
+
+        self.drop = nn.Dropout(dropout)
+        self.attn_layers = nn.ModuleList()
+        self.norm_layers_1 = nn.ModuleList()
+        self.ffn_layers = nn.ModuleList()
+        self.norm_layers_2 = nn.ModuleList()
+        for layer_idx in range(self.n_layers):
+            # Only the first attention layer sees the raw input width.
+            attn_in = in_channels if layer_idx == 0 else hidden_channels
+            attention = MultiHeadAttention(
+                attn_in,
+                hidden_channels,
+                n_heads,
+                p_dropout=dropout,
+                window_size=window_size,
+            )
+            self.attn_layers.append(attention)
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            feed_forward = FFN(
+                hidden_channels,
+                hidden_channels,
+                hidden_channels_ffn,
+                kernel_size,
+                p_dropout=dropout,
+            )
+            self.ffn_layers.append(feed_forward)
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+        if gin_channels != 0:
+            self.cond = nn.Linear(gin_channels, hidden_channels)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_mask: torch.Tensor,
+        g: torch.Tensor = None,
+        lang: torch.Tensor = None,
+    ):
+        """Run the layer stack on ``x``.
+
+        Args:
+            x: input features; multiplied by ``x_mask`` on entry and on exit.
+            x_mask: mask multiplied into ``x``; its outer product with itself
+                is used as the attention mask.
+            g: optional conditioning tensor, projected by ``self.cond`` and
+                added before layer ``speaker_cond_layer``.
+                Presumably channel-first like ``x`` -- TODO confirm.
+            lang: accepted but not used.
+
+        Returns:
+            The masked output of the last layer.
+        """
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        x = x * x_mask
+        cond_index = self.speaker_cond_layer - 1
+        layers = zip(
+            self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
+        )
+        for idx, (attention, norm_1, feed_forward, norm_2) in enumerate(layers):
+            # TODO consider using other conditioning
+            # TODO https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/modules/attentions.py#L12
+            if g is not None and idx == cond_index:
+                # ! g = torch.detach(g)
+                x = x + self.cond(g.mT).mT
+                x = x * x_mask
+
+            attended = self.drop(attention(x, x, attn_mask))
+            x = norm_1(x + attended)
+
+            transformed = self.drop(feed_forward(x, x_mask))
+            x = norm_2(x + transformed)
+        return x * x_mask
+
+
+class MultiHeadAttention(nn.Module):
+    """Multi-head attention over channel-first inputs with optional extras.
+
+    Queries, keys and values are produced by linear layers applied to the
+    transposed inputs, so inputs carry their feature axis second-to-last.
+
+    Args:
+        channels: input feature size; must be divisible by ``n_heads``.
+        out_channels: feature size produced by the output projection.
+        n_heads: number of attention heads.
+        p_dropout: dropout rate applied to the attention probabilities.
+        window_size: when not ``None``, learned relative-position embeddings
+            covering ``2 * window_size + 1`` offsets are added for keys and
+            values.
+        heads_share: when true, a single set of relative embeddings is shared
+            by all heads.
+        block_length: when set and a mask is given, scores further than
+            ``block_length`` from the diagonal are masked out.
+        proximal_bias: add a ``-log1p(|i - j|)`` bias to the scores.
+        proximal_init: initialise the key projection as a copy of the query
+            projection.
+    """
+
+    def __init__(
+        self,
+        channels,
+        out_channels,
+        n_heads,
+        p_dropout=0.0,
+        window_size=None,
+        heads_share=True,
+        block_length=None,
+        proximal_bias=False,
+        proximal_init=False,
+    ):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.out_channels = out_channels
+        self.n_heads = n_heads
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+        self.heads_share = heads_share
+        self.block_length = block_length
+        self.proximal_bias = proximal_bias
+        self.proximal_init = proximal_init
+        # Holds the attention probabilities of the most recent forward call.
+        self.attn = None
+
+        # Per-head feature size.
+        self.k_channels = channels // n_heads
+        self.conv_q = nn.Linear(channels, channels)
+        self.conv_k = nn.Linear(channels, channels)
+        self.conv_v = nn.Linear(channels, channels)
+        self.conv_o = nn.Linear(channels, out_channels)
+        self.drop = nn.Dropout(p_dropout)
+
+        if window_size is not None:
+            # Relative embeddings for keys and values, one row per offset in
+            # [-window_size, window_size], scaled by k_channels ** -0.5.
+            n_heads_rel = 1 if heads_share else n_heads
+            rel_stddev = self.k_channels**-0.5
+            self.emb_rel_k = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+            self.emb_rel_v = nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+
+        # Only the q/k/v projection weights are re-initialised here.
+        nn.init.xavier_uniform_(self.conv_q.weight)
+        nn.init.xavier_uniform_(self.conv_k.weight)
+        nn.init.xavier_uniform_(self.conv_v.weight)
+        if proximal_init:
+            # Start with keys identical to queries.
+            with torch.no_grad():
+                self.conv_k.weight.copy_(self.conv_q.weight)
+                self.conv_k.bias.copy_(self.conv_q.bias)
+
+    def forward(self, x, c, attn_mask=None):
+        """Attend from ``x`` (queries) to ``c`` (keys and values).
+
+        The attention probabilities are stored on ``self.attn`` as a side
+        effect. Returns the output projection of the attended values, with
+        the feature axis second-to-last like the inputs.
+        """
+        q = self.conv_q(x.mT).mT
+        k = self.conv_k(c.mT).mT
+        v = self.conv_v(c.mT).mT
+
+        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+        x = self.conv_o(x.mT).mT
+        return x
+
+    def attention(self, query, key, value, mask=None):
+        """Compute scaled dot-product attention with the configured extras.
+
+        Returns a tuple ``(output, p_attn)``: the attended values reshaped
+        back to ``[b, d, t_t]`` and the attention probabilities (after
+        dropout) of shape ``[b, n_h, t_t, t_s]``.
+        """
+        # reshape [b, d, t] -> [b, n_h, t, d_k]
+        b, d, t_s, t_t = (*key.size(), query.size(2))
+        query = query.view(b, self.n_heads, self.k_channels, t_t).mT
+        key = key.view(b, self.n_heads, self.k_channels, t_s).mT
+        value = value.view(b, self.n_heads, self.k_channels, t_s).mT
+
+        scores = torch.matmul(query / math.sqrt(self.k_channels), key.mT)
+        if self.window_size is not None:
+            assert (
+                t_s == t_t
+            ), "Relative attention is only available for self-attention."
+            # Add the contribution of the relative key embeddings.
+            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+            rel_logits = self._matmul_with_relative_keys(
+                query / math.sqrt(self.k_channels), key_relative_embeddings
+            )
+            scores_local = self._relative_position_to_absolute_position(rel_logits)
+            scores = scores + scores_local
+        if self.proximal_bias:
+            assert t_s == t_t, "Proximal bias is only available for self-attention."
+            scores = scores + self._attention_bias_proximal(t_s).to(
+                device=scores.device, dtype=scores.dtype
+            )
+        if mask is not None:
+            # Masked positions get a large negative score before the softmax.
+            scores = scores.masked_fill(mask == 0, -1e4)
+            # Note: the block-local mask is only applied when a mask is given.
+            if self.block_length is not None:
+                assert (
+                    t_s == t_t
+                ), "Local attention is only available for self-attention."
+                block_mask = (
+                    torch.ones_like(scores)
+                    .triu(-self.block_length)
+                    .tril(self.block_length)
+                )
+                scores = scores.masked_fill(block_mask == 0, -1e4)
+        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+        if self.window_size is not None:
+            # Add the contribution of the relative value embeddings.
+            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            value_relative_embeddings = self._get_relative_embeddings(
+                self.emb_rel_v, t_s
+            )
+            output = output + self._matmul_with_relative_values(
+                relative_weights, value_relative_embeddings
+            )
+        output = output.mT.contiguous().view(
+            b, d, t_t
+        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+        return output, p_attn
+
+    def _matmul_with_relative_values(self, x: torch.Tensor, y: torch.Tensor):
+        """
+        x: [b, h, l, m]
+        y: [h or 1, m, d]
+        ret: [b, h, l, d]
+        """
+        return torch.matmul(x, y.unsqueeze(0))
+
+    def _matmul_with_relative_keys(self, x: torch.Tensor, y: torch.Tensor):
+        """
+        x: [b, h, l, d]
+        y: [h or 1, m, d]
+        ret: [b, h, l, m]
+        """
+        return torch.matmul(x, y.unsqueeze(0).mT)
+
+    def _get_relative_embeddings(self, relative_embeddings, length):
+        """Pad or slice the embedding table to ``2 * length - 1`` offsets.
+
+        Sequences longer than ``window_size + 1`` get zero padding on both
+        sides of the offset axis; shorter ones use the centre slice.
+        """
+        # NOTE: max_relative_position is computed but not used below.
+        max_relative_position = 2 * self.window_size + 1
+        # Pad first before slice to avoid using cond ops.
+        pad_length = max(length - (self.window_size + 1), 0)
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        if pad_length > 0:
+            padded_relative_embeddings = F.pad(
+                relative_embeddings,
+                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+            )
+        else:
+            padded_relative_embeddings = relative_embeddings
+        used_relative_embeddings = padded_relative_embeddings[
+            :, slice_start_position:slice_end_position
+        ]
+        return used_relative_embeddings
+
+    def _relative_position_to_absolute_position(self, x):
+        """
+        x: [b, h, l, 2*l-1]
+        ret: [b, h, l, l]
+        """
+        batch, heads, length, _ = x.size()
+        # Concat columns of pad to shift from relative to absolute indexing.
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+        # Concat extra elements so to add up to shape (len+1, 2*len-1).
+        x_flat = x.view([batch, heads, length * 2 * length])
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+
+        # Reshape and slice out the padded elements.
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+            :, :, :length, length - 1 :
+        ]
+        return x_final
+
+    def _absolute_position_to_relative_position(self, x):
+        """
+        x: [b, h, l, l]
+        ret: [b, h, l, 2*l-1]
+        """
+        batch, heads, length, _ = x.size()
+        # pad along column
+        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+        # add 0's in the beginning that will skew the elements after reshape
+        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+        return x_final
+
+    def _attention_bias_proximal(self, length):
+        """Bias for self-attention to encourage attention to close positions.
+        Args:
+          length: an integer scalar.
+        Returns:
+          a Tensor with shape [1, 1, length, length]
+        """
+        # Built as float32 on the default device; the caller moves and casts it.
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+    """Two-layer 1-D convolutional feed-forward block with ReLU and dropout.
+
+    Args:
+        in_channels: input channels of the first convolution.
+        out_channels: output channels of the second convolution.
+        filter_channels: channels between the two convolutions.
+        kernel_size: kernel size of both convolutions.
+        p_dropout: dropout rate applied after the ReLU.
+        causal: pad only on the left when true, otherwise pad both sides so
+            the length is preserved.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        filter_channels,
+        kernel_size,
+        p_dropout=0.0,
+        causal=False,
+    ):
+        super().__init__()
+        self.kernel_size = kernel_size
+        # Pick the padding strategy once; forward just calls self.padding.
+        if causal:
+            self.padding = self._causal_padding
+        else:
+            self.padding = self._same_padding
+
+        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+        self.drop = nn.Dropout(p_dropout)
+
+    def forward(self, x, x_mask):
+        """Apply conv -> ReLU -> dropout -> conv, masking before each conv and at the end."""
+        hidden = self.conv_1(self.padding(x * x_mask))
+        hidden = self.drop(torch.relu(hidden))
+        projected = self.conv_2(self.padding(hidden * x_mask))
+        return projected * x_mask
+
+    def _causal_padding(self, x):
+        """Left-pad the last axis by ``kernel_size - 1`` (no-op for kernel size 1)."""
+        if self.kernel_size == 1:
+            return x
+        left, right = self.kernel_size - 1, 0
+        pad_spec = convert_pad_shape([[0, 0], [0, 0], [left, right]])
+        return F.pad(x, pad_spec)
+
+    def _same_padding(self, x):
+        """Pad both ends of the last axis so the conv keeps the length (no-op for kernel size 1)."""
+        if self.kernel_size == 1:
+            return x
+        left = (self.kernel_size - 1) // 2
+        right = self.kernel_size // 2
+        pad_spec = convert_pad_shape([[0, 0], [0, 0], [left, right]])
+        return F.pad(x, pad_spec)

+ 21 - 1
fish_speech/models/vqgan/utils.py

@@ -15,7 +15,7 @@ def sequence_mask(length, max_length=None):
     if max_length is None:
     if max_length is None:
         max_length = length.max()
         max_length = length.max()
     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) >= length.unsqueeze(1)
+    return x.unsqueeze(0) < length.unsqueeze(1)
 
 
 
 
 def init_weights(m, mean=0.0, std=0.01):
 def init_weights(m, mean=0.0, std=0.01):
@@ -52,6 +52,26 @@ def plot_mel(data, titles=None):
     return fig
     return fig
 
 
 
 
+def slice_segments(x, ids_str, segment_size=4):
+    """Cut one window of ``segment_size`` steps out of every batch item.
+
+    Args:
+        x: 3-D tensor indexed as ``[batch, channel, time]``.
+        ids_str: per-item start index along the last axis.
+        segment_size: window length along the last axis.
+
+    Returns:
+        Tensor shaped like ``x[:, :, :segment_size]`` holding, for item ``i``,
+        ``x[i, :, ids_str[i] : ids_str[i] + segment_size]``.
+    """
+    segments = torch.zeros_like(x[:, :, :segment_size])
+    for batch_idx, item in enumerate(x):
+        start = ids_str[batch_idx]
+        segments[batch_idx] = item[:, start : start + segment_size]
+
+    return segments
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+    """Slice a random window of ``segment_size`` steps from every batch item.
+
+    Args:
+        x: 3-D tensor indexed as ``[batch, channel, time]``.
+        x_lengths: valid length of every item (tensor or int). Defaults to
+            the full time dimension of ``x`` for all items.
+        segment_size: window length along the last axis.
+
+    Returns:
+        Tuple ``(segments, ids_str)``: the sliced windows and the start index
+        chosen for each item.
+    """
+    b, _, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    # torch.clamp only accepts tensors; the integer default used to raise a
+    # TypeError here, so normalise the lengths to a tensor on x's device.
+    x_lengths = torch.as_tensor(x_lengths, device=x.device)
+    # Largest admissible start (exclusive); never negative for short items.
+    ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
+    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
 @torch.jit.script
 @torch.jit.script
 def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
 def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
     n_channels_int = n_channels[0]
     n_channels_int = n_channels[0]

+ 6 - 3
tools/vqgan/migrate_from_vits.py

@@ -35,11 +35,14 @@ def main(cfg: DictConfig):
 
 
     # Posterior Encoder
     # Posterior Encoder
     encoder_state = {
     encoder_state = {
-        k[6:]: v for k, v in generator_weights.items() if k.startswith("enc_q.")
+        k[6:]: v
+        for k, v in generator_weights.items()
+        if k.startswith("enc_q.")
+        if not k.startswith("enc_q.proj.")
     }
     }
     logger.info(f"Found {len(encoder_state)} posterior encoder weights, restoring...")
     logger.info(f"Found {len(encoder_state)} posterior encoder weights, restoring...")
-    model.posterior_encoder.load_state_dict(encoder_state, strict=True)
-    logger.info("Posterior encoder weights restored.")
+    x = model.posterior_encoder.load_state_dict(encoder_state, strict=False)
+    logger.info(f"Posterior encoder weights restored. {x}")
 
 
     # Flow
     # Flow
     # flow_state = {
     # flow_state = {