|
|
@@ -11,6 +11,7 @@ from torch import nn
|
|
|
|
|
|
from fish_speech.models.vqgan.modules.wavenet import WaveNet
|
|
|
from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
|
|
|
+from fish_speech.models.vqgan.modules.discriminator import Discriminator
|
|
|
|
|
|
|
|
|
class VQGAN(L.LightningModule):
|
|
|
@@ -21,16 +22,14 @@ class VQGAN(L.LightningModule):
|
|
|
encoder: WaveNet,
|
|
|
quantizer: nn.Module,
|
|
|
decoder: WaveNet,
|
|
|
- # reflow: nn.Module,
|
|
|
+ discriminator: Discriminator,
|
|
|
vocoder: nn.Module,
|
|
|
mel_transform: nn.Module,
|
|
|
- weight_reflow: float = 1.0,
|
|
|
+ weight_adv: float = 1.0,
|
|
|
weight_vq: float = 1.0,
|
|
|
weight_mel: float = 1.0,
|
|
|
sampling_rate: int = 44100,
|
|
|
freeze_encoder: bool = False,
|
|
|
- reflow_inference_steps: int = 10,
|
|
|
- reflow_inference_start_t: float = 0.5,
|
|
|
):
|
|
|
super().__init__()
|
|
|
|
|
|
@@ -43,7 +42,7 @@ class VQGAN(L.LightningModule):
|
|
|
self.quantizer = quantizer
|
|
|
self.decoder = decoder
|
|
|
self.vocoder = vocoder
|
|
|
- # self.reflow = reflow
|
|
|
+ self.discriminator = discriminator
|
|
|
self.mel_transform = mel_transform
|
|
|
|
|
|
# Freeze vocoder
|
|
|
@@ -51,7 +50,7 @@ class VQGAN(L.LightningModule):
|
|
|
param.requires_grad = False
|
|
|
|
|
|
# Loss weights
|
|
|
- self.weight_reflow = weight_reflow
|
|
|
+ self.weight_adv = weight_adv
|
|
|
self.weight_vq = weight_vq
|
|
|
self.weight_mel = weight_mel
|
|
|
|
|
|
@@ -59,8 +58,6 @@ class VQGAN(L.LightningModule):
|
|
|
self.spec_min = -12
|
|
|
self.spec_max = 3
|
|
|
self.sampling_rate = sampling_rate
|
|
|
- self.reflow_inference_steps = reflow_inference_steps
|
|
|
- self.reflow_inference_start_t = reflow_inference_start_t
|
|
|
|
|
|
# Disable strict loading
|
|
|
self.strict_loading = False
|
|
|
@@ -72,6 +69,8 @@ class VQGAN(L.LightningModule):
|
|
|
|
|
|
for param in self.quantizer.parameters():
|
|
|
param.requires_grad = False
|
|
|
+
|
|
|
+ self.automatic_optimization = False
|
|
|
|
|
|
def on_save_checkpoint(self, checkpoint):
|
|
|
# Do not save vocoder
|
|
|
@@ -81,16 +80,16 @@ class VQGAN(L.LightningModule):
|
|
|
state_dict.pop(name)
|
|
|
|
|
|
def configure_optimizers(self):
    """Build one optimizer/scheduler pair for the generator and one for the discriminator.

    The previous body was entirely commented out, which left this method
    without any statement (a syntax error) and left the module without
    optimizers.

    Parameters are split into two disjoint groups:

    * discriminator: everything yielded by ``self.discriminator.parameters()``
    * generator: every other parameter yielded by ``self.parameters()``

    Both groups go through the same ``self.optimizer_builder`` and
    ``self.lr_scheduler_builder`` factories that the single-optimizer version
    of this method used.

    Returns:
        A 2-tuple ``(generator_config, discriminator_config)``. Each element
        is a dict with the keys ``"optimizer"`` and ``"lr_scheduler"``, laid
        out exactly like the dict this method used to return.

    NOTE(review): the training step is not visible from here. It is assumed
    to fetch the optimizers via ``self.optimizers()`` in this order
    (generator first) -- confirm against ``training_step``.
    NOTE(review): under manual optimization Lightning does not step
    schedulers on its own, so ``"interval": "step"`` is presumably only
    informational; the training step has to call ``scheduler.step()`` itself.
    """

    discriminator_params = list(self.discriminator.parameters())

    # Compare by identity: tensors overload ``==`` element-wise, so a plain
    # membership test on a list of parameters would be wrong.
    discriminator_ids = {id(param) for param in discriminator_params}
    generator_params = [
        param for param in self.parameters() if id(param) not in discriminator_ids
    ]

    def _build_config(params):
        # Pair an optimizer with its scheduler using the existing dict layout.
        optimizer = self.optimizer_builder(params)
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    return _build_config(generator_params), _build_config(discriminator_params)
|
|
|
|
|
|
def norm_spec(self, x):
    """Linearly rescale ``x`` so that ``spec_min`` maps to -1 and ``spec_max`` maps to 1.

    Values outside ``[self.spec_min, self.spec_max]`` are not clipped; they
    simply land outside ``[-1, 1]``.
    """

    offset = x - self.spec_min
    span = self.spec_max - self.spec_min

    # Same evaluation order as before: divide, scale to [0, 2], shift to [-1, 1].
    return offset / span * 2 - 1
|