Lengyue 2 anni fa
parent
commit
894085acbb
1 file modificato con 9 aggiunte e 1 eliminazione
  1. 9 1
      fish_speech/models/vqgan/lit_module.py

+ 9 - 1
fish_speech/models/vqgan/lit_module.py

@@ -208,11 +208,19 @@ class VQGAN(L.LightningModule):
                 ids_slice // self.hop_length,
                 self.segment_size // self.hop_length,
             )
+            sliced_gt_mels = slice_segments(
+                gt_mels,
+                ids_slice // self.hop_length,
+                self.segment_size // self.hop_length,
+            )
             gen_mel_masks = slice_segments(
                 mel_masks,
                 ids_slice // self.hop_length,
                 self.segment_size // self.hop_length,
             )
+        else:
+            sliced_gt_mels = gt_mels
+            gen_mel_masks = mel_masks
 
         fake_audios = self.generator(input_mels)
         fake_audio_mels = self.mel_transform(fake_audios.squeeze(1))
@@ -248,7 +256,7 @@ class VQGAN(L.LightningModule):
         with torch.autocast(device_type=audios.device.type, enabled=False):
             loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
             loss_mel = F.l1_loss(
-                input_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
+                sliced_gt_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
             )
             loss_adv, _ = generator_loss(y_d_hat_g)
             loss_fm = feature_loss(fmap_r, fmap_g)