@@ -9,15 +9,7 @@ import wandb
 
 from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
-from torch.utils.checkpoint import checkpoint as gradient_checkpoint
-
-from fish_speech.models.vqgan.losses import (
-    MultiResolutionSTFTLoss,
-    discriminator_loss,
-    feature_loss,
-    generator_loss,
-    kl_loss,
-)
+
 from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
 
@@ -40,17 +32,16 @@ class VQGAN(L.LightningModule):
         self,
         optimizer: Callable,
         lr_scheduler: Callable,
-        generator: nn.Module,
-        discriminator: nn.Module,
+        encoder: nn.Module,
+        quantizer: nn.Module,
+        aux_decoder: nn.Module,
+        reflow: nn.Module,
+        vocoder: nn.Module,
         mel_transform: nn.Module,
-        spec_transform: nn.Module,
-        hop_length: int = 640,
-        sample_rate: int = 32000,
-        freeze_discriminator: bool = False,
-        weight_mel: float = 45,
-        weight_kl: float = 0.1,
+        weight_reflow: float = 1.0,
         weight_vq: float = 1.0,
-        weight_aux_mel: float = 20.0,
+        weight_aux_mel: float = 1.0,
+        sampling_rate: int = 44100,
     ):
         super().__init__()
 
@@ -58,62 +49,54 @@ class VQGAN(L.LightningModule):
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
 
-        # Generator and discriminator
-        self.generator = generator
-        self.discriminator = discriminator
+        # Modules
+        self.encoder = encoder
+        self.quantizer = quantizer
+        self.aux_decoder = aux_decoder
+        self.reflow = reflow
         self.mel_transform = mel_transform
-        self.spec_transform = spec_transform
-        self.freeze_discriminator = freeze_discriminator
+        self.vocoder = vocoder
+
+        # Freeze vocoder
+        for param in self.vocoder.parameters():
+            param.requires_grad = False
 
         # Loss weights
-        self.weight_mel = weight_mel
-        self.weight_kl = weight_kl
+        self.weight_reflow = weight_reflow
         self.weight_vq = weight_vq
         self.weight_aux_mel = weight_aux_mel
 
-        # Other parameters
-        self.hop_length = hop_length
-        self.sampling_rate = sample_rate
-
-        # Disable automatic optimization
-        self.automatic_optimization = False
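+        # Assumed dynamic range of log-mel values; norm_spec/denorm_spec use
+        # these bounds to map spectrograms into [-1, 1] and back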
+        self.spec_min = -12
+        self.spec_max = 3
+        self.sampling_rate = sampling_rate
 
-        if self.freeze_discriminator:
-            for p in self.discriminator.parameters():
-                p.requires_grad = False
+    def on_save_checkpoint(self, checkpoint):
+        # Do not save vocoder
+        state_dict = checkpoint["state_dict"]
+        for name in list(state_dict.keys()):
+            if "vocoder" in name:
+                state_dict.pop(name)
 
     def configure_optimizers(self):
-        # Need two optimizers and two schedulers
+        # A single optimizer and scheduler now cover all trainable modules
-        optimizer_generator = self.optimizer_builder(self.generator.parameters())
-        optimizer_discriminator = self.optimizer_builder(
-            self.discriminator.parameters()
-        )
-
-        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
-        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
-
-        return (
-            {
-                "optimizer": optimizer_generator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_generator,
-                    "interval": "step",
-                    "name": "optimizer/generator",
-                },
+        optimizer = self.optimizer_builder(self.parameters())
+        lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "interval": "step",
             },
-            {
-                "optimizer": optimizer_discriminator,
-                "lr_scheduler": {
-                    "scheduler": lr_scheduler_discriminator,
-                    "interval": "step",
-                    "name": "optimizer/discriminator",
-                },
-            },
-        )
+        }
 
-    def training_step(self, batch, batch_idx):
-        optim_g, optim_d = self.optimizers()
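+    # Map log-mels from [spec_min, spec_max] to [-1, 1] (and back) so the
+    # reflow model operates on a roughly unit-scale signal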
+    def norm_spec(self, x):
+        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
+
+    def denorm_spec(self, x):
+        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
+
+    def training_step(self, batch, batch_idx):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
 
         audios = audios.float()
@@ -121,173 +104,84 @@ class VQGAN(L.LightningModule):
 
         with torch.no_grad():
             gt_mels = self.mel_transform(audios)
-            gt_specs = self.spec_transform(audios)
-
-        spec_lengths = audio_lengths // self.hop_length
-        spec_masks = torch.unsqueeze(
-            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
-        ).to(gt_mels.dtype)
-        (
-            fake_audios,
-            ids_slice,
-            y_mask,
-            y_mask,
-            (z, z_p, m_p, logs_p, m_q, logs_q),
-            loss_vq,
-            decoded_aux_mels,
-        ) = self.generator(gt_specs, spec_lengths)
 
-        gt_mels = slice_segments(gt_mels, ids_slice, self.generator.segment_size)
-        decoded_aux_mels = slice_segments(
-            decoded_aux_mels, ids_slice, self.generator.segment_size
-        )
-        spec_masks = slice_segments(spec_masks, ids_slice, self.generator.segment_size)
-        audios = slice_segments(
-            audios,
-            ids_slice * self.hop_length,
-            self.generator.segment_size * self.hop_length,
-        )
-        fake_mels = self.mel_transform(fake_audios.squeeze(1))
+        mel_lengths = audio_lengths // self.mel_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+        mel_masks_float_conv = mel_masks[:, None, :].float()
 
-        assert (
-            audios.shape == fake_audios.shape
-        ), f"{audios.shape} != {fake_audios.shape}"
-
-        # Discriminator
-        if self.freeze_discriminator is False:
-            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(
-                audios, fake_audios.detach()
-            )
+        # Encode
+        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
 
-            with torch.autocast(device_type=audios.device.type, enabled=False):
-                loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
-
-            self.log(
-                f"train/discriminator/loss",
-                loss_disc,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
+        # Quantize
+        vq_result = self.quantizer(encoded_features)
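+        # Not every quantizer reports a commitment loss, so fall back to 0.0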
+        loss_vq = getattr(vq_result, "loss", 0.0)
+        vq_recon_features = vq_result.z * mel_masks_float_conv
 
-            optim_d.zero_grad()
-            self.manual_backward(loss_disc)
-            self.clip_gradients(
-                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
-            )
-            optim_d.step()
+        # VQ Decode
+        aux_mel = self.aux_decoder(vq_recon_features)
+        loss_aux_mel = F.l1_loss(
+            aux_mel * mel_masks_float_conv, gt_mels * mel_masks_float_conv
+        )
 
-        # Adv Loss
-        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios)
+        # Reflow
+        x_1 = self.norm_spec(gt_mels.mT)
+        t = torch.rand(gt_mels.shape[0], device=gt_mels.device)
+        x_0 = torch.randn_like(x_1)
 
-        # Adversarial Loss
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_adv, _ = generator_loss(y_d_hat_g)
+        # x_t = t * x_1 + (1 - t) * x_0
+        x_t = x_0 + t[:, None, None] * (x_1 - x_0)
 
-        self.log(
-            f"train/generator/adv",
-            loss_adv,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
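+        # Rectified flow: the network predicts the constant velocity
+        # v = x_1 - x_0 at x_t; t is scaled by 1000, presumably to match the
+        # range expected by the timestep embedding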
+        v_pred = self.reflow(
+            x_t,
+            1000 * t,
+            condition=vq_recon_features.mT,
+            self_mask=mel_masks,
         )
 
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_fm = feature_loss(y_d_hat_r, y_d_hat_g)
-
-        self.log(
-            f"train/generator/adv_fm",
-            loss_fm,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
+        # Weighted L2 loss: logit-normal timestep weighting (0.398942 ~= 1/sqrt(2*pi))
+        weights = 0.398942 / t / (1 - t) * torch.exp(-0.5 * torch.log(t / (1 - t)) ** 2)
+        loss_reflow = weights[:, None, None] * F.mse_loss(
+            x_1 - x_0, v_pred, reduction="none"
         )
+        loss_reflow = (loss_reflow * mel_masks_float_conv.mT).mean()
 
-        with torch.autocast(device_type=audios.device.type, enabled=False):
-            loss_mel = F.l1_loss(gt_mels * spec_masks, fake_mels * spec_masks)
-            loss_aux_mel = F.l1_loss(
-                gt_mels * spec_masks, decoded_aux_mels * spec_masks
-            )
-
-        self.log(
-            "train/generator/loss_mel",
-            loss_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
+        # Total loss
+        loss = (
+            self.weight_vq * loss_vq
+            + self.weight_aux_mel * loss_aux_mel
+            + self.weight_reflow * loss_reflow
         )
 
+        # Log losses
         self.log(
-            "train/generator/loss_aux_mel",
-            loss_aux_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=False,
-            logger=True,
-            sync_dist=True,
+            "train/loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True
         )
-
         self.log(
-            "train/generator/loss_vq",
+            "train/loss_vq",
             loss_vq,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
             logger=True,
-            sync_dist=True,
         )
-
-        loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, y_mask)
-
         self.log(
-            "train/generator/loss_kl",
-            loss_kl,
+            "train/loss_aux_mel",
+            loss_aux_mel,
            on_step=True,
             on_epoch=False,
             prog_bar=False,
             logger=True,
-            sync_dist=True,
-        )
-
-        loss = (
-            loss_mel * self.weight_mel
-            + loss_aux_mel * self.weight_aux_mel
-            + loss_vq * self.weight_vq
-            + loss_kl * self.weight_kl
-            + loss_adv
-            + loss_fm
         )
         self.log(
-            "train/generator/loss",
-            loss,
+            "train/loss_reflow",
+            loss_reflow,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
             logger=True,
-            sync_dist=True,
-        )
-
-        # Backward
-        optim_g.zero_grad()
-
-        self.manual_backward(loss)
-        self.clip_gradients(
-            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
         )
-        optim_g.step()
 
-        # Manual LR Scheduler
-        scheduler_g, scheduler_d = self.lr_schedulers()
-        scheduler_g.step()
-        scheduler_d.step()
+        return loss
 
     def validation_step(self, batch: Any, batch_idx: int):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
@@ -296,32 +190,25 @@ class VQGAN(L.LightningModule):
         audios = audios[:, None, :]
 
         gt_mels = self.mel_transform(audios)
-        gt_specs = self.spec_transform(audios)
-        spec_lengths = audio_lengths // self.hop_length
-        spec_masks = torch.unsqueeze(
-            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
-        ).to(gt_mels.dtype)
-
-        prior_audios, _, _ = self.generator.infer(gt_specs, spec_lengths)
-        posterior_audios, _, _ = self.generator.infer_posterior(gt_specs, spec_lengths)
-        prior_mels = self.mel_transform(prior_audios.squeeze(1))
-        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))
-
-        min_mel_length = min(
-            gt_mels.shape[-1], prior_mels.shape[-1], posterior_mels.shape[-1]
-        )
-        gt_mels = gt_mels[:, :, :min_mel_length]
-        prior_mels = prior_mels[:, :, :min_mel_length]
-        posterior_mels = posterior_mels[:, :, :min_mel_length]
+        mel_lengths = audio_lengths // self.mel_transform.hop_length
+        mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
+        mel_masks_float_conv = mel_masks[:, None, :].float()
+
+        # Encode
+        encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
 
-        prior_mel_loss = F.l1_loss(gt_mels * spec_masks, prior_mels * spec_masks)
-        posterior_mel_loss = F.l1_loss(
-            gt_mels * spec_masks, posterior_mels * spec_masks
+        # Quantize
+        vq_result = self.quantizer(encoded_features)
+
+        # VQ Decode
+        aux_mels = self.aux_decoder(vq_result.z)
+        loss_aux_mel = F.l1_loss(
+            aux_mels * mel_masks_float_conv, gt_mels * mel_masks_float_conv
         )
 
         self.log(
-            "val/prior_mel_loss",
-            prior_mel_loss,
+            "val/loss_aux_mel",
+            loss_aux_mel,
             on_step=False,
             on_epoch=True,
             prog_bar=False,
@@ -329,9 +216,33 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
+        # Reflow inference
+        t_start = 0.0
+        infer_step = 20
+        gen_mels = torch.randn(gt_mels.shape, device=gt_mels.device).mT
+        t = torch.zeros(gt_mels.shape[0], device=gt_mels.device)
+        dt = (1.0 - t_start) / infer_step
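+
+        # Euler integration of the learned ODE dx/dt = v(x, t) from t_start
+        # to 1, starting from Gaussian noise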
+        for _ in range(infer_step):
+            gen_mels += (
+                self.reflow(
+                    gen_mels,
+                    1000 * t,
+                    condition=vq_result.z.mT,
+                    self_mask=mel_masks,
+                )
+                * dt
+            )
+            t += dt
+
+        gen_mels = self.denorm_spec(gen_mels).mT
+        loss_recon_reflow = F.l1_loss(
+            gen_mels * mel_masks_float_conv, gt_mels * mel_masks_float_conv
+        )
+
         self.log(
-            "val/posterior_mel_loss",
-            posterior_mel_loss,
+            "val/loss_recon_reflow",
+            loss_recon_reflow,
             on_step=False,
             on_epoch=True,
             prog_bar=False,
@@ -339,41 +250,47 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
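 
+        # Render waveforms with the frozen vocoder: reflow-generated mels,
+        # ground-truth mels (a vocoder-only reference), and aux-decoder mels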
+        gen_audios = self.vocoder(gen_mels)
+        recon_audios = self.vocoder(gt_mels)
+        aux_audios = self.vocoder(aux_mels)
+
         # only log the first batch
         if batch_idx != 0:
             return
 
         for idx, (
-            mel,
-            prior_mel,
-            posterior_mel,
+            gt_mel,
+            reflow_mel,
+            aux_mel,
             audio,
-            prior_audio,
-            posterior_audio,
+            reflow_audio,
+            aux_audio,
+            recon_audio,
             audio_len,
         ) in enumerate(
             zip(
                 gt_mels,
-                prior_mels,
-                posterior_mels,
-                audios.detach().float(),
-                prior_audios.detach().float(),
-                posterior_audios.detach().float(),
+                gen_mels,
+                aux_mels,
+                audios.float(),
+                gen_audios.float(),
+                aux_audios.float(),
+                recon_audios.float(),
                 audio_lengths,
             )
         ):
-            mel_len = audio_len // self.hop_length
+            mel_len = audio_len // self.mel_transform.hop_length
 
             image_mels = plot_mel(
                 [
-                    prior_mel[:, :mel_len],
-                    posterior_mel[:, :mel_len],
-                    mel[:, :mel_len],
+                    gt_mel[:, :mel_len],
+                    reflow_mel[:, :mel_len],
+                    aux_mel[:, :mel_len],
                 ],
                 [
-                    "Prior (VQ)",
-                    "Posterior (Reconstruction)",
                     "Ground-Truth",
+                    "Reflow",
+                    "Aux",
                 ],
             )
@@ -388,14 +305,19 @@ class VQGAN(L.LightningModule):
                                 caption="gt",
                             ),
                             wandb.Audio(
-                                prior_audio[0, :audio_len],
+                                reflow_audio[0, :audio_len],
                                 sample_rate=self.sampling_rate,
-                                caption="prior",
+                                caption="reflow",
                             ),
                             wandb.Audio(
-                                posterior_audio[0, :audio_len],
+                                aux_audio[0, :audio_len],
                                 sample_rate=self.sampling_rate,
-                                caption="posterior",
+                                caption="aux",
+                            ),
+                            wandb.Audio(
+                                recon_audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="recon",
                             ),
                         ],
                     },
@@ -414,91 +336,22 @@ class VQGAN(L.LightningModule):
                 sample_rate=self.sampling_rate,
             )
             self.logger.experiment.add_audio(
-                f"sample-{idx}/wavs/prior",
-                prior_audio[0, :audio_len],
+                f"sample-{idx}/wavs/reflow",
+                reflow_audio[0, :audio_len],
                 self.global_step,
                 sample_rate=self.sampling_rate,
             )
             self.logger.experiment.add_audio(
-                f"sample-{idx}/wavs/posterior",
-                posterior_audio[0, :audio_len],
+                f"sample-{idx}/wavs/aux",
+                aux_audio[0, :audio_len],
+                self.global_step,
+                sample_rate=self.sampling_rate,
+            )
+            self.logger.experiment.add_audio(
+                f"sample-{idx}/wavs/recon",
+                recon_audio[0, :audio_len],
                 self.global_step,
                 sample_rate=self.sampling_rate,
             )
 
         plt.close(image_mels)
-
-    # def encode(self, audios, audio_lengths=None):
-    #     if audio_lengths is None:
-    #         audio_lengths = torch.tensor(
-    #             [audios.shape[-1]] * audios.shape[0],
-    #             device=audios.device,
-    #             dtype=torch.long,
-    #         )
-
-    #     with torch.no_grad():
-    #         features = self.mel_transform(audios, sample_rate=self.sampling_rate)
-
-    #         feature_lengths = (
-    #             audio_lengths
-    #             / self.hop_length
-    #             # / self.vq.downsample
-    #         ).long()
-
-    #     # print(features.shape, feature_lengths.shape, torch.max(feature_lengths))
-
-    #     feature_masks = torch.unsqueeze(
-    #         sequence_mask(feature_lengths, features.shape[2]), 1
-    #     ).to(features.dtype)
-
-    #     features = (
-    #         gradient_checkpoint(
-    #             self.encoder, features, feature_masks, use_reentrant=False
-    #         )
-    #         * feature_masks
-    #     )
-    #     vq_features, indices, loss = self.vq(features, feature_masks)
-
-    #     return VQEncodeResult(
-    #         features=vq_features,
-    #         indices=indices,
-    #         loss=loss,
-    #         feature_lengths=feature_lengths,
-    #     )
-
-    # def calculate_audio_lengths(self, feature_lengths):
-    #     return feature_lengths * self.hop_length * self.vq.downsample
-
-    # def decode(
-    #     self,
-    #     indices=None,
-    #     features=None,
-    #     audio_lengths=None,
-    #     feature_lengths=None,
-    #     return_audios=False,
-    # ):
-    #     assert (
-    #         indices is not None or features is not None
-    #     ), "indices or features must be provided"
-    #     assert (
-    #         feature_lengths is not None or audio_lengths is not None
-    #     ), "feature_lengths or audio_lengths must be provided"
-
-    #     if audio_lengths is None:
-    #         audio_lengths = self.calculate_audio_lengths(feature_lengths)
-
-    #     mel_lengths = audio_lengths // self.hop_length
-    #     mel_masks = torch.unsqueeze(
-    #         sequence_mask(mel_lengths, torch.max(mel_lengths)), 1
-    #     ).float()
-
-    #     if indices is not None:
-    #         features = self.vq.decode(indices)
-
-    #     # Sample mels
-    #     decoded = gradient_checkpoint(self.decoder, features, use_reentrant=False)
-
-    #     return VQDecodeResult(
-    #         mels=decoded,
-    #         audios=self.generator(decoded) if return_audios else None,
-    #     )