|
@@ -1,70 +1,41 @@
|
|
|
from typing import Any, Callable
|
|
from typing import Any, Callable
|
|
|
|
|
|
|
|
|
|
+import lightning as L
|
|
|
import torch
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
import torch.nn.functional as F
|
|
|
-from fish_vocoder.models.vocoder import VocoderModel
|
|
|
|
|
-from fish_vocoder.modules.losses.stft import MultiResolutionSTFTLoss
|
|
|
|
|
-from fish_vocoder.utils.grad_norm import grad_norm
|
|
|
|
|
-from fish_vocoder.utils.mask import sequence_mask
|
|
|
|
|
from torch import nn
|
|
from torch import nn
|
|
|
from torch.utils.checkpoint import checkpoint as gradient_checkpointing
|
|
from torch.utils.checkpoint import checkpoint as gradient_checkpointing
|
|
|
|
|
|
|
|
|
|
|
|
|
-class GANModel(VocoderModel):
|
|
|
|
|
|
|
+class VQGAN(L.LightningModule):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
self,
|
|
self,
|
|
|
- sampling_rate: int,
|
|
|
|
|
- n_fft: int,
|
|
|
|
|
- hop_length: int,
|
|
|
|
|
- win_length: int,
|
|
|
|
|
- num_mels: int,
|
|
|
|
|
optimizer: Callable,
|
|
optimizer: Callable,
|
|
|
lr_scheduler: Callable,
|
|
lr_scheduler: Callable,
|
|
|
- mel_transforms: nn.ModuleDict,
|
|
|
|
|
|
|
+ encoder: nn.Module,
|
|
|
generator: nn.Module,
|
|
generator: nn.Module,
|
|
|
- discriminators: nn.ModuleDict,
|
|
|
|
|
- multi_resolution_stft_loss: MultiResolutionSTFTLoss,
|
|
|
|
|
- num_frames: int,
|
|
|
|
|
- crop_length: int | None = None,
|
|
|
|
|
- checkpointing: bool = False,
|
|
|
|
|
- feature_matching: bool = False,
|
|
|
|
|
|
|
+ discriminator: nn.Module,
|
|
|
|
|
+ mel_transform: nn.Module,
|
|
|
|
|
+ segment_size: int = 20480,
|
|
|
):
|
|
):
|
|
|
- super().__init__(
|
|
|
|
|
- sampling_rate=sampling_rate,
|
|
|
|
|
- n_fft=n_fft,
|
|
|
|
|
- hop_length=hop_length,
|
|
|
|
|
- win_length=win_length,
|
|
|
|
|
- num_mels=num_mels,
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ super().__init__()
|
|
|
|
|
|
|
|
# Model parameters
|
|
# Model parameters
|
|
|
self.optimizer_builder = optimizer
|
|
self.optimizer_builder = optimizer
|
|
|
self.lr_scheduler_builder = lr_scheduler
|
|
self.lr_scheduler_builder = lr_scheduler
|
|
|
|
|
|
|
|
- # Spectrogram transforms
|
|
|
|
|
- self.mel_transforms = mel_transforms
|
|
|
|
|
-
|
|
|
|
|
# Generator and discriminators
|
|
# Generator and discriminators
|
|
|
# Compile generator so that snake can save memory
|
|
# Compile generator so that snake can save memory
|
|
|
self.generator = generator
|
|
self.generator = generator
|
|
|
- self.discriminators = discriminators
|
|
|
|
|
-
|
|
|
|
|
- # Loss
|
|
|
|
|
- self.multi_resolution_stft_loss = multi_resolution_stft_loss
|
|
|
|
|
|
|
+ self.discriminator = discriminator
|
|
|
|
|
+ self.mel_transform = mel_transform
|
|
|
|
|
|
|
|
# Crop length for saving memory
|
|
# Crop length for saving memory
|
|
|
- self.num_frames = num_frames
|
|
|
|
|
- self.crop_length = crop_length
|
|
|
|
|
|
|
+ self.segment_size = segment_size
|
|
|
|
|
|
|
|
# Disable automatic optimization
|
|
# Disable automatic optimization
|
|
|
self.automatic_optimization = False
|
|
self.automatic_optimization = False
|
|
|
|
|
|
|
|
- # Gradient checkpointing
|
|
|
|
|
- self.checkpointing = checkpointing
|
|
|
|
|
-
|
|
|
|
|
- # Feature matching
|
|
|
|
|
- self.feature_matching = feature_matching
|
|
|
|
|
-
|
|
|
|
|
def configure_optimizers(self):
|
|
def configure_optimizers(self):
|
|
|
# Need two optimizers and two schedulers
|
|
# Need two optimizers and two schedulers
|
|
|
optimizer_generator = self.optimizer_builder(self.generator.parameters())
|
|
optimizer_generator = self.optimizer_builder(self.generator.parameters())
|
|
@@ -95,12 +66,7 @@ class GANModel(VocoderModel):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
def training_generator(self, audio, audio_mask):
|
|
def training_generator(self, audio, audio_mask):
|
|
|
- if self.training and self.checkpointing:
|
|
|
|
|
- fake_audio, base_loss = gradient_checkpointing(
|
|
|
|
|
- self.forward, audio, audio_mask, use_reentrant=False
|
|
|
|
|
- )
|
|
|
|
|
- else:
|
|
|
|
|
- fake_audio, base_loss = self.forward(audio, audio_mask)
|
|
|
|
|
|
|
+ # fake_audio, base_loss = self.forward(audio, audio_mask)
|
|
|
|
|
|
|
|
assert fake_audio.shape == audio.shape
|
|
assert fake_audio.shape == audio.shape
|
|
|
|
|
|
|
@@ -308,14 +274,6 @@ class GANModel(VocoderModel):
|
|
|
scheduler_g.step()
|
|
scheduler_g.step()
|
|
|
scheduler_d.step()
|
|
scheduler_d.step()
|
|
|
|
|
|
|
|
- def forward(self, audio, mask=None, input_spec=None):
|
|
|
|
|
- if input_spec is None:
|
|
|
|
|
- input_spec = self.mel_transforms.input(audio.squeeze(1))
|
|
|
|
|
-
|
|
|
|
|
- fake_audio = self.generator(input_spec)
|
|
|
|
|
-
|
|
|
|
|
- return fake_audio, 0
|
|
|
|
|
-
|
|
|
|
|
def validation_step(self, batch: Any, batch_idx: int):
|
|
def validation_step(self, batch: Any, batch_idx: int):
|
|
|
audio, lengths = batch["audio"], batch["lengths"]
|
|
audio, lengths = batch["audio"], batch["lengths"]
|
|
|
audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32)
|
|
audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32)
|