Lengyue 2 лет назад
Родитель
Сommit
703bcbd64e

+ 20 - 13
fish_speech/configs/hubert_vq.yaml

@@ -6,42 +6,49 @@ project: hubert_vq
 
 
 # Lightning Trainer
 # Lightning Trainer
 trainer:
 trainer:
-  accumulate_grad_batches: 2
-  gradient_clip_val: 1000.0  # For safety
-  gradient_clip_algorithm: 'norm'
-  precision: 32
+  accelerator: gpu
+  strategy:
+    _target_: lightning.pytorch.strategies.DDPStrategy
+    static_graph: true
+  precision: 16-mixed
   max_steps: 1_000_000
   max_steps: 1_000_000
 
 
 sample_rate: 32000
 sample_rate: 32000
+hop_length: 640
+num_mels: 128
 
 
 # Dataset Configuration
 # Dataset Configuration
 train_dataset:
 train_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
   _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/test.filelist
+  filelist: data/vq_train_filelist.txt
   sample_rate: ${sample_rate}
   sample_rate: ${sample_rate}
 
 
 val_dataset:
 val_dataset:
-  _target_: fish_speech.datasets.text.TextDataset
-  repo: fishaudio/cn-hubert-25hz-vq
-  prefix: 'data/test'
+  _target_: fish_speech.datasets.vqgan.VQGANDataset
+  filelist: data/vq_val_filelist.txt
+  sample_rate: ${sample_rate}
 
 
 data:
 data:
-  _target_: fish_speech.datasets.text.TextDataModule
+  _target_: fish_speech.datasets.vqgan.VQGANDataModule
   train_dataset: ${train_dataset}
   train_dataset: ${train_dataset}
   val_dataset: ${val_dataset}
   val_dataset: ${val_dataset}
   num_workers: 4
   num_workers: 4
   batch_size: 8
   batch_size: 8
-  tokenizer: ${tokenizer}
+  val_batch_size: 4
+  hop_length: ${hop_length}
 
 
 # Model Configuration
 # Model Configuration
 model:
 model:
   _target_: fish_speech.models.vqgan.VQGAN
   _target_: fish_speech.models.vqgan.VQGAN
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  segment_size: 20480
 
 
   encoder:
   encoder:
     _target_: fish_speech.models.vqgan.modules.VQEncoder
     _target_: fish_speech.models.vqgan.modules.VQEncoder
     in_channels: 1024
     in_channels: 1024
     channels: 192
     channels: 192
-    num_mels: 128
+    num_mels: ${num_mels}
     num_heads: 2
     num_heads: 2
     num_feature_layers: 2
     num_feature_layers: 2
     num_speaker_layers: 4
     num_speaker_layers: 4
@@ -70,9 +77,9 @@ model:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     sample_rate: ${sample_rate}
     sample_rate: ${sample_rate}
     n_fft: 2048
     n_fft: 2048
-    hop_length: 640
+    hop_length: ${hop_length}
     win_length: 2048
     win_length: 2048
-    n_mels: 128
+    n_mels: ${num_mels}
 
 
   optimizer:
   optimizer:
     _target_: torch.optim.AdamW
     _target_: torch.optim.AdamW

+ 4 - 3
fish_speech/datasets/vqgan.py

@@ -1,11 +1,12 @@
 from dataclasses import dataclass
 from dataclasses import dataclass
 from pathlib import Path
 from pathlib import Path
+from typing import Optional
 
 
 import librosa
 import librosa
 import numpy as np
 import numpy as np
 import torch
 import torch
 from lightning import LightningDataModule
 from lightning import LightningDataModule
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
 
 
 
 
 class VQGANDataset(Dataset):
 class VQGANDataset(Dataset):
@@ -78,12 +79,14 @@ class VQGANDataModule(LightningDataModule):
         batch_size: int = 32,
         batch_size: int = 32,
         hop_length: int = 640,
         hop_length: int = 640,
         num_workers: int = 4,
         num_workers: int = 4,
