|
@@ -9,7 +9,13 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
|
|
|
from matplotlib import pyplot as plt
|
|
from matplotlib import pyplot as plt
|
|
|
from torch import nn
|
|
from torch import nn
|
|
|
|
|
|
|
|
-from fish_speech.models.vqgan.modules import EnsembleDiscriminator, Generator, VQEncoder
|
|
|
|
|
|
|
+from fish_speech.models.vqgan.modules import (
|
|
|
|
|
+ EnsembleDiscriminator,
|
|
|
|
|
+ Generator,
|
|
|
|
|
+ PosteriorEncoder,
|
|
|
|
|
+ SemanticEncoder,
|
|
|
|
|
+ SpeakerEncoder,
|
|
|
|
|
+)
|
|
|
from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
|
|
from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
|
|
|
|
|
|
|
|
|
|
|
|
@@ -18,7 +24,10 @@ class VQGAN(L.LightningModule):
|
|
|
self,
|
|
self,
|
|
|
optimizer: Callable,
|
|
optimizer: Callable,
|
|
|
lr_scheduler: Callable,
|
|
lr_scheduler: Callable,
|
|
|
- encoder: VQEncoder,
|
|
|
|
|
|
|
+ semantic_encoder: SemanticEncoder,
|
|
|
|
|
+ posterior_encoder: PosteriorEncoder,
|
|
|
|
|
+ speaker_encoder: SpeakerEncoder,
|
|
|
|
|
+ # flow: nn.Module,
|
|
|
generator: Generator,
|
|
generator: Generator,
|
|
|
discriminator: EnsembleDiscriminator,
|
|
discriminator: EnsembleDiscriminator,
|
|
|
mel_transform: nn.Module,
|
|
mel_transform: nn.Module,
|
|
@@ -34,7 +43,10 @@ class VQGAN(L.LightningModule):
|
|
|
|
|
|
|
|
# Generator and discriminators
|
|
# Generator and discriminators
|
|
|
# Compile generator so that snake can save memory
|
|
# Compile generator so that snake can save memory
|
|
|
- self.encoder = encoder
|
|
|
|
|
|
|
+ self.semantic_encoder = semantic_encoder
|
|
|
|
|
+ self.posterior_encoder = posterior_encoder
|
|
|
|
|
+ self.speaker_encoder = speaker_encoder
|
|
|
|
|
+ # self.flow = flow
|
|
|
self.generator = generator
|
|
self.generator = generator
|
|
|
self.discriminator = discriminator
|
|
self.discriminator = discriminator
|
|
|
self.mel_transform = mel_transform
|
|
self.mel_transform = mel_transform
|
|
@@ -50,7 +62,13 @@ class VQGAN(L.LightningModule):
|
|
|
def configure_optimizers(self):
|
|
def configure_optimizers(self):
|
|
|
# Need two optimizers and two schedulers
|
|
# Need two optimizers and two schedulers
|
|
|
optimizer_generator = self.optimizer_builder(
|
|
optimizer_generator = self.optimizer_builder(
|
|
|
- itertools.chain(self.encoder.parameters(), self.generator.parameters())
|
|
|
|
|
|
|
+ itertools.chain(
|
|
|
|
|
+ self.semantic_encoder.parameters(),
|
|
|
|
|
+ self.posterior_encoder.parameters(),
|
|
|
|
|
+ self.speaker_encoder.parameters(),
|
|
|
|
|
+ self.generator.parameters(),
|
|
|
|
|
+ # self.flow.parameters(),
|
|
|
|
|
+ )
|
|
|
)
|
|
)
|
|
|
optimizer_discriminator = self.optimizer_builder(
|
|
optimizer_discriminator = self.optimizer_builder(
|
|
|
self.discriminator.parameters()
|
|
self.discriminator.parameters()
|
|
@@ -117,6 +135,31 @@ class VQGAN(L.LightningModule):
|
|
|
|
|
|
|
|
return loss * 2
|
|
return loss * 2
|
|
|
|
|
|
|
|
|
|
+ @staticmethod
|
|
|
|
|
+ def kl_loss(m_q, logs_q, m_p, logs_p, z_mask):
|
|
|
|
|
+ """
|
|
|
|
|
+ m_q, logs_q: [b, h, t_t]
|
|
|
|
|
+ m_p, logs_p: [b, h, t_t]
|
|
|
|
|
+ """
|
|
|
|
|
+ m_q = m_q.float()
|
|
|
|
|
+ logs_q = logs_q.float()
|
|
|
|
|
+ m_p = m_p.float()
|
|
|
|
|
+ logs_p = logs_p.float()
|
|
|
|
|
+ z_mask = z_mask.float()
|
|
|
|
|
+
|
|
|
|
|
+ kl = 0.5 * (
|
|
|
|
|
+ (m_q - m_p) ** 2 / torch.exp(logs_p)
|
|
|
|
|
+ + torch.exp(logs_q) / torch.exp(logs_p)
|
|
|
|
|
+ - 1
|
|
|
|
|
+ - logs_q
|
|
|
|
|
+ + logs_p
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ kl = torch.sum(kl * z_mask)
|
|
|
|
|
+ l = kl / torch.sum(z_mask)
|
|
|
|
|
+
|
|
|
|
|
+ return l
|
|
|
|
|
+
|
|
|
def training_step(self, batch, batch_idx):
|
|
def training_step(self, batch, batch_idx):
|
|
|
optim_g, optim_d = self.optimizers()
|
|
optim_g, optim_d = self.optimizers()
|
|
|
|
|
|
|
@@ -127,7 +170,8 @@ class VQGAN(L.LightningModule):
|
|
|
features = features.float()
|
|
features = features.float()
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
with torch.no_grad():
|
|
|
- gt_mels = self.mel_transform(audios).transpose(1, 2)
|
|
|
|
|
|
|
+ gt_mels, gt_specs = self.mel_transform(audios, return_linear=True)
|
|
|
|
|
+ gt_mels = gt_mels.transpose(1, 2)
|
|
|
key_padding_mask = sequence_mask(feature_lengths)
|
|
key_padding_mask = sequence_mask(feature_lengths)
|
|
|
mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
|
|
mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
|
|
|
audio_masks = sequence_mask(audio_lengths)[:, None]
|
|
audio_masks = sequence_mask(audio_lengths)[:, None]
|
|
@@ -135,6 +179,7 @@ class VQGAN(L.LightningModule):
|
|
|
assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
|
|
assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
|
|
|
gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
|
|
gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
|
|
|
gt_mels = gt_mels[:, :gt_mel_length]
|
|
gt_mels = gt_mels[:, :gt_mel_length]
|
|
|
|
|
+ gt_specs = gt_specs[:, :, :gt_mel_length]
|
|
|
mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
|
|
mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
|
|
|
|
|
|
|
|
assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
|
|
assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
|
|
@@ -144,39 +189,19 @@ class VQGAN(L.LightningModule):
|
|
|
|
|
|
|
|
audios = audios[:, None, :]
|
|
audios = audios[:, None, :]
|
|
|
|
|
|
|
|
- # # Get slice of audio
|
|
|
|
|
- # if audios.shape[-1] > self.segment_size:
|
|
|
|
|
- # start = torch.randint(
|
|
|
|
|
- # 0, audios.shape[-1] - self.segment_size, (1,), device=audios.device
|
|
|
|
|
- # ).item()
|
|
|
|
|
- # start = start // self.hop_length * self.hop_length
|
|
|
|
|
-
|
|
|
|
|
- # audios = audios[:, :, start : start + self.segment_size]
|
|
|
|
|
- # audio_masks = sequence_mask(audio_lengths)[
|
|
|
|
|
- # :, None, start : start + self.segment_size
|
|
|
|
|
- # ]
|
|
|
|
|
-
|
|
|
|
|
- # mel_start = start // self.hop_length
|
|
|
|
|
- # mel_size = self.segment_size // self.hop_length
|
|
|
|
|
- # gt_mels = gt_mels[:, mel_start : mel_start + mel_size]
|
|
|
|
|
- # mels_key_padding_mask = mels_key_padding_mask[
|
|
|
|
|
- # :, mel_start : mel_start + mel_size
|
|
|
|
|
- # ]
|
|
|
|
|
-
|
|
|
|
|
- # features = features[:, :, mel_start : mel_start + mel_size]
|
|
|
|
|
-
|
|
|
|
|
- # Generator
|
|
|
|
|
- encoded = self.encoder(
|
|
|
|
|
|
|
+ speaker = self.speaker_encoder(gt_mels, mels_key_padding_mask)[:, :, None]
|
|
|
|
|
+ prior = self.semantic_encoder(
|
|
|
x=features,
|
|
x=features,
|
|
|
- mels=gt_mels,
|
|
|
|
|
key_padding_mask=key_padding_mask,
|
|
key_padding_mask=key_padding_mask,
|
|
|
- mels_key_padding_mask=mels_key_padding_mask,
|
|
|
|
|
|
|
+ g=speaker,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- features = encoded.features
|
|
|
|
|
- # features = self.naive_proj(features.transpose(1, 2))
|
|
|
|
|
-
|
|
|
|
|
- fake_audios = self.generator(features)
|
|
|
|
|
|
|
+ posterior_key_padding_mask = (~mels_key_padding_mask).float()[:, None]
|
|
|
|
|
+ posterior = self.posterior_encoder(
|
|
|
|
|
+ gt_specs, posterior_key_padding_mask, g=speaker
|
|
|
|
|
+ )
|
|
|
|
|
+ # z_p = self.flow(posterior.mean, posterior_key_padding_mask, g=speaker)
|
|
|
|
|
+ fake_audios = self.generator(posterior.z, g=speaker)
|
|
|
|
|
|
|
|
min_audio_length = min(audios.shape[-1], fake_audios.shape[-1])
|
|
min_audio_length = min(audios.shape[-1], fake_audios.shape[-1])
|
|
|
audios = audios[:, :, :min_audio_length]
|
|
audios = audios[:, :, :min_audio_length]
|
|
@@ -227,8 +252,15 @@ class VQGAN(L.LightningModule):
|
|
|
loss_mel = F.l1_loss(gt_mels, fake_mels)
|
|
loss_mel = F.l1_loss(gt_mels, fake_mels)
|
|
|
loss_adv, _ = self.generator_loss(y_d_hat_g)
|
|
loss_adv, _ = self.generator_loss(y_d_hat_g)
|
|
|
loss_fm = self.feature_loss(fmap_r, fmap_g)
|
|
loss_fm = self.feature_loss(fmap_r, fmap_g)
|
|
|
|
|
+ loss_kl = self.kl_loss(
|
|
|
|
|
+ posterior.mean,
|
|
|
|
|
+ posterior.logs,
|
|
|
|
|
+ prior.mean,
|
|
|
|
|
+ prior.logs,
|
|
|
|
|
+ posterior_key_padding_mask,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + encoded.loss
|
|
|
|
|
|
|
+ loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + prior.loss + loss_kl
|
|
|
|
|
|
|
|
self.log(
|
|
self.log(
|
|
|
"train/generator/loss",
|
|
"train/generator/loss",
|
|
@@ -266,9 +298,18 @@ class VQGAN(L.LightningModule):
|
|
|
logger=True,
|
|
logger=True,
|
|
|
sync_dist=True,
|
|
sync_dist=True,
|
|
|
)
|
|
)
|
|
|
|
|
+ self.log(
|
|
|
|
|
+ "train/generator/loss_kl",
|
|
|
|
|
+ loss_kl,
|
|
|
|
|
+ on_step=True,
|
|
|
|
|
+ on_epoch=False,
|
|
|
|
|
+ prog_bar=False,
|
|
|
|
|
+ logger=True,
|
|
|
|
|
+ sync_dist=True,
|
|
|
|
|
+ )
|
|
|
self.log(
|
|
self.log(
|
|
|
"train/generator/loss_vq",
|
|
"train/generator/loss_vq",
|
|
|
- encoded.loss,
|
|
|
|
|
|
|
+ prior.loss,
|
|
|
on_step=True,
|
|
on_step=True,
|
|
|
on_epoch=False,
|
|
on_epoch=False,
|
|
|
prog_bar=False,
|
|
prog_bar=False,
|
|
@@ -296,14 +337,15 @@ class VQGAN(L.LightningModule):
|
|
|
features = features.float()
|
|
features = features.float()
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
with torch.no_grad():
|
|
|
- gt_mels = self.mel_transform(audios).transpose(1, 2)
|
|
|
|
|
|
|
+ gt_mels, gt_specs = self.mel_transform(audios, return_linear=True)
|
|
|
|
|
+ gt_mels = gt_mels.transpose(1, 2)
|
|
|
key_padding_mask = sequence_mask(feature_lengths)
|
|
key_padding_mask = sequence_mask(feature_lengths)
|
|
|
mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
|
|
mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
|
|
|
- audio_masks = sequence_mask(audio_lengths)
|
|
|
|
|
|
|
|
|
|
assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
|
|
assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
|
|
|
gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
|
|
gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
|
|
|
gt_mels = gt_mels[:, :gt_mel_length]
|
|
gt_mels = gt_mels[:, :gt_mel_length]
|
|
|
|
|
+ gt_specs = gt_specs[:, :, :gt_mel_length]
|
|
|
mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
|
|
mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
|
|
|
|
|
|
|
|
assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
|
|
assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
|
|
@@ -312,37 +354,41 @@ class VQGAN(L.LightningModule):
|
|
|
key_padding_mask = key_padding_mask[:, :gt_feature_length]
|
|
key_padding_mask = key_padding_mask[:, :gt_feature_length]
|
|
|
|
|
|
|
|
# Generator
|
|
# Generator
|
|
|
- encoded = self.encoder(
|
|
|
|
|
|
|
+ # speaker: (B, C, 1)
|
|
|
|
|
+ speaker = self.speaker_encoder(gt_mels, mels_key_padding_mask)[:, :, None]
|
|
|
|
|
+ posterior_key_padding_mask = (~mels_key_padding_mask).float()[:, None]
|
|
|
|
|
+
|
|
|
|
|
+ z_gen = self.semantic_encoder(
|
|
|
x=features,
|
|
x=features,
|
|
|
- mels=gt_mels,
|
|
|
|
|
key_padding_mask=key_padding_mask,
|
|
key_padding_mask=key_padding_mask,
|
|
|
- mels_key_padding_mask=mels_key_padding_mask,
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ g=speaker,
|
|
|
|
|
+ ).z
|
|
|
|
|
|
|
|
- # features = self.naive_proj(features.transpose(1, 2))
|
|
|
|
|
|
|
+ # z_gen = self.flow(z_gen, posterior_key_padding_mask, g=speaker, reverse=True)
|
|
|
|
|
|
|
|
- features = encoded.features
|
|
|
|
|
- audios = audios[:, None, :]
|
|
|
|
|
|
|
+ z_posterior = self.posterior_encoder(
|
|
|
|
|
+ gt_specs, posterior_key_padding_mask, g=speaker
|
|
|
|
|
+ ).mean
|
|
|
|
|
|
|
|
- fake_audios = self.generator(features)
|
|
|
|
|
- min_audio_length = min(audios.shape[-1], fake_audios.shape[-1])
|
|
|
|
|
|
|
+ audios = audios[:, None, :]
|
|
|
|
|
+ fake_audios = self.generator(z_gen, g=speaker)
|
|
|
|
|
+ posterior_audios = self.generator(z_posterior)
|
|
|
|
|
+ min_audio_length = min(
|
|
|
|
|
+ audios.shape[-1], fake_audios.shape[-1], posterior_audios.shape[-1]
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
audios = audios[:, :, :min_audio_length]
|
|
audios = audios[:, :, :min_audio_length]
|
|
|
fake_audios = fake_audios[:, :, :min_audio_length]
|
|
fake_audios = fake_audios[:, :, :min_audio_length]
|
|
|
- audio_masks = audio_masks[:, None, :min_audio_length]
|
|
|
|
|
-
|
|
|
|
|
- audio = torch.masked_fill(audios, audio_masks, 0.0)
|
|
|
|
|
- fake_audios = torch.masked_fill(fake_audios, audio_masks, 0.0)
|
|
|
|
|
- assert fake_audios.shape == audio.shape
|
|
|
|
|
|
|
+ posterior_audios = posterior_audios[:, :, :min_audio_length]
|
|
|
|
|
+ assert fake_audios.shape == audios.shape == posterior_audios.shape
|
|
|
|
|
|
|
|
fake_mels = self.mel_transform(fake_audios.squeeze(1)).transpose(1, 2)
|
|
fake_mels = self.mel_transform(fake_audios.squeeze(1)).transpose(1, 2)
|
|
|
|
|
+ posterior_mels = self.mel_transform(posterior_audios.squeeze(1)).transpose(1, 2)
|
|
|
|
|
+
|
|
|
min_mel_length = min(gt_mels.shape[1], fake_mels.shape[1])
|
|
min_mel_length = min(gt_mels.shape[1], fake_mels.shape[1])
|
|
|
gt_mels = gt_mels[:, :min_mel_length]
|
|
gt_mels = gt_mels[:, :min_mel_length]
|
|
|
fake_mels = fake_mels[:, :min_mel_length]
|
|
fake_mels = fake_mels[:, :min_mel_length]
|
|
|
- mels_key_padding_mask = mels_key_padding_mask[:, :min_mel_length]
|
|
|
|
|
-
|
|
|
|
|
- gt_mels = torch.masked_fill(gt_mels, mels_key_padding_mask[:, :, None], 0.0)
|
|
|
|
|
- fake_mels = torch.masked_fill(fake_mels, mels_key_padding_mask[:, :, None], 0.0)
|
|
|
|
|
|
|
+ posterior_mels = posterior_mels[:, :min_mel_length]
|
|
|
|
|
|
|
|
mel_loss = F.l1_loss(gt_mels, fake_mels)
|
|
mel_loss = F.l1_loss(gt_mels, fake_mels)
|
|
|
self.log(
|
|
self.log(
|
|
@@ -355,12 +401,22 @@ class VQGAN(L.LightningModule):
|
|
|
sync_dist=True,
|
|
sync_dist=True,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- for idx, (mel, gen_mel, audio, gen_audio, audio_len) in enumerate(
|
|
|
|
|
|
|
+ for idx, (
|
|
|
|
|
+ mel,
|
|
|
|
|
+ gen_mel,
|
|
|
|
|
+ post_mel,
|
|
|
|
|
+ audio,
|
|
|
|
|
+ gen_audio,
|
|
|
|
|
+ post_audio,
|
|
|
|
|
+ audio_len,
|
|
|
|
|
+ ) in enumerate(
|
|
|
zip(
|
|
zip(
|
|
|
gt_mels.transpose(1, 2),
|
|
gt_mels.transpose(1, 2),
|
|
|
fake_mels.transpose(1, 2),
|
|
fake_mels.transpose(1, 2),
|
|
|
|
|
+ posterior_mels.transpose(1, 2),
|
|
|
audios,
|
|
audios,
|
|
|
fake_audios,
|
|
fake_audios,
|
|
|
|
|
+ posterior_audios,
|
|
|
audio_lengths,
|
|
audio_lengths,
|
|
|
)
|
|
)
|
|
|
):
|
|
):
|
|
@@ -369,9 +425,14 @@ class VQGAN(L.LightningModule):
|
|
|
image_mels = plot_mel(
|
|
image_mels = plot_mel(
|
|
|
[
|
|
[
|
|
|
gen_mel[:, :mel_len],
|
|
gen_mel[:, :mel_len],
|
|
|
|
|
+ post_mel[:, :mel_len],
|
|
|
mel[:, :mel_len],
|
|
mel[:, :mel_len],
|
|
|
],
|
|
],
|
|
|
- ["Sampled Spectrogram", "Ground-Truth Spectrogram"],
|
|
|
|
|
|
|
+ [
|
|
|
|
|
+ "Generated Spectrogram",
|
|
|
|
|
+ "Posterior Spectrogram",
|
|
|
|
|
+ "Ground-Truth Spectrogram",
|
|
|
|
|
+ ],
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
if isinstance(self.logger, WandbLogger):
|
|
if isinstance(self.logger, WandbLogger):
|
|
@@ -389,6 +450,11 @@ class VQGAN(L.LightningModule):
|
|
|
sample_rate=self.sampling_rate,
|
|
sample_rate=self.sampling_rate,
|
|
|
caption="prediction",
|
|
caption="prediction",
|
|
|
),
|
|
),
|
|
|
|
|
+ wandb.Audio(
|
|
|
|
|
+ post_audio[0, :audio_len],
|
|
|
|
|
+ sample_rate=self.sampling_rate,
|
|
|
|
|
+ caption="posterior",
|
|
|
|
|
+ ),
|
|
|
],
|
|
],
|
|
|
},
|
|
},
|
|
|
)
|
|
)
|
|
@@ -411,5 +477,11 @@ class VQGAN(L.LightningModule):
|
|
|
self.global_step,
|
|
self.global_step,
|
|
|
sample_rate=self.sampling_rate,
|
|
sample_rate=self.sampling_rate,
|
|
|
)
|
|
)
|
|
|
|
|
+ self.logger.experiment.add_audio(
|
|
|
|
|
+ f"sample-{idx}/wavs/posterior",
|
|
|
|
|
+ post_audio[0, :audio_len],
|
|
|
|
|
+ self.global_step,
|
|
|
|
|
+ sample_rate=self.sampling_rate,
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
plt.close(image_mels)
|
|
plt.close(image_mels)
|