@@ -40,12 +40,11 @@ class VQGAN(L.LightningModule):
         downsample: ConvDownSampler,
         vq_encoder: VQEncoder,
         speaker_encoder: SpeakerEncoder,
-        text_encoder: TextEncoder,
+        mel_encoder: TextEncoder,
         decoder: TextEncoder,
         generator: Generator,
         discriminator: EnsembleDiscriminator,
         mel_transform: nn.Module,
-        feature_mel_transform: nn.Module,
         segment_size: int = 20480,
         hop_length: int = 640,
         sample_rate: int = 32000,
@@ -61,13 +60,11 @@ class VQGAN(L.LightningModule):
         # Generator and discriminators
         self.downsample = downsample
         self.vq_encoder = vq_encoder
-        self.speaker_encoder = speaker_encoder
-        self.text_encoder = text_encoder
+        self.mel_encoder = mel_encoder
         self.decoder = decoder
         self.generator = generator
         self.discriminator = discriminator
         self.mel_transform = mel_transform
-        self.feature_mel_transform = feature_mel_transform

         # Crop length for saving memory
         self.segment_size = segment_size
@@ -91,7 +88,7 @@ class VQGAN(L.LightningModule):
         for p in self.vq_encoder.parameters():
             p.requires_grad = False

-        for p in self.text_encoder.parameters():
+        for p in self.mel_encoder.parameters():
             p.requires_grad = False

         for p in self.downsample.parameters():
@@ -103,8 +100,7 @@ class VQGAN(L.LightningModule):
             itertools.chain(
                 self.downsample.parameters(),
                 self.vq_encoder.parameters(),
-                self.speaker_encoder.parameters(),
-                self.text_encoder.parameters(),
+                self.mel_encoder.parameters(),
                 self.decoder.parameters(),
                 self.generator.parameters(),
             )
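Side note on the freezing pattern above: vq_encoder (and now mel_encoder) keep their parameters in the optimizer chain even though requires_grad is disabled; torch optimizers simply skip parameters whose .grad stays None, so the frozen weights are never updated. A minimal standalone sketch of that behaviour (the module names here are placeholders, not the classes from the diff):

```python
import itertools

import torch
from torch import nn

frozen = nn.Linear(4, 4)     # stands in for the frozen mel_encoder / vq_encoder
trainable = nn.Linear(4, 4)  # stands in for decoder / generator

for p in frozen.parameters():
    p.requires_grad = False

optim = torch.optim.AdamW(
    itertools.chain(frozen.parameters(), trainable.parameters()), lr=1e-4
)

loss = trainable(frozen(torch.randn(2, 4))).sum()
loss.backward()   # frozen params receive no gradient (.grad stays None)
optim.step()      # ...so AdamW leaves them untouched
```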
@@ -144,8 +140,7 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]

         with torch.no_grad():
-            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
-            features = self.feature_mel_transform(
+            features = gt_mels = self.mel_transform(
                 audios, sample_rate=self.sampling_rate
             )

@@ -155,9 +150,7 @@ class VQGAN(L.LightningModule):
         mel_lengths = audio_lengths // self.hop_length
         feature_lengths = (
             audio_lengths
-            / self.sampling_rate
-            * self.feature_mel_transform.sample_rate
-            / self.feature_mel_transform.hop_length
+            / self.hop_length
             / (self.downsample.total_strides if self.downsample is not None else 1)
         ).long()

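With a single mel transform, the frame-count bookkeeping simplifies: gt_mels and features now share one hop_length, so the old round-trip through feature_mel_transform.sample_rate disappears and the feature length is just the mel length divided by the downsampler's stride product. A quick sanity check using the defaults from the diff (total_strides = 2 is an assumed value for illustration):

```python
import torch

hop_length = 640    # default from the diff: 32000 Hz / 640 = 50 frames/sec
total_strides = 2   # assumed downsampler stride product, for illustration

audio_lengths = torch.tensor([32000, 64000])  # 1 s and 2 s of audio

mel_lengths = audio_lengths // hop_length                               # [50, 100]
feature_lengths = (audio_lengths / hop_length / total_strides).long()   # [25, 50]
```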
@@ -168,17 +161,15 @@ class VQGAN(L.LightningModule):
             gt_mels.dtype
         )

-        speaker_features = self.speaker_encoder(features, feature_masks)
-
         # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.text_encoder(features, feature_masks)
+        text_features = self.mel_encoder(features, feature_masks)
         text_features, loss_vq = self.vq_encoder(text_features, feature_masks)
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )

         # Sample mels
-        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        decoded_mels = self.decoder(text_features, mel_masks)
         fake_audios = self.generator(decoded_mels)

         y_hat_mels = self.mel_transform(fake_audios.squeeze(1))
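As the comment in this hunk notes, the quantized features live at the downsampled (50 Hz) rate, so they are stretched back to the mel frame count before decoding. A standalone sketch of that nearest-neighbour upsampling, with illustrative shapes:

```python
import torch
import torch.nn.functional as F

batch, channels = 2, 192        # illustrative sizes
vq_frames, mel_frames = 25, 50  # downsampled features vs. target mel frames

text_features = torch.randn(batch, channels, vq_frames)

# Each VQ frame is repeated along the time axis until it matches
# gt_mels.shape[2], mirroring the F.interpolate(..., mode="nearest") call.
upsampled = F.interpolate(text_features, size=mel_frames, mode="nearest")
assert upsampled.shape == (batch, channels, mel_frames)
```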
@@ -299,8 +290,7 @@ class VQGAN(L.LightningModule):
         audios = audios.float()
         audios = audios[:, None, :]

-        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
-        features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)
+        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)

         if self.downsample is not None:
             features = self.downsample(features)
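One subtlety of the chained assignment `features = gt_mels = self.mel_transform(...)`: both names initially refer to the same tensor, and the later `features = self.downsample(features)` only rebinds `features`, so gt_mels stays intact as long as the downsampler does not modify its input in place. A small sketch of the aliasing (the pooling layer is a stand-in for the real downsampler):

```python
import torch
from torch import nn

downsample = nn.AvgPool1d(kernel_size=2)  # out-of-place stand-in

gt_mels = torch.randn(2, 128, 100)  # fake (batch, n_mels, frames) spectrogram
features = gt_mels                  # chained assignment: one tensor, two names
assert features is gt_mels

features = downsample(features)     # rebinds `features` only
assert features is not gt_mels and gt_mels.shape[-1] == 100
```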
@@ -308,9 +298,7 @@ class VQGAN(L.LightningModule):
         mel_lengths = audio_lengths // self.hop_length
         feature_lengths = (
             audio_lengths
-            / self.sampling_rate
-            * self.feature_mel_transform.sample_rate
-            / self.feature_mel_transform.hop_length
+            / self.hop_length
             / (self.downsample.total_strides if self.downsample is not None else 1)
         ).long()

@@ -321,17 +309,15 @@ class VQGAN(L.LightningModule):
             gt_mels.dtype
         )

-        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-
         # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.text_encoder(features, feature_masks)
+        text_features = self.mel_encoder(features, feature_masks)
         text_features, _ = self.vq_encoder(text_features, feature_masks)
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )

         # Sample mels
-        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        decoded_mels = self.decoder(text_features, mel_masks)
         fake_audios = self.generator(decoded_mels)

         fake_mels = self.mel_transform(fake_audios.squeeze(1))