+        val_batch_size: Optional[int] = None,
     ):
     ):
         super().__init__()
         super().__init__()
 
 
         self.train_dataset = train_dataset
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
         self.val_dataset = val_dataset
         self.batch_size = batch_size
         self.batch_size = batch_size
+        self.val_batch_size = val_batch_size or batch_size
         self.hop_length = hop_length
         self.hop_length = hop_length
         self.num_workers = num_workers
         self.num_workers = num_workers
 
 
@@ -106,8 +109,6 @@ class VQGANDataModule(LightningDataModule):
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    from torch.utils.data import DataLoader
-
     dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
     dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
     dataloader = DataLoader(
     dataloader = DataLoader(
         dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
         dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()

+ 258 - 177
fish_speech/models/vqgan/lit_module.py

@@ -1,10 +1,16 @@
+import itertools
 from typing import Any, Callable
 from typing import Any, Callable
 
 
 import lightning as L
 import lightning as L
 import torch
 import torch
 import torch.nn.functional as F
 import torch.nn.functional as F
+import wandb
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
+from matplotlib import pyplot as plt
 from torch import nn
 from torch import nn
-from torch.utils.checkpoint import checkpoint as gradient_checkpointing
+
+from fish_speech.models.vqgan.modules import EnsembleDiscriminator, Generator, VQEncoder
+from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
 
 
 
 
 class VQGAN(L.LightningModule):
 class VQGAN(L.LightningModule):
@@ -12,11 +18,13 @@ class VQGAN(L.LightningModule):
         self,
         self,
         optimizer: Callable,
         optimizer: Callable,
         lr_scheduler: Callable,
         lr_scheduler: Callable,
-        encoder: nn.Module,
-        generator: nn.Module,
-        discriminator: nn.Module,
+        encoder: VQEncoder,
+        generator: Generator,
+        discriminator: EnsembleDiscriminator,
         mel_transform: nn.Module,
         mel_transform: nn.Module,
         segment_size: int = 20480,
         segment_size: int = 20480,
+        hop_length: int = 640,
+        sample_rate: int = 32000,
     ):
     ):
         super().__init__()
         super().__init__()
 
 
@@ -33,15 +41,19 @@ class VQGAN(L.LightningModule):
 
 
         # Crop length for saving memory
         # Crop length for saving memory
         self.segment_size = segment_size
         self.segment_size = segment_size
+        self.hop_length = hop_length
+        self.sampling_rate = sample_rate
 
 
         # Disable automatic optimization
         # Disable automatic optimization
         self.automatic_optimization = False
         self.automatic_optimization = False
 
 
     def configure_optimizers(self):
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
         # Need two optimizers and two schedulers
-        optimizer_generator = self.optimizer_builder(self.generator.parameters())
+        optimizer_generator = self.optimizer_builder(
+            itertools.chain(self.encoder.parameters(), self.generator.parameters())
+        )
         optimizer_discriminator = self.optimizer_builder(
         optimizer_discriminator = self.optimizer_builder(
-            self.discriminators.parameters()
+            self.discriminator.parameters()
         )
         )
 
 
         lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
         lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
@@ -66,109 +78,112 @@ class VQGAN(L.LightningModule):
             },
             },
         )
         )
 
 
