Kaynağa Gözat

Implement mel direct generation

Lengyue 2 yıl önce
ebeveyn
işleme
a5699b169d

+ 31 - 52
fish_speech/configs/vqgan_pretrain_v2.yaml

@@ -3,16 +3,14 @@ defaults:
   - _self_
   - _self_
 
 
 project: vqgan_pretrain_v2
 project: vqgan_pretrain_v2
-ckpt_path: checkpoints/hifigan-base-comb-mix-lb-020/step_001200000_weights_only.ckpt
-resume_weights_only: true
 
 
 # Lightning Trainer
 # Lightning Trainer
 trainer:
 trainer:
   accelerator: gpu
   accelerator: gpu
   devices: auto
   devices: auto
   strategy: ddp_find_unused_parameters_true
   strategy: ddp_find_unused_parameters_true
-  precision: 32
-  max_steps: 1_000_000
+  precision: bf16-mixed
+  max_steps: 10_000_000
   val_check_interval: 5000
   val_check_interval: 5000
 
 
 sample_rate: 44100
 sample_rate: 44100
@@ -63,38 +61,31 @@ model:
   _target_: fish_speech.models.vqgan.VQGAN
   _target_: fish_speech.models.vqgan.VQGAN
   sample_rate: ${sample_rate}
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   hop_length: ${hop_length}
-  segment_size: 32768
-  mode: pretrain
-  freeze_discriminator: true
-
-  downsample:
-    _target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
-    dims: ["${num_mels}", 512, 256]
-    kernel_sizes: [3, 3]
-    strides: [2, 2]
-
-  mel_encoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WN
-    hidden_channels: 256
+
+  encoder:
+    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
+    hidden_channels: 384
     kernel_size: 3
     kernel_size: 3
     dilation_rate: 2
     dilation_rate: 2
-    n_layers: 12
+    n_layers: 10
+    in_channels: ${num_mels}
 
 
-  vq_encoder:
+  vq:
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
-    in_channels: 256
-    vq_channels: 256
+    in_channels: 384
+    vq_channels: 384
     codebook_size: 256
     codebook_size: 256
-    codebook_groups: 4
-    downsample: 1
+    codebook_groups: 2
+    codebook_layers: 2
+    downsample: 4
 
 
   decoder:
   decoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WN
-    hidden_channels: 256
-    out_channels: ${num_mels}
+    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
+    hidden_channels: 384
     kernel_size: 3
     kernel_size: 3
     dilation_rate: 2
     dilation_rate: 2
-    n_layers: 6
+    n_layers: 10
+    out_channels: ${num_mels}
 
 
   generator:
   generator:
     _target_: fish_speech.models.vqgan.modules.decoder_v2.HiFiGANGenerator
     _target_: fish_speech.models.vqgan.modules.decoder_v2.HiFiGANGenerator
@@ -108,27 +99,16 @@ model:
     use_template: true
     use_template: true
     pre_conv_kernel_size: 7
     pre_conv_kernel_size: 7
     post_conv_kernel_size: 7
     post_conv_kernel_size: 7
+    ckpt_path: checkpoints/hifigan-base-comb-mix-lb-020/step_001200000_weights_only.ckpt
+
+  discriminator:
+    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
+    hidden_channels: 256
+    kernel_size: 3
+    dilation_rate: 2
+    n_layers: 6
+    in_channels: ${num_mels}
 
 
-  discriminators:
-    _target_: torch.nn.ModuleDict
-    modules:
-      mpd:
-        _target_: fish_speech.models.vqgan.modules.discriminators.mpd.MultiPeriodDiscriminator
-        periods: [2, 3, 5, 7, 11, 17, 23, 37]
-
-      mrd:
-        _target_: fish_speech.models.vqgan.modules.discriminators.mrd.MultiResolutionDiscriminator
-        resolutions:
-          - ["${n_fft}", "${hop_length}", "${win_length}"]
-          - [1024, 120, 600]
-          - [2048, 240, 1200]
-          - [4096, 480, 2400]
-          - [512, 50, 240]
-
-  multi_resolution_stft_loss:
-    _target_: fish_speech.models.vqgan.losses.MultiResolutionSTFTLoss
-    resolutions: ${model.discriminators.modules.mrd.resolutions}
-  
   mel_transform:
   mel_transform:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     sample_rate: ${sample_rate}
     sample_rate: ${sample_rate}
@@ -140,7 +120,7 @@ model:
   optimizer:
   optimizer:
     _target_: torch.optim.AdamW
     _target_: torch.optim.AdamW
     _partial_: true
     _partial_: true
-    lr: 1e-4
+    lr: 2e-4
     betas: [0.8, 0.99]
     betas: [0.8, 0.99]
     eps: 1e-5
     eps: 1e-5
 
 
@@ -152,8 +132,7 @@ model:
 callbacks:
 callbacks:
   grad_norm_monitor:
   grad_norm_monitor:
     sub_module: 
     sub_module: 
-      - generator
-      - discriminators
-      - mel_encoder
-      - vq_encoder
+      - encoder
+      - vq
       - decoder
       - decoder
+      - discriminator

+ 91 - 283
fish_speech/models/vqgan/lit_module.py

@@ -9,6 +9,7 @@ import wandb
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from matplotlib import pyplot as plt
 from torch import nn
 from torch import nn
+from torch.utils.checkpoint import checkpoint as gradient_checkpoint
 
 
 from fish_speech.models.vqgan.losses import (
 from fish_speech.models.vqgan.losses import (
     MultiResolutionSTFTLoss,
     MultiResolutionSTFTLoss,
@@ -16,19 +17,9 @@ from fish_speech.models.vqgan.losses import (
     feature_loss,
     feature_loss,
     generator_loss,
     generator_loss,
 )
 )
-from fish_speech.models.vqgan.modules.balancer import Balancer
-from fish_speech.models.vqgan.modules.decoder import Generator
-from fish_speech.models.vqgan.modules.encoders import (
-    ConvDownSampler,
-    TextEncoder,
-    VQEncoder,
-)
-from fish_speech.models.vqgan.utils import (
-    plot_mel,
-    rand_slice_segments,
-    sequence_mask,
-    slice_segments,
-)
+from fish_speech.models.vqgan.modules.convnext import ConvNeXt
+from fish_speech.models.vqgan.modules.encoders import VQEncoder
+from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
 
 
 
 
 @dataclass
 @dataclass
@@ -41,9 +32,7 @@ class VQEncodeResult:
 
 
 @dataclass
 @dataclass
 class VQDecodeResult:
 class VQDecodeResult:
-    audios: torch.Tensor
     mels: torch.Tensor
     mels: torch.Tensor
-    mel_lengths: torch.Tensor
 
 
 
 
 class VQGAN(L.LightningModule):
 class VQGAN(L.LightningModule):
@@ -51,19 +40,15 @@ class VQGAN(L.LightningModule):
         self,
         self,
         optimizer: Callable,
         optimizer: Callable,
         lr_scheduler: Callable,
         lr_scheduler: Callable,
-        downsample: ConvDownSampler,
-        vq_encoder: VQEncoder,
-        mel_encoder: TextEncoder,
-        decoder: TextEncoder,
-        generator: Generator,
-        discriminators: nn.ModuleDict,
+        encoder: ConvNeXt,
+        vq: VQEncoder,
+        decoder: ConvNeXt,
+        generator: nn.Module,
+        discriminator: ConvNeXt,
         mel_transform: nn.Module,
         mel_transform: nn.Module,
-        segment_size: int = 20480,
         hop_length: int = 640,
         hop_length: int = 640,
         sample_rate: int = 32000,
         sample_rate: int = 32000,
-        mode: Literal["pretrain", "finetune"] = "finetune",
         freeze_discriminator: bool = False,
         freeze_discriminator: bool = False,
-        multi_resolution_stft_loss: Optional[MultiResolutionSTFTLoss] = None,
     ):
     ):
         super().__init__()
         super().__init__()
 
 
@@ -74,68 +59,41 @@ class VQGAN(L.LightningModule):
         self.optimizer_builder = optimizer
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
         self.lr_scheduler_builder = lr_scheduler
 
 
-        # Generator and discriminators
-        self.downsample = downsample
-        self.vq_encoder = vq_encoder
-        self.mel_encoder = mel_encoder
+        # Generator and discriminator
+        self.encoder = encoder
+        self.vq = vq
         self.decoder = decoder
         self.decoder = decoder
         self.generator = generator
         self.generator = generator
-        self.discriminators = discriminators
+        self.discriminator = discriminator
         self.mel_transform = mel_transform
         self.mel_transform = mel_transform
         self.freeze_discriminator = freeze_discriminator
         self.freeze_discriminator = freeze_discriminator
 
 
         # Crop length for saving memory
         # Crop length for saving memory
-        self.segment_size = segment_size
         self.hop_length = hop_length
         self.hop_length = hop_length
         self.sampling_rate = sample_rate
         self.sampling_rate = sample_rate
-        self.mode = mode
 
 
         # Disable automatic optimization
         # Disable automatic optimization
         self.automatic_optimization = False
         self.automatic_optimization = False
 
 
-        # Finetune: Train the VQ only
-        if self.mode == "finetune":
-            for p in self.vq_encoder.parameters():
-                p.requires_grad = False
-
-            for p in self.mel_encoder.parameters():
-                p.requires_grad = False
-
-            for p in self.downsample.parameters():
-                p.requires_grad = False
-
         if self.freeze_discriminator:
         if self.freeze_discriminator:
-            for p in self.discriminators.parameters():
+            for p in self.discriminator.parameters():
                 p.requires_grad = False
                 p.requires_grad = False
 
 
-        # Losses
-        self.multi_resolution_stft_loss = multi_resolution_stft_loss
-        loss_dict = {
-            "mel": 1,
-            "adv": 1,
-            "fm": 1,
-        }
-
-        if self.multi_resolution_stft_loss is not None:
-            loss_dict["stft"] = 1
-
-        self.balancer = Balancer(loss_dict)
+        # Freeze generator
+        for p in self.generator.parameters():
+            p.requires_grad = False
 
 
     def configure_optimizers(self):
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
         # Need two optimizers and two schedulers
-        components = [
-            self.downsample.parameters(),
-            self.vq_encoder.parameters(),
-            self.mel_encoder.parameters(),
-        ]
-
-        if self.decoder is not None:
-            components.append(self.decoder.parameters())
-
-        components.append(self.generator.parameters())
-        optimizer_generator = self.optimizer_builder(itertools.chain(*components))
+        optimizer_generator = self.optimizer_builder(
+            itertools.chain(
+                self.encoder.parameters(),
+                self.vq.parameters(),
+                self.decoder.parameters(),
+            )
+        )
         optimizer_discriminator = self.optimizer_builder(
         optimizer_discriminator = self.optimizer_builder(
-            self.discriminators.parameters()
+            self.discriminator.parameters()
         )
         )
 
 
         lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
         lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
@@ -171,13 +129,6 @@ class VQGAN(L.LightningModule):
         with torch.no_grad():
         with torch.no_grad():
             gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
             gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
 
 
-        if self.mode == "finetune":
-            # Disable gradient computation for VQ
-            torch.set_grad_enabled(False)
-            self.vq_encoder.eval()
-            self.mel_encoder.eval()
-            self.downsample.eval()
-
         mel_lengths = audio_lengths // self.hop_length
         mel_lengths = audio_lengths // self.hop_length
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
             gt_mels.dtype
@@ -189,186 +140,80 @@ class VQGAN(L.LightningModule):
         if loss_vq.ndim > 1:
         if loss_vq.ndim > 1:
             loss_vq = loss_vq.mean()
             loss_vq = loss_vq.mean()
 
 
