|
|
@@ -9,6 +9,7 @@ import wandb
|
|
|
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
|
|
|
from matplotlib import pyplot as plt
|
|
|
from torch import nn
|
|
|
+from torch.utils.checkpoint import checkpoint as gradient_checkpoint
|
|
|
|
|
|
from fish_speech.models.vqgan.losses import (
|
|
|
MultiResolutionSTFTLoss,
|
|
|
@@ -16,19 +17,9 @@ from fish_speech.models.vqgan.losses import (
|
|
|
feature_loss,
|
|
|
generator_loss,
|
|
|
)
|
|
|
-from fish_speech.models.vqgan.modules.balancer import Balancer
|
|
|
-from fish_speech.models.vqgan.modules.decoder import Generator
|
|
|
-from fish_speech.models.vqgan.modules.encoders import (
|
|
|
- ConvDownSampler,
|
|
|
- TextEncoder,
|
|
|
- VQEncoder,
|
|
|
-)
|
|
|
-from fish_speech.models.vqgan.utils import (
|
|
|
- plot_mel,
|
|
|
- rand_slice_segments,
|
|
|
- sequence_mask,
|
|
|
- slice_segments,
|
|
|
-)
|
|
|
+from fish_speech.models.vqgan.modules.convnext import ConvNeXt
|
|
|
+from fish_speech.models.vqgan.modules.encoders import VQEncoder
|
|
|
+from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
@@ -41,9 +32,7 @@ class VQEncodeResult:
|
|
|
|
|
|
@dataclass
|
|
|
class VQDecodeResult:
|
|
|
- audios: torch.Tensor
|
|
|
mels: torch.Tensor
|
|
|
- mel_lengths: torch.Tensor
|
|
|
|
|
|
|
|
|
class VQGAN(L.LightningModule):
|
|
|
@@ -51,19 +40,15 @@ class VQGAN(L.LightningModule):
|
|
|
self,
|
|
|
optimizer: Callable,
|
|
|
lr_scheduler: Callable,
|
|
|
- downsample: ConvDownSampler,
|
|
|
- vq_encoder: VQEncoder,
|
|
|
- mel_encoder: TextEncoder,
|
|
|
- decoder: TextEncoder,
|
|
|
- generator: Generator,
|
|
|
- discriminators: nn.ModuleDict,
|
|
|
+ encoder: ConvNeXt,
|
|
|
+ vq: VQEncoder,
|
|
|
+ decoder: ConvNeXt,
|
|
|
+ generator: nn.Module,
|
|
|
+ discriminator: ConvNeXt,
|
|
|
mel_transform: nn.Module,
|
|
|
- segment_size: int = 20480,
|
|
|
hop_length: int = 640,
|
|
|
sample_rate: int = 32000,
|
|
|
- mode: Literal["pretrain", "finetune"] = "finetune",
|
|
|
freeze_discriminator: bool = False,
|
|
|
- multi_resolution_stft_loss: Optional[MultiResolutionSTFTLoss] = None,
|
|
|
):
|
|
|
super().__init__()
|
|
|
|
|
|
@@ -74,68 +59,41 @@ class VQGAN(L.LightningModule):
|
|
|
self.optimizer_builder = optimizer
|
|
|
self.lr_scheduler_builder = lr_scheduler
|
|
|
|
|
|
- # Generator and discriminators
|
|
|
- self.downsample = downsample
|
|
|
- self.vq_encoder = vq_encoder
|
|
|
- self.mel_encoder = mel_encoder
|
|
|
+ # Generator and discriminator
|
|
|
+ self.encoder = encoder
|
|
|
+ self.vq = vq
|
|
|
self.decoder = decoder
|
|
|
self.generator = generator
|
|
|
- self.discriminators = discriminators
|
|
|
+ self.discriminator = discriminator
|
|
|
self.mel_transform = mel_transform
|
|
|
self.freeze_discriminator = freeze_discriminator
|
|
|
|
|
|
# Crop length for saving memory
|
|
|
- self.segment_size = segment_size
|
|
|
self.hop_length = hop_length
|
|
|
self.sampling_rate = sample_rate
|
|
|
- self.mode = mode
|
|
|
|
|
|
# Disable automatic optimization
|
|
|
self.automatic_optimization = False
|
|
|
|
|
|
- # Finetune: Train the VQ only
|
|
|
- if self.mode == "finetune":
|
|
|
- for p in self.vq_encoder.parameters():
|
|
|
- p.requires_grad = False
|
|
|
-
|
|
|
- for p in self.mel_encoder.parameters():
|
|
|
- p.requires_grad = False
|
|
|
-
|
|
|
- for p in self.downsample.parameters():
|
|
|
- p.requires_grad = False
|
|
|
-
|
|
|
if self.freeze_discriminator:
|
|
|
- for p in self.discriminators.parameters():
|
|
|
+ for p in self.discriminator.parameters():
|
|
|
p.requires_grad = False
|
|
|
|
|
|
- # Losses
|
|
|
- self.multi_resolution_stft_loss = multi_resolution_stft_loss
|
|
|
- loss_dict = {
|
|
|
- "mel": 1,
|
|
|
- "adv": 1,
|
|
|
- "fm": 1,
|
|
|
- }
|
|
|
-
|
|
|
- if self.multi_resolution_stft_loss is not None:
|
|
|
- loss_dict["stft"] = 1
|
|
|
-
|
|
|
- self.balancer = Balancer(loss_dict)
|
|
|
+ # Freeze generator
|
|
|
+ for p in self.generator.parameters():
|
|
|
+ p.requires_grad = False
|
|
|
|
|
|
def configure_optimizers(self):
|
|
|
# Need two optimizers and two schedulers
|
|
|
- components = [
|
|
|
- self.downsample.parameters(),
|
|
|
- self.vq_encoder.parameters(),
|
|
|
- self.mel_encoder.parameters(),
|
|
|
- ]
|
|
|
-
|
|
|
- if self.decoder is not None:
|
|
|
- components.append(self.decoder.parameters())
|
|
|
-
|
|
|
- components.append(self.generator.parameters())
|
|
|
- optimizer_generator = self.optimizer_builder(itertools.chain(*components))
|
|
|
+ optimizer_generator = self.optimizer_builder(
|
|
|
+ itertools.chain(
|
|
|
+ self.encoder.parameters(),
|
|
|
+ self.vq.parameters(),
|
|
|
+ self.decoder.parameters(),
|
|
|
+ )
|
|
|
+ )
|
|
|
optimizer_discriminator = self.optimizer_builder(
|
|
|
- self.discriminators.parameters()
|
|
|
+ self.discriminator.parameters()
|
|
|
)
|
|
|
|
|
|
lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
|
|
|
@@ -171,13 +129,6 @@ class VQGAN(L.LightningModule):
|
|
|
with torch.no_grad():
|
|
|
gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
|
|
|
|
|
|
- if self.mode == "finetune":
|
|
|
- # Disable gradient computation for VQ
|
|
|
- torch.set_grad_enabled(False)
|
|
|
- self.vq_encoder.eval()
|
|
|
- self.mel_encoder.eval()
|
|
|
- self.downsample.eval()
|
|
|
-
|
|
|
mel_lengths = audio_lengths // self.hop_length
|
|
|
mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
|
|
|
gt_mels.dtype
|
|
|
@@ -189,186 +140,80 @@ class VQGAN(L.LightningModule):
|
|
|
if loss_vq.ndim > 1:
|
|
|
loss_vq = loss_vq.mean()
|
|
|
|
|
|
- if self.mode == "finetune":
|
|
|
- # Enable gradient computation
|
|
|
- torch.set_grad_enabled(True)
|
|
|
-
|
|
|
- decoded = self.decode(
|
|
|
- indices=vq_result.indices if self.mode == "finetune" else None,
|
|
|
- features=vq_result.features if self.mode == "pretrain" else None,
|
|
|
+ decoded_mels = self.decode(
|
|
|
+ indices=None,
|
|
|
+ features=vq_result.features,
|
|
|
audio_lengths=audio_lengths,
|
|
|
- mel_only=True,
|
|
|
- )
|
|
|
- decoded_mels = decoded.mels
|
|
|
- input_mels = gt_mels if self.mode == "pretrain" else decoded_mels
|
|
|
+ ).mels
|
|
|
|
|
|
- if self.segment_size is not None:
|
|
|
- audios, ids_slice = rand_slice_segments(
|
|
|
- audios, audio_lengths, self.segment_size
|
|
|
- )
|
|
|
- input_mels = slice_segments(
|
|
|
- input_mels,
|
|
|
- ids_slice // self.hop_length,
|
|
|
- self.segment_size // self.hop_length,
|
|
|
- )
|
|
|
- sliced_gt_mels = slice_segments(
|
|
|
- gt_mels,
|
|
|
- ids_slice // self.hop_length,
|
|
|
- self.segment_size // self.hop_length,
|
|
|
- )
|
|
|
- gen_mel_masks = slice_segments(
|
|
|
- mel_masks,
|
|
|
- ids_slice // self.hop_length,
|
|
|
- self.segment_size // self.hop_length,
|
|
|
- )
|
|
|
- else:
|
|
|
- sliced_gt_mels = gt_mels
|
|
|
- gen_mel_masks = mel_masks
|
|
|
+ with torch.no_grad():
|
|
|
+ with torch.autocast(device_type=audios.device.type, enabled=False):
|
|
|
+ fake_audios = self.generator(decoded_mels.float())
|
|
|
|
|
|
- fake_audios = self.generator(input_mels)
|
|
|
- fake_audio_mels = self.mel_transform(fake_audios.squeeze(1))
|
|
|
assert (
|
|
|
audios.shape == fake_audios.shape
|
|
|
), f"{audios.shape} != {fake_audios.shape}"
|
|
|
|
|
|
- # Multi-Resolution STFT Loss
|
|
|
- if self.multi_resolution_stft_loss is not None:
|
|
|
- with torch.autocast(device_type=audios.device.type, enabled=False):
|
|
|
- sc_loss, mag_loss = self.multi_resolution_stft_loss(
|
|
|
- fake_audios.squeeze(1).float(), audios.squeeze(1).float()
|
|
|
- )
|
|
|
- loss_stft = sc_loss + mag_loss
|
|
|
-
|
|
|
# Discriminator
|
|
|
if self.freeze_discriminator is False:
|
|
|
- loss_disc_all = []
|
|
|
-
|
|
|
- for key, disc in self.discriminators.items():
|
|
|
- scores, _ = disc(audios)
|
|
|
- score_fakes, _ = disc(fake_audios.detach())
|
|
|
-
|
|
|
- with torch.autocast(device_type=audios.device.type, enabled=False):
|
|
|
- loss_disc, _, _ = discriminator_loss(scores, score_fakes)
|
|
|
-
|
|
|
- self.log(
|
|
|
- f"train/discriminator/{key}",
|
|
|
- loss_disc,
|
|
|
- on_step=True,
|
|
|
- on_epoch=False,
|
|
|
- prog_bar=False,
|
|
|
- logger=True,
|
|
|
- sync_dist=True,
|
|
|
- )
|
|
|
+ scores = self.discriminator(gt_mels)
|
|
|
+ score_fakes = self.discriminator(decoded_mels.detach())
|
|
|
|
|
|
- loss_disc_all.append(loss_disc)
|
|
|
-
|
|
|
- loss_disc_all = torch.stack(loss_disc_all).mean()
|
|
|
+ with torch.autocast(device_type=audios.device.type, enabled=False):
|
|
|
+ loss_disc, _, _ = discriminator_loss([scores], [score_fakes])
|
|
|
|
|
|
self.log(
|
|
|
- "train/discriminator/loss",
|
|
|
- loss_disc_all,
|
|
|
+ f"train/discriminator/loss",
|
|
|
+ loss_disc,
|
|
|
on_step=True,
|
|
|
on_epoch=False,
|
|
|
- prog_bar=True,
|
|
|
+ prog_bar=False,
|
|
|
logger=True,
|
|
|
sync_dist=True,
|
|
|
)
|
|
|
|
|
|
optim_d.zero_grad()
|
|
|
- self.manual_backward(loss_disc_all)
|
|
|
+ self.manual_backward(loss_disc)
|
|
|
self.clip_gradients(
|
|
|
optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
|
|
|
)
|
|
|
optim_d.step()
|
|
|
|
|
|
# Adv Loss
|
|
|
- loss_adv_all = []
|
|
|
- loss_fm_all = []
|
|
|
-
|
|
|
- for key, disc in self.discriminators.items():
|
|
|
- score_fakes, feat_fake = disc(fake_audios)
|
|
|
-
|
|
|
- # Adversarial Loss
|
|
|
- with torch.autocast(device_type=audios.device.type, enabled=False):
|
|
|
- loss_fake, _ = generator_loss(score_fakes)
|
|
|
-
|
|
|
- self.log(
|
|
|
- f"train/generator/adv_{key}",
|
|
|
- loss_fake,
|
|
|
- on_step=True,
|
|
|
- on_epoch=False,
|
|
|
- prog_bar=False,
|
|
|
- logger=True,
|
|
|
- sync_dist=True,
|
|
|
- )
|
|
|
-
|
|
|
- loss_adv_all.append(loss_fake)
|
|
|
-
|
|
|
- # Feature Matching Loss
|
|
|
- _, feat_real = disc(audios)
|
|
|
-
|
|
|
- with torch.autocast(device_type=audios.device.type, enabled=False):
|
|
|
- loss_fm = feature_loss(feat_real, feat_fake)
|
|
|
-
|
|
|
- self.log(
|
|
|
- f"train/generator/adv_fm_{key}",
|
|
|
- loss_fm,
|
|
|
- on_step=True,
|
|
|
- on_epoch=False,
|
|
|
- prog_bar=False,
|
|
|
- logger=True,
|
|
|
- sync_dist=True,
|
|
|
- )
|
|
|
-
|
|
|
- loss_fm_all.append(loss_fm)
|
|
|
-
|
|
|
- loss_adv_all = torch.stack(loss_adv_all).mean()
|
|
|
- loss_fm_all = torch.stack(loss_fm_all).mean()
|
|
|
+ score_fakes = self.discriminator(decoded_mels)
|
|
|
|
|
|
+ # Adversarial Loss
|
|
|
with torch.autocast(device_type=audios.device.type, enabled=False):
|
|
|
- loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
|
|
|
- loss_mel = F.l1_loss(
|
|
|
- sliced_gt_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
|
|
|
- )
|
|
|
-
|
|
|
- loss_dict = {
|
|
|
- "mel": loss_mel,
|
|
|
- "adv": loss_adv_all,
|
|
|
- "fm": loss_fm_all,
|
|
|
- }
|
|
|
+ loss_adv, _ = generator_loss([score_fakes])
|
|
|
|
|
|
- if self.multi_resolution_stft_loss is not None:
|
|
|
- loss_dict["stft"] = loss_stft
|
|
|
+ self.log(
|
|
|
+ f"train/generator/adv",
|
|
|
+ loss_adv,
|
|
|
+ on_step=True,
|
|
|
+ on_epoch=False,
|
|
|
+ prog_bar=False,
|
|
|
+ logger=True,
|
|
|
+ sync_dist=True,
|
|
|
+ )
|
|
|
|
|
|
- generator_out_grad = self.balancer.compute(
|
|
|
- loss_dict,
|
|
|
- fake_audios,
|
|
|
- )
|
|
|
+ # Feature Matching Loss
|
|
|
+ score_gts = self.discriminator(gt_mels)
|
|
|
|
|
|
- if self.mode == "pretrain":
|
|
|
- loss_vq_all = loss_decoded_mel + loss_vq
|
|
|
+ with torch.autocast(device_type=audios.device.type, enabled=False):
|
|
|
+ loss_fm = feature_loss([score_gts], [score_fakes])
|
|
|
|
|
|
- # Loss vq and loss decoded mel are only used in pretrain stage
|
|
|
- if self.mode == "pretrain":
|
|
|
- self.log(
|
|
|
- "train/generator/loss_vq",
|
|
|
- loss_vq,
|
|
|
- on_step=True,
|
|
|
- on_epoch=False,
|
|
|
- prog_bar=False,
|
|
|
- logger=True,
|
|
|
- sync_dist=True,
|
|
|
- )
|
|
|
+ self.log(
|
|
|
+ f"train/generator/adv_fm",
|
|
|
+ loss_fm,
|
|
|
+ on_step=True,
|
|
|
+ on_epoch=False,
|
|
|
+ prog_bar=False,
|
|
|
+ logger=True,
|
|
|
+ sync_dist=True,
|
|
|
+ )
|
|
|
|
|
|
- self.log(
|
|
|
- "train/generator/loss_decoded_mel",
|
|
|
- loss_decoded_mel,
|
|
|
- on_step=True,
|
|
|
- on_epoch=False,
|
|
|
- prog_bar=False,
|
|
|
- logger=True,
|
|
|
- sync_dist=True,
|
|
|
- )
|
|
|
+ with torch.autocast(device_type=audios.device.type, enabled=False):
|
|
|
+ loss_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
|
|
|
|
|
|
self.log(
|
|
|
"train/generator/loss_mel",
|
|
|
@@ -380,29 +225,20 @@ class VQGAN(L.LightningModule):
|
|
|
sync_dist=True,
|
|
|
)
|
|
|
|
|
|
- if self.multi_resolution_stft_loss is not None:
|
|
|
- self.log(
|
|
|
- "train/generator/loss_stft",
|
|
|
- loss_stft,
|
|
|
- on_step=True,
|
|
|
- on_epoch=False,
|
|
|
- prog_bar=False,
|
|
|
- logger=True,
|
|
|
- sync_dist=True,
|
|
|
- )
|
|
|
-
|
|
|
self.log(
|
|
|
- "train/generator/loss_fm_all",
|
|
|
- loss_fm_all,
|
|
|
+ "train/generator/loss_vq",
|
|
|
+ loss_vq,
|
|
|
on_step=True,
|
|
|
on_epoch=False,
|
|
|
prog_bar=False,
|
|
|
logger=True,
|
|
|
sync_dist=True,
|
|
|
)
|
|
|
+
|
|
|
+ loss = loss_mel * 20 + loss_vq + loss_adv + loss_fm
|
|
|
self.log(
|
|
|
- "train/generator/loss_adv_all",
|
|
|
- loss_adv_all,
|
|
|
+ "train/generator/loss",
|
|
|
+ loss,
|
|
|
on_step=True,
|
|
|
on_epoch=False,
|
|
|
prog_bar=False,
|
|
|
@@ -412,11 +248,7 @@ class VQGAN(L.LightningModule):
|
|
|
|
|
|
optim_g.zero_grad()
|
|
|
|
|
|
- # Only backpropagate loss_vq_all in pretrain stage
|
|
|
- if self.mode == "pretrain":
|
|
|
- self.manual_backward(loss_vq_all, retain_graph=True)
|
|
|
-
|
|
|
- self.manual_backward(fake_audios, gradient=generator_out_grad)
|
|
|
+ self.manual_backward(loss)
|
|
|
self.clip_gradients(
|
|
|
optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
|
|
|
)
|
|
|
@@ -440,19 +272,11 @@ class VQGAN(L.LightningModule):
|
|
|
)
|
|
|
|
|
|
vq_result = self.encode(audios, audio_lengths)
|
|
|
- decoded = self.decode(
|
|
|
+ decoded_mels = self.decode(
|
|
|
indices=vq_result.indices,
|
|
|
audio_lengths=audio_lengths,
|
|
|
- mel_only=self.mode == "pretrain",
|
|
|
- )
|
|
|
-
|
|
|
- decoded_mels = decoded.mels
|
|
|
-
|
|
|
- # Use gt mel as input for pretrain
|
|
|
- if self.mode == "pretrain":
|
|
|
- fake_audios = self.generator(gt_mels)
|
|
|
- else:
|
|
|
- fake_audios = decoded.audios
|
|
|
+ ).mels
|
|
|
+ fake_audios = self.generator(decoded_mels)
|
|
|
|
|
|
fake_mels = self.mel_transform(fake_audios.squeeze(1))
|
|
|
|
|
|
@@ -557,21 +381,25 @@ class VQGAN(L.LightningModule):
|
|
|
with torch.no_grad():
|
|
|
features = self.mel_transform(audios, sample_rate=self.sampling_rate)
|
|
|
|
|
|
- if self.downsample is not None:
|
|
|
- features = self.downsample(features)
|
|
|
-
|
|
|
feature_lengths = (
|
|
|
audio_lengths
|
|
|
/ self.hop_length
|
|
|
- / (self.downsample.total_strides if self.downsample is not None else 1)
|
|
|
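+ # / self.vq.downsample  (disabled; NOTE(review): calculate_audio_lengths still multiplies by self.vq.downsample — confirm the two stay consistent)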
|
|
|
).long()
|
|
|
|
|
|
+ # Debug aid (disabled): print(features.shape, feature_lengths.shape, torch.max(feature_lengths))
|
|
|
+
|
|
|
feature_masks = torch.unsqueeze(
|
|
|
sequence_mask(feature_lengths, features.shape[2]), 1
|
|
|
).to(features.dtype)
|
|
|
|
|
|
- text_features = self.mel_encoder(features, feature_masks)
|
|
|
- vq_features, indices, loss = self.vq_encoder(text_features, feature_masks)
|
|
|
+ features = (
|
|
|
+ gradient_checkpoint(
|
|
|
+ self.encoder, features, feature_masks, use_reentrant=False
|
|
|
+ )
|
|
|
+ * feature_masks
|
|
|
+ )
|
|
|
+ vq_features, indices, loss = self.vq(features, feature_masks)
|
|
|
|
|
|
return VQEncodeResult(
|
|
|
features=vq_features,
|
|
|
@@ -581,18 +409,13 @@ class VQGAN(L.LightningModule):
|
|
|
)
|
|
|
|
|
|
def calculate_audio_lengths(self, feature_lengths):
|
|
|
- return (
|
|
|
- feature_lengths
|
|
|
- * self.hop_length
|
|
|
- * (self.downsample.total_strides if self.downsample is not None else 1)
|
|
|
- )
|
|
|
+ return feature_lengths * self.hop_length * self.vq.downsample
|
|
|
|
|
|
def decode(
|
|
|
self,
|
|
|
indices=None,
|
|
|
features=None,
|
|
|
audio_lengths=None,
|
|
|
- mel_only=False,
|
|
|
feature_lengths=None,
|
|
|
):
|
|
|
assert (
|
|
|
@@ -611,26 +434,11 @@ class VQGAN(L.LightningModule):
|
|
|
).float()
|
|
|
|
|
|
if indices is not None:
|
|
|
- features = self.vq_encoder.decode(indices)
|
|
|
-
|
|
|
- features = F.interpolate(features, size=mel_masks.shape[2], mode="nearest")
|
|
|
+ features = self.vq.decode(indices)
|
|
|
|
|
|
# Sample mels
|
|
|
- if self.decoder is not None:
|
|
|
- decoded_mels = self.decoder(features, mel_masks)
|
|
|
- else:
|
|
|
- decoded_mels = features
|
|
|
-
|
|
|
- if mel_only:
|
|
|
- return VQDecodeResult(
|
|
|
- audios=None,
|
|
|
- mels=decoded_mels,
|
|
|
- mel_lengths=mel_lengths,
|
|
|
- )
|
|
|
+ decoded = gradient_checkpoint(self.decoder, features, use_reentrant=False)
|
|
|
|
|
|
- fake_audios = self.generator(decoded_mels)
|
|
|
return VQDecodeResult(
|
|
|
- audios=fake_audios,
|
|
|
- mels=decoded_mels,
|
|
|
- mel_lengths=mel_lengths,
|
|
|
+ mels=decoded,
|
|
|
)
|