|
|
@@ -72,6 +72,7 @@ class VQGAN(L.LightningModule):
|
|
|
self.hop_length = hop_length
|
|
|
self.sampling_rate = sample_rate
|
|
|
self.freeze_hifigan = freeze_hifigan
|
|
|
+ self.freeze_vq = freeze_vq
|
|
|
|
|
|
# Disable automatic optimization
|
|
|
self.automatic_optimization = False
|
|
|
@@ -164,7 +165,9 @@ class VQGAN(L.LightningModule):
|
|
|
|
|
|
# vq_features is 50 hz, need to convert to true mel size
|
|
|
text_features = self.mel_encoder(features, feature_masks)
|
|
|
- text_features, _, loss_vq = self.vq_encoder(text_features, feature_masks)
|
|
|
+ text_features, _, loss_vq = self.vq_encoder(
|
|
|
+ text_features, feature_masks, freeze_codebook=self.freeze_vq
|
|
|
+ )
|
|
|
text_features = F.interpolate(
|
|
|
text_features, size=gt_mels.shape[2], mode="nearest"
|
|
|
)
|