Lengyue 2 years ago
Parent
Commit
de7a4d187b

+ 0 - 112
fish_speech/configs/hubert_vq.yaml

@@ -1,112 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: hubert_vq
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: 4
-  strategy: ddp_find_unused_parameters_true
-  precision: 32
-  max_steps: 1_000_000
-  val_check_interval: 5000
-
-sample_rate: 32000
-hop_length: 640
-num_mels: 128
-n_fft: 2048
-win_length: 2048
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/filelist.split.train
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  slice_frames: 512
-
-val_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/filelist.split.valid
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-data:
-  _target_: fish_speech.datasets.vqgan.VQGANDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 32
-  val_batch_size: 4
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vqgan.VQGAN
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  segment_size: 20480
-
-  generator:
-    _target_: fish_speech.models.vqgan.modules.models.SynthesizerTrn
-    in_channels: 2048
-    spec_channels: ${num_mels}
-    segment_size: "${eval: '${model.segment_size} // ${hop_length}'}"
-    inter_channels: 192
-    hidden_channels: 192
-    filter_channels: 768
-    n_heads: 2
-    n_layers: 6
-    n_layers_q: 6
-    n_layers_flow: 6
-    n_layers_spk: 4
-    n_flows: 4
-    kernel_size: 3
-    p_dropout: 0.1
-    speaker_cond_layer: 2
-    resblock: "1"
-    resblock_kernel_sizes: [3, 7, 11]
-    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    upsample_rates: [10, 8, 2, 2, 2]
-    upsample_initial_channel: 512
-    upsample_kernel_sizes: [16, 16, 8, 2, 2]
-    gin_channels: 512 # basically the speaker embedding size
-    kmeans_ckpt: results/hubert-vq-pretrain/kmeans.pt
-    codebook_size: 2048
-
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
-
-  mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 2e-4
-    betas: [0.8, 0.99]
-    eps: 1e-5
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.LambdaLR
-    _partial_: true
-    lr_lambda:
-      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-      _partial_: true
-      num_warmup_steps: 0
-      num_training_steps: ${trainer.max_steps}
-      final_lr_ratio: 0.05
-
-callbacks:
-  grad_norm_monitor:
-    sub_module: generator
-
-# Resume from rcell's checkpoint
-ckpt_path: results/hubert-vq-pretrain/rcell/ckpt_23000_pl.pth
-resume_weights_only: true

+ 25 - 8
fish_speech/configs/vq_diffusion.yaml

@@ -2,7 +2,7 @@ defaults:
   - base
   - _self_
 
-project: vq_diffusion
+project: vq_naive
 
 # Lightning Trainer
 trainer:
@@ -86,12 +86,28 @@ model:
     num_layers: 4
     p_dropout: 0.1
   
-  denoiser:
-    _target_: fish_speech.models.vq_diffusion.wavenet.WaveNet
-    d_encoder: 128
-    mel_channels: 100
-    residual_channels: 512
-    residual_layers: 20
+  decoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
+    in_channels: 128
+    out_channels: 100
+    hidden_channels: 192
+    hidden_channels_ffn: 768
+    n_heads: 2
+    n_layers: 6
+    kernel_size: 1
+    use_vae: false
+    dropout: 0
+    gin_channels: 128
+    speaker_cond_layer: 0
+
+  postnet:
+    _target_: fish_speech.models.vq_diffusion.convnext_1d.ConvNext1DModel
+    in_channels: 100
+    out_channels: 100
+    intermediate_dim: 256
+    mlp_dim: 1024
+    num_layers: 6
+    dilation_cycle_length: 2
 
   vocoder:
     _target_: fish_speech.models.vq_diffusion.bigvgan.BigVGAN
@@ -137,4 +153,5 @@ callbacks:
       - vq_encoder
       - text_encoder
       - speaker_encoder
-      - denoiser
+      - decoder
+      - postnet

+ 146 - 0
fish_speech/configs/vqgan.yaml

@@ -0,0 +1,146 @@
+defaults:
+  - base
+  - _self_
+
+project: vqgan
+
+# Lightning Trainer
+trainer:
+  accelerator: gpu
+  devices: 8
+  strategy: ddp_find_unused_parameters_true
+  precision: bf16-mixed
+  max_steps: 1_000_000
+  val_check_interval: 5000
+
+sample_rate: 22050
+hop_length: 256
+num_mels: 80
+n_fft: 1024
+win_length: 1024
+segment_size: 512
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.split.train
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  slice_frames: ${segment_size}
+
+val_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.split.valid
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+data:
+  _target_: fish_speech.datasets.vqgan.VQGANDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 16
+  val_batch_size: 4
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vqgan.VQGAN
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  segment_size: 8192
+  freeze_hifigan: true
+
+  downsample:
+    _target_: fish_speech.models.vq_diffusion.lit_module.ConvDownSample
+    dims: ["${num_mels}", 512, "${num_mels}"]
+    kernel_sizes: [3, 3]
+    strides: [2, 2]
+
+  text_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
+    in_channels: ${num_mels}
+    out_channels: ${num_mels}
+    hidden_channels: 192
+    hidden_channels_ffn: 768
+    n_heads: 2
+    n_layers: 6
+    kernel_size: 1
+    dropout: 0.1
+    use_vae: false
+
+  vq_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
+    in_channels: ${num_mels}
+    vq_channels: ${num_mels}
+    codebook_size: 4096
+    downsample: 1
+
+  speaker_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
+    in_channels: ${num_mels}
+    hidden_channels: 192
+    out_channels: ${num_mels}
+    num_heads: 2
+    num_layers: 4
+    p_dropout: 0.1
+
+  decoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
+    in_channels: ${num_mels}
+    out_channels: ${num_mels}
+    hidden_channels: 192
+    hidden_channels_ffn: 768
+    n_heads: 2
+    n_layers: 6
+    kernel_size: 1
+    use_vae: false
+    dropout: 0
+    gin_channels: ${num_mels}
+    speaker_cond_layer: 0
+
+  generator:
+    _target_: fish_speech.models.vqgan.modules.decoder.Generator
+    initial_channel: ${num_mels}
+    resblock: "1"
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    upsample_rates: [8, 8, 2, 2]
+    upsample_initial_channel: 512
+    upsample_kernel_sizes: [16, 16, 4, 4]
+    ckpt_path: "checkpoints/hifigan-v1-universal-22050/g_02500000"
+
+  discriminator:
+    _target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
+    ckpt_path: checkpoints/hifigan-v1-universal-22050/do_02500000
+
+  mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+    f_min: 0
+    f_max: 8000
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 2e-4
+    betas: [0.8, 0.99]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.ExponentialLR
+    _partial_: true
+    gamma: 0.999999  # Estimated based on LibriTTS dataset
+
+callbacks:
+  grad_norm_monitor:
+    sub_module: 
+      - generator
+      - discriminator
+      - text_encoder
+      - vq_encoder
+      - speaker_encoder
+      - decoder

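A minimal sketch (not part of this commit) of how these configs are consumed: hydra.utils.instantiate builds every _target_ node from its dotted path, and _partial_: true (as on the optimizer and lr_scheduler blocks) returns a functools.partial so the LightningModule can supply the remaining arguments later.

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "optimizer": {
            "_target_": "torch.optim.AdamW",
            "_partial_": True,
            "lr": 2e-4,
            "betas": [0.8, 0.99],
            "eps": 1e-5,
        }
    }
)

optimizer_builder = instantiate(cfg.optimizer)  # functools.partial(AdamW, lr=0.0002, ...)
# later, inside configure_optimizers: optimizer = optimizer_builder(model.parameters())
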
+ 8 - 25
fish_speech/datasets/vqgan.py

@@ -42,22 +42,16 @@ class VQGANDataset(Dataset):
         file = self.files[idx]
 
         audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
-        features = np.load(file.with_suffix(".npy"))  # (T, 1024)
 
         # Slice audio and features
-        if self.slice_frames is not None and features.shape[0] > self.slice_frames:
-            start = np.random.randint(0, features.shape[0] - self.slice_frames)
-            features = features[start : start + self.slice_frames]
-
-            start_in_seconds, end_in_seconds = (
-                start * 320 / 16000,
-                (start + self.slice_frames) * 320 / 16000,
+        if (
+            self.slice_frames is not None
+            and audio.shape[0] > self.slice_frames * self.hop_length
+        ):
+            start = np.random.randint(
+                0, audio.shape[0] - self.slice_frames * self.hop_length
             )
-            audio = audio[
-                int(start_in_seconds * self.sample_rate) : int(
-                    end_in_seconds * self.sample_rate
-                )
-            ]
+            audio = audio[start : start + self.slice_frames * self.hop_length]
 
         if len(audio) == 0:
             return None
@@ -68,7 +62,6 @@ class VQGANDataset(Dataset):
 
         return {
             "audio": torch.from_numpy(audio),
-            "features": torch.from_numpy(features),
         }
 
     def __getitem__(self, idx):
@@ -85,28 +78,18 @@ class VQGANCollator:
         batch = [x for x in batch if x is not None]
 
         audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
-        feature_lengths = torch.tensor([len(x["features"]) for x in batch])
-
         audio_maxlen = audio_lengths.max()
-        feature_maxlen = feature_lengths.max()
 
         # Rounds up to nearest multiple of 2 (audio_lengths)
-        audios, features = [], []
+        audios = []
         for x in batch:
             audios.append(
                 torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
             )
-            features.append(
-                torch.nn.functional.pad(
-                    x["features"], (0, feature_maxlen - len(x["features"]))
-                )
-            )
 
         return {
             "audios": torch.stack(audios),
-            "features": torch.stack(features),
             "audio_lengths": audio_lengths,
-            "feature_lengths": feature_lengths,
         }
 
 

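A standalone sketch (hypothetical helper, not part of the commit) of the new slicing logic: with the HuBERT .npy features gone, the dataset now cuts a window of slice_frames * hop_length samples directly from the waveform, so the slice stays frame-aligned with the mel spectrogram computed later.

import numpy as np

def random_slice(audio: np.ndarray, slice_frames: int, hop_length: int) -> np.ndarray:
    # Mirrors the dataset slicing above: keep short clips whole, otherwise pick a random window
    if slice_frames is None:
        return audio
    max_samples = slice_frames * hop_length
    if audio.shape[0] <= max_samples:
        return audio
    start = np.random.randint(0, audio.shape[0] - max_samples)
    return audio[start : start + max_samples]

# With the vqgan.yaml values (slice_frames=512, hop_length=256) each training clip is
# 512 * 256 = 131072 samples, roughly 5.9 s at 22050 Hz.
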
+ 33 - 44
fish_speech/models/vq_diffusion/lit_module.py

@@ -83,7 +83,8 @@ class VQDiffusion(L.LightningModule):
         vq_encoder: VQEncoder,
         speaker_encoder: SpeakerEncoder,
         text_encoder: TextEncoder,
-        denoiser: ConvNext1DModel,
+        decoder: ConvNext1DModel,
+        postnet: ConvNext1DModel,
         vocoder: nn.Module,
         hop_length: int = 640,
         sample_rate: int = 32000,
@@ -109,7 +110,8 @@ class VQDiffusion(L.LightningModule):
         self.vq_encoder = vq_encoder
         self.speaker_encoder = speaker_encoder
         self.text_encoder = text_encoder
-        self.denoiser = denoiser
+        self.decoder = decoder
+        self.postnet = postnet
         self.downsample = downsample
 
         self.vocoder = vocoder
@@ -187,37 +189,38 @@ class VQDiffusion(L.LightningModule):
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
 
-        text_features = text_features + speaker_features
-
         # Sample noise that we'll add to the images
-        normalized_gt_mels = self.normalize_mels(gt_mels)
-        noise = torch.randn_like(normalized_gt_mels)
-
-        # Sample a random timestep for each image
-        timesteps = torch.randint(
-            0,
-            self.noise_scheduler.config.num_train_timesteps,
-            (normalized_gt_mels.shape[0],),
-            device=normalized_gt_mels.device,
-        ).long()
-
-        # Add noise to the clean images according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_images = self.noise_scheduler.add_noise(
-            normalized_gt_mels, noise, timesteps
-        )
+        normalized_gt_mels = gt_mels / 2.303
 
         # Predict
-        model_output = self.denoiser(noisy_images, timesteps, mel_masks, text_features)
+        mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        t = torch.tensor([0] * mels.shape[0], device=mels.device, dtype=torch.long)
+        postnet_mels = self.postnet(mels, t, mel_masks)
 
         # MSE loss without the mask
-        noise_loss = (torch.abs(model_output * mel_masks - noise * mel_masks)).sum() / (
-            mel_masks.sum() * gt_mels.shape[1]
+        mel_loss = F.l1_loss(
+            mels * mel_masks,
+            normalized_gt_mels * mel_masks,
+        )
+
+        postnet_loss = F.l1_loss(
+            postnet_mels * mel_masks,
+            normalized_gt_mels * mel_masks,
+        )
+
+        self.log(
+            "train/mel_loss",
+            mel_loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
         )
 
         self.log(
-            "train/noise_loss",
-            noise_loss,
+            "train/postnet_loss",
+            postnet_loss,
             on_step=True,
             on_epoch=False,
             prog_bar=True,
@@ -235,7 +238,7 @@ class VQDiffusion(L.LightningModule):
             sync_dist=True,
         )
 
-        return noise_loss + vq_loss
+        return vq_loss + mel_loss + postnet_loss
 
     def validation_step(self, batch: Any, batch_idx: int):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
@@ -278,26 +281,12 @@ class VQDiffusion(L.LightningModule):
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
 
-        text_features = text_features + speaker_features
-
         # Begin sampling
-        sampled_mels = torch.randn_like(gt_mels)
-        self.noise_scheduler.set_timesteps(50)
-
-        for t in tqdm(self.noise_scheduler.timesteps):
-            timesteps = torch.tensor([t], device=sampled_mels.device, dtype=torch.long)
-
-            # 1. predict noise model_output
-            model_output = self.denoiser(
-                sampled_mels, timesteps, mel_masks, text_features
-            )
-
-            # 2. compute previous image: x_t -> x_t-1
-            sampled_mels = self.noise_scheduler.step(
-                model_output, t, sampled_mels
-            ).prev_sample
+        mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        t = torch.tensor([0] * mels.shape[0], device=mels.device, dtype=torch.long)
+        postnet_mels = self.postnet(mels, t, mel_masks)
 
-        sampled_mels = self.denormalize_mels(sampled_mels)
+        sampled_mels = postnet_mels * 2.303
         sampled_mels = sampled_mels * mel_masks
 
         with torch.autocast(device_type=sampled_mels.device.type, enabled=False):

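The 2.303 scale factor is presumably an approximation of ln(10): assuming the mel transform emits natural-log mels, dividing by ln(10) gives the decoder log10-scaled targets, and the validation path multiplies the prediction back by the same constant. A quick check of the constant:

import math
print(round(math.log(10), 3))  # 2.303
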
+ 148 - 89
fish_speech/models/vqgan/lit_module.py

@@ -16,9 +16,20 @@ from fish_speech.models.vqgan.losses import (
     generator_loss,
     kl_loss,
 )
+from fish_speech.models.vqgan.modules.decoder import Generator
 from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
-from fish_speech.models.vqgan.modules.models import SynthesizerTrn
-from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
+from fish_speech.models.vqgan.modules.encoders import (
+    ConvDownSampler,
+    SpeakerEncoder,
+    TextEncoder,
+    VQEncoder,
+)
+from fish_speech.models.vqgan.utils import (
+    plot_mel,
+    rand_slice_segments,
+    sequence_mask,
+    slice_segments,
+)
 
 
 class VQGAN(L.LightningModule):
@@ -26,12 +37,18 @@ class VQGAN(L.LightningModule):
         self,
         optimizer: Callable,
         lr_scheduler: Callable,
-        generator: SynthesizerTrn,
+        downsample: ConvDownSampler,
+        vq_encoder: VQEncoder,
+        speaker_encoder: SpeakerEncoder,
+        text_encoder: TextEncoder,
+        decoder: TextEncoder,
+        generator: Generator,
         discriminator: EnsembleDiscriminator,
         mel_transform: nn.Module,
         segment_size: int = 20480,
         hop_length: int = 640,
         sample_rate: int = 32000,
+        freeze_hifigan: bool = False,
     ):
         super().__init__()
 
@@ -40,6 +57,11 @@ class VQGAN(L.LightningModule):
         self.lr_scheduler_builder = lr_scheduler
 
         # Generator and discriminators
+        self.downsample = downsample
+        self.vq_encoder = vq_encoder
+        self.speaker_encoder = speaker_encoder
+        self.text_encoder = text_encoder
+        self.decoder = decoder
         self.generator = generator
         self.discriminator = discriminator
         self.mel_transform = mel_transform
@@ -48,13 +70,31 @@ class VQGAN(L.LightningModule):
         self.segment_size = segment_size
         self.hop_length = hop_length
         self.sampling_rate = sample_rate
+        self.freeze_hifigan = freeze_hifigan
 
         # Disable automatic optimization
         self.automatic_optimization = False
 
+        # Stage 1: Train the VQ only
+        if self.freeze_hifigan:
+            for p in self.discriminator.parameters():
+                p.requires_grad = False
+
+            for p in self.generator.parameters():
+                p.requires_grad = False
+
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
-        optimizer_generator = self.optimizer_builder(self.generator.parameters())
+        optimizer_generator = self.optimizer_builder(
+            itertools.chain(
+                self.downsample.parameters(),
+                self.vq_encoder.parameters(),
+                self.speaker_encoder.parameters(),
+                self.text_encoder.parameters(),
+                self.decoder.parameters(),
+                self.generator.parameters(),
+            )
+        )
         optimizer_discriminator = self.optimizer_builder(
             self.discriminator.parameters()
         )
@@ -85,30 +125,49 @@ class VQGAN(L.LightningModule):
         optim_g, optim_d = self.optimizers()
 
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-        features, feature_lengths = batch["features"], batch["feature_lengths"]
-        audios = audios[:, None, :]
 
         audios = audios.float()
-        # features = features.long()
+        audios = audios[:, None, :]
 
         with torch.no_grad():
             gt_mels = self.mel_transform(audios)
-            gt_mels = gt_mels[:, :, : features.shape[1]]
-
-        (
-            y_hat,
-            ids_slice,
-            x_mask,
-            y_mask,
-            (z_q, z_p),
-            (m_p, logs_p),
-            (m_q, logs_q),
-            # vq_loss,
-        ) = self.generator(features, feature_lengths, gt_mels)
-
-        y_hat_mel = self.mel_transform(y_hat.squeeze(1))
-        y_mel = slice_segments(gt_mels, ids_slice, self.segment_size // self.hop_length)
-        y = slice_segments(audios, ids_slice * self.hop_length, self.segment_size)
+
+        if self.downsample is not None:
+            features = self.downsample(gt_mels)
+
+        mel_lengths = audio_lengths // self.hop_length
+        feature_lengths = (
+            audio_lengths
+            / self.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(gt_mels.dtype)
+        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
+            gt_mels.dtype
+        )
+
+        speaker_features = self.speaker_encoder(features, feature_masks)
+
+        # vq_features is 50 hz, need to convert to true mel size
+        text_features = self.text_encoder(features, feature_masks)
+        text_features, loss_vq = self.vq_encoder(text_features, feature_masks)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
+
+        # Sample mels
+        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        fake_audios = self.generator(decoded_mels)
+
+        y_hat_mels = self.mel_transform(fake_audios.squeeze(1))
+
+        y, ids_slice = rand_slice_segments(audios, audio_lengths, self.segment_size)
+        y_hat = slice_segments(fake_audios, ids_slice, self.segment_size)
+
+        assert y.shape == y_hat.shape, f"{y.shape} != {y_hat.shape}"
 
         # Discriminator
         y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())
@@ -126,41 +185,29 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
-        optim_d.zero_grad()
-        self.manual_backward(loss_disc_all)
-        self.clip_gradients(
-            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-        )
-        optim_d.step()
+        # Since we don't want to update the discriminator, we skip the backward pass
+        if self.freeze_hifigan is False:
+            optim_d.zero_grad()
+            self.manual_backward(loss_disc_all)
+            self.clip_gradients(
+                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+            )
+            optim_d.step()
 
         y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)
 
         with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_mel = F.l1_loss(y_mel, y_hat_mel)
+            loss_decoded_mel = F.l1_loss(gt_mels, decoded_mels)
+            loss_mel = F.l1_loss(gt_mels, y_hat_mels)
             loss_adv, _ = generator_loss(y_d_hat_g)
             loss_fm = feature_loss(fmap_r, fmap_g)
-            loss_kl = kl_loss(
-                z_p=z_p,
-                logs_q=logs_q,
-                m_p=m_p,
-                logs_p=logs_p,
-                z_mask=x_mask,
-            )
 
-            # Cyclical kl loss
-            # then 500 steps linear: 0.1
-            # then 500 steps 0.1
-            # then go back to 0
+            mel_loss_weight = 25 if self.freeze_hifigan is True else 45
 
-            if self.global_step < 100000:
-                beta = 1e-6
-            else:
-                beta = self.global_step % 1000
-                beta = min(beta, 500) / 500 * 0.1 + 1e-6
+            loss_gen_all = loss_mel * mel_loss_weight + loss_fm + loss_adv + loss_vq
 
-            loss_gen_all = (
-                loss_mel * 45 + loss_fm + loss_adv + loss_kl * beta
-            )  # + vq_loss
+            if self.freeze_hifigan is True:
+                loss_gen_all += loss_decoded_mel * mel_loss_weight
 
         self.log(
             "train/generator/loss",
@@ -171,6 +218,15 @@ class VQGAN(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
+        self.log(
+            "train/generator/loss_decoded_mel",
+            loss_decoded_mel,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
         self.log(
             "train/generator/loss_mel",
             loss_mel,
@@ -199,23 +255,14 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
         self.log(
-            "train/generator/loss_kl",
-            loss_kl,
+            "train/generator/loss_vq",
+            loss_vq,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
             logger=True,
             sync_dist=True,
         )
-        # self.log(
-        #     "train/generator/loss_vq",
-        #     vq_loss,
-        #     on_step=True,
-        #     on_epoch=False,
-        #     prog_bar=False,
-        #     logger=True,
-        #     sync_dist=True,
-        # )
 
         optim_g.zero_grad()
         self.manual_backward(loss_gen_all)
@@ -231,25 +278,50 @@ class VQGAN(L.LightningModule):
 
     def validation_step(self, batch: Any, batch_idx: int):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-        features, feature_lengths = batch["features"], batch["feature_lengths"]
 
         audios = audios.float()
-        # features = features.float()
         audios = audios[:, None, :]
 
         gt_mels = self.mel_transform(audios)
-        gt_mels = gt_mels[:, :, : features.shape[1]]
 
-        fake_audios = self.generator.infer(features, feature_lengths, gt_mels)
-        posterior_audios = self.generator.reconstruct(gt_mels, feature_lengths)
+        if self.downsample is not None:
+            features = self.downsample(gt_mels)
+
+        mel_lengths = audio_lengths // self.hop_length
+        feature_lengths = (
+            audio_lengths
+            / self.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(gt_mels.dtype)
+        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
+            gt_mels.dtype
+        )
+
+        speaker_features = self.speaker_encoder(features, feature_masks)
+
+        # vq_features is 50 hz, need to convert to true mel size
+        text_features = self.text_encoder(features, feature_masks)
+        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
+
+        # Sample mels
+        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        fake_audios = self.generator(decoded_mels)
 
         fake_mels = self.mel_transform(fake_audios.squeeze(1))
-        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))
 
-        min_mel_length = min(gt_mels.shape[-1], fake_mels.shape[-1])
+        min_mel_length = min(
+            decoded_mels.shape[-1], gt_mels.shape[-1], fake_mels.shape[-1]
+        )
+        decoded_mels = decoded_mels[:, :, :min_mel_length]
         gt_mels = gt_mels[:, :, :min_mel_length]
         fake_mels = fake_mels[:, :, :min_mel_length]
-        posterior_mels = posterior_mels[:, :, :min_mel_length]
 
         mel_loss = F.l1_loss(gt_mels, fake_mels)
         self.log(
@@ -265,19 +337,17 @@ class VQGAN(L.LightningModule):
         for idx, (
             mel,
             gen_mel,
-            post_mel,
+            decode_mel,
             audio,
             gen_audio,
-            post_audio,
             audio_len,
         ) in enumerate(
             zip(
                 gt_mels,
                 fake_mels,
-                posterior_mels,
-                audios,
-                fake_audios,
-                posterior_audios,
+                decoded_mels,
+                audios.detach().float(),
+                fake_audios.detach().float(),
                 audio_lengths,
             )
         ):
@@ -286,13 +356,13 @@ class VQGAN(L.LightningModule):
             image_mels = plot_mel(
                 [
                     gen_mel[:, :mel_len],
-                    post_mel[:, :mel_len],
+                    decode_mel[:, :mel_len],
                     mel[:, :mel_len],
                 ],
                 [
-                    "Generated Spectrogram",
-                    "Posterior Spectrogram",
-                    "Ground-Truth Spectrogram",
+                    "Generated",
+                    "Decoded",
+                    "Ground-Truth",
                 ],
             )
 
@@ -311,11 +381,6 @@ class VQGAN(L.LightningModule):
                                 sample_rate=self.sampling_rate,
                                 caption="prediction",
                             ),
-                            wandb.Audio(
-                                post_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="posterior",
-                            ),
                         ],
                     },
                 )
@@ -338,11 +403,5 @@ class VQGAN(L.LightningModule):
                     self.global_step,
                     sample_rate=self.sampling_rate,
                 )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/posterior",
-                    post_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
 
             plt.close(image_mels)

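A worked example (values taken from the vqgan.yaml above) of the length bookkeeping introduced here: mel frames are audio samples divided by hop_length, and the downsampler's total stride (2 * 2 = 4) shrinks that again before the text/speaker/VQ encoders see the features.

import torch

audio_lengths = torch.tensor([131072, 88200])  # samples at 22050 Hz
hop_length = 256
total_strides = 2 * 2                          # downsample strides [2, 2]

mel_lengths = audio_lengths // hop_length                               # tensor([512, 344])
feature_lengths = (audio_lengths / hop_length / total_strides).long()  # tensor([128, 86])
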
+ 46 - 3
fish_speech/models/vqgan/modules/decoder.py

@@ -19,12 +19,13 @@ class Generator(nn.Module):
         upsample_initial_channel,
         upsample_kernel_sizes,
         gin_channels=0,
+        ckpt_path=None,
     ):
         super(Generator, self).__init__()
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
-        self.conv_pre = nn.Conv1d(
-            initial_channel, upsample_initial_channel, 7, 1, padding=3
+        self.conv_pre = weight_norm(
+            nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
         )
         resblock = ResBlock1 if resblock == "1" else ResBlock2
 
@@ -50,12 +51,15 @@ class Generator(nn.Module):
             ):
                 self.resblocks.append(resblock(ch, k, d))
 
-        self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
         self.ups.apply(init_weights)
 
         if gin_channels != 0:
             self.cond = nn.Linear(gin_channels, upsample_initial_channel)
 
+        if ckpt_path is not None:
+            self.load_state_dict(torch.load(ckpt_path)["generator"], strict=True)
+
     def forward(self, x, g=None):
         x = self.conv_pre(x)
         if g is not None:
@@ -225,3 +229,42 @@ class ResBlock2(nn.Module):
     def remove_weight_norm(self):
         for l in self.convs:
             remove_weight_norm(l)
+
+
+if __name__ == "__main__":
+    import librosa
+    import soundfile as sf
+
+    from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
+
+    gen = Generator(
+        80,
+        "1",
+        [3, 7, 11],
+        [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        [8, 8, 2, 2],
+        512,
+        [16, 16, 4, 4],
+        ckpt_path="checkpoints/hifigan-v1-universal-22050/g_02500000",
+    )
+
+    spec = LogMelSpectrogram(
+        sample_rate=22050,
+        n_fft=1024,
+        win_length=1024,
+        hop_length=256,
+        n_mels=80,
+        f_min=0.0,
+        f_max=8000.0,
+    )
+
+    audio = librosa.load("data/StarRail/Chinese/符玄/archive_fuxuan_9.wav", sr=22050)[0]
+    audio = torch.from_numpy(audio).unsqueeze(0)
+
+    spec = spec(audio)
+    print(spec.shape)
+
+    audio = gen(spec)
+    print(audio.shape)
+
+    sf.write("test.wav", audio.detach().squeeze().numpy(), 22050)

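The weight_norm wrappers on conv_pre and conv_post (and the restored bias on conv_post) are presumably what allow the universal HiFi-GAN generator checkpoint to load with strict=True: weight_norm re-parametrises weight as weight_g / weight_v, which is how those tensors are named in the checkpoint. A small illustration, not part of the commit:

import torch.nn as nn
from torch.nn.utils import weight_norm

plain = nn.Conv1d(80, 512, 7, 1, padding=3)
wrapped = weight_norm(nn.Conv1d(80, 512, 7, 1, padding=3))
print(sorted(dict(plain.named_parameters())))    # ['bias', 'weight']
print(sorted(dict(wrapped.named_parameters())))  # ['bias', 'weight_g', 'weight_v']
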
+ 36 - 11
fish_speech/models/vqgan/modules/discriminator.py

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.utils.parametrizations import spectral_norm, weight_norm
+from torch.nn.utils import spectral_norm, weight_norm
 
 from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
 from fish_speech.models.vqgan.utils import get_padding
@@ -91,11 +91,12 @@ class DiscriminatorS(nn.Module):
         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
         self.convs = nn.ModuleList(
             [
-                norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)),
-                norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
-                norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
-                norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
-                norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+                norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
+                norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+                norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+                norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
                 norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
             ]
         )
@@ -116,16 +117,34 @@ class DiscriminatorS(nn.Module):
 
 
 class EnsembleDiscriminator(nn.Module):
-    def __init__(self, use_spectral_norm=False):
+    def __init__(self, ckpt_path=None):
         super(EnsembleDiscriminator, self).__init__()
         periods = [2, 3, 5, 7, 11]  # [1, 2, 3, 5, 7, 11]
 
-        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
-        discs = discs + [
-            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
-        ]
+        discs = [DiscriminatorS(use_spectral_norm=True)]
+        discs = discs + [DiscriminatorP(i, use_spectral_norm=False) for i in periods]
         self.discriminators = nn.ModuleList(discs)
 
+        if ckpt_path is not None:
+            self.restore_from_ckpt(ckpt_path)
+
+    def restore_from_ckpt(self, ckpt_path):
+        ckpt = torch.load(ckpt_path, map_location="cpu")
+        mpd, msd = ckpt["mpd"], ckpt["msd"]
+
+        all_keys = {}
+        for k, v in mpd.items():
+            keys = k.split(".")
+            keys[1] = str(int(keys[1]) + 1)
+            all_keys[".".join(keys)] = v
+
+        for k, v in msd.items():
+            if not k.startswith("discriminators.0"):
+                continue
+            all_keys[k] = v
+
+        self.load_state_dict(all_keys, strict=True)
+
     def forward(self, y, y_hat):
         y_d_rs = []
         y_d_gs = []
@@ -140,3 +159,9 @@ class EnsembleDiscriminator(nn.Module):
             fmap_gs.append(fmap_g)
 
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+if __name__ == "__main__":
+    m = EnsembleDiscriminator(
+        ckpt_path="checkpoints/hifigan-v1-universal-22050/do_02500000"
+    )

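An illustrative walk-through (not from the commit) of the key remapping in restore_from_ckpt: the stock HiFi-GAN checkpoint stores the period discriminators under mpd at indices 0-4 and the scale discriminators under msd at indices 0-2; EnsembleDiscriminator puts its single scale discriminator at index 0, so every MPD index shifts up by one and only msd's discriminators.0 (the spectral-norm one) is carried over.

mpd_key = "discriminators.0.convs.0.weight_v"  # a period-discriminator key in the ckpt
parts = mpd_key.split(".")
parts[1] = str(int(parts[1]) + 1)              # 0 -> 1, as in restore_from_ckpt
print(".".join(parts))                         # discriminators.1.convs.0.weight_v
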
+ 53 - 1
fish_speech/models/vqgan/modules/encoders.py

@@ -13,6 +13,58 @@ from fish_speech.models.vqgan.modules.transformer import (
 from fish_speech.models.vqgan.utils import sequence_mask
 
 
+# * Ready and Tested
+class ConvDownSampler(nn.Module):
+    def __init__(
+        self,
+        dims: list,
+        kernel_sizes: list,
+        strides: list,
+    ):
+        super().__init__()
+
+        self.dims = dims
+        self.kernel_sizes = kernel_sizes
+        self.strides = strides
+        self.total_strides = np.prod(self.strides)
+
+        self.convs = nn.ModuleList(
+            [
+                nn.ModuleList(
+                    [
+                        nn.Conv1d(
+                            in_channels=self.dims[i],
+                            out_channels=self.dims[i + 1],
+                            kernel_size=self.kernel_sizes[i],
+                            stride=self.strides[i],
+                            padding=(self.kernel_sizes[i] - 1) // 2,
+                        ),
+                        nn.LayerNorm(self.dims[i + 1], elementwise_affine=True),
+                        nn.GELU(),
+                    ]
+                )
+                for i in range(len(self.dims) - 1)
+            ]
+        )
+
+        self.apply(self.init_weights)
+
+    def init_weights(self, m):
+        if isinstance(m, nn.Conv1d):
+            nn.init.normal_(m.weight, std=0.02)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.ones_(m.weight)
+            nn.init.zeros_(m.bias)
+
+    def forward(self, x):
+        for conv, norm, act in self.convs:
+            x = conv(x)
+            x = norm(x.mT).mT
+            x = act(x)
+
+        return x
+
+
 # * Ready and Tested
 class TextEncoder(nn.Module):
     def __init__(
@@ -229,7 +281,7 @@ class VQEncoder(nn.Module):
         in_channels: int = 1024,
         vq_channels: int = 1024,
         codebook_size: int = 2048,
-        downsample: int = 2,
+        downsample: int = 1,
         kmeans_ckpt: Optional[str] = None,
     ):
         super().__init__()

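A quick shape check (standalone, mirroring the ConvDownSampler added above with dims [80, 512, 80], kernel_sizes [3, 3], strides [2, 2]; LayerNorm and GELU are omitted since they do not change shapes): two stride-2 convolutions give total_strides = 4, so an 80-band mel with 512 frames comes out as (80, 128).

import torch
import torch.nn as nn

convs = nn.Sequential(
    nn.Conv1d(80, 512, kernel_size=3, stride=2, padding=1),
    nn.Conv1d(512, 80, kernel_size=3, stride=2, padding=1),
)
print(convs(torch.randn(1, 80, 512)).shape)  # torch.Size([1, 80, 128])
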
+ 5 - 7
fish_speech/models/vqgan/utils.py

@@ -53,13 +53,11 @@ def plot_mel(data, titles=None):
 
 
 def slice_segments(x, ids_str, segment_size=4):
-    ret = torch.zeros_like(x[:, :, :segment_size])
-    for i in range(x.size(0)):
-        idx_str = ids_str[i]
-        idx_end = idx_str + segment_size
-        ret[i] = x[i, :, idx_str:idx_end]
-
-    return ret
+    # Slice segments
+    gather_indices = ids_str[:, None, None] + torch.arange(
+        segment_size, device=x.device
+    )
+    return torch.gather(x, 2, gather_indices)
 
 
 def rand_slice_segments(x, x_lengths=None, segment_size=4):
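
A standalone sketch (assumed shapes) of the vectorised slicing: broadcasting ids_str[:, None, None] against arange(segment_size) builds per-item index windows of shape (batch, 1, segment_size), which a single torch.gather call reads out in place of the old Python loop; in this module the function is applied to single-channel waveforms of shape (batch, 1, samples).

import torch

x = torch.arange(2 * 1 * 10, dtype=torch.float32).view(2, 1, 10)
ids_str = torch.tensor([3, 0])
segment_size = 4

gather_indices = ids_str[:, None, None] + torch.arange(segment_size)
print(torch.gather(x, 2, gather_indices))  # windows [3..6] and [10..13]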