@@ -45,10 +45,12 @@ class VQGAN(L.LightningModule):
         generator: Generator,
         discriminator: EnsembleDiscriminator,
         mel_transform: nn.Module,
+        feature_mel_transform: nn.Module,
         segment_size: int = 20480,
         hop_length: int = 640,
         sample_rate: int = 32000,
         freeze_hifigan: bool = False,
+        freeze_vq: bool = False,
     ):
         super().__init__()
@@ -65,6 +67,7 @@ class VQGAN(L.LightningModule):
         self.generator = generator
         self.discriminator = discriminator
         self.mel_transform = mel_transform
+        self.feature_mel_transform = feature_mel_transform

         # Crop length for saving memory
         self.segment_size = segment_size
@@ -83,6 +86,17 @@ class VQGAN(L.LightningModule):
             for p in self.generator.parameters():
                 p.requires_grad = False

+        # Stage 2: Train the HifiGAN + Decoder + Generator
+        if freeze_vq:
+            for p in self.vq_encoder.parameters():
+                p.requires_grad = False
+
+            for p in self.text_encoder.parameters():
+                p.requires_grad = False
+
+            for p in self.downsample.parameters():
+                p.requires_grad = False
+
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
         optimizer_generator = self.optimizer_builder(
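Note: `freeze_vq` complements the existing `freeze_hifigan` flag, giving two training stages: stage 1 (`freeze_hifigan`) freezes the HiFi-GAN generator while the VQ/text path trains; stage 2 (`freeze_vq`, above) freezes the VQ encoder, text encoder, and downsampler while the decoder path trains. A minimal sketch of the freezing pattern, with illustrative module names (not the repo's API):

```python
import torch.nn as nn

def set_requires_grad(module: nn.Module, flag: bool) -> None:
    """Toggle gradient tracking for every parameter in `module`."""
    for p in module.parameters():
        p.requires_grad = flag

def apply_stage(model: nn.Module, freeze_hifigan: bool, freeze_vq: bool) -> None:
    # Hypothetical two-branch model: stage 1 trains the VQ branch,
    # stage 2 trains the vocoder/decoder branch.
    if freeze_hifigan:
        set_requires_grad(model.generator, False)
    if freeze_vq:
        for sub in (model.vq_encoder, model.text_encoder, model.downsample):
            set_requires_grad(sub, False)
```

Frozen parameters still receive optimizer state unless excluded, so filtering on `p.requires_grad` when building the optimizers avoids carrying dead state.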
@@ -130,15 +144,20 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]

         with torch.no_grad():
-            gt_mels = self.mel_transform(audios)
+            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
+            features = self.feature_mel_transform(
+                audios, sample_rate=self.sampling_rate
+            )

             if self.downsample is not None:
-                features = self.downsample(gt_mels)
+                features = self.downsample(features)

         mel_lengths = audio_lengths // self.hop_length
         feature_lengths = (
             audio_lengths
-            / self.hop_length
+            / self.sampling_rate
+            * self.feature_mel_transform.sample_rate
+            / self.feature_mel_transform.hop_length
             / (self.downsample.total_strides if self.downsample is not None else 1)
         ).long()
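Note: the old `feature_lengths` formula divided by `self.hop_length`, which is only valid when the features share the output mel's sample rate and hop. Since `features` now come from `feature_mel_transform`, which may run at a different rate, the new expression converts through seconds: samples ÷ audio sample rate × feature frames per second. A worked check, where the feature-transform settings (16 kHz, hop 320) are assumptions for illustration; only `sample_rate=32000` and `hop_length=640` appear in this diff:

```python
audio_sample_rate = 32000
audio_len = 64000                        # 2 s of audio

hop_length = 640
mel_frames = audio_len // hop_length     # 100 frames, i.e. 50 Hz
assert mel_frames == 100

feature_sample_rate = 16000              # assumed for illustration
feature_hop = 320                        # assumed for illustration
total_strides = 1                        # no extra downsampling

feature_frames = int(
    audio_len / audio_sample_rate        # 2.0 seconds
    * feature_sample_rate / feature_hop  # 50 frames per second
    / total_strides
)
assert feature_frames == 100             # matches the 50 Hz mel grid
```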
@@ -169,45 +188,43 @@ class VQGAN(L.LightningModule):

         assert y.shape == y_hat.shape, f"{y.shape} != {y_hat.shape}"

-        # Discriminator
-        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())
-
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
-
-        self.log(
-            "train/discriminator/loss",
-            loss_disc_all,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-
         # Since we don't want to update the discriminator, we skip the backward pass
         if self.freeze_hifigan is False:
+            # Discriminator
+            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())
+
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
+
+            self.log(
+                "train/discriminator/loss",
+                loss_disc_all,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=True,
+                logger=True,
+                sync_dist=True,
+            )
+
             optim_d.zero_grad()
             self.manual_backward(loss_disc_all)
             self.clip_gradients(
-                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+                optim_d, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
             )
             optim_d.step()

         y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)

         with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_decoded_mel = F.l1_loss(gt_mels, decoded_mels)
-            loss_mel = F.l1_loss(gt_mels, y_hat_mels)
+            loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
+            loss_mel = F.l1_loss(gt_mels * mel_masks, y_hat_mels * mel_masks)
             loss_adv, _ = generator_loss(y_d_hat_g)
             loss_fm = feature_loss(fmap_r, fmap_g)

-        mel_loss_weight = 25 if self.freeze_hifigan is True else 45
-
-        loss_gen_all = loss_mel * mel_loss_weight + loss_fm + loss_adv + loss_vq
-
         if self.freeze_hifigan is True:
-            loss_gen_all += loss_decoded_mel * mel_loss_weight
+            loss_gen_all = loss_decoded_mel + loss_vq
+        else:
+            loss_gen_all = loss_mel * 45 + loss_vq * 45 + loss_fm + loss_adv

         self.log(
             "train/generator/loss",
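Note: two behavioral changes in this hunk. First, the discriminator forward pass, loss, and logging move inside `if self.freeze_hifigan is False:`, so no discriminator work happens while it is frozen. Second, the mel losses are now masked and the loss assembly is stage-dependent: with `freeze_hifigan` set, only the decoded-mel L1 plus VQ loss trains; otherwise the weighted GAN objective applies. One subtlety of masking by multiplication, sketched below with a conventional `sequence_mask` (the repo's own implementation is assumed to match): padded positions contribute zeros but still count toward the mean.

```python
import torch
import torch.nn.functional as F

def sequence_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # Standard boolean mask: True for valid frames, False for padding.
    steps = torch.arange(max_len, device=lengths.device)
    return steps[None, :] < lengths[:, None]

B, C, T = 2, 128, 100
gt = torch.randn(B, C, T)
pred = torch.randn(B, C, T)
lengths = torch.tensor([100, 60])

mask = sequence_mask(lengths, T).unsqueeze(1).to(gt.dtype)  # (B, 1, T)
loss = F.l1_loss(gt * mask, pred * mask)

# Masked-out frames are zeroed yet still divide the mean, so short
# sequences are implicitly down-weighted. A fully normalized variant:
loss_norm = (gt - pred).abs().mul(mask).sum() / (mask.sum() * C)
```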
@@ -267,7 +284,7 @@ class VQGAN(L.LightningModule):
         optim_g.zero_grad()
         self.manual_backward(loss_gen_all)
         self.clip_gradients(
-            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+            optim_g, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
         )
         optim_g.step()
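Note: the clip threshold drops from 1000.0 (effectively no clipping) to a tight 1.0 norm for both optimizers. Because this module uses Lightning's manual optimization, clipping must be invoked explicitly between `manual_backward` and `step`; the trainer's `gradient_clip_val` flag does not apply in manual mode. A minimal self-contained skeleton of the pattern (the model and losses are placeholders):

```python
import lightning as L
import torch

class TwoOptModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # required for manual_backward
        self.gen = torch.nn.Linear(8, 8)
        self.disc = torch.nn.Linear(8, 1)

    def configure_optimizers(self):
        return (
            torch.optim.AdamW(self.gen.parameters(), lr=1e-4),
            torch.optim.AdamW(self.disc.parameters(), lr=1e-4),
        )

    def training_step(self, batch, batch_idx):
        optim_g, optim_d = self.optimizers()

        # Discriminator update on detached generator output.
        loss_d = self.disc(self.gen(batch).detach()).mean()
        optim_d.zero_grad()
        self.manual_backward(loss_d)
        # With automatic optimization off, Lightning never clips for you.
        self.clip_gradients(optim_d, gradient_clip_val=1.0, gradient_clip_algorithm="norm")
        optim_d.step()

        # Generator update.
        loss_g = -self.disc(self.gen(batch)).mean()
        optim_g.zero_grad()
        self.manual_backward(loss_g)
        self.clip_gradients(optim_g, gradient_clip_val=1.0, gradient_clip_algorithm="norm")
        optim_g.step()
        return loss_g
```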
@@ -282,15 +299,18 @@ class VQGAN(L.LightningModule):
         audios = audios.float()
         audios = audios[:, None, :]

-        gt_mels = self.mel_transform(audios)
+        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
+        features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)

         if self.downsample is not None:
-            features = self.downsample(gt_mels)
+            features = self.downsample(features)

         mel_lengths = audio_lengths // self.hop_length
         feature_lengths = (
             audio_lengths
-            / self.hop_length
+            / self.sampling_rate
+            * self.feature_mel_transform.sample_rate
+            / self.feature_mel_transform.hop_length
             / (self.downsample.total_strides if self.downsample is not None else 1)
         ).long()
@@ -301,11 +321,11 @@ class VQGAN(L.LightningModule):
             gt_mels.dtype
         )

-        speaker_features = self.speaker_encoder(features, feature_masks)
+        speaker_features = self.speaker_encoder(gt_mels, mel_masks)

         # vq_features is 50 hz, need to convert to true mel size
         text_features = self.text_encoder(features, feature_masks)
-        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
+        text_features, _ = self.vq_encoder(text_features, feature_masks)
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
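Note: the VQ output lives on the coarse feature grid (50 Hz, per the comment) while the decoder consumes mel-rate frames, so `F.interpolate(..., mode="nearest")` stretches each code across the finer grid without blending neighboring codes. A quick demonstration:

```python
import torch
import torch.nn.functional as F

# Toy shapes: 2-channel codes at 50 frames, upsampled to a 100-frame mel grid.
codes = torch.randn(1, 2, 50)            # (batch, channels, feature frames)
upsampled = F.interpolate(codes, size=100, mode="nearest")

assert upsampled.shape == (1, 2, 100)
# Nearest-neighbour interpolation duplicates each code rather than
# averaging neighbours, so the discrete code identity is preserved:
assert torch.equal(upsampled[..., 0], upsampled[..., 1])
```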
@@ -323,7 +343,7 @@ class VQGAN(L.LightningModule):
         gt_mels = gt_mels[:, :, :min_mel_length]
         fake_mels = fake_mels[:, :, :min_mel_length]

-        mel_loss = F.l1_loss(gt_mels, fake_mels)
+        mel_loss = F.l1_loss(gt_mels * mel_masks, fake_mels * mel_masks)
         self.log(
             "val/mel_loss",
             mel_loss,
@@ -405,3 +425,244 @@ class VQGAN(L.LightningModule):
                 )

             plt.close(image_mels)
+
+
+class VQNaive(L.LightningModule):
+    def __init__(
+        self,
+        optimizer: Callable,
+        lr_scheduler: Callable,
+        downsample: ConvDownSampler,
+        vq_encoder: VQEncoder,
+        speaker_encoder: SpeakerEncoder,
+        mel_encoder: TextEncoder,
+        decoder: TextEncoder,
+        mel_transform: nn.Module,
+        hop_length: int = 640,
+        sample_rate: int = 32000,
+        vocoder: Generator = None,
+    ):
+        super().__init__()
+
+        # Model parameters
+        self.optimizer_builder = optimizer
+        self.lr_scheduler_builder = lr_scheduler
+
+        # Generator and discriminators
+        self.downsample = downsample
+        self.vq_encoder = vq_encoder
+        self.speaker_encoder = speaker_encoder
+        self.mel_encoder = mel_encoder
+        self.decoder = decoder
+        self.mel_transform = mel_transform
+
+        # Crop length for saving memory
+        self.hop_length = hop_length
+        self.sampling_rate = sample_rate
+
+        # Vocoder
+        self.vocoder = vocoder
+
+        for p in self.vocoder.parameters():
+            p.requires_grad = False
+
+    def configure_optimizers(self):
+        optimizer = self.optimizer_builder(self.parameters())
+        lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "interval": "step",
+            },
+        }
+
+    def training_step(self, batch, batch_idx):
+        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+
+        audios = audios.float()
+        audios = audios[:, None, :]
+
+        with torch.no_grad():
+            features = gt_mels = self.mel_transform(
+                audios, sample_rate=self.sampling_rate
+            )
+
+            if self.downsample is not None:
+                features = self.downsample(features)
+
+        mel_lengths = audio_lengths // self.hop_length
+        feature_lengths = (
+            audio_lengths
+            / self.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(gt_mels.dtype)
+        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
+            gt_mels.dtype
+        )
+
+        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
+
+        # vq_features is 50 hz, need to convert to true mel size
+        text_features = self.mel_encoder(features, feature_masks)
+        text_features, loss_vq = self.vq_encoder(text_features, feature_masks)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
+
+        # Sample mels
+        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
+        loss = loss_mel + loss_vq
+
+        self.log(
+            "train/generator/loss",
+            loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train/loss_mel",
+            loss_mel,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train/generator/loss_vq",
+            loss_vq,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+
+        return loss
+
+    def validation_step(self, batch: Any, batch_idx: int):
+        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+
+        audios = audios.float()
+        audios = audios[:, None, :]
+
+        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
+
+        if self.downsample is not None:
+            features = self.downsample(features)
+
+        mel_lengths = audio_lengths // self.hop_length
+        feature_lengths = (
+            audio_lengths
+            / self.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(gt_mels.dtype)
+        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
+            gt_mels.dtype
+        )
+
+        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
+
+        # vq_features is 50 hz, need to convert to true mel size
+        text_features = self.mel_encoder(features, feature_masks)
+        text_features, _ = self.vq_encoder(text_features, feature_masks)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
+
+        # Sample mels
+        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        fake_audios = self.vocoder(decoded_mels)
+
+        mel_loss = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
+        self.log(
+            "val/mel_loss",
+            mel_loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        for idx, (
+            mel,
+            decoded_mel,
+            audio,
+            gen_audio,
+            audio_len,
+        ) in enumerate(
+            zip(
+                gt_mels,
+                decoded_mels,
+                audios.detach().float(),
+                fake_audios.detach().float(),
+                audio_lengths,
+            )
+        ):
+            mel_len = audio_len // self.hop_length
+
+            image_mels = plot_mel(
+                [
+                    decoded_mel[:, :mel_len],
+                    mel[:, :mel_len],
+                ],
+                [
+                    "Generated",
+                    "Ground-Truth",
+                ],
+            )
+
+            if isinstance(self.logger, WandbLogger):
+                self.logger.experiment.log(
+                    {
+                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
+                        "wavs": [
+                            wandb.Audio(
+                                audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="gt",
+                            ),
+                            wandb.Audio(
+                                gen_audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="prediction",
+                            ),
+                        ],
+                    },
+                )
+
+            if isinstance(self.logger, TensorBoardLogger):
+                self.logger.experiment.add_figure(
+                    f"sample-{idx}/mels",
+                    image_mels,
+                    global_step=self.global_step,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/gt",
+                    audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/prediction",
+                    gen_audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+
+            plt.close(image_mels)
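Note on the new class: `VQNaive` trains only the mel-encoder → VQ → decoder path with Lightning's default automatic optimization and a single combined `loss_mel + loss_vq` objective, with no discriminator; the vocoder is frozen and used solely to synthesize validation audio. Also worth noting: `features = gt_mels = self.mel_transform(...)` binds both names to one tensor, and `features` is only rebound when `downsample` is set. The `optimizer`/`lr_scheduler` arguments are partially applied constructors, consumed as sketched below (hyperparameter values are illustrative, not from this diff):

```python
from functools import partial

import torch

# Each "builder" is a callable that finishes construction once the
# parameters (or the optimizer) are available.
optimizer_builder = partial(torch.optim.AdamW, lr=2e-4, betas=(0.8, 0.99))
lr_scheduler_builder = partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999)

params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = optimizer_builder(params)
lr_scheduler = lr_scheduler_builder(optimizer)

# The dict returned by VQNaive.configure_optimizers tells Lightning to
# step the scheduler every optimizer step instead of once per epoch:
config = {
    "optimizer": optimizer,
    "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
}
```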