Lengyue, 2 years ago
parent
commit
79b5e1907b

+ 145 - 0
fish_speech/configs/vq_naive.yaml

@@ -0,0 +1,145 @@
+defaults:
+  - base
+  - _self_
+
+project: vq_naive
+
+# Lightning Trainer
+trainer:
+  accelerator: gpu
+  devices: 1
+  strategy: ddp_find_unused_parameters_true
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
+  precision: bf16-mixed
+  max_steps: 100_000
+  val_check_interval: 5000
+
+sample_rate: 22050
+hop_length: 256
+num_mels: 80
+n_fft: 1024
+win_length: 1024
+segment_size: 512
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.split.train
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  slice_frames: ${segment_size}
+
+val_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.split.valid
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+data:
+  _target_: fish_speech.datasets.vqgan.VQGANDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 8
+  batch_size: 128
+  val_batch_size: 16
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vqgan.VQNaive
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+  downsample:
+    _target_: fish_speech.models.vq_diffusion.lit_module.ConvDownSample
+    dims: ["${num_mels}", 512, 256]
+    kernel_sizes: [3, 3]
+    strides: [2, 2]
+
+  mel_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
+    in_channels: 256
+    out_channels: 256
+    hidden_channels: 192
+    hidden_channels_ffn: 768
+    n_heads: 2
+    n_layers: 6
+    kernel_size: 1
+    dropout: 0.1
+    use_vae: false
+
+  vq_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
+    in_channels: 256
+    vq_channels: 256
+    codebook_size: 4096
+    downsample: 1
+
+  speaker_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
+    in_channels: ${num_mels}
+    hidden_channels: 192
+    out_channels: 256
+    num_heads: 2
+    num_layers: 4
+    p_dropout: 0.1
+
+  decoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
+    in_channels: 256
+    out_channels: ${num_mels}
+    hidden_channels: 192
+    hidden_channels_ffn: 768
+    n_heads: 2
+    n_layers: 8
+    kernel_size: 1
+    use_vae: false
+    dropout: 0.1
+    gin_channels: 256
+    speaker_cond_layer: 4
+
+  mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+    f_min: 0
+    f_max: 8000
+
+  vocoder:
+    _target_: fish_speech.models.vqgan.modules.decoder.Generator
+    initial_channel: ${num_mels}
+    resblock: "1"
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    upsample_rates: [8, 8, 2, 2]
+    upsample_initial_channel: 512
+    upsample_kernel_sizes: [16, 16, 4, 4]
+    ckpt_path: "checkpoints/hifigan-v1-universal-22050/g_02500000"
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 1e-4
+    betas: [0.9, 0.999]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 1000
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.05
+
+callbacks:
+  grad_norm_monitor:
+    sub_module: 
+      - mel_encoder
+      - vq_encoder
+      - speaker_encoder
+      - decoder
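
Worth noting in this config is the scheduler wiring: `LambdaLR` receives a partially-applied `lr_lambda` built by Hydra. A minimal sketch of the warmup-plus-cosine-decay shape that `get_cosine_schedule_with_warmup_lr_lambda` is assumed to implement (the real function lives in `fish_speech.scheduler`; here `final_lr_ratio: 0.05` floors the decay at 5% of the base LR):

import math

def cosine_warmup_lr_lambda(
    current_step: int,
    *,
    num_warmup_steps: int,
    num_training_steps: int,
    final_lr_ratio: float,
) -> float:
    # Linear warmup: 0 -> 1 over the first num_warmup_steps steps.
    if current_step < num_warmup_steps:
        return current_step / max(1, num_warmup_steps)

    # Cosine decay: 1 -> final_lr_ratio over the remaining steps.
    progress = (current_step - num_warmup_steps) / max(
        1, num_training_steps - num_warmup_steps
    )
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return final_lr_ratio + (1.0 - final_lr_ratio) * cosine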

+ 146 - 0
fish_speech/configs/vq_naive_lfq.yaml

@@ -0,0 +1,146 @@
+defaults:
+  - base
+  - _self_
+
+project: vq_naive_lfq
+
+# Lightning Trainer
+trainer:
+  accelerator: gpu
+  devices: [1]
+  strategy: ddp_find_unused_parameters_true
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
+  precision: bf16-mixed
+  max_steps: 100_000
+  val_check_interval: 5000
+
+sample_rate: 22050
+hop_length: 256
+num_mels: 80
+n_fft: 1024
+win_length: 1024
+segment_size: 512
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.split.train
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  slice_frames: ${segment_size}
+
+val_dataset:
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/filelist.split.valid
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+data:
+  _target_: fish_speech.datasets.vqgan.VQGANDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 8
+  batch_size: 128
+  val_batch_size: 16
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vqgan.VQNaive
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+
+  downsample:
+    _target_: fish_speech.models.vq_diffusion.lit_module.ConvDownSample
+    dims: ["${num_mels}", 512, 256]
+    kernel_sizes: [3, 3]
+    strides: [2, 2]
+
+  mel_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
+    in_channels: 256
+    out_channels: 256
+    hidden_channels: 192
+    hidden_channels_ffn: 768
+    n_heads: 2
+    n_layers: 6
+    kernel_size: 1
+    dropout: 0.1
+    use_vae: false
+
+  vq_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
+    in_channels: 256
+    vq_channels: 14
+    codebook_size: 16384
+    downsample: 1
+    use_lfq: true
+
+  speaker_encoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
+    in_channels: ${num_mels}
+    hidden_channels: 192
+    out_channels: 256
+    num_heads: 2
+    num_layers: 4
+    p_dropout: 0.1
+
+  decoder:
+    _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
+    in_channels: 256
+    out_channels: ${num_mels}
+    hidden_channels: 192
+    hidden_channels_ffn: 768
+    n_heads: 2
+    n_layers: 8
+    kernel_size: 1
+    use_vae: false
+    dropout: 0.1
+    gin_channels: 256
+    speaker_cond_layer: 4
+
+  mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+    f_min: 0
+    f_max: 8000
+
+  vocoder:
+    _target_: fish_speech.models.vqgan.modules.decoder.Generator
+    initial_channel: ${num_mels}
+    resblock: "1"
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    upsample_rates: [8, 8, 2, 2]
+    upsample_initial_channel: 512
+    upsample_kernel_sizes: [16, 16, 4, 4]
+    ckpt_path: "checkpoints/hifigan-v1-universal-22050/g_02500000"
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 1e-4
+    betas: [0.9, 0.999]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 1000
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.05
+
+callbacks:
+  grad_norm_monitor:
+    sub_module: 
+      - mel_encoder
+      - vq_encoder
+      - speaker_encoder
+      - decoder
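
Relative to `vq_naive.yaml`, this config swaps the VQ bottleneck for lookup-free quantization (`use_lfq: true`) and pins training to the single listed device. With LFQ each quantized channel carries one bit of the code index, so `vq_channels` must equal log2 of `codebook_size` (the encoder asserts this; see `encoders.py` below): 2^14 = 16384. A quick sanity check mirroring that assertion:

from math import log2

codebook_size = 16384
vq_channels = int(log2(codebook_size))  # 14: one bit per quantized channel

assert 2**vq_channels == codebook_size, (
    "LFQ requires 2 ** vq_channels == codebook_size"
)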

+ 19 - 11
fish_speech/configs/vqgan.yaml

@@ -9,7 +9,7 @@ trainer:
   accelerator: gpu
   devices: 8
   strategy: ddp_find_unused_parameters_true
-  precision: bf16-mixed
+  precision: 32
   max_steps: 1_000_000
   val_check_interval: 5000
 
@@ -48,18 +48,18 @@ model:
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
   segment_size: 8192
-  freeze_hifigan: true
+  freeze_hifigan: false
 
   downsample:
     _target_: fish_speech.models.vq_diffusion.lit_module.ConvDownSample
-    dims: ["${num_mels}", 512, "${num_mels}"]
+    dims: [128, 512, 128]
     kernel_sizes: [3, 3]
     strides: [2, 2]
 
   text_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
-    in_channels: ${num_mels}
-    out_channels: ${num_mels}
+    in_channels: 128
+    out_channels: 256
     hidden_channels: 192
     hidden_channels_ffn: 768
     n_heads: 2
@@ -70,23 +70,23 @@ model:
 
   vq_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
-    in_channels: ${num_mels}
-    vq_channels: ${num_mels}
+    in_channels: 256
+    vq_channels: 256
     codebook_size: 4096
     downsample: 1
 
   speaker_encoder:
     _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
-    in_channels: ${num_mels}
+    in_channels: 128
     hidden_channels: 192
-    out_channels: ${num_mels}
+    out_channels: 256
     num_heads: 2
     num_layers: 4
     p_dropout: 0.1
 
   decoder:
     _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
-    in_channels: ${num_mels}
+    in_channels: 256
     out_channels: ${num_mels}
     hidden_channels: 192
     hidden_channels_ffn: 768
@@ -95,7 +95,7 @@ model:
     kernel_size: 1
     use_vae: false
     dropout: 0
-    gin_channels: ${num_mels}
+    gin_channels: 256
     speaker_cond_layer: 0
 
   generator:
@@ -123,6 +123,14 @@ model:
     f_min: 0
     f_max: 8000
 
+  feature_mel_transform:
+    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    sample_rate: 32000
+    n_fft: 2048
+    hop_length: 320
+    win_length: 2048
+    n_mels: 128
+
   optimizer:
     _target_: torch.optim.AdamW
     _partial_: true
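
The new `feature_mel_transform` decouples the encoder-side features (128 mels computed at 32 kHz with a 320-sample hop) from the mels the vocoder reconstructs, which is why `lit_module.py` below reworks `feature_lengths` to convert through sample rates instead of assuming a shared hop. A worked example of the resulting frame rates, assuming one second of audio:

# Frame-rate bookkeeping for feature_mel_transform (values from this config).
audio_seconds = 1.0
feature_sample_rate = 32_000   # feature_mel_transform.sample_rate
feature_hop_length = 320       # feature_mel_transform.hop_length
total_strides = 2 * 2          # ConvDownSample strides [2, 2]

feature_frames = audio_seconds * feature_sample_rate / feature_hop_length
assert feature_frames == 100.0                 # features arrive at 100 Hz
assert feature_frames / total_strides == 25.0  # 25 Hz entering the VQ encoder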

+ 2 - 2
fish_speech/models/vqgan/__init__.py

@@ -1,3 +1,3 @@
-from .lit_module import VQGAN
+from .lit_module import VQGAN, VQNaive
 
-__all__ = ["VQGAN"]
+__all__ = ["VQGAN", "VQNaive"]

+ 295 - 34
fish_speech/models/vqgan/lit_module.py

@@ -45,10 +45,12 @@ class VQGAN(L.LightningModule):
         generator: Generator,
         discriminator: EnsembleDiscriminator,
         mel_transform: nn.Module,
+        feature_mel_transform: nn.Module,
         segment_size: int = 20480,
         hop_length: int = 640,
         sample_rate: int = 32000,
         freeze_hifigan: bool = False,
+        freeze_vq: bool = False,
     ):
         super().__init__()
 
@@ -65,6 +67,7 @@ class VQGAN(L.LightningModule):
         self.generator = generator
         self.discriminator = discriminator
         self.mel_transform = mel_transform
+        self.feature_mel_transform = feature_mel_transform
 
         # Crop length for saving memory
         self.segment_size = segment_size
@@ -83,6 +86,17 @@ class VQGAN(L.LightningModule):
             for p in self.generator.parameters():
                 p.requires_grad = False
 
+        # Stage 2: Train the HifiGAN + Decoder + Generator
+        if freeze_vq:
+            for p in self.vq_encoder.parameters():
+                p.requires_grad = False
+
+            for p in self.text_encoder.parameters():
+                p.requires_grad = False
+
+            for p in self.downsample.parameters():
+                p.requires_grad = False
+
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
         optimizer_generator = self.optimizer_builder(
@@ -130,15 +144,20 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]
 
         with torch.no_grad():
-            gt_mels = self.mel_transform(audios)
+            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
+            features = self.feature_mel_transform(
+                audios, sample_rate=self.sampling_rate
+            )
 
         if self.downsample is not None:
-            features = self.downsample(gt_mels)
+            features = self.downsample(features)
 
         mel_lengths = audio_lengths // self.hop_length
         feature_lengths = (
             audio_lengths
-            / self.hop_length
+            / self.sampling_rate
+            * self.feature_mel_transform.sample_rate
+            / self.feature_mel_transform.hop_length
             / (self.downsample.total_strides if self.downsample is not None else 1)
         ).long()
 
@@ -169,45 +188,43 @@ class VQGAN(L.LightningModule):
 
         assert y.shape == y_hat.shape, f"{y.shape} != {y_hat.shape}"
 
-        # Discriminator
-        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())
-
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
-
-        self.log(
-            "train/discriminator/loss",
-            loss_disc_all,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-
         # Since we don't want to update the discriminator, we skip the backward pass
         if self.freeze_hifigan is False:
+            # Discriminator
+            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())
+
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
+
+            self.log(
+                "train/discriminator/loss",
+                loss_disc_all,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=True,
+                logger=True,
+                sync_dist=True,
+            )
+
             optim_d.zero_grad()
             self.manual_backward(loss_disc_all)
             self.clip_gradients(
-                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+                optim_d, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
             )
             optim_d.step()
 
         y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)
 
         with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_decoded_mel = F.l1_loss(gt_mels, decoded_mels)
-            loss_mel = F.l1_loss(gt_mels, y_hat_mels)
+            loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
+            loss_mel = F.l1_loss(gt_mels * mel_masks, y_hat_mels * mel_masks)
             loss_adv, _ = generator_loss(y_d_hat_g)
             loss_fm = feature_loss(fmap_r, fmap_g)
 
-            mel_loss_weight = 25 if self.freeze_hifigan is True else 45
-
-            loss_gen_all = loss_mel * mel_loss_weight + loss_fm + loss_adv + loss_vq
-
             if self.freeze_hifigan is True:
-                loss_gen_all += loss_decoded_mel * mel_loss_weight
+                loss_gen_all = loss_decoded_mel + loss_vq
+            else:
+                loss_gen_all = loss_mel * 45 + loss_vq * 45 + loss_fm + loss_adv
 
         self.log(
             "train/generator/loss",
@@ -267,7 +284,7 @@ class VQGAN(L.LightningModule):
         optim_g.zero_grad()
         self.manual_backward(loss_gen_all)
         self.clip_gradients(
-            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+            optim_g, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
         )
         optim_g.step()
 
@@ -282,15 +299,18 @@ class VQGAN(L.LightningModule):
         audios = audios.float()
         audios = audios[:, None, :]
 
-        gt_mels = self.mel_transform(audios)
+        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
+        features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)
 
         if self.downsample is not None:
-            features = self.downsample(gt_mels)
+            features = self.downsample(features)
 
         mel_lengths = audio_lengths // self.hop_length
         feature_lengths = (
             audio_lengths
-            / self.hop_length
+            / self.sampling_rate
+            * self.feature_mel_transform.sample_rate
+            / self.feature_mel_transform.hop_length
             / (self.downsample.total_strides if self.downsample is not None else 1)
         ).long()
 
@@ -301,11 +321,11 @@ class VQGAN(L.LightningModule):
             gt_mels.dtype
         )
 
-        speaker_features = self.speaker_encoder(features, feature_masks)
+        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
 
         # vq_features is 50 hz, need to convert to true mel size
         text_features = self.text_encoder(features, feature_masks)
-        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
+        text_features, _ = self.vq_encoder(text_features, feature_masks)
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
@@ -323,7 +343,7 @@ class VQGAN(L.LightningModule):
         gt_mels = gt_mels[:, :, :min_mel_length]
         fake_mels = fake_mels[:, :, :min_mel_length]
 
-        mel_loss = F.l1_loss(gt_mels, fake_mels)
+        mel_loss = F.l1_loss(gt_mels * mel_masks, fake_mels * mel_masks)
         self.log(
             "val/mel_loss",
             mel_loss,
@@ -405,3 +425,244 @@ class VQGAN(L.LightningModule):
                 )
 
             plt.close(image_mels)
+
+
+class VQNaive(L.LightningModule):
+    def __init__(
+        self,
+        optimizer: Callable,
+        lr_scheduler: Callable,
+        downsample: ConvDownSampler,
+        vq_encoder: VQEncoder,
+        speaker_encoder: SpeakerEncoder,
+        mel_encoder: TextEncoder,
+        decoder: TextEncoder,
+        mel_transform: nn.Module,
+        hop_length: int = 640,
+        sample_rate: int = 32000,
+        vocoder: Generator = None,
+    ):
+        super().__init__()
+
+        # Model parameters
+        self.optimizer_builder = optimizer
+        self.lr_scheduler_builder = lr_scheduler
+
+        # Generator and discriminators
+        self.downsample = downsample
+        self.vq_encoder = vq_encoder
+        self.speaker_encoder = speaker_encoder
+        self.mel_encoder = mel_encoder
+        self.decoder = decoder
+        self.mel_transform = mel_transform
+
+        # Crop length for saving memory
+        self.hop_length = hop_length
+        self.sampling_rate = sample_rate
+
+        # Vocoder
+        self.vocoder = vocoder
+
+        for p in self.vocoder.parameters():
+            p.requires_grad = False
+
+    def configure_optimizers(self):
+        optimizer = self.optimizer_builder(self.parameters())
+        lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "interval": "step",
+            },
+        }
+
+    def training_step(self, batch, batch_idx):
+        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+
+        audios = audios.float()
+        audios = audios[:, None, :]
+
+        with torch.no_grad():
+            features = gt_mels = self.mel_transform(
+                audios, sample_rate=self.sampling_rate
+            )
+
+        if self.downsample is not None:
+            features = self.downsample(features)
+
+        mel_lengths = audio_lengths // self.hop_length
+        feature_lengths = (
+            audio_lengths
+            / self.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(gt_mels.dtype)
+        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
+            gt_mels.dtype
+        )
+
+        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
+
+        # vq_features is 50 hz, need to convert to true mel size
+        text_features = self.mel_encoder(features, feature_masks)
+        text_features, loss_vq = self.vq_encoder(text_features, feature_masks)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
+
+        # Sample mels
+        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
+        loss = loss_mel + loss_vq
+
+        self.log(
+            "train/generator/loss",
+            loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train/loss_mel",
+            loss_mel,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train/generator/loss_vq",
+            loss_vq,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+
+        return loss
+
+    def validation_step(self, batch: Any, batch_idx: int):
+        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+
+        audios = audios.float()
+        audios = audios[:, None, :]
+
+        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
+
+        if self.downsample is not None:
+            features = self.downsample(features)
+
+        mel_lengths = audio_lengths // self.hop_length
+        feature_lengths = (
+            audio_lengths
+            / self.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(gt_mels.dtype)
+        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
+            gt_mels.dtype
+        )
+
+        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
+
+        # vq_features is 50 hz, need to convert to true mel size
+        text_features = self.mel_encoder(features, feature_masks)
+        text_features, _ = self.vq_encoder(text_features, feature_masks)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
+
+        # Sample mels
+        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        fake_audios = self.vocoder(decoded_mels)
+
+        mel_loss = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
+        self.log(
+            "val/mel_loss",
+            mel_loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        for idx, (
+            mel,
+            decoded_mel,
+            audio,
+            gen_audio,
+            audio_len,
+        ) in enumerate(
+            zip(
+                gt_mels,
+                decoded_mels,
+                audios.detach().float(),
+                fake_audios.detach().float(),
+                audio_lengths,
+            )
+        ):
+            mel_len = audio_len // self.hop_length
+
+            image_mels = plot_mel(
+                [
+                    decoded_mel[:, :mel_len],
+                    mel[:, :mel_len],
+                ],
+                [
+                    "Generated",
+                    "Ground-Truth",
+                ],
+            )
+
+            if isinstance(self.logger, WandbLogger):
+                self.logger.experiment.log(
+                    {
+                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
+                        "wavs": [
+                            wandb.Audio(
+                                audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="gt",
+                            ),
+                            wandb.Audio(
+                                gen_audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="prediction",
+                            ),
+                        ],
+                    },
+                )
+
+            if isinstance(self.logger, TensorBoardLogger):
+                self.logger.experiment.add_figure(
+                    f"sample-{idx}/mels",
+                    image_mels,
+                    global_step=self.global_step,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/gt",
+                    audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/prediction",
+                    gen_audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+
+            plt.close(image_mels)
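
One subtlety in the new masked losses: `F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)` zeroes padded frames but still averages over the full tensor, so heavily padded batches dilute the per-frame penalty. A length-normalized variant would divide by the number of valid elements instead; a sketch, not what this commit does:

import torch
import torch.nn.functional as F

def masked_l1(gt: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # gt, pred: (B, n_mels, T); mask: (B, 1, T) with 1.0 on valid frames.
    total = F.l1_loss(gt * mask, pred * mask, reduction="sum")
    valid_elements = mask.sum() * gt.shape[1]  # valid frames x mel bins
    return total / valid_elements.clamp(min=1)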

+ 34 - 8
fish_speech/models/vqgan/modules/encoders.py

@@ -1,3 +1,4 @@
+from math import log2
 from typing import Optional
 
 import torch
@@ -283,16 +284,34 @@ class VQEncoder(nn.Module):
         codebook_size: int = 2048,
         downsample: int = 1,
         kmeans_ckpt: Optional[str] = None,
+        use_lfq: bool = False,
     ):
         super().__init__()
 
-        self.vq = VectorQuantize(
-            dim=vq_channels,
-            codebook_size=codebook_size,
-            threshold_ema_dead_code=2,
-            kmeans_init=False,
-            channel_last=False,
-        )
+        if use_lfq:
+            assert 2**vq_channels == codebook_size, (
+                "LFQ requires 2 ** vq_channels == codebook_size. "
+                f"Got vq_channels={vq_channels} and codebook_size={codebook_size}"
+            )
+
+            self.ln = nn.LayerNorm(vq_channels, eps=1e-5)
+            self.vq = LFQ(
+                dim=vq_channels,
+                codebook_size=codebook_size,
+                entropy_loss_weight=0.1,
+                commitment_loss_weight=1,
+                diversity_gamma=2.5,
+            )
+        else:
+            self.vq = VectorQuantize(
+                dim=vq_channels,
+                codebook_size=codebook_size,
+                threshold_ema_dead_code=2,
+                kmeans_init=False,
+                channel_last=False,
+            )
+
+        self.use_lfq = use_lfq
         self.downsample = downsample
         self.conv_in = nn.Conv1d(
             in_channels, vq_channels, kernel_size=downsample, stride=downsample
@@ -338,7 +357,14 @@ class VQEncoder(nn.Module):
             x_mask = F.pad(x_mask, (0, self.downsample - x_len % self.downsample))
 
         x = self.conv_in(x)
-        q, _, loss = self.vq(x)
+
+        if self.use_lfq:
+            x = self.ln(x.mT)
+            q, _, loss = self.vq(x)
+            q = q.mT
+        else:
+            q, _, loss = self.vq(x)
+
         x = self.conv_out(q) * x_mask
         x = x[:, :, :x_len]
 
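
The `.mT` round-trip exists because `conv_in` emits channel-first tensors while `nn.LayerNorm` (and LFQ, as used here) operate over the last dimension; `Tensor.mT` swaps the final two dims. A minimal shape check:

import torch

B, T, vq_channels = 2, 50, 14
x = torch.randn(B, vq_channels, T)  # channel-first, as produced by conv_in

x_last = x.mT                       # (B, T, vq_channels) for LayerNorm / LFQ
assert x_last.shape == (B, T, vq_channels)
assert x_last.mT.shape == (B, vq_channels, T)  # transpose back before conv_out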

+ 7 - 5
fish_speech/models/vqgan/utils.py

@@ -53,11 +53,13 @@ def plot_mel(data, titles=None):
 
 
 def slice_segments(x, ids_str, segment_size=4):
-    # Slice segments
-    gather_indices = ids_str[:, None, None] + torch.arange(
-        segment_size, device=x.device
-    )
-    return torch.gather(x, 2, gather_indices)
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+
+    return ret
 
 
 def rand_slice_segments(x, x_lengths=None, segment_size=4):
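
The rewritten `slice_segments` trades the vectorized `torch.gather` for an explicit per-sample copy; both produce the same `(B, C, segment_size)` output whenever `ids_str[i] + segment_size <= T`. A small usage check:

import torch

x = torch.arange(2 * 3 * 10, dtype=torch.float32).reshape(2, 3, 10)
ids_str = torch.tensor([0, 4])

out = slice_segments(x, ids_str, segment_size=4)
assert out.shape == (2, 3, 4)
assert torch.equal(out[1], x[1, :, 4:8])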