@@ -208,11 +208,19 @@ class VQGAN(L.LightningModule):
                 ids_slice // self.hop_length,
                 self.segment_size // self.hop_length,
             )
+            sliced_gt_mels = slice_segments(
+                gt_mels,
+                ids_slice // self.hop_length,
+                self.segment_size // self.hop_length,
+            )
             gen_mel_masks = slice_segments(
                 mel_masks,
                 ids_slice // self.hop_length,
                 self.segment_size // self.hop_length,
             )
+        else:
+            sliced_gt_mels = gt_mels
+            gen_mel_masks = mel_masks
 
         fake_audios = self.generator(input_mels)
         fake_audio_mels = self.mel_transform(fake_audios.squeeze(1))
@@ -248,7 +256,7 @@ class VQGAN(L.LightningModule):
         with torch.autocast(device_type=audios.device.type, enabled=False):
             loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
             loss_mel = F.l1_loss(
-                input_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
+                sliced_gt_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
             )
             loss_adv, _ = generator_loss(y_d_hat_g)
             loss_fm = feature_loss(fmap_r, fmap_g)
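
With this change, `loss_mel` compares the mels of the generated audio against `sliced_gt_mels` (the ground-truth mels cropped to the same frames as the generated segment) rather than against `input_mels`, which are the mels actually fed to the generator and so may not match the ground truth if the encoder input is masked or augmented. The new `else` branch keeps `sliced_gt_mels` and `gen_mel_masks` defined when no segment slicing is performed. For reference, `slice_segments` is assumed here to behave like the VITS/HiFi-GAN helper of the same name, gathering one fixed-length window per batch item; a minimal sketch under that assumption:

```python
import torch


def slice_segments(x: torch.Tensor, ids_str: torch.Tensor, segment_size: int) -> torch.Tensor:
    # Assumed semantics (matching the VITS/HiFi-GAN helper):
    #   x:            (batch, channels, frames)
    #   ids_str:      (batch,) start frame for each batch item
    #   segment_size: window length in frames
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        start = int(ids_str[i])
        ret[i] = x[i, :, start : start + segment_size]  # copy one window per item
    return ret
```

Dividing `ids_slice` and `self.segment_size` by `self.hop_length` converts sample-domain offsets into mel-frame offsets (one mel frame per `hop_length` samples), so the sliced ground-truth mels line up with the audio segment the generator produced.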