@@ -111,7 +111,8 @@ class VQGAN(L.LightningModule):
         if self.speaker_encoder is not None:
             components.append(self.speaker_encoder.parameters())
 
-        components.append(self.decoder.parameters())
+        if self.decoder is not None:
+            components.append(self.decoder.parameters())
 
         if self.freeze_hifigan is False:
             components.append(self.generator.parameters())
@@ -197,12 +198,16 @@ class VQGAN(L.LightningModule):
         torch.set_grad_enabled(True)
 
         # Sample mels
-        speaker_features = (
-            self.speaker_encoder(gt_mels, mel_masks)
-            if self.speaker_encoder is not None
-            else None
-        )
-        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        if self.decoder is not None:
+            speaker_features = (
+                self.speaker_encoder(gt_mels, mel_masks)
+                if self.speaker_encoder is not None
+                else None
+            )
+            decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        else:
+            decoded_mels = text_features
+
         fake_audios = self.generator(decoded_mels)
 
         y_hat_mels = self.mel_transform(fake_audios.squeeze(1))
@@ -350,12 +355,16 @@ class VQGAN(L.LightningModule):
         )
 
         # Sample mels
-        speaker_features = (
-            self.speaker_encoder(gt_mels, mel_masks)
-            if self.speaker_encoder is not None
-            else None
-        )
-        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        if self.decoder is not None:
+            speaker_features = (
+                self.speaker_encoder(gt_mels, mel_masks)
+                if self.speaker_encoder is not None
+                else None
+            )
+            decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        else:
+            decoded_mels = text_features
+
         fake_audios = self.generator(decoded_mels)
 
         fake_mels = self.mel_transform(fake_audios.squeeze(1))
@@ -449,233 +458,3 @@ class VQGAN(L.LightningModule):
                 )
 
             plt.close(image_mels)
-
-
-class VQNaive(L.LightningModule):
-    def __init__(
-        self,
-        optimizer: Callable,
-        lr_scheduler: Callable,
-        downsample: ConvDownSampler,
-        vq_encoder: VQEncoder,
-        speaker_encoder: SpeakerEncoder,
-        mel_encoder: TextEncoder,
-        decoder: TextEncoder,
-        mel_transform: nn.Module,
-        hop_length: int = 640,
-        sample_rate: int = 32000,
-        vocoder: Generator = None,
-    ):
-        super().__init__()
-
-        # Model parameters
-        self.optimizer_builder = optimizer
-        self.lr_scheduler_builder = lr_scheduler
-
-        # Generator and discriminators
-        self.downsample = downsample
-        self.vq_encoder = vq_encoder
-        self.speaker_encoder = speaker_encoder
-        self.mel_encoder = mel_encoder
-        self.decoder = decoder
-        self.mel_transform = mel_transform
-
-        # Crop length for saving memory
-        self.hop_length = hop_length
-        self.sampling_rate = sample_rate
-
-        # Vocoder
-        self.vocoder = vocoder
-
-        for p in self.vocoder.parameters():
-            p.requires_grad = False
-
-    def configure_optimizers(self):
-        optimizer = self.optimizer_builder(self.parameters())
-        lr_scheduler = self.lr_scheduler_builder(optimizer)
-
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": {
-                "scheduler": lr_scheduler,
-                "interval": "step",
-            },
-        }
-
-    def vq_encode(self, audios, audio_lengths):
-        with torch.no_grad():
-            features = gt_mels = self.mel_transform(
-                audios, sample_rate=self.sampling_rate
-            )
-
-        if self.downsample is not None:
-            features = self.downsample(features)
-
-        mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = (
-            audio_lengths
-            / self.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
-        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
-            gt_mels.dtype
-        )
-
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.mel_encoder(features, feature_masks)
-        text_features, indices, loss_vq = self.vq_encoder(text_features, feature_masks)
-
-        return mel_masks, gt_mels, text_features, indices, loss_vq
-
-    def vq_decode(self, text_features, speaker_features, gt_mels, mel_masks):
-        text_features = F.interpolate(
-            text_features, size=gt_mels.shape[2], mode="nearest"
-        )
-
-        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
-
-        return decoded_mels
-
-    def training_step(self, batch, batch_idx):
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
-        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
-            audios, audio_lengths
-        )
-        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        decoded_mels = self.vq_decode(
-            text_features, speaker_features, gt_mels, mel_masks
-        )
-        loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
-        loss = loss_mel + loss_vq
-
-        self.log(
-            "train/generator/loss",
-            loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-        self.log(
-            "train/loss_mel",
-            loss_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-        self.log(
-            "train/generator/loss_vq",
-            loss_vq,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
-        return loss
-
-    def validation_step(self, batch: Any, batch_idx: int):
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
-        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
-            audios, audio_lengths
-        )
-        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        decoded_mels = self.vq_decode(
-            text_features, speaker_features, gt_mels, mel_masks
-        )
-        fake_audios = self.vocoder(decoded_mels)
-
-        mel_loss = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
-        self.log(
-            "val/mel_loss",
-            mel_loss,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-
-        for idx, (
-            mel,
-            decoded_mel,
-            audio,
-            gen_audio,
-            audio_len,
-        ) in enumerate(
-            zip(
-                gt_mels,
-                decoded_mels,
-                audios.detach().float(),
-                fake_audios.detach().float(),
-                audio_lengths,
-            )
-        ):
-            mel_len = audio_len // self.hop_length
-
-            image_mels = plot_mel(
-                [
-                    decoded_mel[:, :mel_len],
-                    mel[:, :mel_len],
-                ],
-                [
-                    "Generated",
-                    "Ground-Truth",
-                ],
-            )
-
-            if isinstance(self.logger, WandbLogger):
-                self.logger.experiment.log(
-                    {
-                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
-                        "wavs": [
-                            wandb.Audio(
-                                audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="gt",
-                            ),
-                            wandb.Audio(
-                                gen_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="prediction",
-                            ),
-                        ],
-                    },
-                )
-
-            if isinstance(self.logger, TensorBoardLogger):
-                self.logger.experiment.add_figure(
-                    f"sample-{idx}/mels",
-                    image_mels,
-                    global_step=self.global_step,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/gt",
-                    audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/prediction",
-                    gen_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
-
-            plt.close(image_mels)