@@ -16,9 +16,20 @@ from fish_speech.models.vqgan.losses import (
     generator_loss,
     kl_loss,
 )
+from fish_speech.models.vqgan.modules.decoder import Generator
 from fish_speech.models.vqgan.modules.discriminator import EnsembleDiscriminator
-from fish_speech.models.vqgan.modules.models import SynthesizerTrn
-from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
+from fish_speech.models.vqgan.modules.encoders import (
+    ConvDownSampler,
+    SpeakerEncoder,
+    TextEncoder,
+    VQEncoder,
+)
+from fish_speech.models.vqgan.utils import (
+    plot_mel,
+    rand_slice_segments,
+    sequence_mask,
+    slice_segments,
+)
 
 
 class VQGAN(L.LightningModule):
@@ -26,12 +37,18 @@ class VQGAN(L.LightningModule):
         self,
         optimizer: Callable,
         lr_scheduler: Callable,
-        generator: SynthesizerTrn,
+        downsample: ConvDownSampler,
+        vq_encoder: VQEncoder,
+        speaker_encoder: SpeakerEncoder,
+        text_encoder: TextEncoder,
+        decoder: TextEncoder,
+        generator: Generator,
         discriminator: EnsembleDiscriminator,
         mel_transform: nn.Module,
         segment_size: int = 20480,
         hop_length: int = 640,
         sample_rate: int = 32000,
+        freeze_hifigan: bool = False,
     ):
         super().__init__()
 
@@ -40,6 +57,11 @@ class VQGAN(L.LightningModule):
         self.lr_scheduler_builder = lr_scheduler
 
         # Generator and discriminators
+        self.downsample = downsample
+        self.vq_encoder = vq_encoder
+        self.speaker_encoder = speaker_encoder
+        self.text_encoder = text_encoder
+        self.decoder = decoder
         self.generator = generator
         self.discriminator = discriminator
         self.mel_transform = mel_transform
@@ -48,13 +70,31 @@
         self.segment_size = segment_size
         self.hop_length = hop_length
         self.sampling_rate = sample_rate
+        self.freeze_hifigan = freeze_hifigan
 
         # Disable automatic optimization
         self.automatic_optimization = False
 
+        # Stage 1: Train the VQ only
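+        # With freeze_hifigan=True, the discriminator and the HiFi-GAN generator
+        # receive no gradients, so only the encoder/VQ/decoder stack is trained.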
+        if self.freeze_hifigan:
+            for p in self.discriminator.parameters():
+                p.requires_grad = False
+
+            for p in self.generator.parameters():
+                p.requires_grad = False
+
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
-        optimizer_generator = self.optimizer_builder(self.generator.parameters())
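+        # All generator-side modules share a single optimizer, so their
+        # parameters are chained together for the builder.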
+        optimizer_generator = self.optimizer_builder(
+            itertools.chain(
+                self.downsample.parameters(),
+                self.vq_encoder.parameters(),
+                self.speaker_encoder.parameters(),
+                self.text_encoder.parameters(),
+                self.decoder.parameters(),
+                self.generator.parameters(),
+            )
+        )
         optimizer_discriminator = self.optimizer_builder(
             self.discriminator.parameters()
         )
@@ -85,30 +125,49 @@
         optim_g, optim_d = self.optimizers()
 
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-        features, feature_lengths = batch["features"], batch["feature_lengths"]
-        audios = audios[:, None, :]
 
         audios = audios.float()
-        # features = features.long()
+        audios = audios[:, None, :]
 
         with torch.no_grad():
             gt_mels = self.mel_transform(audios)
-            gt_mels = gt_mels[:, :, : features.shape[1]]
-
-        (
-            y_hat,
-            ids_slice,
-            x_mask,
-            y_mask,
-            (z_q, z_p),
-            (m_p, logs_p),
-            (m_q, logs_q),
-            # vq_loss,
-        ) = self.generator(features, feature_lengths, gt_mels)
-
-        y_hat_mel = self.mel_transform(y_hat.squeeze(1))
-        y_mel = slice_segments(gt_mels, ids_slice, self.segment_size // self.hop_length)
-        y = slice_segments(audios, ids_slice * self.hop_length, self.segment_size)
+
+        if self.downsample is not None:
+            features = self.downsample(gt_mels)
+
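+        # Length bookkeeping: one mel frame per hop_length samples; the downsampler
+        # shrinks the time axis by a further factor of total_strides.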
+        mel_lengths = audio_lengths // self.hop_length
+        feature_lengths = (
+            audio_lengths
+            / self.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
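+        # Masks are 1 for valid frames and 0 for padding, broadcast over channels.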
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(gt_mels.dtype)
+        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
+            gt_mels.dtype
+        )
+
+        speaker_features = self.speaker_encoder(features, feature_masks)
+
+        # The VQ features are at 50 Hz; upsample them back to the mel frame rate.
+        text_features = self.text_encoder(features, feature_masks)
+        text_features, loss_vq = self.vq_encoder(text_features, feature_masks)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
+
+        # Decode mels conditioned on the speaker embedding, then vocode to audio.
+        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        fake_audios = self.generator(decoded_mels)
+
+        y_hat_mels = self.mel_transform(fake_audios.squeeze(1))
+
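+        # GAN losses operate on random fixed-size segments to bound memory use.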
+        y, ids_slice = rand_slice_segments(audios, audio_lengths, self.segment_size)
+        y_hat = slice_segments(fake_audios, ids_slice, self.segment_size)
+
+        assert y.shape == y_hat.shape, f"{y.shape} != {y_hat.shape}"
 
         # Discriminator
         y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(y, y_hat.detach())
@@ -126,41 +185,29 @@
             sync_dist=True,
         )
 
-        optim_d.zero_grad()
-        self.manual_backward(loss_disc_all)
-        self.clip_gradients(
-            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-        )
-        optim_d.step()
+        # Skip the discriminator's backward/optimizer step when it is frozen (stage 1)
+        if not self.freeze_hifigan:
+            optim_d.zero_grad()
+            self.manual_backward(loss_disc_all)
+            self.clip_gradients(
+                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+            )
+            optim_d.step()
 
         y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(y, y_hat)
 
         with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_mel = F.l1_loss(y_mel, y_hat_mel)
+            loss_decoded_mel = F.l1_loss(gt_mels, decoded_mels)
+            loss_mel = F.l1_loss(gt_mels, y_hat_mels)
             loss_adv, _ = generator_loss(y_d_hat_g)
             loss_fm = feature_loss(fmap_r, fmap_g)
-            loss_kl = kl_loss(
-                z_p=z_p,
-                logs_q=logs_q,
-                m_p=m_p,
-                logs_p=logs_p,
-                z_mask=x_mask,
-            )
 
-            # Cyclical kl loss
-            # then 500 steps linear: 0.1
-            # then 500 steps 0.1
-            # then go back to 0
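+            # A fixed mel weight replaces the cyclical KL schedule removed above;
+            # stage 1 uses a lower weight, and loss_decoded_mel is added below.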
+            mel_loss_weight = 25 if self.freeze_hifigan else 45
 
-            if self.global_step < 100000:
-                beta = 1e-6
-            else:
-                beta = self.global_step % 1000
-                beta = min(beta, 500) / 500 * 0.1 + 1e-6
+            loss_gen_all = loss_mel * mel_loss_weight + loss_fm + loss_adv + loss_vq
 
-            loss_gen_all = (
-                loss_mel * 45 + loss_fm + loss_adv + loss_kl * beta
-            )  # + vq_loss
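+            # Stage 1 only: the frozen vocoder cannot adapt, so the decoder output
+            # is also matched directly against the ground-truth mels.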
+            if self.freeze_hifigan:
+                loss_gen_all += loss_decoded_mel * mel_loss_weight
 
         self.log(
             "train/generator/loss",
@@ -171,6 +218,15 @@
             logger=True,
             sync_dist=True,
         )
+        self.log(
+            "train/generator/loss_decoded_mel",
+            loss_decoded_mel,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
         self.log(
             "train/generator/loss_mel",
             loss_mel,
@@ -199,23 +255,14 @@
             sync_dist=True,
         )
         self.log(
-            "train/generator/loss_kl",
-            loss_kl,
+            "train/generator/loss_vq",
+            loss_vq,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
             logger=True,
             sync_dist=True,
         )
-        # self.log(
-        #     "train/generator/loss_vq",
-        #     vq_loss,
-        #     on_step=True,
-        #     on_epoch=False,
-        #     prog_bar=False,
-        #     logger=True,
-        #     sync_dist=True,
-        # )
 
         optim_g.zero_grad()
         self.manual_backward(loss_gen_all)
@@ -231,25 +278,50 @@
 
     def validation_step(self, batch: Any, batch_idx: int):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
-        features, feature_lengths = batch["features"], batch["feature_lengths"]
 
         audios = audios.float()
-        # features = features.float()
         audios = audios[:, None, :]
 
         gt_mels = self.mel_transform(audios)
-        gt_mels = gt_mels[:, :, : features.shape[1]]
 
-        fake_audios = self.generator.infer(features, feature_lengths, gt_mels)
-        posterior_audios = self.generator.reconstruct(gt_mels, feature_lengths)
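+        # Validation mirrors the training path: features are derived from the
+        # ground-truth mels instead of being read from the batch.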
+        if self.downsample is not None:
+            features = self.downsample(gt_mels)
+
+        mel_lengths = audio_lengths // self.hop_length
+        feature_lengths = (
+            audio_lengths
+            / self.hop_length
+            / (self.downsample.total_strides if self.downsample is not None else 1)
+        ).long()
+
+        feature_masks = torch.unsqueeze(
+            sequence_mask(feature_lengths, features.shape[2]), 1
+        ).to(gt_mels.dtype)
+        mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
+            gt_mels.dtype
+        )
+
+        speaker_features = self.speaker_encoder(features, feature_masks)
+
+        # The VQ features are at 50 Hz; upsample them back to the mel frame rate.
+        text_features = self.text_encoder(features, feature_masks)
+        text_features, vq_loss = self.vq_encoder(text_features, feature_masks)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
+
+        # Decode mels conditioned on the speaker embedding, then vocode to audio.
+        decoded_mels = self.decoder(text_features, mel_masks, g=speaker_features)
+        fake_audios = self.generator(decoded_mels)
 
         fake_mels = self.mel_transform(fake_audios.squeeze(1))
-        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))
 
-        min_mel_length = min(gt_mels.shape[-1], fake_mels.shape[-1])
+        min_mel_length = min(
+            decoded_mels.shape[-1], gt_mels.shape[-1], fake_mels.shape[-1]
+        )
+        decoded_mels = decoded_mels[:, :, :min_mel_length]
         gt_mels = gt_mels[:, :, :min_mel_length]
         fake_mels = fake_mels[:, :, :min_mel_length]
-        posterior_mels = posterior_mels[:, :, :min_mel_length]
 
         mel_loss = F.l1_loss(gt_mels, fake_mels)
         self.log(
@@ -265,19 +337,17 @@
         for idx, (
             mel,
             gen_mel,
-            post_mel,
+            decode_mel,
             audio,
             gen_audio,
-            post_audio,
             audio_len,
         ) in enumerate(
             zip(
                 gt_mels,
                 fake_mels,
-                posterior_mels,
-                audios,
-                fake_audios,
-                posterior_audios,
+                decoded_mels,
+                audios.detach().float(),
+                fake_audios.detach().float(),
                 audio_lengths,
             )
         ):
@@ -286,13 +356,13 @@
             image_mels = plot_mel(
                 [
                     gen_mel[:, :mel_len],
-                    post_mel[:, :mel_len],
+                    decode_mel[:, :mel_len],
                     mel[:, :mel_len],
                 ],
                 [
-                    "Generated Spectrogram",
-                    "Posterior Spectrogram",
-                    "Ground-Truth Spectrogram",
+                    "Generated",
+                    "Decoded",
+                    "Ground-Truth",
                 ],
             )
 
@@ -311,11 +381,6 @@
                                 sample_rate=self.sampling_rate,
                                 caption="prediction",
                             ),
-                            wandb.Audio(
-                                post_audio[0, :audio_len],
-                                sample_rate=self.sampling_rate,
-                                caption="posterior",
-                            ),
                         ],
                     },
                 )
@@ -338,11 +403,5 @@
                 self.global_step,
                 sample_rate=self.sampling_rate,
             )
-            self.logger.experiment.add_audio(
-                f"sample-{idx}/wavs/posterior",
-                post_audio[0, :audio_len],
-                self.global_step,
-                sample_rate=self.sampling_rate,
-            )
 
         plt.close(image_mels)