Lengyue · 2 years ago
parent
commit fd879ec280
2 changed files with 24 additions and 245 deletions
  1. +2 -2     fish_speech/models/vqgan/__init__.py
  2. +22 -243  fish_speech/models/vqgan/lit_module.py

+2 -2
fish_speech/models/vqgan/__init__.py

@@ -1,3 +1,3 @@
-from .lit_module import VQGAN, VQNaive
+from .lit_module import VQGAN
 
-__all__ = ["VQGAN", "VQNaive"]
+__all__ = ["VQGAN"]
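The practical effect of this file's diff is that VQNaive can no longer be imported from the package. A minimal before/after for caller code (hypothetical usage, not part of the commit):

    from fish_speech.models.vqgan import VQGAN      # still valid
    # from fish_speech.models.vqgan import VQNaive  # ImportError after this commit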

+22 -243
fish_speech/models/vqgan/lit_module.py

@@ -111,7 +111,8 @@ class VQGAN(L.LightningModule):
         if self.speaker_encoder is not None:
             components.append(self.speaker_encoder.parameters())
 
-        components.append(self.decoder.parameters())
+        if self.decoder is not None:
+            components.append(self.decoder.parameters())
 
         if self.freeze_hifigan is False:
             components.append(self.generator.parameters())
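The guard added above matters because calling .parameters() on a missing submodule would raise AttributeError. A hedged sketch of the collection pattern with stand-in modules (assumed names, not the repository's full configure_optimizers):

    import itertools

    import torch
    import torch.nn as nn

    # Stand-ins for the optional submodules; in the real class these are
    # speaker_encoder, decoder, and the HiFi-GAN generator.
    speaker_encoder = None
    decoder = None
    generator = nn.Linear(4, 4)

    components = []
    for module in (speaker_encoder, decoder, generator):
        if module is not None:  # skip absent modules instead of crashing
            components.append(module.parameters())

    optimizer = torch.optim.AdamW(itertools.chain(*components), lr=1e-4)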
@@ -197,12 +198,16 @@ class VQGAN(L.LightningModule):
             torch.set_grad_enabled(True)
 
         # Sample mels
-        speaker_features = (
-            self.speaker_encoder(gt_mels, mel_masks)
-            if self.speaker_encoder is not None
-            else None
-        )
-        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        if self.decoder is not None:
+            speaker_features = (
+                self.speaker_encoder(gt_mels, mel_masks)
+                if self.speaker_encoder is not None
+                else None
+            )
+            decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        else:
+            decoded_mels = text_features
+
         fake_audios = self.generator(decoded_mels)
 
         y_hat_mels = self.mel_transform(fake_audios.squeeze(1))
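This hunk and the identical one in validation_step below add the same bypass: when no mel decoder is configured, the quantized features are treated as mels and fed straight to the generator. A self-contained sketch of that branch (toy shapes and nn.Identity stand-ins; not the repository's code):

    import torch
    import torch.nn as nn

    def render_audio(text_features, mel_masks, gt_mels,
                     decoder, speaker_encoder, generator):
        # Mirrors the new branch: only run the mel decoder when it exists;
        # otherwise the VQ features already live in mel space.
        if decoder is not None:
            speaker_features = (
                speaker_encoder(gt_mels, mel_masks)
                if speaker_encoder is not None
                else None
            )
            decoded_mels = decoder(text_features, mel_masks, g=speaker_features)
        else:
            decoded_mels = text_features
        return generator(decoded_mels)

    # Decoder-less path: features pass through to a stand-in vocoder.
    feats = torch.randn(1, 128, 50)  # (batch, mel_bins, frames)
    masks = torch.ones(1, 1, 50)
    audio = render_audio(feats, masks, gt_mels=None, decoder=None,
                         speaker_encoder=None, generator=nn.Identity())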
@@ -350,12 +355,16 @@ class VQGAN(L.LightningModule):
         )
 
         # Sample mels
-        speaker_features = (
-            self.speaker_encoder(gt_mels, mel_masks)
-            if self.speaker_encoder is not None
-            else None
-        )
-        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        if self.decoder is not None:
+            speaker_features = (
+                self.speaker_encoder(gt_mels, mel_masks)
+                if self.speaker_encoder is not None
+                else None
+            )
+            decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        else:
+            decoded_mels = text_features
+
         fake_audios = self.generator(decoded_mels)
 
         fake_mels = self.mel_transform(fake_audios.squeeze(1))
@@ -449,233 +458,3 @@ class VQGAN(L.LightningModule):
                 )
 
             plt.close(image_mels)
-
-
-class VQNaive(L.LightningModule):
-    def __init__(
-        self,
-        optimizer: Callable,
-        lr_scheduler: Callable,
-        downsample: ConvDownSampler,
-        vq_encoder: VQEncoder,
-        speaker_encoder: SpeakerEncoder,
-        mel_encoder: TextEncoder,
-        decoder: TextEncoder,
-        mel_transform: nn.Module,
-        hop_length: int = 640,
-        sample_rate: int = 32000,
-        vocoder: Generator = None,
-    ):
-        super().__init__()
-
-        # Model parameters
-        self.optimizer_builder = optimizer
-        self.lr_scheduler_builder = lr_scheduler
-
-        # Generator and discriminators
-        self.downsample = downsample
-        self.vq_encoder = vq_encoder
-        self.speaker_encoder = speaker_encoder
-        self.mel_encoder = mel_encoder
-        self.decoder = decoder
-        self.mel_transform = mel_transform
-
-        # Crop length for saving memory
-        self.hop_length = hop_length
-        self.sampling_rate = sample_rate
-
-        # Vocoder
-        self.vocoder = vocoder
-
-        for p in self.vocoder.parameters():
-            p.requires_grad = False
-
-    def configure_optimizers(self):
-        optimizer = self.optimizer_builder(self.parameters())
-        lr_scheduler = self.lr_scheduler_builder(optimizer)
-
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": {
-                "scheduler": lr_scheduler,
-                "interval": "step",
-            },
-        }
-
-    def vq_encode(self, audios, audio_lengths):
-        with torch.no_grad():
-            features = gt_mels = self.mel_transform(
-                audios, sample_rate=self.sampling_rate
-            )
-
-        if self.downsample is not None:
-            features = self.downsample(features)
-
-        mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = (
-            audio_lengths
-            / self.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
-        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
-            gt_mels.dtype
-        )
-
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.mel_encoder(features, feature_masks)
-        text_features, indices, loss_vq = self.vq_encoder(text_features, feature_masks)
-
-        return mel_masks, gt_mels, text_features, indices, loss_vq
-
-    def vq_decode(self, text_features, speaker_features, gt_mels, mel_masks):
-        text_features = F.interpolate(
-            text_features, size=gt_mels.shape[2], mode="nearest"
-        )
-
-        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
-
-        return decoded_mels
-
-    def training_step(self, batch, batch_idx):
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
-        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
-            audios, audio_lengths
-        )
-        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        decoded_mels = self.vq_decode(
-            text_features, speaker_features, gt_mels, mel_masks
-        )
-        loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
-        loss = loss_mel + loss_vq
-
-        self.log(
-            "train/generator/loss",
-            loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-        self.log(
-            "train/loss_mel",
-            loss_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-        self.log(
-            "train/generator/loss_vq",
-            loss_vq,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        return loss
-
-    def validation_step(self, batch: Any, batch_idx: int):
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
-        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
-            audios, audio_lengths
-        )
-        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        decoded_mels = self.vq_decode(
-            text_features, speaker_features, gt_mels, mel_masks
-        )
-        fake_audios = self.vocoder(decoded_mels)
-
-        mel_loss = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
-        self.log(
-            "val/mel_loss",
-            mel_loss,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-
-        for idx, (
-            mel,
-            decoded_mel,
-            audio,
-            gen_audio,
-            audio_len,
-        ) in enumerate(
-            zip(
-                gt_mels,
-                decoded_mels,
-                audios.detach().float(),
-                fake_audios.detach().float(),
-                audio_lengths,
-            )
-        ):
-            mel_len = audio_len // self.hop_length
-
-            image_mels = plot_mel(
-                [
-                    decoded_mel[:, :mel_len],
-                    mel[:, :mel_len],
-                ],
-                [
-                    "Generated",
-                    "Ground-Truth",
-                ],
-            )
-
-            if isinstance(self.logger, WandbLogger):
-                self.logger.experiment.log(
-                    {
-                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
-                        "wavs": [
-                            wandb.Audio(
-                                audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="gt",
-                            ),
-                            wandb.Audio(
-                                gen_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="prediction",
-                            ),
-                        ],
-                    },
-                )
-
-            if isinstance(self.logger, TensorBoardLogger):
-                self.logger.experiment.add_figure(
-                    f"sample-{idx}/mels",
-                    image_mels,
-                    global_step=self.global_step,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/gt",
-                    audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/prediction",
-                    gen_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-
-            plt.close(image_mels)