-        if self.mode == "finetune":
-            # Enable gradient computation
-            torch.set_grad_enabled(True)
-
-        decoded = self.decode(
-            indices=vq_result.indices if self.mode == "finetune" else None,
-            features=vq_result.features if self.mode == "pretrain" else None,
+        decoded_mels = self.decode(
+            indices=None,
+            features=vq_result.features,
             audio_lengths=audio_lengths,
             audio_lengths=audio_lengths,
-            mel_only=True,
-        )
-        decoded_mels = decoded.mels
-        input_mels = gt_mels if self.mode == "pretrain" else decoded_mels
+        ).mels
 
 
-        if self.segment_size is not None:
-            audios, ids_slice = rand_slice_segments(
-                audios, audio_lengths, self.segment_size
-            )
-            input_mels = slice_segments(
-                input_mels,
-                ids_slice // self.hop_length,
-                self.segment_size // self.hop_length,
-            )
-            sliced_gt_mels = slice_segments(
-                gt_mels,
-                ids_slice // self.hop_length,
-                self.segment_size // self.hop_length,
-            )
-            gen_mel_masks = slice_segments(
-                mel_masks,
-                ids_slice // self.hop_length,
-                self.segment_size // self.hop_length,
-            )
-        else:
-            sliced_gt_mels = gt_mels
-            gen_mel_masks = mel_masks
+        with torch.no_grad():
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                fake_audios = self.generator(decoded_mels.float())
 
 
-        fake_audios = self.generator(input_mels)
-        fake_audio_mels = self.mel_transform(fake_audios.squeeze(1))
         assert (
         assert (
             audios.shape == fake_audios.shape
             audios.shape == fake_audios.shape
         ), f"{audios.shape} != {fake_audios.shape}"
         ), f"{audios.shape} != {fake_audios.shape}"
 
 
-        # Multi-Resolution STFT Loss
-        if self.multi_resolution_stft_loss is not None:
-            with torch.autocast(device_type=audios.device.type, enabled=False):
-                sc_loss, mag_loss = self.multi_resolution_stft_loss(
-                    fake_audios.squeeze(1).float(), audios.squeeze(1).float()
-                )
-                loss_stft = sc_loss + mag_loss
-
         # Discriminator
         # Discriminator
         if self.freeze_discriminator is False:
         if self.freeze_discriminator is False:
-            loss_disc_all = []
-
-            for key, disc in self.discriminators.items():
-                scores, _ = disc(audios)
-                score_fakes, _ = disc(fake_audios.detach())
-
-                with torch.autocast(device_type=audios.device.type, enabled=False):
-                    loss_disc, _, _ = discriminator_loss(scores, score_fakes)
-
-                self.log(
-                    f"train/discriminator/{key}",
-                    loss_disc,
-                    on_step=True,
-                    on_epoch=False,
-                    prog_bar=False,
-                    logger=True,
-                    sync_dist=True,
-                )
+            scores = self.discriminator(gt_mels)
+            score_fakes = self.discriminator(decoded_mels.detach())
 
 
-                loss_disc_all.append(loss_disc)
-
-            loss_disc_all = torch.stack(loss_disc_all).mean()
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                loss_disc, _, _ = discriminator_loss([scores], [score_fakes])
 
 
             self.log(
             self.log(
-                "train/discriminator/loss",
-                loss_disc_all,
+                f"train/discriminator/loss",
+                loss_disc,
                 on_step=True,
                 on_step=True,
                 on_epoch=False,
                 on_epoch=False,
-                prog_bar=True,
+                prog_bar=False,
                 logger=True,
                 logger=True,
                 sync_dist=True,
                 sync_dist=True,
             )
             )
 
 
             optim_d.zero_grad()
             optim_d.zero_grad()
-            self.manual_backward(loss_disc_all)
+            self.manual_backward(loss_disc)
             self.clip_gradients(
             self.clip_gradients(
                 optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
                 optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
             )
             )
             optim_d.step()
             optim_d.step()
 
 
         # Adv Loss
         # Adv Loss
