@@ -1,5 +1,6 @@
 import itertools
-from typing import Any, Callable, Literal
+from dataclasses import dataclass
+from typing import Any, Callable, Literal, Optional

 import lightning as L
 import torch
@@ -8,19 +9,17 @@ import wandb
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
-from vector_quantize_pytorch import VectorQuantize

 from fish_speech.models.vqgan.losses import (
+    MultiResolutionSTFTLoss,
     discriminator_loss,
     feature_loss,
     generator_loss,
-    kl_loss,
 )
+from fish_speech.models.vqgan.modules.balancer import Balancer
 from fish_speech.models.vqgan.modules.decoder import Generator
-from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
 from fish_speech.models.vqgan.modules.encoders import (
     ConvDownSampler,
-    SpeakerEncoder,
     TextEncoder,
     VQEncoder,
 )
@@ -32,6 +31,21 @@ from fish_speech.models.vqgan.utils import (
 )


+@dataclass
+class VQEncodeResult:
+    features: torch.Tensor
+    indices: torch.Tensor
+    loss: torch.Tensor
+    feature_lengths: torch.Tensor
+
+
+@dataclass
+class VQDecodeResult:
+    audios: torch.Tensor
+    mels: torch.Tensor
+    mel_lengths: torch.Tensor
+
+
 class VQGAN(L.LightningModule):
     def __init__(
         self,
@@ -42,18 +56,18 @@ class VQGAN(L.LightningModule):
         mel_encoder: TextEncoder,
         decoder: TextEncoder,
         generator: Generator,
-        discriminator: EnsembleDiscriminator,
+        discriminators: nn.ModuleDict,
         mel_transform: nn.Module,
         segment_size: int = 20480,
         hop_length: int = 640,
         sample_rate: int = 32000,
-        mode: Literal["pretrain-stage1", "pretrain-stage2", "finetune"] = "finetune",
-        speaker_encoder: SpeakerEncoder = None,
+        mode: Literal["pretrain", "finetune"] = "finetune",
+        freeze_discriminator: bool = False,
+        multi_resolution_stft_loss: Optional[MultiResolutionSTFTLoss] = None,
     ):
         super().__init__()

-        # pretrain-stage1: vq use gt mel as target, hifigan use gt mel as input
-        # pretrain-stage2: end-to-end training, use gt mel as hifi gan target
+        # pretrain: vq use gt mel as target, hifigan use gt mel as input
         # finetune: end-to-end training, use gt mel as hifi gan target but freeze vq

         # Model parameters
@@ -64,11 +78,11 @@ class VQGAN(L.LightningModule):
         self.downsample = downsample
         self.vq_encoder = vq_encoder
         self.mel_encoder = mel_encoder
-        self.speaker_encoder = speaker_encoder
         self.decoder = decoder
         self.generator = generator
-        self.discriminator = discriminator
+        self.discriminators = discriminators
         self.mel_transform = mel_transform
+        self.freeze_discriminator = freeze_discriminator

         # Crop length for saving memory
         self.segment_size = segment_size
@@ -90,20 +104,30 @@ class VQGAN(L.LightningModule):
             for p in self.downsample.parameters():
                 p.requires_grad = False

+        if self.freeze_discriminator:
+            for p in self.discriminators.parameters():
+                p.requires_grad = False
+
+        # Losses
+        self.multi_resolution_stft_loss = multi_resolution_stft_loss
+        loss_dict = {
+            "mel": 1,
+            "adv": 1,
+            "fm": 1,
+        }
+
+        if self.multi_resolution_stft_loss is not None:
+            loss_dict["stft"] = 1
+
+        self.balancer = Balancer(loss_dict)
+
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
-        components = []
-        if self.mode != "finetune":
-            components.extend(
-                [
-                    self.downsample.parameters(),
-                    self.vq_encoder.parameters(),
-                    self.mel_encoder.parameters(),
-                ]
-            )
-
-            if self.speaker_encoder is not None:
-                components.append(self.speaker_encoder.parameters())
+        components = [
+            self.downsample.parameters(),
+            self.vq_encoder.parameters(),
+            self.mel_encoder.parameters(),
+        ]

         if self.decoder is not None:
             components.append(self.decoder.parameters())
@@ -111,7 +135,7 @@ class VQGAN(L.LightningModule):
         components.append(self.generator.parameters())
         optimizer_generator = self.optimizer_builder(itertools.chain(*components))
         optimizer_discriminator = self.optimizer_builder(
-            self.discriminator.parameters()
+            self.discriminators.parameters()
         )

         lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
@@ -145,9 +169,7 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]

         with torch.no_grad():
-            features = gt_mels = self.mel_transform(
-                audios, sample_rate=self.sampling_rate
-            )
+            gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)

         if self.mode == "finetune":
             # Disable gradient computation for VQ
@@ -156,29 +178,13 @@ class VQGAN(L.LightningModule):
             self.mel_encoder.eval()
             self.downsample.eval()

-        if self.downsample is not None:
-            features = self.downsample(features)
-
         mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = (
-            audio_lengths
-            / self.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
         )

-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.mel_encoder(features, feature_masks)
-        text_features, _, loss_vq = self.vq_encoder(text_features, feature_masks)
-        text_features = F.interpolate(
-            text_features, size=gt_mels.shape[2], mode="nearest"
-        )
+        vq_result = self.encode(audios, audio_lengths)
+        loss_vq = vq_result.loss

         if loss_vq.ndim > 1:
             loss_vq = loss_vq.mean()
@@ -187,18 +193,15 @@ class VQGAN(L.LightningModule):
         # Enable gradient computation
         torch.set_grad_enabled(True)

-        # Sample mels
-        if self.decoder is not None:
-            speaker_features = (
-                self.speaker_encoder(gt_mels, mel_masks)
-                if self.speaker_encoder is not None
-                else None
-            )
-            decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
-        else:
-            decoded_mels = text_features
+        decoded = self.decode(
+            indices=vq_result.indices if self.mode == "finetune" else None,
+            features=vq_result.features if self.mode == "pretrain" else None,
+            audio_lengths=audio_lengths,
+            mel_only=True,
+        )
+        decoded_mels = decoded.mels
+        input_mels = gt_mels if self.mode == "pretrain" else decoded_mels

-        input_mels = gt_mels if self.mode == "pretrain-stage1" else decoded_mels
         if self.segment_size is not None:
             audios, ids_slice = rand_slice_segments(
                 audios, audio_lengths, self.segment_size
@@ -228,75 +231,145 @@ class VQGAN(L.LightningModule):
             audios.shape == fake_audios.shape
         ), f"{audios.shape} != {fake_audios.shape}"

+        # Multi-Resolution STFT Loss
+        if self.multi_resolution_stft_loss is not None:
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                sc_loss, mag_loss = self.multi_resolution_stft_loss(
+                    fake_audios.squeeze(1).float(), audios.squeeze(1).float()
+                )
+                loss_stft = sc_loss + mag_loss
+
         # Discriminator
-        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios.detach())
+        if self.freeze_discriminator is False:
+            loss_disc_all = []
+
+            for key, disc in self.discriminators.items():
+                scores, _ = disc(audios)
+                score_fakes, _ = disc(fake_audios.detach())
+
+                with torch.autocast(device_type=audios.device.type, enabled=False):
+                    loss_disc, _, _ = discriminator_loss(scores, score_fakes)
+
+                self.log(
+                    f"train/discriminator/{key}",
+                    loss_disc,
+                    on_step=True,
+                    on_epoch=False,
+                    prog_bar=False,
+                    logger=True,
+                    sync_dist=True,
+                )

-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_disc_all, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
+                loss_disc_all.append(loss_disc)

-        self.log(
-            "train/discriminator/loss",
-            loss_disc_all,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
+            loss_disc_all = torch.stack(loss_disc_all).mean()

-        optim_d.zero_grad()
-        self.manual_backward(loss_disc_all)
-        self.clip_gradients(
-            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-        )
-        optim_d.step()
+            self.log(
+                "train/discriminator/loss",
+                loss_disc_all,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=True,
+                logger=True,
+                sync_dist=True,
+            )
+
+            optim_d.zero_grad()
+            self.manual_backward(loss_disc_all)
+            self.clip_gradients(
+                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+            )
+            optim_d.step()
+
+        # Adv Loss
+        loss_adv_all = []
+        loss_fm_all = []
+
+        for key, disc in self.discriminators.items():
+            score_fakes, feat_fake = disc(fake_audios)
+
+            # Adversarial Loss
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                loss_fake, _ = generator_loss(score_fakes)
+
+            self.log(
+                f"train/generator/adv_{key}",
+                loss_fake,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
+            loss_adv_all.append(loss_fake)
+
+            # Feature Matching Loss
+            _, feat_real = disc(audios)
+
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                loss_fm = feature_loss(feat_real, feat_fake)

-        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(audios, fake_audios)
+            self.log(
+                f"train/generator/adv_fm_{key}",
+                loss_fm,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
+            loss_fm_all.append(loss_fm)
+
+        loss_adv_all = torch.stack(loss_adv_all).mean()
+        loss_fm_all = torch.stack(loss_fm_all).mean()

         with torch.autocast(device_type=audios.device.type, enabled=False):
             loss_decoded_mel = F.l1_loss(gt_mels * mel_masks, decoded_mels * mel_masks)
             loss_mel = F.l1_loss(
                 sliced_gt_mels * gen_mel_masks, fake_audio_mels * gen_mel_masks
             )
-            loss_adv, _ = generator_loss(y_d_hat_g)
-            loss_fm = feature_loss(fmap_r, fmap_g)

-            if self.mode == "pretrain-stage1":
+            loss_dict = {
+                "mel": loss_mel,
+                "adv": loss_adv_all,
+                "fm": loss_fm_all,
+            }
+
+            if self.multi_resolution_stft_loss is not None:
+                loss_dict["stft"] = loss_stft
+
+            generator_out_grad = self.balancer.compute(
+                loss_dict,
+                fake_audios,
+            )
+
+            if self.mode == "pretrain":
                 loss_vq_all = loss_decoded_mel + loss_vq
-                loss_gen_all = loss_mel * 45 + loss_fm + loss_adv
-            else:
-                loss_gen_all = loss_mel * 45 + loss_vq * 45 + loss_fm + loss_adv

-        self.log(
-            "train/generator/loss_gen_all",
-            loss_gen_all,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
+        # Loss vq and loss decoded mel are only used in pretrain stage
+        if self.mode == "pretrain":
+            self.log(
+                "train/generator/loss_vq",
+                loss_vq,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )

-        if self.mode == "pretrain-stage1":
             self.log(
-                "train/generator/loss_vq_all",
-                loss_vq_all,
+                "train/generator/loss_decoded_mel",
+                loss_decoded_mel,
                 on_step=True,
                 on_epoch=False,
-                prog_bar=True,
+                prog_bar=False,
                 logger=True,
                 sync_dist=True,
             )

-            self.log(
-                "train/generator/loss_decoded_mel",
-                loss_decoded_mel,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
         self.log(
             "train/generator/loss_mel",
             loss_mel,
@@ -306,18 +379,21 @@ class VQGAN(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
+
+        if self.multi_resolution_stft_loss is not None:
+            self.log(
+                "train/generator/loss_stft",
+                loss_stft,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
         self.log(
-            "train/generator/loss_fm",
-            loss_fm,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
-        )
-        self.log(
-            "train/generator/loss_adv",
-            loss_adv,
+            "train/generator/loss_fm_all",
+            loss_fm_all,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
@@ -325,8 +401,8 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
         self.log(
-            "train/generator/loss_vq",
-            loss_vq,
+            "train/generator/loss_adv_all",
+            loss_adv_all,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
@@ -336,11 +412,11 @@ class VQGAN(L.LightningModule):

         optim_g.zero_grad()

-        # Only backpropagate loss_vq_all in pretrain-stage1
-        if self.mode == "pretrain-stage1":
-            self.manual_backward(loss_vq_all)
+        # Only backpropagate loss_vq_all in pretrain stage
+        if self.mode == "pretrain":
+            self.manual_backward(loss_vq_all, retain_graph=True)

-        self.manual_backward(loss_gen_all)
+        self.manual_backward(fake_audios, gradient=generator_out_grad)
         self.clip_gradients(
             optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
         )
@@ -357,44 +433,26 @@ class VQGAN(L.LightningModule):
         audios = audios.float()
         audios = audios[:, None, :]

-        features = gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
-
-        if self.downsample is not None:
-            features = self.downsample(features)
-
+        gt_mels = self.mel_transform(audios, sample_rate=self.sampling_rate)
         mel_lengths = audio_lengths // self.hop_length
-        feature_lengths = (
-            audio_lengths
-            / self.hop_length
-            / (self.downsample.total_strides if self.downsample is not None else 1)
-        ).long()
-
-        feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
-        ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
         )

-        # vq_features is 50 hz, need to convert to true mel size
-        text_features = self.mel_encoder(features, feature_masks)
-        text_features, _, _ = self.vq_encoder(text_features, feature_masks)
-        text_features = F.interpolate(
-            text_features, size=gt_mels.shape[2], mode="nearest"
+        vq_result = self.encode(audios, audio_lengths)
+        decoded = self.decode(
+            indices=vq_result.indices,
+            audio_lengths=audio_lengths,
+            mel_only=self.mode == "pretrain",
         )

-        # Sample mels
-        if self.decoder is not None:
-            speaker_features = (
-                self.speaker_encoder(gt_mels, mel_masks)
-                if self.speaker_encoder is not None
-                else None
-            )
-            decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
-        else:
-            decoded_mels = text_features
+        decoded_mels = decoded.mels

-        fake_audios = self.generator(decoded_mels)
+        # Use gt mel as input for pretrain
+        if self.mode == "pretrain":
+            fake_audios = self.generator(gt_mels)
+        else:
+            fake_audios = decoded.audios

         fake_mels = self.mel_transform(fake_audios.squeeze(1))

@@ -487,3 +545,92 @@ class VQGAN(L.LightningModule):
                 )

             plt.close(image_mels)
+
+    def encode(self, audios, audio_lengths=None):
+        if audio_lengths is None:
+            audio_lengths = torch.tensor(
+                [audios.shape[-1]] * audios.shape[0],
+                device=audios.device,
+                dtype=torch.long,
+            )
+
+        with torch.no_grad():
+            features = self.mel_transform(audios, sample_rate=self.sampling_rate)
+
+        if self.downsample is not None:
+            features = self.downsample(features)
+
+        feature_lengths = (
+            audio_lengths
+            / self.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(features.dtype)
+
+        text_features = self.mel_encoder(features, feature_masks)
+        vq_features, indices, loss = self.vq_encoder(text_features, feature_masks)
+
+        return VQEncodeResult(
+            features=vq_features,
+            indices=indices,
+            loss=loss,
+            feature_lengths=feature_lengths,
+        )
+
+    def calculate_audio_lengths(self, feature_lengths):
+        return (
+            feature_lengths
+            * self.hop_length
+            * (self.downsample.total_strides if self.downsample is not None else 1)
+        )
+
+    def decode(
+        self,
+        indices=None,
+        features=None,
+        audio_lengths=None,
+        mel_only=False,
+        feature_lengths=None,
+    ):
+        assert (
+            indices is not None or features is not None
+        ), "indices or features must be provided"
+        assert (
+            feature_lengths is not None or audio_lengths is not None
+        ), "feature_lengths or audio_lengths must be provided"
+
+        if audio_lengths is None:
+            audio_lengths = self.calculate_audio_lengths(feature_lengths)
+
+        mel_lengths = audio_lengths // self.hop_length
+        mel_masks = torch.unsqueeze(
+            sequence_mask(mel_lengths, torch.max(mel_lengths)), 1
+        ).float()
+
+        if indices is not None:
+            features = self.vq_encoder.decode(indices)
+
+        features = F.interpolate(features, size=mel_masks.shape[2], mode="nearest")
+
+        # Sample mels
+        if self.decoder is not None:
+            decoded_mels = self.decoder(features, mel_masks)
+        else:
+            decoded_mels = features
+
+        if mel_only:
+            return VQDecodeResult(
+                audios=None,
+                mels=decoded_mels,
+                mel_lengths=mel_lengths,
+            )
+
+        fake_audios = self.generator(decoded_mels)
+        return VQDecodeResult(
+            audios=fake_audios,
+            mels=decoded_mels,
+            mel_lengths=mel_lengths,
+        )
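Note on the balancer-driven generator update in the patch: `self.manual_backward(fake_audios, gradient=generator_out_grad)` backpropagates an explicit gradient taken with respect to the generator output rather than a scalar loss, with `Balancer.compute(loss_dict, fake_audios)` supplying that combined gradient. The standalone Python sketch below only illustrates the underlying PyTorch mechanism with a stand-in linear "generator" and a naive per-loss L2 normalization; it is an assumption-based toy, not the actual `Balancer` from `fish_speech.models.vqgan.modules.balancer`.

# Toy sketch of gradient-balanced backprop (assumptions: stand-in model, naive normalization).
import torch

net = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)
out = net(x)  # stand-in for the generator output (fake_audios)

loss_mel = out.pow(2).mean()  # stand-ins for the "mel" / "adv" entries of loss_dict
loss_adv = out.abs().mean()

# Per-loss gradients w.r.t. the model output, rescaled to comparable magnitude.
grads = [
    torch.autograd.grad(loss, out, retain_graph=True)[0]
    for loss in (loss_mel, loss_adv)
]
out_grad = sum(g / (g.norm() + 1e-12) for g in grads)

# Backpropagate the combined output gradient, mirroring
# manual_backward(fake_audios, gradient=generator_out_grad) in the diff.
out.backward(gradient=out_grad)
print(net.weight.grad.norm())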