|
@@ -1,5 +1,5 @@
|
|
|
import itertools
|
|
import itertools
|
|
|
-from typing import Any, Callable
|
|
|
|
|
|
|
+from typing import Any, Callable, Literal
|
|
|
|
|
|
|
|
import lightning as L
|
|
import lightning as L
|
|
|
import torch
|
|
import torch
|
|
@@ -47,12 +47,15 @@ class VQGAN(L.LightningModule):
|
|
|
segment_size: int = 20480,
|
|
segment_size: int = 20480,
|
|
|
hop_length: int = 640,
|
|
hop_length: int = 640,
|
|
|
sample_rate: int = 32000,
|
|
sample_rate: int = 32000,
|
|
|
- freeze_hifigan: bool = False,
|
|
|
|
|
- freeze_vq: bool = False,
|
|
|
|
|
|
|
+ mode: Literal["pretrain-stage1", "pretrain-stage2", "finetune"] = "finetune",
|
|
|
speaker_encoder: SpeakerEncoder = None,
|
|
speaker_encoder: SpeakerEncoder = None,
|
|
|
):
|
|
):
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
+ # pretrain-stage1: vq use gt mel as target, hifigan use gt mel as input
|
|
|
|
|
+ # pretrain-stage2: end-to-end training, use gt mel as hifi gan target
|
|
|
|
|
+ # finetune: end-to-end training, use gt mel as hifi gan target but freeze vq
|
|
|
|
|
+
|
|
|
# Model parameters
|
|
# Model parameters
|
|
|
self.optimizer_builder = optimizer
|
|
self.optimizer_builder = optimizer
|
|
|
self.lr_scheduler_builder = lr_scheduler
|
|
self.lr_scheduler_builder = lr_scheduler
|
|
@@ -71,22 +74,13 @@ class VQGAN(L.LightningModule):
|
|
|
self.segment_size = segment_size
|
|
self.segment_size = segment_size
|
|
|
self.hop_length = hop_length
|
|
self.hop_length = hop_length
|
|
|
self.sampling_rate = sample_rate
|
|
self.sampling_rate = sample_rate
|
|
|
- self.freeze_hifigan = freeze_hifigan
|
|
|
|
|
- self.freeze_vq = freeze_vq
|
|
|
|
|
|
|
+ self.mode = mode
|
|
|
|
|
|
|
|
# Disable automatic optimization
|
|
# Disable automatic optimization
|
|
|
self.automatic_optimization = False
|
|
self.automatic_optimization = False
|
|
|
|
|
|
|
|
- # Stage 1: Train the VQ only
|
|
|
|
|
- if self.freeze_hifigan:
|
|
|
|
|
- for p in self.discriminator.parameters():
|
|
|
|
|
- p.requires_grad = False
|
|
|
|
|
-
|
|
|
|
|
- for p in self.generator.parameters():
|
|
|
|
|
- p.requires_grad = False
|
|
|
|
|
-
|
|
|
|
|
- # Stage 2: Train the HifiGAN + Decoder + Generator
|
|
|
|
|
- if freeze_vq:
|
|
|
|
|
|
|
+ # Finetune: Train the VQ only
|
|
|
|
|
+ if self.mode == "finetune":
|
|
|
for p in self.vq_encoder.parameters():
|
|
for p in self.vq_encoder.parameters():
|
|
|
p.requires_grad = False
|
|
p.requires_grad = False
|
|
|
|
|
|
|
@@ -99,7 +93,7 @@ class VQGAN(L.LightningModule):
|
|
|
def configure_optimizers(self):
|
|
def configure_optimizers(self):
|
|
|
# Need two optimizers and two schedulers
|
|
# Need two optimizers and two schedulers
|
|
|
components = []
|
|
components = []
|
|
|
- if self.freeze_vq is False:
|
|
|
|
|
|
|
+ if self.mode != "finetune":
|
|
|
components.extend(
|
|
components.extend(
|
|
|
[
|
|
[
|
|
|
self.downsample.parameters(),
|
|
self.downsample.parameters(),
|
|
@@ -114,9 +108,7 @@ class VQGAN(L.LightningModule):
|
|
|
if self.decoder is not None:
|
|
if self.decoder is not None:
|
|
|
components.append(self.decoder.parameters())
|
|
components.append(self.decoder.parameters())
|
|
|
|
|
|
|
|
- if self.freeze_hifigan is False:
|
|
|
|
|
- components.append(self.generator.parameters())
|
|
|
|
|
-
|
|
|
|
|
|
|
+ components.append(self.generator.parameters())
|
|
|
optimizer_generator = self.optimizer_builder(itertools.chain(*components))
|
|
optimizer_generator = self.optimizer_builder(itertools.chain(*components))
|
|
|
optimizer_discriminator = self.optimizer_builder(
|
|
optimizer_discriminator = self.optimizer_builder(
|
|
|
self.discriminator.parameters()
|
|
self.discriminator.parameters()
|
|
@@ -157,7 +149,7 @@ class VQGAN(L.LightningModule):
|
|
|
audios, sample_rate=self.sampling_rate
|
|
audios, sample_rate=self.sampling_rate
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- if self.freeze_vq:
|
|
|
|
|
|
|
+ if self.mode == "finetune":
|
|
|
# Disable gradient computation for VQ
|
|
# Disable gradient computation for VQ
|
|
|
torch.set_grad_enabled(False)
|
|
torch.set_grad_enabled(False)
|
|
|
self.vq_encoder.eval()
|
|
self.vq_encoder.eval()
|
|
@@ -183,9 +175,7 @@ class VQGAN(L.LightningModule):
|
|
|
|
|
|
|
|
# vq_features is 50 hz, need to convert to true mel size
|
|
# vq_features is 50 hz, need to convert to true mel size
|
|
|
text_features = self.mel_encoder(features, feature_masks)
|
|
text_features = self.mel_encoder(features, feature_masks)
|
|
|
- text_features, _, loss_vq = self.vq_encoder(
|
|
|
|
|
- text_features, feature_masks, freeze_codebook=self.freeze_vq
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ text_features, _, loss_vq = self.vq_encoder(text_features, feature_masks)
|
|
|
text_features = F.interpolate(
|
|
text_features = F.interpolate(
|
|
|
text_features, size=gt_mels.shape[2], mode="nearest"
|
|
text_features, size=gt_mels.shape[2], mode="nearest"
|
|
|
)
|
|
)
|
|
@@ -193,7 +183,7 @@ class VQGAN(L.LightningModule):
|
|
|
if loss_vq.ndim > 1:
|
|
if loss_vq.ndim > 1:
|
|
|
loss_vq = loss_vq.mean()
|
|
loss_vq = loss_vq.mean()
|
|
|
|
|
|
|
|
- if self.freeze_vq:
|
|
|
|
|
|
|
+ if self.mode == "finetune":
|
|
|
# Enable gradient computation
|
|
# Enable gradient computation
|
|
|
torch.set_grad_enabled(True)
|
|
torch.set_grad_enabled(True)
|
|
|
|
|
|
|
@@ -208,55 +198,69 @@ class VQGAN(L.LightningModule):
|
|
|
else:
|
|
else:
|
|
|
decoded_mels = text_features
|
|
decoded_mels = text_features
|
|
|
|
|
|
|
|
- fake_audios = self.generator(decoded_mels)
|
|
|
|
|
-
|
|
|
|
|
- y_hat_mels = self.mel_transform(fake_audios.squeeze(1))
|
|
|
|
|
-
|
|
|
|
|
- y, ids_slice = rand_slice_segments(audios, audio_lengths, self.segment_size)
|
|
|
|
|
- y_hat = slice_segments(fake_audios, ids_slice, self.segment_size)
|
|
|
|
|
|
|
+ input_mels = gt_mels if self.mode == "pretrain-stage1" else decoded_mels
|
|
|
|
|
+ if self.segment_size is not None:
|
|
|
|
|
+ audios, ids_slice = rand_slice_segments(
|
|
|
|
|
+ audios, audio_lengths, self.segment_size
|
|
|
|
|
+ )
|
|
|
|
|
+ input_mels = slice_segments(
|
|
|
|
|
+ input_mels,
|
|
|
|
|
+ ids_slice // self.hop_length,
|
|
|
|
|
+ self.segment_size // self.hop_length,
|
|
|
|
|
+ )
|
|
|
|
|
+ gen_mel_masks = slice_segments(
|
|
|
|
|
+ mel_masks,
|
|
|
|
|
+ ids_slice // self.hop_length,
|
|
|
|
|
+ self.segment_size // self.hop_length,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- assert y.shape == y_hat.shape, f"{y.shape} != {y_hat.shape}"
|
|
|
|
|
|
|
+ fake_audios = self.generator(input_mels)
|
|
|
|
|
+ fake_audio_mels = self.mel_transform(fake_audios.squeeze(1))
|
|
|
|
|
+ assert (
|
|
|
|
|
+ audios.shape == fake_audios.shape
|
|
|
|
|
+ ), f"{audios.shape} != {fake_audios.shape}"
|
|
|
|
|
|
|
|
- # Since we don't want to update the discriminator, we skip the backward pass
|
|
|
|
|
- if self.freeze_hifigan is False:
|
|
|
|
|
- # Discriminator
|
|
|
|
|
- y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())
|
|
|
|
|
|
|
+ # Discriminator
|
|
|
|
|
+ y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios.detach())
|
|
|
|
|
|
|
|
- with torch.autocast(device_type=audios.device.type, enabled=False):
|
|
|
|
|
- loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
|
|
|
|
|
|
|
+ with torch.autocast(device_type=audios.device.type, enabled=False):
|
|
|
|
|
+ loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
|
|
|
|
|
|
|
|
- self.log(
|
|
|
|
|
- "train/discriminator/loss",
|
|
|
|
|
- loss_disc_all,
|
|
|
|
|
- on_step=True,
|
|
|
|
|
- on_epoch=False,
|
|
|
|
|
- prog_bar=True,
|
|
|
|
|
- logger=True,
|
|
|
|
|
- sync_dist=True,
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ self.log(
|
|
|
|
|
+ "train/discriminator/loss",
|
|
|
|
|
+ loss_disc_all,
|
|
|
|
|
+ on_step=True,
|
|
|
|
|
+ on_epoch=False,
|
|
|
|
|
+ prog_bar=True,
|
|
|
|
|
+ logger=True,
|
|
|
|
|
+ sync_dist=True,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- optim_d.zero_grad()
|
|
|
|
|
- self.manual_backward(loss_disc_all)
|
|
|
|
|
- self.clip_gradients(
|
|
|
|
|
- optim_d, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
|
|
|
|
|
- )
|
|
|
|
|
- optim_d.step()
|
|
|
|
|
|
|
+ optim_d.zero_grad()
|
|
|
|
|
+ self.manual_backward(loss_disc_all)
|
|
|
|
|
+ self.clip_gradients(
|
|
|
|
|
+ optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
|
|
|
|
|
+ )
|
|
|
|
|
+ optim_d.step()
|
|
|
|
|
|
|
|
- y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)
|
|
|
|
|
|
|
+ y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(audios, fake_audios)
|
|
|
|
|
|
|
|
with torch.autocast(device_type=audios.device.type, enabled=False):
|
|
with torch.autocast(device_type=audios.device.type, enabled=False):
|
|
|
loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
|
|
loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
|
|
|
- loss_mel = F.l1_loss(gt_mels * mel_masks, y_hat_mels * mel_masks)
|
|
|
|
|
|
|
+ loss_mel = F.l1_loss(
|
|
|
|
|
+ input_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
|
|
|
|
|
+ )
|
|
|
loss_adv, _ = generator_loss(y_d_hat_g)
|
|
loss_adv, _ = generator_loss(y_d_hat_g)
|
|
|
loss_fm = feature_loss(fmap_r, fmap_g)
|
|
loss_fm = feature_loss(fmap_r, fmap_g)
|
|
|
|
|
|
|
|
- if self.freeze_hifigan is True:
|
|
|
|
|
- loss_gen_all = loss_decoded_mel + loss_vq
|
|
|
|
|
|
|
+ if self.mode == "pretrain-stage1":
|
|
|
|
|
+ loss_vq_all = loss_decoded_mel + loss_vq
|
|
|
|
|
+ loss_gen_all = loss_mel * 45 + loss_fm + loss_adv
|
|
|
else:
|
|
else:
|
|
|
loss_gen_all = loss_mel * 45 + loss_vq * 45 + loss_fm + loss_adv
|
|
loss_gen_all = loss_mel * 45 + loss_vq * 45 + loss_fm + loss_adv
|
|
|
|
|
|
|
|
self.log(
|
|
self.log(
|
|
|
- "train/generator/loss",
|
|
|
|
|
|
|
+ "train/generator/loss_gen_all",
|
|
|
loss_gen_all,
|
|
loss_gen_all,
|
|
|
on_step=True,
|
|
on_step=True,
|
|
|
on_epoch=False,
|
|
on_epoch=False,
|
|
@@ -264,6 +268,18 @@ class VQGAN(L.LightningModule):
|
|
|
logger=True,
|
|
logger=True,
|
|
|
sync_dist=True,
|
|
sync_dist=True,
|
|
|
)
|
|
)
|
|
|
|
|
+
|
|
|
|
|
+ if self.mode == "pretrain-stage1":
|
|
|
|
|
+ self.log(
|
|
|
|
|
+ "train/generator/loss_vq_all",
|
|
|
|
|
+ loss_vq_all,
|
|
|
|
|
+ on_step=True,
|
|
|
|
|
+ on_epoch=False,
|
|
|
|
|
+ prog_bar=True,
|
|
|
|
|
+ logger=True,
|
|
|
|
|
+ sync_dist=True,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
self.log(
|
|
self.log(
|
|
|
"train/generator/loss_decoded_mel",
|
|
"train/generator/loss_decoded_mel",
|
|
|
loss_decoded_mel,
|
|
loss_decoded_mel,
|
|
@@ -311,9 +327,14 @@ class VQGAN(L.LightningModule):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
optim_g.zero_grad()
|
|
optim_g.zero_grad()
|
|
|
|
|
+
|
|
|
|
|
+ # Only backpropagate loss_vq_all in pretrain-stage1
|
|
|
|
|
+ if self.mode == "pretrain-stage1":
|
|
|
|
|
+ self.manual_backward(loss_vq_all)
|
|
|
|
|
+
|
|
|
self.manual_backward(loss_gen_all)
|
|
self.manual_backward(loss_gen_all)
|
|
|
self.clip_gradients(
|
|
self.clip_gradients(
|
|
|
- optim_g, gradient_clip_val=1.0, gradient_clip_algorithm="norm"
|
|
|
|
|
|
|
+ optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
|
|
|
)
|
|
)
|
|
|
optim_g.step()
|
|
optim_g.step()
|
|
|
|
|
|