-        loss_adv_all = []
-        loss_fm_all = []
-
-        for key, disc in self.discriminators.items():
-            score_fakes, feat_fake = disc(fake_audios)
-
-            # Adversarial Loss
-            with torch.autocast(device_type=audios.device.type, enabled=False):
-                loss_fake, _ = generator_loss(score_fakes)
-
-            self.log(
-                f"train/generator/adv_{key}",
-                loss_fake,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
-
-            loss_adv_all.append(loss_fake)
-
-            # Feature Matching Loss
-            _, feat_real = disc(audios)
-
-            with torch.autocast(device_type=audios.device.type, enabled=False):
-                loss_fm = feature_loss(feat_real, feat_fake)
-
-            self.log(
-                f"train/generator/adv_fm_{key}",
-                loss_fm,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
-
-            loss_fm_all.append(loss_fm)
-
-        loss_adv_all = torch.stack(loss_adv_all).mean()
-        loss_fm_all = torch.stack(loss_fm_all).mean()
+        score_fakes = self.discriminator(decoded_mels)
 
 
+        # Adversarial Loss
         with torch.autocast(device_type=audios.device.type, enabled=False):
         with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
-            loss_mel = F.l1_loss(
-                sliced_gt_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
-            )
-
-            loss_dict = {
-                "mel": loss_mel,
-                "adv": loss_adv_all,
-                "fm": loss_fm_all,
-            }
+            loss_adv, _ = generator_loss([score_fakes])
 
 
-            if self.multi_resolution_stft_loss is not None:
-                loss_dict["stft"] = loss_stft
+        self.log(
+            f"train/generator/adv",
+            loss_adv,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
 
 
-            generator_out_grad = self.balancer.compute(
-                loss_dict,
-                fake_audios,
-            )
+        # Feature Matching Loss
+        score_gts = self.discriminator(gt_mels)
 
 
-            if self.mode == "pretrain":
-                loss_vq_all = loss_decoded_mel + loss_vq
+        with torch.autocast(device_type=audios.device.type, enabled=False):
+            loss_fm = feature_loss([score_gts], [score_fakes])
 
 
-        # Loss vq and loss decoded mel are only used in pretrain stage
-        if self.mode == "pretrain":
-            self.log(
-                "train/generator/loss_vq",
-                loss_vq,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
+        self.log(
+            f"train/generator/adv_fm",
+            loss_fm,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
 
 
-            self.log(
-                "train/generator/loss_decoded_mel",
-                loss_decoded_mel,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
+        with torch.autocast(device_type=audios.device.type, enabled=False):
+            loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
 
 
         self.log(
         self.log(
             "train/generator/loss_mel",
             "train/generator/loss_mel",
@@ -380,29 +225,20 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
             sync_dist=True,
         )
         )
 
 
-        if self.multi_resolution_stft_loss is not None:
-            self.log(
-                "train/generator/loss_stft",
-                loss_stft,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
-
         self.log(
         self.log(
-            "train/generator/loss_fm_all",
-            loss_fm_all,
+            "train/generator/loss_vq",
+            loss_vq,
             on_step=True,
             on_step=True,
             on_epoch=False,
             on_epoch=False,
             prog_bar=False,
             prog_bar=False,
             logger=True,
             logger=True,
             sync_dist=True,
             sync_dist=True,
         )
         )
+
+        loss = loss_mel * 20 + loss_vq + loss_adv + loss_fm
         self.log(
         self.log(
-            "train/generator/loss_adv_all",
-            loss_adv_all,
+            "train/generator/loss",
+            loss,
             on_step=True,
             on_step=True,
             on_epoch=False,
             on_epoch=False,
             prog_bar=False,
             prog_bar=False,
@@ -412,11 +248,7 @@ class VQGAN(L.LightningModule):
 
 
         optim_g.zero_grad()
         optim_g.zero_grad()
 
 
-        # Only backpropagate loss_vq_all in pretrain stage
-        if self.mode == "pretrain":
-            self.manual_backward(loss_vq_all, retain_graph=True)
-
-        self.manual_backward(fake_audios, gradient=generator_out_grad)
+        self.manual_backward(loss)
         self.clip_gradients(
         self.clip_gradients(
             optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
             optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
         )
         )
@@ -440,19 +272,11 @@ class VQGAN(L.LightningModule):
         )
         )
 
 
         vq_result = self.encode(audios, audio_lengths)
         vq_result = self.encode(audios, audio_lengths)
-        decoded = self.decode(
+        decoded_mels = self.decode(
             indices=vq_result.indices,
             indices=vq_result.indices,
             audio_lengths=audio_lengths,
             audio_lengths=audio_lengths,
-            mel_only=self.mode == "pretrain",
-        )
-
-        decoded_mels = decoded.mels
-
-        # Use gt mel as input for pretrain
-        if self.mode == "pretrain":
-            fake_audios = self.generator(gt_mels)
-        else:
-            fake_audios = decoded.audios
+        ).mels
+        fake_audios = self.generator(decoded_mels)
 
 
         fake_mels = self.mel_transform(fake_audios.squeeze(1))
         fake_mels = self.mel_transform(fake_audios.squeeze(1))
 
 
@@ -557,21 +381,25 @@ class VQGAN(L.LightningModule):
         with torch.no_grad():
         with torch.no_grad():
             features = self.mel_transform(audios, sample_rate=self.sampling_rate)
             features = self.mel_transform(audios, sample_rate=self.sampling_rate)
 
 
-        if self.downsample is not None:
-            features = self.downsample(features)
-
         feature_lengths = (
         feature_lengths = (
             audio_lengths
             audio_lengths
             / self.hop_length
             / self.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
+            # / self.vq.downsample
         ).long()
         ).long()
 
 
+        # print(features.shape, feature_lengths.shape, torch.max(feature_lengths))
+
         feature_masks = torch.unsqueeze(
         feature_masks = torch.unsqueeze(
             sequence_mask(feature_lengths, features.shape[2]), 1
             sequence_mask(feature_lengths, features.shape[2]), 1
         ).to(features.dtype)
         ).to(features.dtype)
 
 
-        text_features = self.mel_encoder(features, feature_masks)
-        vq_features, indices, loss = self.vq_encoder(text_features, feature_masks)
+        features = (
+            gradient_checkpoint(
+                self.encoder, features, feature_masks, use_reentrant=False
+            )
+            * feature_masks
+        )
+        vq_features, indices, loss = self.vq(features, feature_masks)
 
 
         return VQEncodeResult(
         return VQEncodeResult(
             features=vq_features,
             features=vq_features,
@@ -581,18 +409,13 @@ class VQGAN(L.LightningModule):
         )
         )
 
 
     def calculate_audio_lengths(self, feature_lengths):
     def calculate_audio_lengths(self, feature_lengths):
-        return (
-            feature_lengths
-            * self.hop_length
-            * (self.downsample.total_strides if self.downsample is not None else 1)
-        )
+        return feature_lengths * self.hop_length * self.vq.downsample
 
 
     def decode(
     def decode(
         self,
         self,
         indices=None,
         indices=None,
         features=None,
         features=None,
         audio_lengths=None,
         audio_lengths=None,
-        mel_only=False,
         feature_lengths=None,
         feature_lengths=None,
     ):
     ):
         assert (
         assert (
@@ -611,26 +434,11 @@ class VQGAN(L.LightningModule):
         ).float()
         ).float()
 
 
         if indices is not None:
         if indices is not None:
-            features = self.vq_encoder.decode(indices)
-
-        features = F.interpolate(features, size=mel_masks.shape[2], mode="nearest")
+            features = self.vq.decode(indices)
 
 
         # Sample mels
         # Sample mels
-        if self.decoder is not None:
-            decoded_mels = self.decoder(features, mel_masks)
-        else:
-            decoded_mels = features
-
-        if mel_only:
-            return VQDecodeResult(
-                audios=None,
-                mels=decoded_mels,
-                mel_lengths=mel_lengths,
-            )
+        decoded = gradient_checkpoint(self.decoder, features, use_reentrant=False)
 
 
-        fake_audios = self.generator(decoded_mels)
         return VQDecodeResult(
         return VQDecodeResult(
-            audios=fake_audios,
-            mels=decoded_mels,
-            mel_lengths=mel_lengths,
+            mels=decoded,
         )
         )

+ 3 - 4
fish_speech/models/vqgan/losses.py

@@ -6,10 +6,9 @@ from torch import nn
 def feature_loss(fmap_r: list[torch.Tensor], fmap_g: list[torch.Tensor]):
 def feature_loss(fmap_r: list[torch.Tensor], fmap_g: list[torch.Tensor]):
     loss = 0
     loss = 0
     for dr, dg in zip(fmap_r, fmap_g):
     for dr, dg in zip(fmap_r, fmap_g):
-        for rl, gl in zip(dr, dg):
-            rl = rl.float().detach()
-            gl = gl.float()
-            loss += torch.mean(torch.abs(rl - gl))
+        dr = dr.float().detach()
+        dg = dg.float()
+        loss += torch.mean(torch.abs(dr - dg))
 
 
     return loss * 2
     return loss * 2
 
 

+ 0 - 37
fish_speech/models/vqgan/modules/condition.py

@@ -1,37 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-class MultiCondLayer(nn.Module):
-    def __init__(
-        self,
-        gin_channels: int,
-        out_channels: int,
-        n_cond: int,
-    ):
-        """MultiCondLayer of VITS model.
-
-        Args:
-            gin_channels (int): Number of conditioning tensor channels.
-            out_channels (int): Number of output tensor channels.
-            n_cond (int): Number of conditions.
-        """
-        super().__init__()
-        self.n_cond = n_cond
-
-        self.cond_layers = nn.ModuleList()
-        for _ in range(n_cond):
-            self.cond_layers.append(nn.Linear(gin_channels, out_channels))
-
-    def forward(self, cond: torch.Tensor, x_mask: torch.Tensor):
-        """
-        Shapes:
-            - cond: :math:`[B, C, N]`
-            - x_mask: :math`[B, 1, T]`
-        """
-
-        cond_out = torch.zeros_like(cond)
-        for i in range(self.n_cond):
-            cond_in = self.cond_layers[i](cond.mT).mT
-            cond_out = cond_out + cond_in
-        return cond_out * x_mask

+ 315 - 0
fish_speech/models/vqgan/modules/convnext.py

@@ -0,0 +1,315 @@
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from vocos.spectral_ops import ISTFT
+
+
+def drop_path(
+    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
+    """  # noqa: E501
+
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""  # noqa: E501
+
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
+class LayerNorm(nn.Module):
+    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+    with shape (batch_size, channels, height, width).
+    """  # noqa: E501
+
+    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.data_format = data_format
+        if self.data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError
+        self.normalized_shape = (normalized_shape,)
+
+    def forward(self, x):
+        if self.data_format == "channels_last":
+            return F.layer_norm(
+                x, self.normalized_shape, self.weight, self.bias, self.eps
+            )
+        elif self.data_format == "channels_first":
+            u = x.mean(1, keepdim=True)
+            s = (x - u).pow(2).mean(1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.eps)
+            x = self.weight[:, None] * x + self.bias[:, None]
+            return x
+
+
+class ConvNeXtBlock(nn.Module):
+    r"""ConvNeXt Block. There are two equivalent implementations:
+    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+    We use (2) as we find it slightly faster in PyTorch
+
+    Args:
+        dim (int): Number of input channels.
+        drop_path (float): Stochastic depth rate. Default: 0.0
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+        kernel_size (int): Kernel size for depthwise conv. Default: 7.
+        dilation (int): Dilation for depthwise conv. Default: 1.
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        dim: int,
+        drop_path: float = 0.0,
+        layer_scale_init_value: float = 1e-6,
+        mlp_ratio: float = 4.0,
+        kernel_size: int = 7,
+        dilation: int = 1,
+    ):
+        super().__init__()
+
+        self.dwconv = nn.Conv1d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding=int(dilation * (kernel_size - 1) / 2),
+            groups=dim,
+        )  # depthwise conv
+        self.norm = LayerNorm(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(
+            dim, int(mlp_ratio * dim)
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
+        self.gamma = (
+            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x, apply_residual: bool = True):
+        input = x
+
+        x = self.dwconv(x)
+        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
+        x = self.norm(x)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+
+        if self.gamma is not None:
+            x = self.gamma * x
+
+        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
+        x = self.drop_path(x)
+
+        if apply_residual:
+            x = input + x
+
+        return x
+
+
+class ParallelConvNeXtBlock(nn.Module):
+    def __init__(self, kernel_sizes: list[int], *args, **kwargs):
+        super().__init__()
+        self.blocks = nn.ModuleList(
+            [
+                ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
+                for kernel_size in kernel_sizes
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.stack(
+            [block(x, apply_residual=False) for block in self.blocks] + [x],
+            dim=1,
+        ).sum(dim=1)
+
+
+class ConvNeXt(nn.Module):
+    def __init__(
+        self,
+        input_channels: int = 3,
+        depths: list[int] = [3, 3, 9, 3],
+        dims: list[int] = [96, 192, 384, 768],
+        drop_path_rate: float = 0.0,
+        layer_scale_init_value: float = 1e-6,
+        kernel_sizes: tuple[int] = (7,),
+    ):
+        super().__init__()
+        assert len(depths) == len(dims)
+
+        self.channel_layers = nn.ModuleList()
+        stem = nn.Sequential(
+            nn.Conv1d(
+                input_channels,
+                dims[0],
+                kernel_size=7,
+                padding=3,
+                # padding_mode="replicate",
+                padding_mode="zeros",
+            ),
+            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+        )
+        self.channel_layers.append(stem)
+
+        for i in range(len(depths) - 1):
+            mid_layer = nn.Sequential(
+                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+                nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
+            )
+            self.channel_layers.append(mid_layer)
+
+        block_fn = (
+            partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
+            if len(kernel_sizes) == 1
+            else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
+        )
+
+        self.stages = nn.ModuleList()
+        drop_path_rates = [
+            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+        ]
+
+        cur = 0
+        for i in range(len(depths)):
+            stage = nn.Sequential(
+                *[
+                    block_fn(
+                        dim=dims[i],
+                        drop_path=drop_path_rates[cur + j],
+                        layer_scale_init_value=layer_scale_init_value,
+                    )
+                    for j in range(depths[i])
+                ]
+            )
+            self.stages.append(stage)
+            cur += depths[i]
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Conv1d, nn.Linear)):
+            nn.init.trunc_normal_(m.weight, std=0.02)
+            nn.init.constant_(m.bias, 0)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        return_features: bool = False,
+    ) -> torch.Tensor:
+        features = []
+
+        for channel_layer, stage in zip(self.channel_layers, self.stages):
+            x = channel_layer(x)
+            x = stage(x)
+
+            if return_features:
+                features.append(x)
+
+        if return_features:
+            return features
+
+        return x
+
+
+class ISTFTHead(nn.Module):
+    """
+    ISTFT Head module for predicting STFT complex coefficients.
+
+    Args:
+        dim (int): Hidden dimension of the model.
+        n_fft (int): Size of Fourier transform.
+        hop_length (int): The distance between neighboring sliding window frames, which should align with
+                          the resolution of the input features.
+        win_length (int): The size of window frame and STFT filter.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+    """  # noqa: E501
+
+    def __init__(
+        self,
+        dim: int,
+        n_fft: int,
+        hop_length: int,
+        win_length: int,
+        padding: str = "same",
+    ):
+        super().__init__()
+
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+
+        self.istft = ISTFT(
+            n_fft=n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            padding=padding,
+        )
+
+        out_dim = n_fft * 2
+        self.out = nn.Conv1d(dim, out_dim, 1)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the ISTFTHead module.
+
+        Args:
+            x (Tensor): Input tensor of shape (B, H, L), where B is the batch size,
+                        L is the sequence length, and H denotes the model dimension.
+
+        Returns:
+            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+        """  # noqa: E501
+
+        x = self.out(x)
+
+        mag, p = x.chunk(2, dim=1)
+        mag = torch.exp(mag)
+        mag = torch.clip(
+            mag, max=1e2
+        )  # safeguard to prevent excessively large magnitudes
+
+        # wrapping happens here. These two lines produce real and imaginary value
+        x = torch.cos(p)
+        y = torch.sin(p)
+
+        S = mag * (x + 1j * y)
+
+        x = self.istft(S)
+        return x.unsqueeze(1)

+ 0 - 270
fish_speech/models/vqgan/modules/decoder.py

@@ -1,270 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.utils.parametrizations import weight_norm
-from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
-
-from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
-from fish_speech.models.vqgan.utils import get_padding, init_weights
-
-
-class Generator(nn.Module):
-    def __init__(
-        self,
-        initial_channel,
-        resblock,
-        resblock_kernel_sizes,
-        resblock_dilation_sizes,
-        upsample_rates,
-        upsample_initial_channel,
-        upsample_kernel_sizes,
-        gin_channels=0,
-        ckpt_path=None,
-    ):
-        super(Generator, self).__init__()
-        self.num_kernels = len(resblock_kernel_sizes)
-        self.num_upsamples = len(upsample_rates)
-        self.conv_pre = weight_norm(
-            nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
-        )
-        resblock = ResBlock1 if resblock == "1" else ResBlock2
-
-        self.ups = nn.ModuleList()
-        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            self.ups.append(
-                weight_norm(
-                    nn.ConvTranspose1d(
-                        upsample_initial_channel // (2**i),
-                        upsample_initial_channel // (2 ** (i + 1)),
-                        k,
-                        u,
-                        padding=(k - u) // 2,
-                    )
-                )
-            )
-
-        self.resblocks = nn.ModuleList()
-        for i in range(len(self.ups)):
-            ch = upsample_initial_channel // (2 ** (i + 1))
-            for j, (k, d) in enumerate(
-                zip(resblock_kernel_sizes, resblock_dilation_sizes)
-            ):
-                self.resblocks.append(resblock(ch, k, d))
-
-        self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
-        self.ups.apply(init_weights)
-
-        if gin_channels != 0:
-            self.cond = nn.Linear(gin_channels, upsample_initial_channel)
-
-        if ckpt_path is not None:
-            self.load_state_dict(torch.load(ckpt_path)["generator"], strict=True)
-
-    def forward(self, x, g=None):
-        x = self.conv_pre(x)
-        if g is not None:
-            x = x + self.cond(g.mT).mT
-
-        for i in range(self.num_upsamples):
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            x = self.ups[i](x)
-            xs = None
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
-                else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-            x = xs / self.num_kernels
-        x = F.leaky_relu(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def remove_weight_norm(self):
-        print("Removing weight norm...")
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-
-
-class ResBlock1(nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super(ResBlock1, self).__init__()
-        self.convs1 = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[2],
-                        padding=get_padding(kernel_size, dilation[2]),
-                    )
-                ),
-            ]
-        )
-        self.convs1.apply(init_weights)
-
-        self.convs2 = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-            ]
-        )
-        self.convs2.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c1(xt)
-            xt = F.leaky_relu(xt, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c2(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs1:
-            remove_weight_norm(l)
-        for l in self.convs2:
-            remove_weight_norm(l)
-
-
-class ResBlock2(nn.Module):
-    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
-        super(ResBlock2, self).__init__()
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    nn.Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-            ]
-        )
-        self.convs.apply(init_weights)
-
-    def forward(self, x, x_mask=None):
-        for c in self.convs:
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            if x_mask is not None:
-                xt = xt * x_mask
-            xt = c(xt)
-            x = xt + x
-        if x_mask is not None:
-            x = x * x_mask
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs:
-            remove_weight_norm(l)
-
-
-if __name__ == "__main__":
-    import librosa
-    import soundfile as sf
-
-    from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
-
-    gen = Generator(
-        80,
-        "1",
-        [3, 7, 11],
-        [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        [8, 8, 2, 2],
-        512,
-        [16, 16, 4, 4],
-        ckpt_path="checkpoints/hifigan-v1-universal-22050/g_02500000",
-    )
-
-    spec = LogMelSpectrogram(
-        sample_rate=22050,
-        n_fft=1024,
-        win_length=1024,
-        hop_length=256,
-        n_mels=80,
-        f_min=0.0,
-        f_max=8000.0,
-    )
-
-    audio = librosa.load("data/StarRail/Chinese/符玄/archive_fuxuan_9.wav", sr=22050)[0]
-    audio = torch.from_numpy(audio).unsqueeze(0)
-
-    spec = spec(audio)
-    print(spec.shape)
-
-    audio = gen(spec)
-    print(audio.shape)
-
-    sf.write("test.wav", audio.detach().squeeze().numpy(), 22050)

+ 45 - 0
fish_speech/models/vqgan/modules/decoder_v2.py

@@ -150,6 +150,7 @@ class HiFiGANGenerator(nn.Module):
         post_conv_kernel_size: int = 7,
         post_conv_kernel_size: int = 7,
         post_activation: Callable = partial(nn.SiLU, inplace=True),
         post_activation: Callable = partial(nn.SiLU, inplace=True),
         checkpointing: bool = False,
         checkpointing: bool = False,
+        ckpt_path: str = None,
     ):
     ):
         super().__init__()
         super().__init__()
 
 
@@ -229,6 +230,17 @@ class HiFiGANGenerator(nn.Module):
         # Gradient checkpointing
         # Gradient checkpointing
         self.checkpointing = checkpointing
         self.checkpointing = checkpointing
 
 
+        if ckpt_path is not None:
+            states = torch.load(ckpt_path, map_location="cpu")
+            if "state_dict" in states:
+                states = states["state_dict"]
+            states = {
+                k.replace("generator.", ""): v
+                for k, v in states.items()
+                if k.startswith("generator")
+            }
+            self.load_state_dict(states, strict=True)
+
     def forward(self, x, template=None):
     def forward(self, x, template=None):
         if self.use_template and template is None:
         if self.use_template and template is None:
             length = x.shape[-1] * self.hop_length
             length = x.shape[-1] * self.hop_length
@@ -268,3 +280,36 @@ class HiFiGANGenerator(nn.Module):
             block.remove_parametrizations()
             block.remove_parametrizations()
         remove_parametrizations(self.conv_pre)
         remove_parametrizations(self.conv_pre)
         remove_parametrizations(self.conv_post)
         remove_parametrizations(self.conv_post)
+
+
+if __name__ == "__main__":
+    import torchaudio
+
+    from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
+
+    spec = LogMelSpectrogram(n_mels=160)
+    audio, sr = torchaudio.load("test.wav")
+    audio = audio[None, :]
+    spec = spec(audio, sample_rate=sr)
+
+    model = HiFiGANGenerator(
+        hop_length=512,
+        upsample_rates=(8, 8, 2, 2, 2),
+        upsample_kernel_sizes=(16, 16, 4, 4, 4),
+        resblock_kernel_sizes=(3, 7, 11),
+        resblock_dilation_sizes=((1, 3, 5), (1, 3, 5), (1, 3, 5)),
+        num_mels=160,
+        upsample_initial_channel=512,
+        use_template=True,
+        pre_conv_kernel_size=7,
+        post_conv_kernel_size=7,
+        post_activation=partial(nn.SiLU, inplace=True),
+        ckpt_path="checkpoints/hifigan-base-comb-mix-lb-020/step_001200000_weights_only.ckpt",
+    )
+
+    print(model)
+
+    out = model(spec)
+    print(out.shape)
+
+    torchaudio.save("out.wav", out[0], 44100)

+ 0 - 80
fish_speech/models/vqgan/modules/discriminators/mpd.py

@@ -1,80 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.utils.parametrizations import weight_norm
-
-
-class DiscriminatorP(nn.Module):
-    def __init__(
-        self,
-        *,
-        period: int,
-        kernel_size: int = 5,
-        stride: int = 3,
-        channels: tuple[int] = (1, 64, 128, 256, 512, 1024),
-    ) -> None:
-        super(DiscriminatorP, self).__init__()
-
-        self.period = period
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    nn.Conv2d(
-                        in_channels,
-                        out_channels,
-                        (kernel_size, 1),
-                        (stride, 1),
-                        padding=(kernel_size // 2, 0),
-                    )
-                )
-                for in_channels, out_channels in zip(channels[:-1], channels[1:])
-            ]
-        )
-
-        self.conv_post = weight_norm(
-            nn.Conv2d(channels[-1], 1, (3, 1), 1, padding=(1, 0))
-        )
-
-    def forward(self, x):
-        fmap = []
-
-        # 1d to 2d
-        b, c, t = x.shape
-        if t % self.period != 0:  # pad first
-            n_pad = self.period - (t % self.period)
-            x = F.pad(x, (0, n_pad), "constant")
-            t = t + n_pad
-        x = x.view(b, c, t // self.period, self.period)
-
-        for conv in self.convs:
-            x = conv(x)
-            x = F.silu(x, inplace=True)
-            fmap.append(x)
-
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class MultiPeriodDiscriminator(nn.Module):
-    def __init__(self, periods: tuple[int] = (2, 3, 5, 7, 11)) -> None:
-        super().__init__()
-
-        self.discriminators = nn.ModuleList(
-            [DiscriminatorP(period=period) for period in periods]
-        )
-
-    def forward(
-        self, x: torch.Tensor
-    ) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]:
-        scores, feature_map = [], []
-
-        for disc in self.discriminators:
-            res, fmap = disc(x)
-
-            scores.append(res)
-            feature_map.append(fmap)
-
-        return scores, feature_map

+ 0 - 100
fish_speech/models/vqgan/modules/discriminators/mrd.py

@@ -1,100 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.utils.parametrizations import weight_norm
-
-
-class DiscriminatorR(torch.nn.Module):
-    def __init__(
-        self,
-        *,
-        n_fft: int = 1024,
-        hop_length: int = 120,
-        win_length: int = 600,
-    ):
-        super(DiscriminatorR, self).__init__()
-
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
-                weight_norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
-                weight_norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
-                weight_norm(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
-                weight_norm(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
-            ]
-        )
-
-        self.conv_post = weight_norm(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
-
-    def forward(self, x):
-        fmap = []
-
-        x = self.spectrogram(x)
-        x = x.unsqueeze(1)
-
-        for conv in self.convs:
-            x = conv(x)
-            x = F.silu(x, inplace=True)
-            fmap.append(x)
-
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-    def spectrogram(self, x):
-        x = F.pad(
-            x,
-            (
-                (self.n_fft - self.hop_length) // 2,
-                (self.n_fft - self.hop_length + 1) // 2,
-            ),
-            mode="reflect",
-        )
-        x = x.squeeze(1)
-        x = torch.stft(
-            x,
-            n_fft=self.n_fft,
-            hop_length=self.hop_length,
-            win_length=self.win_length,
-            center=False,
-            return_complex=True,
-        )
-        x = torch.view_as_real(x)  # [B, F, TT, 2]
-        mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]
-
-        return mag
-
-
-class MultiResolutionDiscriminator(torch.nn.Module):
-    def __init__(self, resolutions: list[tuple[int]]):
-        super().__init__()
-
-        self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorR(
-                    n_fft=n_fft,
-                    hop_length=hop_length,
-                    win_length=win_length,
-                )
-                for n_fft, hop_length, win_length in resolutions
-            ]
-        )
-
-    def forward(
-        self, x: torch.Tensor
-    ) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]:
-        scores, feature_map = [], []
-
-        for disc in self.discriminators:
-            res, fmap = disc(x)
-
-            scores.append(res)
-            feature_map.append(fmap)
-
-        return scores, feature_map

