@@ -25,7 +25,8 @@ class VQGAN(L.LightningModule):
         decoder: WaveNet,
         discriminator: Discriminator,
         vocoder: nn.Module,
-        mel_transform: nn.Module,
+        encode_mel_transform: nn.Module,
+        gt_mel_transform: nn.Module,
         weight_adv: float = 1.0,
         weight_vq: float = 1.0,
         weight_mel: float = 1.0,
@@ -44,7 +45,11 @@ class VQGAN(L.LightningModule):
         self.decoder = decoder
         self.vocoder = vocoder
         self.discriminator = discriminator
-        self.mel_transform = mel_transform
+        self.encode_mel_transform = encode_mel_transform
+        self.gt_mel_transform = gt_mel_transform
+
+        # A simple linear layer to project quality to condition channels
+        self.quality_projection = nn.Linear(1, 768)
 
         # Freeze vocoder
         for param in self.vocoder.parameters():
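The new `quality_projection` layer is the core of this change: one scalar quality value per clip is mapped to the decoder's 768 condition channels and broadcast-added to the quantized features (see the training and decode hunks below). A minimal standalone sketch of that pattern; the tensor shapes are inferred from this diff, and none of this is repo code:

import torch
import torch.nn as nn

# Sketch of the quality-conditioning pattern: project a per-clip scalar
# to channel space, then broadcast the result over the time axis.
quality_projection = nn.Linear(1, 768)

features = torch.randn(4, 768, 100)                   # [batch, channels, frames]
quality = torch.tensor([[0.5], [1.0], [1.5], [2.0]])  # [batch, 1]

conditioned = features + quality_projection(quality)[:, :, None]
assert conditioned.shape == (4, 768, 100)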
@@ -84,6 +89,7 @@ class VQGAN(L.LightningModule):
                 self.encoder.parameters(),
                 self.quantizer.parameters(),
                 self.decoder.parameters(),
+                self.quality_projection.parameters(),
             )
         )
         optimizer_discriminator = self.optimizer_builder(
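This one-line addition matters: the generator optimizer is built from an explicit parameter chain rather than `self.parameters()`, so without it the new projection layer would never receive updates. A sketch of the grouping pattern with stand-in modules; `AdamW` and the learning rate here are illustrative assumptions, not the repo's `optimizer_builder`:

import itertools

import torch
import torch.nn as nn

# Stand-in modules; the real code chains encoder/quantizer/decoder params.
encoder, quantizer, decoder = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8)
quality_projection = nn.Linear(1, 768)

optimizer_generator = torch.optim.AdamW(
    itertools.chain(
        encoder.parameters(),
        quantizer.parameters(),
        decoder.parameters(),
        quality_projection.parameters(),  # newly trained layer must be included
    ),
    lr=1e-4,
)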
@@ -121,20 +127,27 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]
 
         with torch.no_grad():
-            gt_mels = self.mel_transform(audios)
+            encoded_mels = self.encode_mel_transform(audios)
+            gt_mels = self.gt_mel_transform(audios)
+            quality = ((gt_mels.mean(-1) > -8).sum(-1) - 90) / 10
+            quality = quality.unsqueeze(-1)
 
-        mel_lengths = audio_lengths // self.mel_transform.hop_length
+        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
         mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
         mel_masks_float_conv = mel_masks[:, None, :].float()
         gt_mels = gt_mels * mel_masks_float_conv
+        encoded_mels = encoded_mels * mel_masks_float_conv
 
         # Encode
-        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
+        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
 
         # Quantize
         vq_result = self.quantizer(encoded_features)
         loss_vq = getattr(vq_result, "loss", 0.0)
         vq_recon_features = vq_result.z * mel_masks_float_conv
+        vq_recon_features = (
+            vq_recon_features + self.quality_projection(quality)[:, :, None]
+        )
 
         # VQ Decode
         gen_mel = (
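The `quality` heuristic counts, per clip, how many mel bands keep an average log-magnitude above -8 (a rough proxy for bandwidth and recording quality), then recenters by 90 bands and scales by 10. A runnable sketch of just that computation; `n_mels = 128` and the fake log-mel values are assumptions, the real filterbank comes from `gt_mel_transform`:

import torch

# Fake log-mels, shape [batch, n_mels, frames]; quiet or band-limited
# audio leaves fewer bands above the -8 threshold and scores lower.
gt_mels = torch.randn(2, 128, 200) * 4 - 6

per_band_mean = gt_mels.mean(-1)             # [batch, n_mels]
active_bands = (per_band_mean > -8).sum(-1)  # bands with mean energy > -8
quality = (active_bands - 90) / 10           # recenter and rescale
quality = quality.unsqueeze(-1)              # [batch, 1], ready for nn.Linear(1, 768)
print(quality.shape)  # torch.Size([2, 1])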
@@ -233,14 +246,6 @@ class VQGAN(L.LightningModule):
             prog_bar=False,
             logger=True,
         )
-        self.log(
-            "train/generator/loss_speaker_id",
-            loss_speaker_id,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-        )
 
         # Generator backward
         optim_g.zero_grad()
@@ -260,17 +265,29 @@ class VQGAN(L.LightningModule):
         audios = audios.float()
         audios = audios[:, None, :]
 
-        gt_mels = self.mel_transform(audios)
-        mel_lengths = audio_lengths // self.mel_transform.hop_length
+        encoded_mels = self.encode_mel_transform(audios)
+        gt_mels = self.gt_mel_transform(audios)
+
+        mel_lengths = audio_lengths // self.gt_mel_transform.hop_length
         mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
         mel_masks_float_conv = mel_masks[:, None, :].float()
         gt_mels = gt_mels * mel_masks_float_conv
+        encoded_mels = encoded_mels * mel_masks_float_conv
 
         # Encode
-        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
+        encoded_features = self.encoder(encoded_mels) * mel_masks_float_conv
 
         # Quantize
         vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv
+        vq_recon_features = (
+            vq_recon_features
+            + self.quality_projection(
+                torch.ones(
+                    vq_recon_features.shape[0], 1, device=vq_recon_features.device
+                )
+                * 2
+            )[:, :, None]
+        )
 
         # VQ Decode
         gen_aux_mels = (
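At validation time the measured quality is replaced by a constant 2, the high end of the training-time scale, so reconstructions are always decoded under the best condition. This hunk and the `decode` hunk below inline the same expression; a hypothetical helper (name and signature are mine, not the repo's) that captures it:

import torch
import torch.nn as nn

def fixed_quality_condition(
    projection: nn.Linear, batch: int, device: torch.device, value: float = 2.0
) -> torch.Tensor:
    # [batch, 1] scalar condition -> [batch, channels, 1], broadcastable over time
    quality = torch.ones(batch, 1, device=device) * value
    return projection(quality)[:, :, None]

proj = nn.Linear(1, 768)
cond = fixed_quality_condition(proj, batch=2, device=torch.device("cpu"))
assert cond.shape == (2, 768, 1)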
@@ -319,7 +336,7 @@ class VQGAN(L.LightningModule):
             if idx > 4:
                 break
 
-            mel_len = audio_len // self.mel_transform.hop_length
+            mel_len = audio_len // self.gt_mel_transform.hop_length
 
             image_mels = plot_mel(
                 [
@@ -386,14 +403,14 @@ class VQGAN(L.LightningModule):
     def encode(self, audios, audio_lengths):
         audios = audios.float()
 
-        gt_mels = self.mel_transform(audios)
-        mel_lengths = audio_lengths // self.mel_transform.hop_length
-        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+        mels = self.encode_mel_transform(audios)
+        mel_lengths = audio_lengths // self.encode_mel_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
         mel_masks_float_conv = mel_masks[:, None, :].float()
-        gt_mels = gt_mels * mel_masks_float_conv
+        mels = mels * mel_masks_float_conv
 
         # Encode
-        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
+        encoded_features = self.encoder(mels) * mel_masks_float_conv
         feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
 
         return self.quantizer.encode(encoded_features), feature_lengths
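`encode` now builds its input mels with `encode_mel_transform`, and token lengths follow from integer division: first by the mel hop, then by the quantizer's cumulative downsampling. A toy check of that bookkeeping; the hop length and downsample factors are assumed example values, not the repo's config:

import math

hop_length = 512            # assumed mel hop
downsample_factor = (2, 2)  # assumed quantizer strides; time shrinks 4x overall

audio_lengths = [32000, 48000]
mel_lengths = [n // hop_length for n in audio_lengths]  # [62, 93]
feature_lengths = [n // math.prod(downsample_factor) for n in mel_lengths]
print(feature_lengths)  # [15, 23]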
@@ -404,6 +421,13 @@ class VQGAN(L.LightningModule):
         mel_masks_float_conv = mel_masks[:, None, :].float()
 
         z = self.quantizer.decode(indices) * mel_masks_float_conv
+        z = (
+            z
+            + self.quality_projection(torch.ones(z.shape[0], 1, device=z.device) * 2)[
+                :, :, None
+            ]
+        )
+
         gen_mel = (
             self.decoder(
                 torch.randn_like(z) * mel_masks_float_conv,