فهرست منبع

weight mel loss at different resolution

Lengyue 2 سال پیش
والد
کامیت
4709094b9c
1 فایلهای تغییر یافته به همراه 22 افزوده شده و 2 حذف شده
  1. 22 2
      fish_speech/models/vqgan/lit_module.py

+ 22 - 2
fish_speech/models/vqgan/lit_module.py

@@ -174,8 +174,20 @@ class VQGAN(L.LightningModule):
         )
         optim_d.step()
 
-        # Mel Loss
-        loss_mel = avg_with_mask((gen_mel - gt_mels).abs(), mel_masks_float_conv)
+        # Mel Loss, applying l1, using a weighted sum
+        mel_distance = (
+            gen_mel - gt_mels
+        ).abs()  # * 0.5 + self.ssim(gen_mel, gt_mels) * 0.5
+        loss_mel_low_freq = avg_with_mask(mel_distance[:, :40, :], mel_masks_float_conv)
+        loss_mel_mid_freq = avg_with_mask(
+            mel_distance[:, 40:70, :], mel_masks_float_conv
+        )
+        loss_mel_high_freq = avg_with_mask(
+            mel_distance[:, 70:, :], mel_masks_float_conv
+        )
+        loss_mel = (
+            loss_mel_low_freq * 0.6 + loss_mel_mid_freq * 0.3 + loss_mel_high_freq * 0.1
+        )
 
         # Adversarial Loss
         fake_logits = self.discriminator(gen_mel)
@@ -221,6 +233,14 @@ class VQGAN(L.LightningModule):
             prog_bar=False,
             logger=True,
         )
+        self.log(
+            "train/generator/loss_speaker_id",
+            loss_speaker_id,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+        )
 
         # Generator backward
         optim_g.zero_grad()