+ 0 - 188
fish_speech/models/vqgan/modules/discriminators/mssbcqtd.py

@@ -1,188 +0,0 @@
-# Copyright (c) 2023 Amphion.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-# Monkey patching to fix a bug in nnAudio
-import numpy as np
-import torch
-import torchaudio.transforms as T
-from einops import rearrange
-from nnAudio import features
-from torch import nn
-
-from .msstftd import NormConv2d, get_2d_padding
-
-np.float = float
-
-LRELU_SLOPE = 0.1
-
-
-class DiscriminatorCQT(nn.Module):
-    def __init__(
-        self,
-        hop_length,
-        n_octaves,
-        bins_per_octave,
-        filters=32,
-        max_filters=1024,
-        filters_scale=1,
-        dilations=[1, 2, 4],
-        in_channels=1,
-        out_channels=1,
-        sample_rate=16000,
-    ):
-        super().__init__()
-
-        self.filters = filters
-        self.max_filters = max_filters
-        self.filters_scale = filters_scale
-        self.kernel_size = (3, 9)
-        self.dilations = dilations
-        self.stride = (1, 2)
-
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.fs = sample_rate
-        self.hop_length = hop_length
-        self.n_octaves = n_octaves
-        self.bins_per_octave = bins_per_octave
-
-        self.cqt_transform = features.cqt.CQT2010v2(
-            sr=self.fs * 2,
-            hop_length=self.hop_length,
-            n_bins=self.bins_per_octave * self.n_octaves,
-            bins_per_octave=self.bins_per_octave,
-            output_format="Complex",
-            pad_mode="constant",
-        )
-
-        self.conv_pres = nn.ModuleList()
-        for i in range(self.n_octaves):
-            self.conv_pres.append(
-                NormConv2d(
-                    self.in_channels * 2,
-                    self.in_channels * 2,
-                    kernel_size=self.kernel_size,
-                    padding=get_2d_padding(self.kernel_size),
-                )
-            )
-
-        self.convs = nn.ModuleList()
-
-        self.convs.append(
-            NormConv2d(
-                self.in_channels * 2,
-                self.filters,
-                kernel_size=self.kernel_size,
-                padding=get_2d_padding(self.kernel_size),
-            )
-        )
-
-        in_chs = min(self.filters_scale * self.filters, self.max_filters)
-        for i, dilation in enumerate(self.dilations):
-            out_chs = min(
-                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
-            )
-            self.convs.append(
-                NormConv2d(
-                    in_chs,
-                    out_chs,
-                    kernel_size=self.kernel_size,
-                    stride=self.stride,
-                    dilation=(dilation, 1),
-                    padding=get_2d_padding(self.kernel_size, (dilation, 1)),
-                    norm="weight_norm",
-                )
-            )
-            in_chs = out_chs
-        out_chs = min(
-            (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
-            self.max_filters,
-        )
-        self.convs.append(
-            NormConv2d(
-                in_chs,
-                out_chs,
-                kernel_size=(self.kernel_size[0], self.kernel_size[0]),
-                padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
-                norm="weight_norm",
-            )
-        )
-
-        self.conv_post = NormConv2d(
-            out_chs,
-            self.out_channels,
-            kernel_size=(self.kernel_size[0], self.kernel_size[0]),
-            padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
-            norm="weight_norm",
-        )
-
-        self.activation = torch.nn.LeakyReLU(negative_slope=LRELU_SLOPE)
-        self.resample = T.Resample(orig_freq=self.fs, new_freq=self.fs * 2)
-
-    def forward(self, x):
-        fmap = []
-
-        x = self.resample(x)
-
-        z = self.cqt_transform(x)
-
-        z_amplitude = z[:, :, :, 0].unsqueeze(1)
-        z_phase = z[:, :, :, 1].unsqueeze(1)
-
-        z = torch.cat([z_amplitude, z_phase], dim=1)
-        z = rearrange(z, "b c w t -> b c t w")
-
-        latent_z = []
-        for i in range(self.n_octaves):
-            latent_z.append(
-                self.conv_pres[i](
-                    z[
-                        :,
-                        :,
-                        :,
-                        i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
-                    ]
-                )
-            )
-        latent_z = torch.cat(latent_z, dim=-1)
-
-        for i, layer in enumerate(self.convs):
-            latent_z = layer(latent_z)
-
-            latent_z = self.activation(latent_z)
-            fmap.append(latent_z)
-
-        latent_z = self.conv_post(latent_z)
-
-        return latent_z, fmap
-
-
-class MultiScaleSubbandCQTDiscriminator(nn.Module):
-    def __init__(self, hop_lengths, n_octaves, bins_per_octaves, **kwargs):
-        super().__init__()
-
-        self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorCQT(
-                    hop_length=hop_length,
-                    n_octaves=n_octaves,
-                    bins_per_octave=bins_per_octave,
-                    **kwargs,
-                )
-                for hop_length, n_octaves, bins_per_octave in zip(
-                    hop_lengths, n_octaves, bins_per_octaves
-                )
-            ]
-        )
-
-    def forward(self, x: torch.Tensor):
-        logits = []
-        fmaps = []
-        for disc in self.discriminators:
-            logit, fmap = disc(x)
-            logits.append(logit)
-            fmaps.append(fmap)
-
-        return logits, fmaps

