@@ -28,6 +28,7 @@ class VQDiffusion(L.LightningModule):
         optimizer: Callable,
         lr_scheduler: Callable,
         mel_transform: nn.Module,
+        feature_mel_transform: nn.Module,
         vq_encoder: VQEncoder,
         speaker_encoder: SpeakerEncoder,
         text_encoder: TextEncoder,
@@ -44,6 +45,7 @@ class VQDiffusion(L.LightningModule):
 
         # Generator and discriminators
         self.mel_transform = mel_transform
+        self.feature_mel_transform = feature_mel_transform
         self.noise_scheduler_train = DDIMScheduler(num_train_timesteps=1000)
         self.noise_scheduler_infer = UniPCMultistepScheduler(num_train_timesteps=1000)
 
@@ -91,26 +93,30 @@ class VQDiffusion(L.LightningModule):
         audios = audios[:, None, :]
 
         with torch.no_grad():
-            gt_mels = self.mel_transform(audios)
+            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
+            features = self.feature_mel_transform(
+                audios, sample_rate=self.sampling_rate
+            )
 
         mel_lengths = audio_lengths // self.hop_length
-
+        feature_lengths = audio_lengths // self.hop_length // 2
         feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[1]), 1
+            sequence_mask(feature_lengths, features.shape[2]), 1
         ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
         )
 
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        # vq_features, vq_loss = self.vq_encoder(features, feature_masks)
-
         # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.text_encoder(features, feature_masks, g=speaker_features)
+        text_features = self.text_encoder(features, feature_masks)
+        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
 
+        text_features = text_features + speaker_features
+
         # Sample noise that we'll add to the images
         normalized_gt_mels = self.normalize_mels(gt_mels)
         noise = torch.randn_like(normalized_gt_mels)
@@ -147,17 +153,17 @@ class VQDiffusion(L.LightningModule):
             sync_dist=True,
         )
 
-        # self.log(
-        #     "train/vq_loss",
-        #     vq_loss,
-        #     on_step=True,
-        #     on_epoch=False,
-        #     prog_bar=True,
-        #     logger=True,
-        #     sync_dist=True,
-        # )
+        self.log(
+            "train/vq_loss",
+            vq_loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
 
-        return noise_loss  # + vq_loss
+        return noise_loss + vq_loss
 
     def validation_step(self, batch: Any, batch_idx: int):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
@@ -166,25 +172,29 @@ class VQDiffusion(L.LightningModule):
         audios = audios.float()
         # features = features.float().mT
         audios = audios[:, None, :]
-        gt_mels = self.mel_transform(audios)
-        mel_lengths = audio_lengths // self.hop_length
+        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
+        features = self.feature_mel_transform(audios, sample_rate=self.sampling_rate)
 
+        mel_lengths = audio_lengths // self.hop_length
+        feature_lengths = audio_lengths // self.hop_length // 2
         feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[1]), 1
+            sequence_mask(feature_lengths, features.shape[2]), 1
         ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
         )
 
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        # vq_features, vq_loss = self.vq_encoder(features, feature_masks)
 
         # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.text_encoder(features, feature_masks, g=speaker_features)
+        text_features = self.text_encoder(features, feature_masks)
+        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
         text_features = F.interpolate(
             text_features, size=gt_mels.shape[2], mode="nearest"
         )
 
+        text_features = text_features + speaker_features
+
         # Begin sampling
         sampled_mels = torch.randn_like(gt_mels)
         self.noise_scheduler_infer.set_timesteps(100)
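
For orientation, here is a minimal runnable sketch of the shape bookkeeping this diff introduces: features now come from a second mel transform at half the mel frame rate (hence `audio_lengths // self.hop_length // 2` and the `shape[1]` → `shape[2]` mask fix), get upsampled back to the mel length with nearest-neighbor interpolation, and receive the speaker embedding as an additive global condition instead of the old `g=` argument. The concrete sizes (`n_mels=128`, channel dim 256, `hop_length=256`) and the `[B, C, 1]` speaker-embedding shape are illustrative assumptions, not values taken from the diff.

```python
import torch
import torch.nn.functional as F


def sequence_mask(lengths: torch.Tensor, max_length: int) -> torch.Tensor:
    # [B, T] boolean mask; True inside each sequence.
    positions = torch.arange(max_length, device=lengths.device)
    return positions[None, :] < lengths[:, None]


hop_length = 256  # assumed config value
audio_lengths = torch.tensor([48_000, 43_904])

mel_lengths = audio_lengths // hop_length  # full-rate mel frames
feature_lengths = audio_lengths // hop_length // 2  # half-rate (~50 Hz) frames

gt_mels = torch.randn(2, 128, int(mel_lengths.max()))  # [B, n_mels, T]
features = torch.randn(2, 128, int(feature_lengths.max()))  # [B, n_mels, T // 2]

# Masks are built over the time axis (dim 2), matching the shape[2] fix.
feature_masks = sequence_mask(feature_lengths, features.shape[2])[:, None].to(gt_mels.dtype)
mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])[:, None].to(gt_mels.dtype)

# Stand-ins for the text/VQ encoder output and the speaker embedding.
text_features = torch.randn(2, 256, features.shape[2])  # 50 Hz sequence
speaker_features = torch.randn(2, 256, 1)  # one embedding per utterance

# Upsample the 50 Hz features to the mel frame rate, then broadcast-add the
# speaker condition (this replaces the old g=speaker_features argument).
text_features = F.interpolate(text_features, size=gt_mels.shape[2], mode="nearest")
text_features = text_features + speaker_features
assert text_features.shape[2] == gt_mels.shape[2]
```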
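Since the diff also touches both noise schedulers, this second sketch shows how `DDIMScheduler` and `UniPCMultistepScheduler` from Hugging Face `diffusers` are typically wired around such a model: `add_noise` builds the noisy training input, and a 100-step UniPC loop mirrors the `set_timesteps(100)` call in `validation_step`. The epsilon-prediction objective and the zero-output stand-in `denoiser` are assumptions made to keep the sketch runnable.

```python
import torch
from diffusers import DDIMScheduler, UniPCMultistepScheduler


def denoiser(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Zero-output stand-in; the real model also takes the text/speaker
    # features and is trained to predict the noise added to the mels.
    return torch.zeros_like(x)


# Training side: pick random timesteps, noise the normalized mels, and
# regress the model output against the sampled noise.
train_sched = DDIMScheduler(num_train_timesteps=1000)
x0 = torch.randn(2, 128, 187)  # stands in for normalize_mels(gt_mels)
noise = torch.randn_like(x0)
t = torch.randint(0, train_sched.config.num_train_timesteps, (x0.shape[0],))
noisy = train_sched.add_noise(x0, noise, t)
# noise_loss = F.mse_loss(denoiser(noisy, t), noise)

# Inference side: start from pure noise and run 100 UniPC steps, mirroring
# sampled_mels = torch.randn_like(gt_mels) and set_timesteps(100) above.
infer_sched = UniPCMultistepScheduler(num_train_timesteps=1000)
infer_sched.set_timesteps(100)
sample = torch.randn(2, 128, 187)
for step_t in infer_sched.timesteps:
    sample = infer_sched.step(denoiser(sample, step_t), step_t, sample).prev_sample
```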