|
|
@@ -30,7 +30,6 @@ class VQGAN(L.LightningModule):
|
|
|
weight_mel: float = 1.0,
|
|
|
sampling_rate: int = 44100,
|
|
|
freeze_encoder: bool = False,
|
|
|
- reflow_use_shallow: bool = False,
|
|
|
reflow_inference_steps: int = 10,
|
|
|
reflow_inference_start_t: float = 0.5,
|
|
|
):
|
|
|
@@ -61,7 +60,6 @@ class VQGAN(L.LightningModule):
|
|
|
self.spec_min = -12
|
|
|
self.spec_max = 3
|
|
|
self.sampling_rate = sampling_rate
|
|
|
- self.reflow_use_shallow = reflow_use_shallow
|
|
|
self.reflow_inference_steps = reflow_inference_steps
|
|
|
self.reflow_inference_start_t = reflow_inference_start_t
|
|
|
|
|
|
@@ -133,18 +131,12 @@ class VQGAN(L.LightningModule):
|
|
|
|
|
|
# Reflow, given x_1_aux, we want to reconstruct x_1
|
|
|
x_1 = self.norm_spec(gt_mels)
|
|
|
-
|
|
|
- if self.reflow_use_shallow:
|
|
|
- # Detach the gradient of the generated mel
|
|
|
- x_1_aux = self.norm_spec(gen_mel.detach())
|
|
|
- else:
|
|
|
- x_1_aux = x_1
|
|
|
-
|
|
|
- t = torch.rand(gt_mels.shape[0], device=gt_mels.device)
|
|
|
+ t = torch.rand(gt_mels.shape[0], device=gt_mels.device, dtype=torch.float32)
|
|
|
+ t = torch.clamp(t, 1e-6, 1 - 1e-6) # Avoid 0 and 1
|
|
|
x_0 = torch.randn_like(x_1)
|
|
|
|
|
|
# X_t = t * X_1 + (1 - t) * X_0
|
|
|
- x_t = x_0 + t[:, None, None] * (x_1_aux - x_0)
|
|
|
+ x_t = x_0 + t[:, None, None] * (x_1 - x_0)
|
|
|
|
|
|
v_pred = self.reflow(
|
|
|
x_t,
|
|
|
@@ -153,13 +145,21 @@ class VQGAN(L.LightningModule):
|
|
|
)
|
|
|
|
|
|
# Log L2 loss with
|
|
|
- weights = 0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / (1 - t)) ** 2)
|
|
|
- loss_reflow = weights[:, None, None] * F.mse_loss(
|
|
|
- x_1 - x_0, v_pred, reduction="none"
|
|
|
- )
|
|
|
- loss_reflow = (loss_reflow * mel_masks_float_conv).mean(
|
|
|
- dim=1
|
|
|
- ).sum() / mel_masks_float_conv.sum()
|
|
|
+ with torch.autocast(device_type=gt_mels.device.type, dtype=torch.float32):
|
|
|
+ weights = (
|
|
|
+ 0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / (1 - t)) ** 2)
|
|
|
+ )
|
|
|
+ assert (
|
|
|
+ torch.isnan(weights).any() == False
|
|
|
+ and torch.isinf(weights).any() == False
|
|
|
+ ), "Found NaN or Inf in weights."
|
|
|
+
|
|
|
+ loss_reflow = weights[:, None, None] * F.mse_loss(
|
|
|
+ x_1 - x_0, v_pred, reduction="none"
|
|
|
+ )
|
|
|
+ loss_reflow = (loss_reflow * mel_masks_float_conv).mean(
|
|
|
+ dim=1
|
|
|
+ ).sum() / mel_masks_float_conv.sum()
|
|
|
|
|
|
# Total loss
|
|
|
loss = (
|
|
|
@@ -239,7 +239,7 @@ class VQGAN(L.LightningModule):
|
|
|
)
|
|
|
|
|
|
# Reflow inference
|
|
|
- t_start = self.reflow_inference_start_t if self.reflow_use_shallow else 0.0
|
|
|
+ t_start = 0.0
|
|
|
|
|
|
x_1 = self.norm_spec(gen_aux_mels)
|
|
|
x_0 = torch.randn_like(x_1)
|