@@ -105,7 +105,7 @@ class VQGAN(L.LightningModule):
(z_q_audio, z_p),
(m_p_text, logs_p_text),
(m_q, logs_q),
- ) = self.generator(features, feature_lengths, gt_mels, feature_lengths)
+ ) = self.generator(features, feature_lengths, gt_mels)
y_hat_mel = self.mel_transform(y_hat.squeeze(1))
y_mel = slice_segments(gt_mels, ids_slice, self.segment_size // self.hop_length)
@@ -202,6 +202,6 @@ class SpeakerEncoder(nn.Module):
for block in self.blocks:
x = block(x, x, x, key_padding_mask=x_mask)[0]
- x = self.out_proj(x[:, 0])
+ x = self.out_proj(x[:, :1, :]).mT
return x