+ 0 - 303
fish_speech/models/vqgan/modules/discriminators/msstftd.py

@@ -1,303 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""MS-STFT discriminator, provided here for reference."""
-
-import typing as tp
-
-import einops
-import torch
-import torchaudio
-from einops import rearrange
-from torch import nn
-from torch.nn.utils import spectral_norm, weight_norm
-
-FeatureMapType = tp.List[torch.Tensor]
-LogitsType = torch.Tensor
-DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
-
-
-class ConvLayerNorm(nn.LayerNorm):
-    """
-    Convolution-friendly LayerNorm that moves channels to last dimensions
-    before running the normalization and moves them back to original position right after.
-    """  # noqa: E501
-
-    def __init__(
-        self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs
-    ):
-        super().__init__(normalized_shape, **kwargs)
-
-    def forward(self, x):
-        x = einops.rearrange(x, "b ... t -> b t ...")
-        x = super().forward(x)
-        x = einops.rearrange(x, "b t ... -> b ... t")
-        return
-
-
-CONV_NORMALIZATIONS = frozenset(
-    [
-        "none",
-        "weight_norm",
-        "spectral_norm",
-        "time_layer_norm",
-        "layer_norm",
-        "time_group_norm",
-    ]
-)
-
-
-def apply_parametrization_norm(module: nn.Module, norm: str = "none") -> nn.Module:
-    assert norm in CONV_NORMALIZATIONS
-    if norm == "weight_norm":
-        return weight_norm(module)
-    elif norm == "spectral_norm":
-        return spectral_norm(module)
-    else:
-        # We already check was in CONV_NORMALIZATION, so any other choice
-        # doesn't need reparametrization.
-        return module
-
-
-def get_norm_module(
-    module: nn.Module, causal: bool = False, norm: str = "none", **norm_kwargs
-) -> nn.Module:
-    """Return the proper normalization module. If causal is True, this will ensure the returned
-    module is causal, or return an error if the normalization doesn't support causal evaluation.
-    """  # noqa: E501
-    assert norm in CONV_NORMALIZATIONS
-    if norm == "layer_norm":
-        assert isinstance(module, nn.modules.conv._ConvNd)
-        return ConvLayerNorm(module.out_channels, **norm_kwargs)
-    elif norm == "time_group_norm":
-        if causal:
-            raise ValueError("GroupNorm doesn't support causal evaluation.")
-        assert isinstance(module, nn.modules.conv._ConvNd)
-        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
-    else:
-        return nn.Identity()
-
-
-class NormConv2d(nn.Module):
-    """Wrapper around Conv2d and normalization applied to this conv
-    to provide a uniform interface across normalization approaches.
-    """
-
-    def __init__(
-        self,
-        *args,
-        norm: str = "none",
-        norm_kwargs: tp.Dict[str, tp.Any] = {},
-        **kwargs,
-    ):
-        super().__init__()
-        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
-        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
-        self.norm_type = norm
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.norm(x)
-        return x
-
-
-def get_2d_padding(
-    kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)
-):
-    return (
-        ((kernel_size[0] - 1) * dilation[0]) // 2,
-        ((kernel_size[1] - 1) * dilation[1]) // 2,
-    )
-
-
-class DiscriminatorSTFT(nn.Module):
-    """STFT sub-discriminator.
-    Args:
-        filters (int): Number of filters in convolutions
-        in_channels (int): Number of input channels. Default: 1
-        out_channels (int): Number of output channels. Default: 1
-        n_fft (int): Size of FFT for each scale. Default: 1024
-        hop_length (int): Length of hop between STFT windows for each scale. Default: 256
-        kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
-        stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
-        dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
-        win_length (int): Window size for each scale. Default: 1024
-        normalized (bool): Whether to normalize by magnitude after stft. Default: True
-        norm (str): Normalization method. Default: `'weight_norm'`
-        activation (str): Activation function. Default: `'LeakyReLU'`
-        activation_params (dict): Parameters to provide to the activation function.
-        growth (int): Growth factor for the filters. Default: 1
-    """  # noqa: E501
-
-    def __init__(
-        self,
-        filters: int,
-        in_channels: int = 1,
-        out_channels: int = 1,
-        n_fft: int = 1024,
-        hop_length: int = 256,
-        win_length: int = 1024,
-        max_filters: int = 1024,
-        filters_scale: int = 1,
-        kernel_size: tp.Tuple[int, int] = (3, 9),
-        dilations: tp.List = [1, 2, 4],
-        stride: tp.Tuple[int, int] = (1, 2),
-        normalized: bool = True,
-        norm: str = "weight_norm",
-        activation: str = "LeakyReLU",
-        activation_params: dict = {"negative_slope": 0.2},
-    ):
-        super().__init__()
-        assert len(kernel_size) == 2
-        assert len(stride) == 2
-        self.filters = filters
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-        self.normalized = normalized
-        self.activation = getattr(torch.nn, activation)(**activation_params)
-        self.spec_transform = torchaudio.transforms.Spectrogram(
-            n_fft=self.n_fft,
-            hop_length=self.hop_length,
-            win_length=self.win_length,
-            window_fn=torch.hann_window,
-            normalized=self.normalized,
-            center=False,
-            pad_mode=None,
-            power=None,
-        )
-        spec_channels = 2 * self.in_channels
-        self.convs = nn.ModuleList()
-        self.convs.append(
-            NormConv2d(
-                spec_channels,
-                self.filters,
-                kernel_size=kernel_size,
-                padding=get_2d_padding(kernel_size),
-            )
-        )
-        in_chs = min(filters_scale * self.filters, max_filters)
-        for i, dilation in enumerate(dilations):
-            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
-            self.convs.append(
-                NormConv2d(
-                    in_chs,
-                    out_chs,
-                    kernel_size=kernel_size,
-                    stride=stride,
-                    dilation=(dilation, 1),
-                    padding=get_2d_padding(kernel_size, (dilation, 1)),
-                    norm=norm,
-                )
-            )
-            in_chs = out_chs
-        out_chs = min(
-            (filters_scale ** (len(dilations) + 1)) * self.filters, max_filters
-        )
-        self.convs.append(
-            NormConv2d(
-                in_chs,
-                out_chs,
-                kernel_size=(kernel_size[0], kernel_size[0]),
-                padding=get_2d_padding((kernel_size[0], kernel_size[0])),
-                norm=norm,
-            )
-        )
-        self.conv_post = NormConv2d(
-            out_chs,
-            self.out_channels,
-            kernel_size=(kernel_size[0], kernel_size[0]),
-            padding=get_2d_padding((kernel_size[0], kernel_size[0])),
-            norm=norm,
-        )
-
-    def forward(self, x: torch.Tensor):
-        fmap = []
-        z = self.spec_transform(x)  # [B, 2, Freq, Frames, 2]
-        z = torch.cat([z.real, z.imag], dim=1)
-        z = rearrange(z, "b c w t -> b c t w")
-        for i, layer in enumerate(self.convs):
-            z = layer(z)
-            z = self.activation(z)
-            fmap.append(z)
-        z = self.conv_post(z)
-        return z, fmap
-
-
-class MultiScaleSTFTDiscriminator(nn.Module):
-    """Multi-Scale STFT (MS-STFT) discriminator.
-    Args:
-        filters (int): Number of filters in convolutions
-        in_channels (int): Number of input channels. Default: 1
-        out_channels (int): Number of output channels. Default: 1
-        n_ffts (Sequence[int]): Size of FFT for each scale
-        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale
-        win_lengths (Sequence[int]): Window size for each scale
-        **kwargs: additional args for STFTDiscriminator
-    """
-
-    def __init__(
-        self,
-        filters: int,
-        in_channels: int = 1,
-        out_channels: int = 1,
-        n_ffts: tp.List[int] = [1024, 2048, 512],
-        hop_lengths: tp.List[int] = [256, 512, 128],
-        win_lengths: tp.List[int] = [1024, 2048, 512],
-        **kwargs,
-    ):
-        super().__init__()
-        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
-        self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorSTFT(
-                    filters,
-                    in_channels=in_channels,
-                    out_channels=out_channels,
-                    n_fft=n_ffts[i],
-                    win_length=win_lengths[i],
-                    hop_length=hop_lengths[i],
-                    **kwargs,
-                )
-                for i in range(len(n_ffts))
-            ]
-        )
-        self.num_discriminators = len(self.discriminators)
-
-    def forward(self, x: torch.Tensor) -> DiscriminatorOutput:
-        logits = []
-        fmaps = []
-        for disc in self.discriminators:
-            logit, fmap = disc(x)
-            logits.append(logit)
-            fmaps.append(fmap)
-
-        return logits, fmaps
-
-
-def test():
-    disc = MultiScaleSTFTDiscriminator(filters=32)
-    y = torch.randn(1, 1, 24000)
-    y_hat = torch.randn(1, 1, 24000)
-
-    y_disc_r, fmap_r = disc(y)
-    y_disc_gen, fmap_gen = disc(y_hat)
-    assert (
-        len(y_disc_r)
-        == len(y_disc_gen)
-        == len(fmap_r)
-        == len(fmap_gen)
-        == disc.num_discriminators
-    )
-    assert all([len(fm) == 5 for fm in fmap_r + fmap_gen])
-    assert all([list(f.shape)[:2] == [1, 32] for fm in fmap_r + fmap_gen for f in fm])
-    assert all([len(logits.shape) == 4 for logits in y_disc_r + y_disc_gen])
-
-
-if __name__ == "__main__":
-    test()

+ 2 - 261
fish_speech/models/vqgan/modules/encoders.py

@@ -8,265 +8,6 @@ import torch.nn.functional as F
 from einops import rearrange
 from einops import rearrange
 from vector_quantize_pytorch import LFQ, GroupedResidualVQ, VectorQuantize
 from vector_quantize_pytorch import LFQ, GroupedResidualVQ, VectorQuantize
 
 
-from fish_speech.models.vqgan.modules.modules import WN
-from fish_speech.models.vqgan.modules.transformer import (
-    MultiHeadAttention,
-    RelativePositionTransformer,
-)
-from fish_speech.models.vqgan.utils import sequence_mask
-
-
-# * Ready and Tested
-class ConvDownSampler(nn.Module):
-    def __init__(
-        self,
-        dims: list,
-        kernel_sizes: list,
-        strides: list,
-    ):
-        super().__init__()
-
-        self.dims = dims
-        self.kernel_sizes = kernel_sizes
-        self.strides = strides
-        self.total_strides = np.prod(self.strides)
-
-        self.convs = nn.ModuleList(
-            [
-                nn.ModuleList(
-                    [
-                        nn.Conv1d(
-                            in_channels=self.dims[i],
-                            out_channels=self.dims[i + 1],
-                            kernel_size=self.kernel_sizes[i],
-                            stride=self.strides[i],
-                            padding=(self.kernel_sizes[i] - 1) // 2,
-                        ),
-                        nn.LayerNorm(self.dims[i + 1], elementwise_affine=True),
-                        nn.GELU(),
-                    ]
-                )
-                for i in range(len(self.dims) - 1)
-            ]
-        )
-
-        self.apply(self.init_weights)
-
-    def init_weights(self, m):
-        if isinstance(m, nn.Conv1d):
-            nn.init.normal_(m.weight, std=0.02)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.ones_(m.weight)
-            nn.init.zeros_(m.bias)
-
-    def forward(self, x):
-        for conv, norm, act in self.convs:
-            x = conv(x)
-            x = norm(x.mT).mT
-            x = act(x)
-
-        return x
-
-
-# * Ready and Tested
-class TextEncoder(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        hidden_channels: int,
-        hidden_channels_ffn: int,
-        n_heads: int,
-        n_layers: int,
-        kernel_size: int,
-        dropout: float,
-        gin_channels=0,
-        speaker_cond_layer=0,
-        use_vae=True,
-        use_embedding=False,
-    ):
-        """Text Encoder for VITS model.
-
-        Args:
-            in_channels (int): Number of characters for the embedding layer.
-            out_channels (int): Number of channels for the output.
-            hidden_channels (int): Number of channels for the hidden layers.
-            hidden_channels_ffn (int): Number of channels for the convolutional layers.
-            n_heads (int): Number of attention heads for the Transformer layers.
-            n_layers (int): Number of Transformer layers.
-            kernel_size (int): Kernel size for the FFN layers in Transformer network.
-            dropout (float): Dropout rate for the Transformer layers.
-            gin_channels (int, optional): Number of channels for speaker embedding. Defaults to 0.
-        """
-        super().__init__()
-        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
-        self.use_embedding = use_embedding
-
-        if use_embedding:
-            self.proj_in = nn.Embedding(in_channels, hidden_channels)
-        else:
-            self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)
-
-        self.encoder = RelativePositionTransformer(
-            in_channels=hidden_channels,
-            out_channels=hidden_channels,
-            hidden_channels=hidden_channels,
-            hidden_channels_ffn=hidden_channels_ffn,
-            n_heads=n_heads,
-            n_layers=n_layers,
-            kernel_size=kernel_size,
-            dropout=dropout,
-            window_size=4,
-            gin_channels=gin_channels,
-            speaker_cond_layer=speaker_cond_layer,
-        )
-        self.proj_out = nn.Conv1d(
-            hidden_channels, out_channels * 2 if use_vae else out_channels, 1
-        )
-        self.use_vae = use_vae
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        x_mask: torch.Tensor,
-        g: torch.Tensor = None,
-        noise_scale: float = 1,
-    ):
-        """
-        Shapes:
-            - x: :math:`[B, T]`
-            - x_length: :math:`[B]`
-        """
-
-        if self.use_embedding:
-            x = self.proj_in(x.long()).mT * x_mask
-        else:
-            x = self.proj_in(x) * x_mask
-
-        x = self.encoder(x, x_mask, g=g)
-        x = self.proj_out(x) * x_mask
-
-        if self.use_vae is False:
-            return x
-
-        m, logs = torch.split(x, self.out_channels, dim=1)
-        z = m + torch.randn_like(m) * torch.exp(logs) * x_mask * noise_scale
-        return z, m, logs, x, x_mask
-
-
-# * Ready and Tested
-class PosteriorEncoder(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        hidden_channels: int,
-        kernel_size: int,
-        dilation_rate: int,
-        n_layers: int,
-        gin_channels=0,
-    ):
-        """Posterior Encoder of VITS model.
-
-        ::
-            x -> conv1x1() -> WaveNet() (non-causal) -> conv1x1() -> split() -> [m, s] -> sample(m, s) -> z
-
-        Args:
-            in_channels (int): Number of input tensor channels.
-            out_channels (int): Number of output tensor channels.
-            hidden_channels (int): Number of hidden channels.
-            kernel_size (int): Kernel size of the WaveNet convolution layers.
-            dilation_rate (int): Dilation rate of the WaveNet layers.
-            num_layers (int): Number of the WaveNet layers.
-            cond_channels (int, optional): Number of conditioning tensor channels. Defaults to 0.
-        """
-        super().__init__()
-        self.out_channels = out_channels
-
-        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
-        self.enc = WN(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            gin_channels=gin_channels,
-        )
-        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        x_mask: torch.Tensor,
-        g: torch.Tensor,
-        noise_scale: float = 1,
-    ):
-        """
-        Shapes:
-            - x: :math:`[B, C, T]`
-            - x_lengths: :math:`[B, 1]`
-            - g: :math:`[B, C, 1]`
-        """
-        x = self.pre(x) * x_mask
-        x = self.enc(x, x_mask, g=g)
-        stats = self.proj(x) * x_mask
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        z = m + torch.randn_like(m) * torch.exp(logs) * x_mask * noise_scale
-        return z, m, logs, x_mask
-
-
-# TODO: Ready for testing
-class SpeakerEncoder(nn.Module):
-    def __init__(
-        self,
-        in_channels: int = 128,
-        hidden_channels: int = 192,
-        out_channels: int = 512,
-        num_layers: int = 4,
-    ) -> None:
-        super().__init__()
-
-        self.in_proj = nn.Sequential(
-            nn.Conv1d(in_channels, hidden_channels, 1),
-            nn.Mish(),
-            nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
-            nn.Mish(),
-            nn.Conv1d(hidden_channels, hidden_channels, 5, padding=2),
-            nn.Mish(),
-        )
-        self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1)
-        self.apply(self._init_weights)
-
-        self.encoder = WN(
-            hidden_channels,
-            kernel_size=3,
-            dilation_rate=1,
-            n_layers=num_layers,
-        )
-
-    def _init_weights(self, m):
-        if isinstance(m, (nn.Conv1d, nn.Linear)):
-            nn.init.normal_(m.weight, mean=0, std=0.02)
-            nn.init.zeros_(m.bias)
-
-    def forward(self, mels, mel_masks: torch.Tensor):
-        """
-        Shapes:
-            - x: :math:`[B, C, T]`
-            - x_lengths: :math:`[B, 1]`
-        """
-
-        x = self.in_proj(mels) * mel_masks
-        x = self.encoder(x, mel_masks)
-
-        # Avg Pooling
-        x = x * mel_masks
-        x = self.out_proj(x)
-        x = torch.sum(x, dim=-1) / torch.sum(mel_masks, dim=-1)
-        x = x[..., None]
-
-        return x
-
 
 
 class VQEncoder(nn.Module):
 class VQEncoder(nn.Module):
     def __init__(
     def __init__(
@@ -286,7 +27,7 @@ class VQEncoder(nn.Module):
                 dim=vq_channels,
                 dim=vq_channels,
                 codebook_size=codebook_size,
                 codebook_size=codebook_size,
                 threshold_ema_dead_code=threshold_ema_dead_code,
                 threshold_ema_dead_code=threshold_ema_dead_code,
-                kmeans_init=False,
+                kmeans_init=True,
                 groups=codebook_groups,
                 groups=codebook_groups,
                 num_quantizers=codebook_layers,
                 num_quantizers=codebook_layers,
             )
             )
@@ -295,7 +36,7 @@ class VQEncoder(nn.Module):
                 dim=vq_channels,
                 dim=vq_channels,
                 codebook_size=codebook_size,
                 codebook_size=codebook_size,
                 threshold_ema_dead_code=threshold_ema_dead_code,
                 threshold_ema_dead_code=threshold_ema_dead_code,
-                kmeans_init=False,
+                kmeans_init=True,
             )
             )
 
 
         self.codebook_groups = codebook_groups
         self.codebook_groups = codebook_groups

+ 0 - 297
fish_speech/models/vqgan/modules/flow.py

@@ -1,297 +0,0 @@
-import torch
-from torch import nn
-
-from fish_speech.models.vqgan.modules.modules import WN, Flip
-from fish_speech.models.vqgan.modules.normalization import LayerNorm
-from fish_speech.models.vqgan.modules.transformer import FFN, MultiHeadAttention
-
-
-class ResidualCouplingBlock(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        n_flows=4,
-        gin_channels=0,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.n_flows = n_flows
-        self.gin_channels = gin_channels
-
-        self.flows = nn.ModuleList()
-
-        for i in range(n_flows):
-            self.flows.append(
-                ResidualCouplingLayer(
-                    channels,
-                    hidden_channels,
-                    kernel_size,
-                    dilation_rate,
-                    n_layers,
-                    gin_channels=gin_channels,
-                    mean_only=True,
-                )
-            )
-            self.flows.append(Flip())
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        if not reverse:
-            for flow in self.flows:
-                x, _ = flow(x, x_mask, g=g, reverse=reverse)
-        else:
-            for flow in reversed(self.flows):
-                x = flow(x, x_mask, g=g, reverse=reverse)
-        return x
-
-
-class ResidualCouplingLayer(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        dilation_rate,
-        n_layers,
-        p_dropout=0,
-        gin_channels=0,
-        mean_only=False,
-    ):
-        assert channels % 2 == 0, "channels should be divisible by 2"
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.half_channels = channels // 2
-        self.mean_only = mean_only
-
-        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-        self.enc = WN(
-            hidden_channels,
-            kernel_size,
-            dilation_rate,
-            n_layers,
-            p_dropout=p_dropout,
-            gin_channels=gin_channels,
-        )
-        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-        self.post.weight.data.zero_()
-        self.post.bias.data.zero_()
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-        h = self.pre(x0) * x_mask
-        h = self.enc(h, x_mask, g=g)
-        stats = self.post(h) * x_mask
-        if not self.mean_only:
-            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
-        else:
-            m = stats
-            logs = torch.zeros_like(m)
-
-        if not reverse:
-            x1 = m + x1 * torch.exp(logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            logdet = torch.sum(logs, [1, 2])
-            return x, logdet
-        else:
-            x1 = (x1 - m) * torch.exp(-logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            return x
-
-
-class TransformerCouplingBlock(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size,
-        p_dropout,
-        n_flows=4,
-        gin_channels=0,
-    ):
-        super().__init__()
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.n_flows = n_flows
-        self.gin_channels = gin_channels
-
-        self.flows = nn.ModuleList()
-
-        for i in range(n_flows):
-            self.flows.append(
-                TransformerCouplingLayer(
-                    channels,
-                    hidden_channels,
-                    kernel_size,
-                    n_layers,
-                    n_heads,
-                    p_dropout,
-                    filter_channels,
-                    mean_only=True,
-                    gin_channels=self.gin_channels,
-                )
-            )
-            self.flows.append(Flip())
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        if not reverse:
-            for flow in self.flows:
-                x, _ = flow(x, x_mask, g=g, reverse=reverse)
-        else:
-            for flow in reversed(self.flows):
-                x = flow(x, x_mask, g=g, reverse=reverse)
-        return x
-
-
-class TransformerCouplingLayer(nn.Module):
-    def __init__(
-        self,
-        channels,
-        hidden_channels,
-        kernel_size,
-        n_layers,
-        n_heads,
-        p_dropout=0,
-        filter_channels=0,
-        mean_only=False,
-        gin_channels=0,
-    ):
-        super().__init__()
-
-        assert channels % 2 == 0, "channels should be divisible by 2"
-
-        self.channels = channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.n_layers = n_layers
-        self.half_channels = channels // 2
-        self.mean_only = mean_only
-
-        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
-        self.enc = Encoder(
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers,
-            kernel_size,
-            p_dropout,
-            gin_channels=gin_channels,
-        )
-        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
-        self.post.weight.data.zero_()
-        self.post.bias.data.zero_()
-
-    def forward(self, x, x_mask, g=None, reverse=False):
-        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
-        h = self.pre(x0) * x_mask
-        h = self.enc(h, x_mask, g=g)
-        stats = self.post(h) * x_mask
-        if not self.mean_only:
-            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
-        else:
-            m = stats
-            logs = torch.zeros_like(m)
-
-        if not reverse:
-            x1 = m + x1 * torch.exp(logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            logdet = torch.sum(logs, [1, 2])
-            return x, logdet
-        else:
-            x1 = (x1 - m) * torch.exp(-logs) * x_mask
-            x = torch.cat([x0, x1], 1)
-            return x
-
-
-class Encoder(nn.Module):
-    def __init__(
-        self,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        kernel_size=1,
-        p_dropout=0.0,
-        window_size=4,
-        gin_channels=512,
-        cond_layer_idx=2,
-    ):
-        super().__init__()
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.window_size = window_size
-
-        self.spk_emb_linear = nn.Linear(gin_channels, self.hidden_channels)
-        self.cond_layer_idx = cond_layer_idx
-
-        assert (
-            self.cond_layer_idx < self.n_layers
-        ), "cond_layer_idx should be less than n_layers"
-
-        self.drop = nn.Dropout(p_dropout)
-        self.attn_layers = nn.ModuleList()
-        self.norm_layers_1 = nn.ModuleList()
-        self.ffn_layers = nn.ModuleList()
-        self.norm_layers_2 = nn.ModuleList()
-        for i in range(self.n_layers):
-            self.attn_layers.append(
-                MultiHeadAttention(
-                    hidden_channels,
-                    hidden_channels,
-                    n_heads,
-                    p_dropout=p_dropout,
-                    window_size=window_size,
-                )
-            )
-            self.norm_layers_1.append(LayerNorm(hidden_channels))
-            self.ffn_layers.append(
-                FFN(
-                    hidden_channels,
-                    hidden_channels,
-                    filter_channels,
-                    kernel_size,
-                    p_dropout=p_dropout,
-                )
-            )
-            self.norm_layers_2.append(LayerNorm(hidden_channels))
-
-    def forward(self, x, x_mask, g=None):
-        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-        x = x * x_mask
-
-        for i in range(self.n_layers):
-            if i == self.cond_layer_idx and g is not None:
-                g = self.spk_emb_linear(g.transpose(1, 2))
-                g = g.transpose(1, 2)
-                x = x + g
-                x = x * x_mask
-            y = self.attn_layers[i](x, x, attn_mask)
-            y = self.drop(y)
-            x = self.norm_layers_1[i](x + y)
-
-            y = self.ffn_layers[i](x, x_mask)
-            y = self.drop(y)
-            x = self.norm_layers_2[i](x + y)
-
-        x = x * x_mask
-
-        return x

+ 0 - 174
fish_speech/models/vqgan/modules/models.py

@@ -1,174 +0,0 @@
-import torch
-from torch import nn
-
-from fish_speech.models.vqgan.modules.decoder import Generator
-from fish_speech.models.vqgan.modules.encoders import (
-    PosteriorEncoder,
-    SpeakerEncoder,
-    TextEncoder,
-    VQEncoder,
-)
-from fish_speech.models.vqgan.modules.flow import ResidualCouplingBlock
-from fish_speech.models.vqgan.utils import rand_slice_segments, sequence_mask
-
-
-class SynthesizerTrn(nn.Module):
-    """
-    Synthesizer for Training
-    """
-
-    def __init__(
-        self,
-        *,
-        in_channels,
-        spec_channels,
-        segment_size,
-        inter_channels,
-        hidden_channels,
-        filter_channels,
-        n_heads,
-        n_layers,
-        n_flows,
-        n_layers_q,
-        n_layers_spk,
-        n_layers_flow,
-        kernel_size,
-        p_dropout,
-        speaker_cond_layer,
-        resblock,
-        resblock_kernel_sizes,
-        resblock_dilation_sizes,
-        upsample_rates,
-        upsample_initial_channel,
-        upsample_kernel_sizes,
-        gin_channels,
-        codebook_size,
-        kmeans_ckpt=None,
-    ):
-        super().__init__()
-
-        self.segment_size = segment_size
-
-        # self.vq = VQEncoder(
-        #     in_channels=in_channels,
-        #     vq_channels=in_channels,
-        #     codebook_size=codebook_size,
-        #     kmeans_ckpt=kmeans_ckpt,
-        # )
-        self.enc_p = TextEncoder(
-            in_channels,
-            inter_channels,
-            hidden_channels,
-            filter_channels,
-            n_heads,
-            n_layers,
-            kernel_size,
-            p_dropout,
-            gin_channels=gin_channels,
-            speaker_cond_layer=speaker_cond_layer,
-        )
-        self.enc_spk = SpeakerEncoder(
-            in_channels=spec_channels,
-            hidden_channels=inter_channels,
-            out_channels=gin_channels,
-            num_heads=n_heads,
-            num_layers=n_layers_spk,
-            p_dropout=p_dropout,
-        )
-        self.flow = ResidualCouplingBlock(
-            channels=inter_channels,
-            hidden_channels=hidden_channels,
-            kernel_size=5,
-            dilation_rate=1,
-            n_layers=n_layers_flow,
-            n_flows=n_flows,
-            gin_channels=gin_channels,
-        )
-        self.enc_q = PosteriorEncoder(
-            spec_channels,
-            inter_channels,
-            hidden_channels,
-            5,
-            1,
-            n_layers_q,
-            gin_channels=gin_channels,
-        )
-        self.dec = Generator(
-            inter_channels,
-            resblock,
-            resblock_kernel_sizes,
-            resblock_dilation_sizes,
-            upsample_rates,
-            upsample_initial_channel,
-            upsample_kernel_sizes,
-            gin_channels=gin_channels,
-        )
-
-    def forward(self, x, x_lengths, specs):
-        # x = x.mT
-
-        min_length = min(x.shape[1], specs.shape[2])
-        if min_length % 2 != 0:
-            min_length -= 1
-
-        x = x[:, :min_length]
-        specs = specs[:, :, :min_length]
-        x_lengths = torch.clamp(x_lengths, max=min_length)
-
-        spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
-            specs.dtype
-        )
-        x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
-
-        g = self.enc_spk(specs, spec_masks)
-
-        # with torch.no_grad():
-        #     x, _ = self.vq(x, x_masks)
-        #     vq_loss = 0
-
-        _, m_p, logs_p, _, _ = self.enc_p(x, x_masks, g=g)
-        z_q, m_q, logs_q, _ = self.enc_q(specs, spec_masks, g=g)
-        z_p = self.flow(z_q, spec_masks, g=g, reverse=False)
-
-        z_slice, ids_slice = rand_slice_segments(z_q, x_lengths, self.segment_size)
-        o = self.dec(z_slice, g=g)
-
-        return (
-            o,
-            ids_slice,
-            x_masks,
-            spec_masks,
-            (z_q, z_p),
-            (m_p, logs_p),
-            (m_q, logs_q),
-            # vq_loss,
-        )
-
-    def infer(self, x, x_lengths, specs, max_len=None, noise_scale=0.35):
-        # x = x.mT
-        spec_masks = torch.unsqueeze(sequence_mask(x_lengths, specs.shape[2]), 1).to(
-            specs.dtype
-        )
-        # print(x_lengths, x.shape)
-        x_masks = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
-        g = self.enc_spk(specs, spec_masks)
-        # x, vq_loss = self.vq(x, x_masks)
-        z_p, m_p, logs_p, h_text, _ = self.enc_p(
-            x, x_masks, g=g, noise_scale=noise_scale
-        )
-        z_p = self.flow(z_p, x_masks, g=g, reverse=True)
-
-        o = self.dec((z_p * x_masks)[:, :, :max_len], g=g)
-        return o
-
-    def reconstruct(self, specs, spec_lengths, max_len=None, noise_scale=0.35):
-        spec_masks = torch.unsqueeze(sequence_mask(spec_lengths, specs.shape[2]), 1).to(
-            specs.dtype
-        )
-        g = self.enc_spk(specs, spec_masks)
-        z_q, m_q, logs_q, _ = self.enc_q(
-            specs, spec_masks, g=g, noise_scale=noise_scale
-        )
-        o = self.dec((z_q * spec_masks)[:, :, :max_len], g=g)
-
-        return o

+ 17 - 33
fish_speech/models/vqgan/modules/modules.py

@@ -10,31 +10,30 @@ LRELU_SLOPE = 0.1
 
 
 # ! PosteriorEncoder
 # ! PosteriorEncoder
 # ! ResidualCouplingLayer
 # ! ResidualCouplingLayer
-class WN(nn.Module):
+class WaveNet(nn.Module):
     def __init__(
     def __init__(
         self,
         self,
         hidden_channels,
         hidden_channels,
         kernel_size,
         kernel_size,
         dilation_rate,
         dilation_rate,
         n_layers,
         n_layers,
-        gin_channels=0,
         p_dropout=0,
         p_dropout=0,
         out_channels=None,
         out_channels=None,
+        in_channels=None,
     ):
     ):
-        super(WN, self).__init__()
+        super(WaveNet, self).__init__()
         assert kernel_size % 2 == 1
         assert kernel_size % 2 == 1
         self.hidden_channels = hidden_channels
         self.hidden_channels = hidden_channels
         self.kernel_size = (kernel_size,)
         self.kernel_size = (kernel_size,)
         self.n_layers = n_layers
         self.n_layers = n_layers
-        self.gin_channels = gin_channels
 
 
         self.in_layers = nn.ModuleList()
         self.in_layers = nn.ModuleList()
         self.res_skip_layers = nn.ModuleList()
         self.res_skip_layers = nn.ModuleList()
         self.drop = nn.Dropout(p_dropout)
         self.drop = nn.Dropout(p_dropout)
 
 
-        if gin_channels != 0:
-            cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
-            self.cond_layer = weight_norm(cond_layer, name="weight")
+        self.in_channels = in_channels
+        if in_channels is not None:
+            self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)
 
 
         for i in range(n_layers):
         for i in range(n_layers):
             dilation = dilation_rate**i
             dilation = dilation_rate**i
@@ -61,33 +60,31 @@ class WN(nn.Module):
         if out_channels is not None:
         if out_channels is not None:
             self.out_layer = nn.Conv1d(hidden_channels, out_channels, 1)
             self.out_layer = nn.Conv1d(hidden_channels, out_channels, 1)
 
 
-    def forward(self, x, x_mask, g=None, **kwargs):
-        output = torch.zeros_like(x)
+    def forward(self, x, x_mask=None):
         n_channels_tensor = torch.IntTensor([self.hidden_channels])
         n_channels_tensor = torch.IntTensor([self.hidden_channels])
 
 
-        if g is not None:
-            g = self.cond_layer(g)
+        if self.in_channels is not None:
+            x = self.proj_in(x)
+
+        output = torch.zeros_like(x)
 
 
         for i in range(self.n_layers):
         for i in range(self.n_layers):
             x_in = self.in_layers[i](x)
             x_in = self.in_layers[i](x)
-            if g is not None:
-                cond_offset = i * 2 * self.hidden_channels
-                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
-            else:
-                g_l = torch.zeros_like(x_in)
-
-            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+            acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor)
             acts = self.drop(acts)
             acts = self.drop(acts)
 
 
             res_skip_acts = self.res_skip_layers[i](acts)
             res_skip_acts = self.res_skip_layers[i](acts)
             if i < self.n_layers - 1:
             if i < self.n_layers - 1:
                 res_acts = res_skip_acts[:, : self.hidden_channels, :]
                 res_acts = res_skip_acts[:, : self.hidden_channels, :]
-                x = (x + res_acts) * x_mask
+                x = x + res_acts
+                if x_mask is not None:
+                    x = x * x_mask
                 output = output + res_skip_acts[:, self.hidden_channels :, :]
                 output = output + res_skip_acts[:, self.hidden_channels :, :]
             else:
             else:
                 output = output + res_skip_acts
                 output = output + res_skip_acts
 
 
-        x = output * x_mask
+        if x_mask is not None:
+            x = output * x_mask
 
 
         if self.out_channels is not None:
         if self.out_channels is not None:
             x = self.out_layer(x)
             x = self.out_layer(x)
@@ -101,16 +98,3 @@ class WN(nn.Module):
             remove_parametrizations(l)
             remove_parametrizations(l)
         for l in self.res_skip_layers:
         for l in self.res_skip_layers:
             remove_parametrizations(l)
             remove_parametrizations(l)
-
-
-# ! StochasticDurationPredictor
-# ! ResidualCouplingBlock
-# TODO convert to class method
-class Flip(nn.Module):
-    def forward(self, x, *args, reverse=False, **kwargs):
-        x = torch.flip(x, [1])
-        if not reverse:
-            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
-            return x, logdet
-        else:
-            return x

+ 0 - 34
fish_speech/models/vqgan/modules/normalization.py

@@ -1,34 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class LayerNorm(nn.Module):
-    def __init__(self, channels, eps=1e-5):
-        super().__init__()
-        self.channels = channels
-        self.eps = eps
-
-        self.gamma = nn.Parameter(torch.ones(channels))
-        self.beta = nn.Parameter(torch.zeros(channels))
-
-    def forward(self, x: torch.Tensor):
-        x = F.layer_norm(x.mT, (self.channels,), self.gamma, self.beta, self.eps)
-        return x.mT
-
-
-class CondLayerNorm(nn.Module):
-    def __init__(self, channels, eps=1e-5, cond_channels=0):
-        super().__init__()
-        self.channels = channels
-        self.eps = eps
-
-        self.linear_gamma = nn.Linear(cond_channels, channels)
-        self.linear_beta = nn.Linear(cond_channels, channels)
-
-    def forward(self, x: torch.Tensor, cond: torch.Tensor):
-        gamma = self.linear_gamma(cond)
-        beta = self.linear_beta(cond)
-
-        x = F.layer_norm(x.mT, (self.channels,), gamma, beta, self.eps)
-        return x.mT

+ 0 - 328
fish_speech/models/vqgan/modules/transformer.py

@@ -1,328 +0,0 @@
-import math
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from fish_speech.models.vqgan.modules.normalization import LayerNorm
-from fish_speech.models.vqgan.utils import convert_pad_shape
-
-
-# TODO add conditioning on language
-# TODO check whether we need to stop gradient for speaker embedding
-class RelativePositionTransformer(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        hidden_channels: int,
-        out_channels: int,
-        hidden_channels_ffn: int,
-        n_heads: int,
-        n_layers: int,
-        kernel_size=1,
-        dropout=0.0,
-        window_size=4,
-        gin_channels=0,
-        speaker_cond_layer=0,
-    ):
-        super().__init__()
-        assert (
-            out_channels == hidden_channels
-        ), "out_channels must be equal to hidden_channels"
-
-        self.n_layers = n_layers
-        self.speaker_cond_layer = speaker_cond_layer
-
-        self.drop = nn.Dropout(dropout)
-        self.attn_layers = nn.ModuleList()
-        self.norm_layers_1 = nn.ModuleList()
-        self.ffn_layers = nn.ModuleList()
-        self.norm_layers_2 = nn.ModuleList()
-        for i in range(self.n_layers):
-            self.attn_layers.append(
-                MultiHeadAttention(
-                    hidden_channels if i != 0 else in_channels,
-                    hidden_channels,
-                    n_heads,
-                    p_dropout=dropout,
-                    window_size=window_size,
-                )
-            )
-            self.norm_layers_1.append(LayerNorm(hidden_channels))
-            self.ffn_layers.append(
-                FFN(
-                    hidden_channels,
-                    hidden_channels,
-                    hidden_channels_ffn,
-                    kernel_size,
-                    p_dropout=dropout,
-                )
-            )
-            self.norm_layers_2.append(LayerNorm(hidden_channels))
-
-        if gin_channels != 0:
-            self.cond = nn.Conv1d(gin_channels, hidden_channels, 1)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        x_mask: torch.Tensor,
-        g: torch.Tensor = None,
-    ):
-        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-        x = x * x_mask
-        for i in range(self.n_layers):
-            # TODO consider using other conditioning
-            # TODO https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/modules/attentions.py#L12
-            if i == self.speaker_cond_layer and g is not None:
-                # ! g = torch.detach(g)
-                x = x + self.cond(g)
-                x = x * x_mask
-
-            y = self.attn_layers[i](x, x, attn_mask)
-            y = self.drop(y)
-            x = self.norm_layers_1[i](x + y)
-
-            y = self.ffn_layers[i](x, x_mask)
-            y = self.drop(y)
-            x = self.norm_layers_2[i](x + y)
-        x = x * x_mask
-        return x
-
-
-class MultiHeadAttention(nn.Module):
-    def __init__(
-        self,
-        channels,
-        out_channels,
-        n_heads,
-        p_dropout=0.0,
-        window_size=None,
-        heads_share=True,
-        block_length=None,
-        proximal_bias=False,
-        proximal_init=False,
-    ):
-        super().__init__()
-        assert channels % n_heads == 0
-
-        self.channels = channels
-        self.out_channels = out_channels
-        self.n_heads = n_heads
-        self.p_dropout = p_dropout
-        self.window_size = window_size
-        self.heads_share = heads_share
-        self.block_length = block_length
-        self.proximal_bias = proximal_bias
-        self.proximal_init = proximal_init
-        self.attn = None
-
-        self.k_channels = channels // n_heads
-        self.conv_q = nn.Linear(channels, channels)
-        self.conv_k = nn.Linear(channels, channels)
-        self.conv_v = nn.Linear(channels, channels)
-        self.conv_o = nn.Linear(channels, out_channels)
-        self.drop = nn.Dropout(p_dropout)
-
-        if window_size is not None:
-            n_heads_rel = 1 if heads_share else n_heads
-            rel_stddev = self.k_channels**-0.5
-            self.emb_rel_k = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-            self.emb_rel_v = nn.Parameter(
-                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
-                * rel_stddev
-            )
-
-        nn.init.xavier_uniform_(self.conv_q.weight)
-        nn.init.xavier_uniform_(self.conv_k.weight)
-        nn.init.xavier_uniform_(self.conv_v.weight)
-        if proximal_init:
-            with torch.no_grad():
-                self.conv_k.weight.copy_(self.conv_q.weight)
-                self.conv_k.bias.copy_(self.conv_q.bias)
-
-    def forward(self, x, c, attn_mask=None):
-        q = self.conv_q(x.mT).mT
-        k = self.conv_k(c.mT).mT
-        v = self.conv_v(c.mT).mT
-
-        x, self.attn = self.attention(q, k, v, mask=attn_mask)
-
-        x = self.conv_o(x.mT).mT
-        return x
-
-    def attention(self, query, key, value, mask=None):
-        # reshape [b, d, t] -> [b, n_h, t, d_k]
-        b, d, t_s, t_t = (*key.size(), query.size(2))
-        query = query.view(b, self.n_heads, self.k_channels, t_t).mT
-        key = key.view(b, self.n_heads, self.k_channels, t_s).mT
-        value = value.view(b, self.n_heads, self.k_channels, t_s).mT
-
-        scores = torch.matmul(query / math.sqrt(self.k_channels), key.mT)
-        if self.window_size is not None:
-            assert (
-                t_s == t_t
-            ), "Relative attention is only available for self-attention."
-            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-            rel_logits = self._matmul_with_relative_keys(
-                query / math.sqrt(self.k_channels), key_relative_embeddings
-            )
-            scores_local = self._relative_position_to_absolute_position(rel_logits)
-            scores = scores + scores_local
-        if self.proximal_bias:
-            assert t_s == t_t, "Proximal bias is only available for self-attention."
-            scores = scores + self._attention_bias_proximal(t_s).to(
-                device=scores.device, dtype=scores.dtype
-            )
-        if mask is not None:
-            scores = scores.masked_fill(mask == 0, -1e4)
-            if self.block_length is not None:
-                assert (
-                    t_s == t_t
-                ), "Local attention is only available for self-attention."
-                block_mask = (
-                    torch.ones_like(scores)
-                    .triu(-self.block_length)
-                    .tril(self.block_length)
-                )
-                scores = scores.masked_fill(block_mask == 0, -1e4)
-        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
-        p_attn = self.drop(p_attn)
-        output = torch.matmul(p_attn, value)
-        if self.window_size is not None:
-            relative_weights = self._absolute_position_to_relative_position(p_attn)
-            value_relative_embeddings = self._get_relative_embeddings(
-                self.emb_rel_v, t_s
-            )
-            output = output + self._matmul_with_relative_values(
-                relative_weights, value_relative_embeddings
-            )
-        output = output.mT.contiguous().view(
-            b, d, t_t
-        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
-        return output, p_attn
-
-    def _matmul_with_relative_values(self, x: torch.Tensor, y: torch.Tensor):
-        """
-        x: [b, h, l, m]
-        y: [h or 1, m, d]
-        ret: [b, h, l, d]
-        """
-        return torch.matmul(x, y.unsqueeze(0))
-
-    def _matmul_with_relative_keys(self, x: torch.Tensor, y: torch.Tensor):
-        """
-        x: [b, h, l, d]
-        y: [h or 1, m, d]
-        ret: [b, h, l, m]
-        """
-        return torch.matmul(x, y.unsqueeze(0).mT)
-
-    def _get_relative_embeddings(self, relative_embeddings, length):
-        max_relative_position = 2 * self.window_size + 1
-        # Pad first before slice to avoid using cond ops.
-        pad_length = max(length - (self.window_size + 1), 0)
-        slice_start_position = max((self.window_size + 1) - length, 0)
-        slice_end_position = slice_start_position + 2 * length - 1
-        if pad_length > 0:
-            padded_relative_embeddings = F.pad(
-                relative_embeddings,
-                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
-            )
-        else:
-            padded_relative_embeddings = relative_embeddings
-        used_relative_embeddings = padded_relative_embeddings[
-            :, slice_start_position:slice_end_position
-        ]
-        return used_relative_embeddings
-
-    def _relative_position_to_absolute_position(self, x):
-        """
-        x: [b, h, l, 2*l-1]
-        ret: [b, h, l, l]
-        """
-        batch, heads, length, _ = x.size()
-        # Concat columns of pad to shift from relative to absolute indexing.
-        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
-
-        # Concat extra elements so to add up to shape (len+1, 2*len-1).
-        x_flat = x.view([batch, heads, length * 2 * length])
-        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
-
-        # Reshape and slice out the padded elements.
-        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
-            :, :, :length, length - 1 :
-        ]
-        return x_final
-
-    def _absolute_position_to_relative_position(self, x):
-        """
-        x: [b, h, l, l]
-        ret: [b, h, l, 2*l-1]
-        """
-        batch, heads, length, _ = x.size()
-        # pad along column
-        x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
-        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
-        # add 0's in the beginning that will skew the elements after reshape
-        x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
-        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
-        return x_final
-
-    def _attention_bias_proximal(self, length):
-        """Bias for self-attention to encourage attention to close positions.
-        Args:
-          length: an integer scalar.
-        Returns:
-          a Tensor with shape [1, 1, length, length]
-        """
-        r = torch.arange(length, dtype=torch.float32)
-        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
-        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
-
-
-class FFN(nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        filter_channels,
-        kernel_size,
-        p_dropout=0.0,
-        causal=False,
-    ):
-        super().__init__()
-        self.kernel_size = kernel_size
-        self.padding = self._causal_padding if causal else self._same_padding
-
-        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
-        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
-        self.drop = nn.Dropout(p_dropout)
-
-    def forward(self, x, x_mask):
-        x = self.conv_1(self.padding(x * x_mask))
-        x = torch.relu(x)
-        x = self.drop(x)
-        x = self.conv_2(self.padding(x * x_mask))
-        return x * x_mask
-
-    def _causal_padding(self, x):
-        if self.kernel_size == 1:
-            return x
-        pad_l = self.kernel_size - 1
-        pad_r = 0
-        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, convert_pad_shape(padding))
-        return x
-
-    def _same_padding(self, x):
-        if self.kernel_size == 1:
-            return x
-        pad_l = (self.kernel_size - 1) // 2
-        pad_r = self.kernel_size // 2
-        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = F.pad(x, convert_pad_shape(padding))
-        return x

+ 2 - 3
fish_speech/models/vqgan/utils.py

@@ -40,7 +40,7 @@ def plot_mel(data, titles=None):
         mel = data[i]
         mel = data[i]
 
 
         if isinstance(mel, torch.Tensor):
         if isinstance(mel, torch.Tensor):
-            mel = mel.detach().cpu().numpy()
+            mel = mel.float().detach().cpu().numpy()
 
 
         axes[i][0].imshow(mel, origin="lower")
         axes[i][0].imshow(mel, origin="lower")
         axes[i][0].set_aspect(2.5, adjustable="box")
         axes[i][0].set_aspect(2.5, adjustable="box")
@@ -73,9 +73,8 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
 
 
 
 
 @torch.jit.script
 @torch.jit.script
-def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
     n_channels_int = n_channels[0]
     n_channels_int = n_channels[0]
-    in_act = input_a + input_b
     t_act = torch.tanh(in_act[:, :n_channels_int, :])
     t_act = torch.tanh(in_act[:, :n_channels_int, :])
     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
     s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
     acts = t_act * s_act
     acts = t_act * s_act