@@ -478,12 +478,7 @@ class VQNaive(L.LightningModule):
             },
         }
 
-    def training_step(self, batch, batch_idx):
-        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-
-        audios = audios.float()
-        audios = audios[:, None, :]
-
+    def vq_encode(self, audios, audio_lengths):
         with torch.no_grad():
             features = gt_mels = self.mel_transform(
                 audios, sample_rate=self.sampling_rate
@@ -506,17 +501,34 @@ class VQNaive(L.LightningModule):
             gt_mels.dtype
         )
 
-        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-
         # vq_features is 50 hz, need to convert to true mel size
         text_features = self.mel_encoder(features, feature_masks)
-        text_features, loss_vq = self.vq_encoder(text_features, feature_masks)
+        text_features, indices, loss_vq = self.vq_encoder(text_features, feature_masks)
+
+        return mel_masks, gt_mels, text_features, indices, loss_vq
+
+    def vq_decode(self, text_features, speaker_features, gt_mels, mel_masks):
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
 
-        # Sample mels
         decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+
+        return decoded_mels
+
+    def training_step(self, batch, batch_idx):
+        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+
+        audios = audios.float()
+        audios = audios[:, None, :]
+
+        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
+            audios, audio_lengths
+        )
+        speaker_features = self.speaker_encoder(gt_mels, mel_masks)
+        decoded_mels = self.vq_decode(
+            text_features, speaker_features, gt_mels, mel_masks
+        )
         loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
         loss = loss_mel + loss_vq
 
@@ -556,36 +568,13 @@ class VQNaive(L.LightningModule):
         audios = audios.float()
         audios = audios[:, None, :]
 
-        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
-
-        if self.downsample is not None:
-            features = self.downsample(features)
-
-        mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = (
-            audio_lengths
-            / self.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
-        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
-            gt_mels.dtype
+        mel_masks, gt_mels, text_features, indices, loss_vq = self.vq_encode(
+            audios, audio_lengths
         )
-
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-
-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.mel_encoder(features, feature_masks)
-        text_features, _ = self.vq_encoder(text_features, feature_masks)
-        text_features = F.interpolate(
-            text_features, size=gt_mels.shape[2], mode="nearest"
+        decoded_mels = self.vq_decode(
+            text_features, speaker_features, gt_mels, mel_masks
        )
-
-        # Sample mels
-        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
         fake_audios = self.vocoder(decoded_mels)
 
         mel_loss = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
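With training_step and validation_step now sharing vq_encode and vq_decode, the same pair can be reused outside the Lightning loop, e.g. to pull the discrete codebook indices out of audio. A minimal sketch, assuming a trained checkpoint and mono input shaped the way training_step shapes it; the checkpoint path and tensor sizes are illustrative only:

import torch

model = VQNaive.load_from_checkpoint("checkpoint.ckpt")  # hypothetical path
model.eval()

audios = torch.randn(1, 1, 48000)      # (batch, channel, samples), as in training_step
audio_lengths = torch.tensor([48000])

with torch.no_grad():
    # Encode: mel extraction -> mel_encoder -> vector quantization.
    mel_masks, gt_mels, text_features, indices, _ = model.vq_encode(
        audios, audio_lengths
    )
    # `indices` are the discrete codes newly exposed by this refactor.
    speaker_features = model.speaker_encoder(gt_mels, mel_masks)
    # Decode: vq_decode upsamples the 50 Hz quantized features to the mel
    # frame rate (F.interpolate) before running the decoder.
    decoded_mels = model.vq_decode(
        text_features, speaker_features, gt_mels, mel_masks
    )
    fake_audios = model.vocoder(decoded_mels)  # as in validation_step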