-    def training_generator(self, audio, audio_mask):
-        # fake_audio, base_loss = self.forward(audio, audio_mask)
+    @staticmethod
+    def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+        loss = 0
+        r_losses = []
+        g_losses = []
+        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+            dr = dr.float()
+            dg = dg.float()
+            r_loss = torch.mean((1 - dr) ** 2)
+            g_loss = torch.mean(dg**2)
+            loss += r_loss + g_loss
+            r_losses.append(r_loss.item())
+            g_losses.append(g_loss.item())
+
+        return loss, r_losses, g_losses
+
+    @staticmethod
+    def generator_loss(disc_outputs):
+        loss = 0
+        gen_losses = []
+        for dg in disc_outputs:
+            dg = dg.float()
+            l = torch.mean((1 - dg) ** 2)
+            gen_losses.append(l)
+            loss += l
+
+        return loss, gen_losses
+
+    @staticmethod
+    def feature_loss(fmap_r, fmap_g):
+        loss = 0
+        for dr, dg in zip(fmap_r, fmap_g):
+            for rl, gl in zip(dr, dg):
+                rl = rl.float().detach()
+                gl = gl.float()
+                loss += torch.mean(torch.abs(rl - gl))
+
+        return loss * 2
 
 
-        assert fake_audio.shape == audio.shape
+    def training_step(self, batch, batch_idx):
+        optim_g, optim_d = self.optimizers()
 
 
-        # Apply mask
-        audio = audio * audio_mask
-        fake_audio = fake_audio * audio_mask
+        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+        features, feature_lengths = batch["features"], batch["feature_lengths"]
 
 
-        # Multi-Resolution STFT Loss
-        sc_loss, mag_loss = self.multi_resolution_stft_loss(
-            fake_audio.squeeze(1), audio.squeeze(1)
-        )
-        loss_stft = sc_loss + mag_loss
+        with torch.no_grad():
+            gt_mels = self.mel_transform(audios).transpose(1, 2)
+            key_padding_mask = sequence_mask(feature_lengths)
+            mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
 
 
-        self.log(
-            "train/generator/stft",
-            loss_stft,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
+            assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
+            gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
+            gt_mels = gt_mels[:, :gt_mel_length]
+            mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
 
 
-        # L1 Mel-Spectrogram Loss
-        # This is not used in backpropagation currently
-        audio_mel = self.mel_transforms.loss(audio.squeeze(1))
-        fake_audio_mel = self.mel_transforms.loss(fake_audio.squeeze(1))
-        loss_mel = F.l1_loss(audio_mel, fake_audio_mel)
+            assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
+            gt_feature_length = min(features.shape[1], key_padding_mask.shape[1])
+            features = features[:, :gt_feature_length]
+            key_padding_mask = key_padding_mask[:, :gt_feature_length]
 
 
-        self.log(
-            "train/generator/mel",
-            loss_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
+        # Generator
+        encoded = self.encoder(
+            x=features,
+            mels=gt_mels,
+            key_padding_mask=key_padding_mask,
+            mels_key_padding_mask=mels_key_padding_mask,
         )
         )
 
 
-        # Now, we need to reduce the length of the audio to save memory
-        if self.crop_length is not None and audio.shape[2] > self.crop_length:
-            slice_idx = torch.randint(0, audio.shape[-1] - self.crop_length, (1,))
+        features = encoded.features
+        audios = audios[:, None, :]
 
 
-            audio = audio[..., slice_idx : slice_idx + self.crop_length]
-            fake_audio = fake_audio[..., slice_idx : slice_idx + self.crop_length]
-            audio_mask = audio_mask[..., slice_idx : slice_idx + self.crop_length]
-
-            assert audio.shape == fake_audio.shape == audio_mask.shape
-
-        # Adv Loss
-        loss_adv_all = 0
-
-        for key, disc in self.discriminators.items():
-            score_fakes, feat_fake = disc(fake_audio)
-
-            # Adversarial Loss
-            score_fakes = torch.cat(score_fakes, dim=1)
-            loss_fake = torch.mean((1 - score_fakes) ** 2)
-
-            self.log(
-                f"train/generator/adv_{key}",
-                loss_fake,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
+        # Get slice of audio
+        if audios.shape[-1] > self.segment_size:
+            start = torch.randint(
+                0, audios.shape[-1] - self.segment_size, (1,), device=audios.device
+            ).item()
+            start = start // self.hop_length * self.hop_length
 
 
-            loss_adv_all += loss_fake
+            audios = audios[:, :, start : start + self.segment_size]
+            audio_masks = sequence_mask(audio_lengths)[
+                :, None, start : start + self.segment_size
+            ]
 
 
-            if self.feature_matching is False:
-                continue
+            mel_start = start // self.hop_length
+            mel_size = self.segment_size // self.hop_length
+            gt_mels = gt_mels[:, mel_start : mel_start + mel_size]
+            mels_key_padding_mask = mels_key_padding_mask[
+                :, mel_start : mel_start + mel_size
+            ]
 
 
-            # Feature Matching Loss
-            _, feat_real = disc(audio)
-            loss_fm = 0
-            for dr, dg in zip(feat_real, feat_fake):
-                for rl, gl in zip(dr, dg):
-                    loss_fm += F.l1_loss(rl, gl)
+            features = features[:, :, mel_start : mel_start + mel_size]
 
 
-            loss_fm /= len(feat_real)
+        fake_audios = self.generator(features)
+        audio = torch.masked_fill(audios, audio_masks, 0.0)
+        fake_audios = torch.masked_fill(fake_audios, audio_masks, 0.0)
+        assert fake_audios.shape == audio.shape
 
 
-            self.log(
-                f"train/generator/adv_fm_{key}",
-                loss_fm,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
-
-            loss_adv_all += loss_fm
+        # Discriminator
+        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audio, fake_audios.detach())
 
 
-        loss_adv_all /= len(self.discriminators)
-        loss_gen_all = base_loss + loss_stft * 2.5 + loss_mel * 45 + loss_adv_all
+        with torch.autocast(device_type=audios.device.type, enabled=False):
+            loss_disc_all, _, _ = self.discriminator_loss(y_d_hat_r, y_d_hat_g)
 
 
         self.log(
         self.log(
-            "train/generator/all",
-            loss_gen_all,
+            "train/discriminator/loss",
+            loss_disc_all,
             on_step=True,
             on_step=True,
             on_epoch=False,
             on_epoch=False,
             prog_bar=True,
             prog_bar=True,
@@ -176,99 +191,79 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
             sync_dist=True,
         )
         )
 
 
-        return loss_gen_all, audio, fake_audio
+        optim_d.zero_grad()
+        self.manual_backward(loss_disc_all)
+        self.clip_gradients(
+            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+        )
+        optim_d.step()
 
 
-    def training_discriminator(self, audio, fake_audio):
-        loss_disc_all = 0
+        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(audios, fake_audios)
+        fake_mels = self.mel_transform(fake_audios.squeeze(1)).transpose(1, 2)
 
 
-        for key, disc in self.discriminators.items():
-            if self.training and self.checkpointing:
-                scores, _ = gradient_checkpointing(disc, audio, use_reentrant=False)
-                score_fakes, _ = gradient_checkpointing(
-                    disc, fake_audio.detach(), use_reentrant=False
-                )
-            else:
-                scores, _ = disc(audio)
-                score_fakes, _ = disc(fake_audio.detach())
-
-            scores = torch.cat(scores, dim=1)
-            score_fakes = torch.cat(score_fakes, dim=1)
-            loss_disc = torch.mean((scores - 1) ** 2) + torch.mean((score_fakes) ** 2)
-
-            self.log(
-                f"train/discriminator/{key}",
-                loss_disc,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
+        # Fill mel mask
+        fake_mels = torch.masked_fill(fake_mels, mels_key_padding_mask[:, :, None], 0.0)
+        gt_mels = torch.masked_fill(gt_mels, mels_key_padding_mask[:, :, None], 0.0)
 
 
-            loss_disc_all += loss_disc
+        with torch.autocast(device_type=audios.device.type, enabled=False):
+            loss_mel = F.l1_loss(gt_mels, fake_mels)
+            loss_adv, _ = self.generator_loss(y_d_hat_g)
+            loss_fm = self.feature_loss(fmap_r, fmap_g)
 
 
-        loss_disc_all /= len(self.discriminators)
+            loss_gen_all = loss_fm * 45 + loss_mel + loss_adv + encoded.loss
 
 
         self.log(
         self.log(
-            "train/discriminator/all",
-            loss_disc_all,
+            "train/generator/loss",
+            loss_gen_all,
             on_step=True,
             on_step=True,
             on_epoch=False,
             on_epoch=False,
             prog_bar=True,
             prog_bar=True,
             logger=True,
             logger=True,
             sync_dist=True,
             sync_dist=True,
         )
         )
-
-        return loss_disc_all
-
-    def training_step(self, batch, batch_idx):
-        optim_g, optim_d = self.optimizers()
-
-        audio, lengths = batch["audio"], batch["lengths"]
-        audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32)
-
-        # Generator
-        optim_g.zero_grad()
-        loss_gen_all, audio, fake_audio = self.training_generator(audio, audio_mask)
-        self.manual_backward(loss_gen_all)
-
         self.log(
         self.log(
-            "train/generator/grad_norm",
-            grad_norm(self.generator.parameters()),
+            "train/generator/loss_mel",
+            loss_mel,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train/generator/loss_fm",
+            loss_fm,
             on_step=True,
             on_step=True,
             on_epoch=False,
             on_epoch=False,
             prog_bar=False,
             prog_bar=False,
             logger=True,
             logger=True,
             sync_dist=True,
             sync_dist=True,
         )
         )
-
-        self.clip_gradients(
-            optim_g, gradient_clip_val=1000, gradient_clip_algorithm="norm"
+        self.log(
+            "train/generator/loss_adv",
+            loss_adv,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train/generator/loss_vq",
+            encoded.loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
         )
         )
-        optim_g.step()
-
-        # Discriminator
-        assert fake_audio.shape == audio.shape
-
-        optim_d.zero_grad()
-        loss_disc_all = self.training_discriminator(audio, fake_audio)
-        self.manual_backward(loss_disc_all)
-
-        for key, disc in self.discriminators.items():
-            self.log(
-                f"train/discriminator/grad_norm_{key}",
-                grad_norm(disc.parameters()),
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
 
 
+        optim_g.zero_grad()
+        self.manual_backward(loss_gen_all)
         self.clip_gradients(
         self.clip_gradients(
-            optim_d, gradient_clip_val=1000, gradient_clip_algorithm="norm"
+            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
         )
         )
-        optim_d.step()
+        optim_g.step()
 
 
         # Manual LR Scheduler
         # Manual LR Scheduler
         scheduler_g, scheduler_d = self.lr_schedulers()
         scheduler_g, scheduler_d = self.lr_schedulers()
@@ -276,25 +271,55 @@ class VQGAN(L.LightningModule):
         scheduler_d.step()
         scheduler_d.step()
 
 
     def validation_step(self, batch: Any, batch_idx: int):
     def validation_step(self, batch: Any, batch_idx: int):
-        audio, lengths = batch["audio"], batch["lengths"]
-        audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32)
+        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+        features, feature_lengths = batch["features"], batch["feature_lengths"]
+
+        with torch.no_grad():
+            gt_mels = self.mel_transform(audios).transpose(1, 2)
+            key_padding_mask = sequence_mask(feature_lengths)
+            mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
+            audio_masks = sequence_mask(audio_lengths)
+
+            assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
+            gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
+            gt_mels = gt_mels[:, :gt_mel_length]
+            mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
+
+            assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
+            gt_feature_length = min(features.shape[1], key_padding_mask.shape[1])
+            features = features[:, :gt_feature_length]
+            key_padding_mask = key_padding_mask[:, :gt_feature_length]
 
 
         # Generator
         # Generator
-        fake_audio, _ = self.forward(audio, audio_mask)
-        assert fake_audio.shape == audio.shape
+        encoded = self.encoder(
+            x=features,
+            mels=gt_mels,
+            key_padding_mask=key_padding_mask,
+            mels_key_padding_mask=mels_key_padding_mask,
+        )
+
+        features = encoded.features
+        audios = audios[:, None, :]
 
 
-        # Apply mask
-        audio = audio * audio_mask
-        fake_audio = fake_audio * audio_mask
+        fake_audios = self.generator(features)
+        min_audio_length = min(audios.shape[-1], fake_audios.shape[-1])
 
 
-        # L1 Mel-Spectrogram Loss
-        audio_mel = self.mel_transforms.loss(audio.squeeze(1))
-        fake_audio_mel = self.mel_transforms.loss(fake_audio.squeeze(1))
-        loss_mel = F.l1_loss(audio_mel, fake_audio_mel)
+        audios = audios[:, :, :min_audio_length]
+        fake_audios = fake_audios[:, :, :min_audio_length]
+        audio_masks = audio_masks[:, None, :min_audio_length]
 
 
+        audio = torch.masked_fill(audios, audio_masks, 0.0)
+        fake_audios = torch.masked_fill(fake_audios, audio_masks, 0.0)
+        assert fake_audios.shape == audio.shape
+
+        fake_mels = self.mel_transform(fake_audios.squeeze(1)).transpose(1, 2)
+        gt_mels = torch.masked_fill(gt_mels, mels_key_padding_mask[:, :, None], 0.0)
+        fake_mels = torch.masked_fill(fake_mels, mels_key_padding_mask[:, :, None], 0.0)
+
+        mel_loss = F.l1_loss(gt_mels, fake_mels)
         self.log(
         self.log(
-            "val/metrics/mel",
-            loss_mel,
+            "val/mel_loss",
+            mel_loss,
             on_step=False,
             on_step=False,
             on_epoch=True,
             on_epoch=True,
             prog_bar=True,
             prog_bar=True,
@@ -302,5 +327,61 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
             sync_dist=True,
         )
         )
 
 
-        # Report other metrics
-        self.report_val_metrics(fake_audio, audio, lengths)
+        for idx, (mel, gen_mel, audio, gen_audio, audio_len) in enumerate(
+            zip(
+                gt_mels.transpose(1, 2),
+                fake_mels.transpose(1, 2),
+                audios,
+                fake_audios,
+                audio_lengths,
+            )
+        ):
+            mel_len = audio_len // self.hop_length
+
+            image_mels = plot_mel(
+                [
+                    gen_mel[:, :mel_len],
+                    mel[:, :mel_len],
+                ],
+                ["Sampled Spectrogram", "Ground-Truth Spectrogram"],
+            )
+
+            if isinstance(self.logger, WandbLogger):
+                self.logger.experiment.log(
+                    {
+                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
+                        "wavs": [
+                            wandb.Audio(
+                                audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="gt",
+                            ),
+                            wandb.Audio(
+                                gen_audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="prediction",
+                            ),
+                        ],
+                    },
+                )
+
+            if isinstance(self.logger, TensorBoardLogger):
+                self.logger.experiment.add_figure(
+                    f"sample-{idx}/mels",
+                    image_mels,
+                    global_step=self.global_step,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/gt",
+                    audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/prediction",
+                    gen_audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+
+            plt.close(image_mels)

+ 38 - 22
fish_speech/models/vqgan/modules.py

@@ -38,12 +38,14 @@ class VQEncoder(nn.Module):
         # Feature Encoder
         # Feature Encoder
         down_sample = 2 if input_downsample else 1
         down_sample = 2 if input_downsample else 1
 
 
-        self.vq_in = nn.Linear(in_channels * down_sample, in_channels)
+        self.vq_in = nn.Conv1d(
+            in_channels, in_channels, kernel_size=down_sample, stride=down_sample
+        )
         self.vq = VectorQuantization(
         self.vq = VectorQuantization(
             dim=in_channels,
             dim=in_channels,
             codebook_size=code_book_size,
             codebook_size=code_book_size,
             threshold_ema_dead_code=2,
             threshold_ema_dead_code=2,
-            kmeans_init=True,
+            kmeans_init=False,
             kmeans_iters=50,
             kmeans_iters=50,
         )
         )
 
 
@@ -78,7 +80,7 @@ class VQEncoder(nn.Module):
         )
         )
 
 
         # Final Mixer
         # Final Mixer
-        self.mixer_in = nn.ModuleList(
+        self.mixer_blocks = nn.ModuleList(
             [
             [
                 TransformerBlock(
                 TransformerBlock(
                     channels,
                     channels,
@@ -102,47 +104,61 @@ class VQEncoder(nn.Module):
             for p in self.vq_in.parameters():
             for p in self.vq_in.parameters():
                 p.requires_grad = False
                 p.requires_grad = False
 
 
-    def forward(self, x, mels, key_padding_mask=None):
+    def forward(
+        self, x, mels, key_padding_mask=None, mels_key_padding_mask=None
+    ) -> VQEncoderOutput:
         # x: (batch, seq_len, channels)
         # x: (batch, seq_len, channels)
-        # x: (batch, seq_len, 128)
-
-        if self.input_downsample and key_padding_mask is not None:
-            key_padding_mask = key_padding_mask[:, ::2]
+        # mels: (batch, seq_len, 128)
 
 
-        # Merge Channels
-        if self.input_downsample:
-            feature_0, feature_1 = x[:, ::2], x[:, 1::2]
-            min_len = min(feature_0.size(1), feature_1.size(1))
-            x = torch.cat([feature_0[:, :min_len], feature_1[:, :min_len]], dim=2)
+        assert key_padding_mask.size(1) == x.size(
+            1
+        ), f"key_padding_mask shape {key_padding_mask.size()} does not match features shape {features.size()}"
 
 
-        # Encode Features
-        features = self.vq_in(x)
-        assert key_padding_mask.size(1) == features.size(
+        assert mels_key_padding_mask.size(1) == mels.size(
             1
             1
-        ), f"key_padding_mask shape {key_padding_mask.size()} is not (batch_size, seq_len)"
+        ), f"mels_key_padding_mask shape {mels_key_padding_mask.size()} does not match mels shape {mels.size()}"
 
 
-        features, _, loss = self.vq(features, mask=~key_padding_mask)
+        # Encode Features
+        features = self.vq_in(x.transpose(1, 2))
+        features, _, loss = self.vq(features)
+        features = features.transpose(1, 2)
 
 
         if self.input_downsample:
         if self.input_downsample:
             features = F.interpolate(
             features = F.interpolate(
                 features.transpose(1, 2), scale_factor=2
                 features.transpose(1, 2), scale_factor=2
             ).transpose(1, 2)
             ).transpose(1, 2)
 
 
+        # Shape may change due to downsampling, let's cut it to the same size
+        if features.shape[1] != key_padding_mask.shape[1]:
+            assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
+            min_len = min(features.shape[1], key_padding_mask.shape[1])
+            features = features[:, :min_len]
+            key_padding_mask = key_padding_mask[:, :min_len]
+
         features = self.feature_in(features)
         features = self.feature_in(features)
         for block in self.feature_blocks:
         for block in self.feature_blocks:
             features = block(features, key_padding_mask=key_padding_mask)
             features = block(features, key_padding_mask=key_padding_mask)
 
 
         # Encode Speaker
         # Encode Speaker
-        speaker = self.speaker_in(x)
+        speaker = self.speaker_in(mels)
         speaker = torch.cat(
         speaker = torch.cat(
             [self.speaker_query.expand(speaker.shape[0], -1, -1), speaker], dim=1
             [self.speaker_query.expand(speaker.shape[0], -1, -1), speaker], dim=1
         )
         )
+        mels_key_padding_mask = torch.cat(
+            [
+                torch.ones(
+                    speaker.shape[0], 1, dtype=torch.bool, device=speaker.device
+                ),
+                mels_key_padding_mask,
+            ],
+            dim=1,
+        )
         for block in self.speaker_blocks:
         for block in self.speaker_blocks:
-            speaker = block(mels, key_padding_mask=key_padding_mask)
+            speaker = block(speaker, key_padding_mask=mels_key_padding_mask)
 
 
         # Mix
         # Mix
         x = features + speaker[:, :1]
         x = features + speaker[:, :1]
-        for block in self.mixer_in:
+        for block in self.mixer_blocks:
             x = block(x, key_padding_mask=key_padding_mask)
             x = block(x, key_padding_mask=key_padding_mask)
 
 
         return VQEncoderOutput(
         return VQEncoderOutput(
@@ -350,7 +366,7 @@ class RelativeAttention(nn.Module):
             assert key_padding_mask.size() == (
             assert key_padding_mask.size() == (
                 batch_size,
                 batch_size,
                 seq_len,
                 seq_len,
-            ), f"key_padding_mask shape {key_padding_mask.size()} is not (batch_size, seq_len)"
+            ), f"key_padding_mask shape {key_padding_mask.size()} does not match x shape {x.size()}"
             assert (
             assert (
                 key_padding_mask.dtype == torch.bool
                 key_padding_mask.dtype == torch.bool
             ), f"key_padding_mask dtype {key_padding_mask.dtype} is not bool"
             ), f"key_padding_mask dtype {key_padding_mask.dtype} is not bool"

+ 29 - 3
fish_speech/models/vqgan/utils.py

@@ -1,6 +1,8 @@
+import matplotlib
 import torch
 import torch
-import torch.utils.data
-from librosa.filters import mel as librosa_mel_fn
+from matplotlib import pyplot as plt
+
+matplotlib.use("Agg")
 
 
 
 
 def convert_pad_shape(pad_shape):
 def convert_pad_shape(pad_shape):
@@ -13,7 +15,7 @@ def sequence_mask(length, max_length=None):
     if max_length is None:
     if max_length is None:
         max_length = length.max()
         max_length = length.max()
     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) < length.unsqueeze(1)
+    return x.unsqueeze(0) >= length.unsqueeze(1)
 
 
 
 
 def init_weights(m, mean=0.0, std=0.01):
 def init_weights(m, mean=0.0, std=0.01):
@@ -24,3 +26,27 @@ def init_weights(m, mean=0.0, std=0.01):
 
 
 def get_padding(kernel_size, dilation=1):
 def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)
     return int((kernel_size * dilation - dilation) / 2)
+
+
+def plot_mel(data, titles=None):
+    fig, axes = plt.subplots(len(data), 1, squeeze=False)
+
+    if titles is None:
+        titles = [None for i in range(len(data))]
+
+    plt.tight_layout()
+
+    for i in range(len(data)):
+        mel = data[i]
+
+        if isinstance(mel, torch.Tensor):
+            mel = mel.detach().cpu().numpy()
+
+        axes[i][0].imshow(mel, origin="lower")
+        axes[i][0].set_aspect(2.5, adjustable="box")
+        axes[i][0].set_ylim(0, mel.shape[0])
+        axes[i][0].set_title(titles[i], fontsize="medium")
+        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
+        axes[i][0].set_anchor("W")
+
+    return fig

+ 0 - 29
fish_speech/utils/viz.py

@@ -1,29 +0,0 @@
-import matplotlib
-from matplotlib import pyplot as plt
-from torch import Tensor
-
-matplotlib.use("Agg")
-
-
-def plot_mel(data, titles=None):
-    fig, axes = plt.subplots(len(data), 1, squeeze=False)
-
-    if titles is None:
-        titles = [None for i in range(len(data))]
-
-    plt.tight_layout()
-
-    for i in range(len(data)):
-        mel = data[i]
-
-        if isinstance(mel, Tensor):
-            mel = mel.detach().cpu().numpy()
-
-        axes[i][0].imshow(mel, origin="lower")
-        axes[i][0].set_aspect(2.5, adjustable="box")
-        axes[i][0].set_ylim(0, mel.shape[0])
-        axes[i][0].set_title(titles[i], fontsize="medium")
-        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
-        axes[i][0].set_anchor("W")
-
-    return fig

+ 2 - 0
pyproject.toml

@@ -30,6 +30,8 @@ dependencies = [
     "jieba",
     "jieba",
     "g2p_en",
     "g2p_en",
     "pyopenjtalk",
     "pyopenjtalk",
+    "wandb",
+    "tensorboard",
 ]
 ]
 
 
 [build-system]
 [build-system]