|
|
@@ -174,8 +174,20 @@ class VQGAN(L.LightningModule):
|
|
|
)
|
|
|
optim_d.step()
|
|
|
|
|
|
- # Mel Loss
|
|
|
- loss_mel = avg_with_mask((gen_mel - gt_mels).abs(), mel_masks_float_conv)
|
|
|
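+ # Mel loss: masked L1 distance, computed separately for the low (:40), mid (40:70) and high (70:) bands and combined as a weighted sum (0.6 / 0.3 / 0.1)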
|
|
|
+ mel_distance = (
|
|
|
+ gen_mel - gt_mels
|
|
|
+ ).abs() # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5
|
|
|
+ loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
|
|
|
+ loss_mel_mid_freq = avg_with_mask(
|
|
|
+ mel_distance[:, 40:70, :], mel_masks_float_conv
|
|
|
+ )
|
|
|
+ loss_mel_high_freq = avg_with_mask(
|
|
|
+ mel_distance[:, 70:, :], mel_masks_float_conv
|
|
|
+ )
|
|
|
+ loss_mel = (
|
|
|
+ loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
|
|
|
+ )
|
|
|
|
|
|
# Adversarial Loss
|
|
|
fake_logits = self.discriminator(gen_mel)
|
|
|
@@ -221,6 +233,14 @@ class VQGAN(L.LightningModule):
|
|
|
prog_bar=False,
|
|
|
logger=True,
|
|
|
)
|
|
|
+ self.log(
|
|
|
+ "train/generator/loss_speaker_id",
|
|
|
+ loss_speaker_id,
|
|
|
+ on_step=True,
|
|
|
+ on_epoch=False,
|
|
|
+ prog_bar=False,
|
|
|
+ logger=True,
|
|
|
+ )
|
|
|
|
|
|
# Generator backward
|
|
|
optim_g.zero_grad()
|