@@ -1,6 +1,5 @@
-import itertools
-from dataclasses import dataclass
-from typing import Any, Callable, Literal, Optional
+import math
+from typing import Any, Callable

import lightning as L
import torch
@@ -22,7 +21,7 @@ class VQGAN(L.LightningModule):
        encoder: WaveNet,
        quantizer: nn.Module,
        decoder: WaveNet,
-        reflow: nn.Module,
+        # reflow: nn.Module,
        vocoder: nn.Module,
        mel_transform: nn.Module,
        weight_reflow: float = 1.0,
@@ -44,7 +43,7 @@ class VQGAN(L.LightningModule):
        self.quantizer = quantizer
        self.decoder = decoder
        self.vocoder = vocoder
-        self.reflow = reflow
+        # self.reflow = reflow
        self.mel_transform = mel_transform

        # Freeze vocoder
@@ -122,51 +121,21 @@ class VQGAN(L.LightningModule):
        vq_recon_features = vq_result.z * mel_masks_float_conv

        # VQ Decode
-        gen_mel = self.decoder(vq_recon_features) * mel_masks_float_conv
+        gen_mel = (
+            self.decoder(
+                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
+                condition=vq_recon_features,
+            )
+            * mel_masks_float_conv
+        )

        # Mel Loss
        loss_mel = (gen_mel - gt_mels).abs().mean(
            dim=1, keepdim=True
        ).sum() / mel_masks_float_conv.sum()

-        # Reflow, given x_1_aux, we want to reconstruct x_1
-        x_1 = self.norm_spec(gt_mels)
-        t = torch.rand(gt_mels.shape[0], device=gt_mels.device, dtype=torch.float32)
-        t = torch.clamp(t, 1e-6, 1 - 1e-6)  # Avoid 0 and 1
-        x_0 = torch.randn_like(x_1)
-
-        # X_t = t * X_1 + (1 - t) * X_0
-        x_t = x_0 + t[:, None, None] * (x_1 - x_0)
-
-        v_pred = self.reflow(
-            x_t,
-            1000 * t,
-            vq_recon_features.detach(),  # Stop gradients, avoid reflow to destroy the VQ
-        )
-
-        # Log L2 loss with
-        with torch.autocast(device_type=gt_mels.device.type, dtype=torch.float32):
-            weights = (
-                0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / (1 - t)) ** 2)
-            )
-            assert (
-                torch.isnan(weights).any() == False
-                and torch.isinf(weights).any() == False
-            ), "Found NaN or Inf in weights."
-
-            loss_reflow = weights[:, None, None] * F.mse_loss(
-                x_1 - x_0, v_pred, reduction="none"
-            )
-            loss_reflow = (loss_reflow * mel_masks_float_conv).mean(
-                dim=1
-            ).sum() / mel_masks_float_conv.sum()
-
        # Total loss
-        loss = (
-            self.weight_vq * loss_vq
-            + self.weight_mel * loss_mel
-            + self.weight_reflow * loss_reflow
-        )
+        loss = self.weight_vq * loss_vq + self.weight_mel * loss_mel

        # Log losses
        self.log(
@@ -193,14 +162,6 @@ class VQGAN(L.LightningModule):
            prog_bar=False,
            logger=True,
        )
-        self.log(
-            "train/generator/loss_reflow",
-            loss_reflow,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-        )

        return loss

@@ -223,7 +184,13 @@ class VQGAN(L.LightningModule):
        vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv

        # VQ Decode
-        gen_aux_mels = self.decoder(vq_recon_features) * mel_masks_float_conv
+        gen_aux_mels = (
+            self.decoder(
+                torch.randn_like(vq_recon_features) * mel_masks_float_conv,
+                condition=vq_recon_features,
+            )
+            * mel_masks_float_conv
+        )
        loss_mel = (gen_aux_mels - gt_mels).abs().mean(
            dim=1, keepdim=True
        ).sum() / mel_masks_float_conv.sum()
@@ -238,45 +205,8 @@ class VQGAN(L.LightningModule):
            sync_dist=True,
        )

-        # Reflow inference
-        t_start = 0.0
-
-        x_1 = self.norm_spec(gen_aux_mels)
-        x_0 = torch.randn_like(x_1)
-        gen_reflow_mels = (1 - t_start) * x_0 + t_start * x_1
-
-        t = torch.zeros(gt_mels.shape[0], device=gt_mels.device)
-        dt = (1.0 - t_start) / self.reflow_inference_steps
-
-        for _ in range(self.reflow_inference_steps):
-            gen_reflow_mels += (
-                self.reflow(
-                    gen_reflow_mels,
-                    1000 * t,
-                    vq_recon_features,
-                )
-                * dt
-            )
-            t += dt
-
-        gen_reflow_mels = self.denorm_spec(gen_reflow_mels) * mel_masks_float_conv
-        loss_reflow_mel = (gen_reflow_mels - gt_mels).abs().mean(
-            dim=1, keepdim=True
-        ).sum() / mel_masks_float_conv.sum()
-
-        self.log(
-            "val/loss_reflow_mel",
-            loss_reflow_mel,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-
        recon_audios = self.vocoder(gt_mels)
        gen_aux_audios = self.vocoder(gen_aux_mels)
-        gen_reflow_audios = self.vocoder(gen_reflow_mels)

        # only log the first batch
        if batch_idx != 0:
@@ -285,36 +215,33 @@ class VQGAN(L.LightningModule):
        for idx, (
            gt_mel,
            gen_aux_mel,
-            gen_reflow_mel,
            audio,
            gen_aux_audio,
-            gen_reflow_audio,
            recon_audio,
            audio_len,
        ) in enumerate(
            zip(
                gt_mels,
                gen_aux_mels,
-                gen_reflow_mels,
-                audios.float(),
-                gen_aux_audios.float(),
-                gen_reflow_audios.float(),
-                recon_audios.float(),
+                audios.cpu().float(),
+                gen_aux_audios.cpu().float(),
+                recon_audios.cpu().float(),
                audio_lengths,
            )
        ):
+            if idx > 4:
+                break
+
            mel_len = audio_len // self.mel_transform.hop_length

            image_mels = plot_mel(
                [
                    gt_mel[:, :mel_len],
                    gen_aux_mel[:, :mel_len],
-                    gen_reflow_mel[:, :mel_len],
                ],
                [
                    "Ground-Truth",
                    "Auxiliary",
-                    "Reflow",
                ],
            )

@@ -333,11 +260,6 @@ class VQGAN(L.LightningModule):
                                sample_rate=self.sampling_rate,
                                caption="aux",
                            ),
-                            wandb.Audio(
-                                gen_reflow_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="reflow",
-                            ),
                            wandb.Audio(
                                recon_audio[0, :audio_len],
                                sample_rate=self.sampling_rate,
@@ -365,12 +287,6 @@ class VQGAN(L.LightningModule):
                    self.global_step,
                    sample_rate=self.sampling_rate,
                )
-                self.logger.experiment.add_audio(
-                    f"sample-{idx}/wavs/reflow",
-                    gen_reflow_audio[0, :audio_len],
-                    self.global_step,
-                    sample_rate=self.sampling_rate,
-                )
                self.logger.experiment.add_audio(
                    f"sample-{idx}/wavs/recon",
                    recon_audio[0, :audio_len],
@@ -379,3 +295,37 @@ class VQGAN(L.LightningModule):
                )

            plt.close(image_mels)
+
+    def encode(self, audios, audio_lengths):
+        audios = audios.float()
+
+        gt_mels = self.mel_transform(audios)
+        mel_lengths = audio_lengths // self.mel_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+        gt_mels = gt_mels * mel_masks_float_conv
+
+        # Encode
+        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
+        feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
+
+        return self.quantizer.encode(encoded_features), feature_lengths
+
+    def decode(self, indices, feature_lengths, return_audios=False):
+        factor = math.prod(self.quantizer.downsample_factor)
+        mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+
+        z = self.quantizer.decode(indices) * mel_masks_float_conv
+        gen_mel = (
+            self.decoder(
+                torch.randn_like(z) * mel_masks_float_conv,
+                condition=z,
+            )
+            * mel_masks_float_conv
+        )
+
+        if return_audios:
+            return self.vocoder(gen_mel)
+
+        return gen_mel
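For reference, a minimal usage sketch of the encode/decode round-trip introduced by this patch. Only the method signatures, encode(audios, audio_lengths) and decode(indices, feature_lengths, return_audios=False), come from the diff; the model variable, tensor shapes, and sample count below are illustrative assumptions.

import torch

# Hypothetical setup: `model` is assumed to be an already-constructed VQGAN
# instance in eval mode; the [batch, 1, samples] waveform layout and the
# 44100-sample length are assumptions, not taken from the patch.
audios = torch.randn(1, 1, 44100)
audio_lengths = torch.tensor([44100])

with torch.no_grad():
    # encode() returns quantizer indices plus per-item feature lengths
    indices, feature_lengths = model.encode(audios, audio_lengths)
    # decode() rebuilds the mel from the indices; return_audios=True also
    # runs the frozen vocoder and returns waveforms instead of mels
    fake_audios = model.decode(indices, feature_lengths, return_audios=True)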