|
|
@@ -18,7 +18,7 @@ from fish_speech.models.vqgan.losses import (
|
|
|
generator_loss,
|
|
|
kl_loss,
|
|
|
)
|
|
|
-from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
|
|
|
+from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
@@ -134,6 +134,13 @@ class VQGAN(L.LightningModule):
|
|
|
quantized,
|
|
|
) = self.generator(gt_specs, spec_lengths)
|
|
|
|
|
|
+ gt_mels = slice_segments(gt_mels, ids_slice, self.generator.segment_size)
|
|
|
+ spec_masks = slice_segments(spec_masks, ids_slice, self.generator.segment_size)
|
|
|
+ audios = slice_segments(
|
|
|
+ audios,
|
|
|
+ ids_slice * self.hop_length,
|
|
|
+ self.generator.segment_size * self.hop_length,
|
|
|
+ )
|
|
|
fake_mels = self.mel_transform(fake_audios.squeeze(1))
|
|
|
|
|
|
assert (
|
|
|
@@ -223,8 +230,8 @@ class VQGAN(L.LightningModule):
|
|
|
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, y_mask)
|
|
|
|
|
|
self.log(
|
|
|
- "train/generator/loss",
|
|
|
- loss,
|
|
|
+ "train/generator/loss_kl",
|
|
|
+ loss_kl,
|
|
|
on_step=True,
|
|
|
on_epoch=False,
|
|
|
prog_bar=False,
|