|
|
@@ -1,3 +1,4 @@
|
|
|
+import itertools
|
|
|
import math
|
|
|
from typing import Any, Callable
|
|
|
|
|
|
@@ -9,9 +10,9 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
|
|
|
from matplotlib import pyplot as plt
|
|
|
from torch import nn
|
|
|
|
|
|
-from fish_speech.models.vqgan.modules.wavenet import WaveNet
|
|
|
-from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
|
|
|
from fish_speech.models.vqgan.modules.discriminator import Discriminator
|
|
|
+from fish_speech.models.vqgan.modules.wavenet import WaveNet
|
|
|
+from fish_speech.models.vqgan.utils import avg_with_mask, plot_mel, sequence_mask
|
|
|
|
|
|
|
|
|
class VQGAN(L.LightningModule):
|
|
|
@@ -55,8 +56,6 @@ class VQGAN(L.LightningModule):
|
|
|
self.weight_mel = weight_mel
|
|
|
|
|
|
# Other parameters
|
|
|
- self.spec_min = -12
|
|
|
- self.spec_max = 3
|
|
|
self.sampling_rate = sampling_rate
|
|
|
|
|
|
# Disable strict loading
|
|
|
@@ -69,7 +68,7 @@ class VQGAN(L.LightningModule):
|
|
|
|
|
|
for param in self.quantizer.parameters():
|
|
|
param.requires_grad = False
|
|
|
-
|
|
|
+
|
|
|
self.automatic_optimization = False
|
|
|
|
|
|
def on_save_checkpoint(self, checkpoint):
|
|
|
@@ -80,24 +79,42 @@ class VQGAN(L.LightningModule):
|
|
|
state_dict.pop(name)
|
|
|
|
|
|
def configure_optimizers(self):
|
|
|
- # optimizer = self.optimizer_builder(self.parameters())
|
|
|
- # lr_scheduler = self.lr_scheduler_builder(optimizer)
|
|
|
-
|
|
|
- # return {
|
|
|
- # "optimizer": optimizer,
|
|
|
- # "lr_scheduler": {
|
|
|
- # "scheduler": lr_scheduler,
|
|
|
- # "interval": "step",
|
|
|
- # },
|
|
|
- # }
|
|
|
-
|
|
|
- def norm_spec(self, x):
|
|
|
- return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
|
|
|
+ optimizer_generator = self.optimizer_builder(
|
|
|
+ itertools.chain(
|
|
|
+ self.encoder.parameters(),
|
|
|
+ self.quantizer.parameters(),
|
|
|
+ self.decoder.parameters(),
|
|
|
+ )
|
|
|
+ )
|
|
|
+ optimizer_discriminator = self.optimizer_builder(
|
|
|
+ self.discriminator.parameters()
|
|
|
+ )
|
|
|
|
|
|
- def denorm_spec(self, x):
|
|
|
- return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
|
|
|
+ lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
|
|
|
+ lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
|
|
|
+
|
|
|
+ return (
|
|
|
+ {
|
|
|
+ "optimizer": optimizer_generator,
|
|
|
+ "lr_scheduler": {
|
|
|
+ "scheduler": lr_scheduler_generator,
|
|
|
+ "interval": "step",
|
|
|
+ "name": "optimizer/generator",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "optimizer": optimizer_discriminator,
|
|
|
+ "lr_scheduler": {
|
|
|
+ "scheduler": lr_scheduler_discriminator,
|
|
|
+ "interval": "step",
|
|
|
+ "name": "optimizer/discriminator",
|
|
|
+ },
|
|
|
+ },
|
|
|
+ )
|
|
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
|
+ optim_g, optim_d = self.optimizers()
|
|
|
+
|
|
|
audios, audio_lengths = batch["audios"], batch["audio_lengths"]
|
|
|
|
|
|
audios = audios.float()
|
|
|
@@ -128,13 +145,48 @@ class VQGAN(L.LightningModule):
|
|
|
* mel_masks_float_conv
|
|
|
)
|
|
|
|
|
|
+ # Discriminator
|
|
|
+ real_logits = self.discriminator(gt_mels)
|
|
|
+ fake_logits = self.discriminator(gen_mel.detach())
|
|
|
+ d_mask = F.interpolate(
|
|
|
+ mel_masks_float_conv, size=(real_logits.shape[2],), mode="nearest"
|
|
|
+ )
|
|
|
+
|
|
|
+ loss_real = avg_with_mask((real_logits - 1) ** 2, d_mask)
|
|
|
+ loss_fake = avg_with_mask(fake_logits**2, d_mask)
|
|
|
+
|
|
|
+ loss_d = loss_real + loss_fake
|
|
|
+
|
|
|
+ self.log(
|
|
|
+ "train/discriminator/loss",
|
|
|
+ loss_d,
|
|
|
+ on_step=True,
|
|
|
+ on_epoch=False,
|
|
|
+ prog_bar=True,
|
|
|
+ logger=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ # Discriminator backward
|
|
|
+ optim_d.zero_grad()
|
|
|
+ self.manual_backward(loss_d)
|
|
|
+ self.clip_gradients(
|
|
|
+ optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
|
|
|
+ )
|
|
|
+ optim_d.step()
|
|
|
+
|
|
|
# Mel Loss
|
|
|
- loss_mel = (gen_mel - gt_mels).abs().mean(
|
|
|
- dim=1, keepdim=True
|
|
|
- ).sum() / mel_masks_float_conv.sum()
|
|
|
+ loss_mel = avg_with_mask((gen_mel - gt_mels).abs(), mel_masks_float_conv)
|
|
|
+
|
|
|
+ # Adversarial Loss
|
|
|
+ fake_logits = self.discriminator(gen_mel)
|
|
|
+ loss_adv = avg_with_mask((fake_logits - 1) ** 2, d_mask)
|
|
|
|
|
|
# Total loss
|
|
|
- loss = self.weight_vq * loss_vq + self.weight_mel * loss_mel
|
|
|
+ loss = (
|
|
|
+ self.weight_vq * loss_vq
|
|
|
+ + self.weight_mel * loss_mel
|
|
|
+ + self.weight_adv * loss_adv
|
|
|
+ )
|
|
|
|
|
|
# Log losses
|
|
|
self.log(
|
|
|
@@ -162,7 +214,17 @@ class VQGAN(L.LightningModule):
|
|
|
logger=True,
|
|
|
)
|
|
|
|
|
|
- return loss
|
|
|
+ # Generator backward
|
|
|
+ optim_g.zero_grad()
|
|
|
+ self.manual_backward(loss)
|
|
|
+ self.clip_gradients(
|
|
|
+ optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
|
|
|
+ )
|
|
|
+ optim_g.step()
|
|
|
+
|
|
|
+ scheduler_g, scheduler_d = self.lr_schedulers()
|
|
|
+ scheduler_g.step()
|
|
|
+ scheduler_d.step()
|
|
|
|
|
|
def validation_step(self, batch: Any, batch_idx: int):
|
|
|
audios, audio_lengths = batch["audios"], batch["audio_lengths"]
|
|
|
@@ -190,9 +252,7 @@ class VQGAN(L.LightningModule):
|
|
|
)
|
|
|
* mel_masks_float_conv
|
|
|
)
|
|
|
- loss_mel = (gen_aux_mels - gt_mels).abs().mean(
|
|
|
- dim=1, keepdim=True
|
|
|
- ).sum() / mel_masks_float_conv.sum()
|
|
|
+ loss_mel = avg_with_mask((gen_aux_mels - gt_mels).abs(), mel_masks_float_conv)
|
|
|
|
|
|
self.log(
|
|
|
"val/loss_mel",
|