|
|
@@ -39,7 +39,6 @@ class VQGAN(L.LightningModule):
|
|
|
lr_scheduler: Callable,
|
|
|
downsample: ConvDownSampler,
|
|
|
vq_encoder: VQEncoder,
|
|
|
- speaker_encoder: SpeakerEncoder,
|
|
|
mel_encoder: TextEncoder,
|
|
|
decoder: TextEncoder,
|
|
|
generator: Generator,
|
|
|
@@ -163,7 +162,7 @@ class VQGAN(L.LightningModule):
|
|
|
|
|
|
# vq_features is 50 hz, need to convert to true mel size
|
|
|
text_features = self.mel_encoder(features, feature_masks)
|
|
|
- text_features, loss_vq = self.vq_encoder(text_features, feature_masks)
|
|
|
+ text_features, _, loss_vq = self.vq_encoder(text_features, feature_masks)
|
|
|
text_features = F.interpolate(
|
|
|
text_features, size=gt_mels.shape[2], mode="nearest"
|
|
|
)
|
|
|
@@ -311,7 +310,7 @@ class VQGAN(L.LightningModule):
|
|
|
|
|
|
# vq_features is 50 hz, need to convert to true mel size
|
|
|
text_features = self.mel_encoder(features, feature_masks)
|
|
|
- text_features, _ = self.vq_encoder(text_features, feature_masks)
|
|
|
+ text_features, _, _ = self.vq_encoder(text_features, feature_masks)
|
|
|
text_features = F.interpolate(
|
|
|
text_features, size=gt_mels.shape[2], mode="nearest"
|
|
|
)
|