Lengyue 2 years ago
parent
commit
3fb677a04c

+ 2 - 3
fish_speech/configs/vqgan_pretrain.yaml

@@ -20,7 +20,6 @@ hop_length: 640
 num_mels: 128
 n_fft: 2048
 win_length: 2048
-segment_size: 128
 
 # Dataset Configuration
 train_dataset:
@@ -28,7 +27,7 @@ train_dataset:
   filelist: data/vq_train_filelist.txt
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
-  slice_frames: ${segment_size}
+  slice_frames: 128
 
 val_dataset:
   _target_: fish_speech.datasets.vqgan.VQGANDataset
@@ -58,7 +57,7 @@ model:
   generator:
     _target_: fish_speech.models.vqgan.modules.models.SynthesizerTrn
     spec_channels: 1025
-    segment_size: 20480
+    segment_size: 32
     inter_channels: 192
     hidden_channels: 192
     filter_channels: 768

+ 10 - 3
fish_speech/models/vqgan/lit_module.py

@@ -18,7 +18,7 @@ from fish_speech.models.vqgan.losses import (
     generator_loss,
     kl_loss,
 )
-from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
+from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
 
 
 @dataclass
@@ -134,6 +134,13 @@ class VQGAN(L.LightningModule):
             quantized,
         ) = self.generator(gt_specs, spec_lengths)
 
+        gt_mels = slice_segments(gt_mels, ids_slice, self.generator.segment_size)
+        spec_masks = slice_segments(spec_masks, ids_slice, self.generator.segment_size)
+        audios = slice_segments(
+            audios,
+            ids_slice * self.hop_length,
+            self.generator.segment_size * self.hop_length,
+        )
         fake_mels = self.mel_transform(fake_audios.squeeze(1))
 
         assert (
@@ -223,8 +230,8 @@ class VQGAN(L.LightningModule):
         loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, y_mask)
 
         self.log(
-            "train/generator/loss",
-            loss,
+            "train/generator/loss_kl",
+            loss_kl,
             on_step=True,
             on_epoch=False,
             prog_bar=False,