Sfoglia il codice sorgente

Fix speaker encoder

Lengyue 2 anni fa
parent
commit
77371cc16b

+ 1 - 1
fish_speech/models/vqgan/lit_module.py

@@ -105,7 +105,7 @@ class VQGAN(L.LightningModule):
             (z_q_audio, z_p),
             (m_p_text, logs_p_text),
             (m_q, logs_q),
-        ) = self.generator(features, feature_lengths, gt_mels, feature_lengths)
+        ) = self.generator(features, feature_lengths, gt_mels)
 
         y_hat_mel = self.mel_transform(y_hat.squeeze(1))
         y_mel = slice_segments(gt_mels, ids_slice, self.segment_size // self.hop_length)

+ 1 - 1
fish_speech/models/vqgan/modules/encoders.py

@@ -202,6 +202,6 @@ class SpeakerEncoder(nn.Module):
         for block in self.blocks:
             x = block(x, x, x, key_padding_mask=x_mask)[0]
 
-        x = self.out_proj(x[:, 0])
+        x = self.out_proj(x[:, :1, :]).mT
 
         return x