@@ -1,10 +1,16 @@
+import itertools
 from typing import Any, Callable
 
 import lightning as L
 import torch
 import torch.nn.functional as F
+import wandb
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
+from matplotlib import pyplot as plt
 from torch import nn
-from torch.utils.checkpoint import checkpoint as gradient_checkpointing
+
+from fish_speech.models.vqgan.modules import EnsembleDiscriminator, Generator, VQEncoder
+from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
 
 
 class VQGAN(L.LightningModule):
@@ -12,11 +18,13 @@ class VQGAN(L.LightningModule):
         self,
         optimizer: Callable,
         lr_scheduler: Callable,
-        encoder: nn.Module,
-        generator: nn.Module,
-        discriminator: nn.Module,
+        encoder: VQEncoder,
+        generator: Generator,
+        discriminator: EnsembleDiscriminator,
         mel_transform: nn.Module,
         segment_size: int = 20480,
+        hop_length: int = 640,
+        sample_rate: int = 32000,
     ):
         super().__init__()
 
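Note: the new hop_length and sample_rate arguments tie the waveform segment to the mel-frame grid. With the defaults above, one mel frame spans 640 samples, so a 20480-sample training crop is exactly 32 frames (0.64 s at 32 kHz). A standalone check of that arithmetic, using only the default values shown above:

    segment_size = 20480  # samples per training crop
    hop_length = 640      # samples per mel frame
    sample_rate = 32000   # Hz

    assert segment_size % hop_length == 0  # crop must be a whole number of frames
    print(segment_size // hop_length)      # 32 mel frames per crop
    print(segment_size / sample_rate)      # 0.64 seconds of audio per crop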
@@ -33,15 +41,19 @@ class VQGAN(L.LightningModule):
 
         # Crop length for saving memory
         self.segment_size = segment_size
+        self.hop_length = hop_length
+        self.sampling_rate = sample_rate
 
         # Disable automatic optimization
         self.automatic_optimization = False
 
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
-        optimizer_generator = self.optimizer_builder(self.generator.parameters())
+        optimizer_generator = self.optimizer_builder(
+            itertools.chain(self.encoder.parameters(), self.generator.parameters())
+        )
         optimizer_discriminator = self.optimizer_builder(
-            self.discriminators.parameters()
+            self.discriminator.parameters()
         )
 
         lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
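Note: optimizer and lr_scheduler are plain Callables, so optimizer_builder / lr_scheduler_builder are presumably partially-applied constructors injected from the training config. A minimal sketch of how the generator side would be wired; the builders and hyperparameters here are illustrative assumptions, not values from this patch:

    import itertools
    from functools import partial

    import torch

    # Hypothetical builders standing in for the injected Callables.
    optimizer_builder = partial(torch.optim.AdamW, lr=2e-4, betas=(0.8, 0.99))
    lr_scheduler_builder = partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.999)

    encoder = torch.nn.Linear(4, 4)    # stand-ins for VQEncoder / Generator
    generator = torch.nn.Linear(4, 4)

    # One optimizer now covers encoder and generator, as in the hunk above.
    optim_g = optimizer_builder(
        itertools.chain(encoder.parameters(), generator.parameters())
    )
    sched_g = lr_scheduler_builder(optim_g)

The itertools.chain change matters: before this patch the optimizer only saw self.generator.parameters(), so the encoder received no gradient updates at all.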
@@ -66,109 +78,112 @@ class VQGAN(L.LightningModule):
             },
         )
 
-    def training_generator(self, audio, audio_mask):
-        # fake_audio, base_loss = self.forward(audio, audio_mask)
+    @staticmethod
+    def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+        loss = 0
+        r_losses = []
+        g_losses = []
+        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+            dr = dr.float()
+            dg = dg.float()
+            r_loss = torch.mean((1 - dr) ** 2)
+            g_loss = torch.mean(dg**2)
+            loss += r_loss + g_loss
+            r_losses.append(r_loss.item())
+            g_losses.append(g_loss.item())
+
+        return loss, r_losses, g_losses
+
+    @staticmethod
+    def generator_loss(disc_outputs):
+        loss = 0
+        gen_losses = []
+        for dg in disc_outputs:
+            dg = dg.float()
+            l = torch.mean((1 - dg) ** 2)
+            gen_losses.append(l)
+            loss += l
+
+        return loss, gen_losses
+
+    @staticmethod
+    def feature_loss(fmap_r, fmap_g):
+        loss = 0
+        for dr, dg in zip(fmap_r, fmap_g):
+            for rl, gl in zip(dr, dg):
+                rl = rl.float().detach()
+                gl = gl.float()
+                loss += torch.mean(torch.abs(rl - gl))
+
+        return loss * 2
 
-        assert fake_audio.shape == audio.shape
+    def training_step(self, batch, batch_idx):
+        optim_g, optim_d = self.optimizers()
 
-        # Apply mask
-        audio = audio * audio_mask
-        fake_audio = fake_audio * audio_mask
+        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+        features, feature_lengths = batch["features"], batch["feature_lengths"]
 
-        # Multi-Resolution STFT Loss
-        sc_loss, mag_loss = self.multi_resolution_stft_loss(
-            fake_audio.squeeze(1), audio.squeeze(1)
-        )
-        loss_stft = sc_loss + mag_loss
+        with torch.no_grad():
+            gt_mels = self.mel_transform(audios).transpose(1, 2)
+            key_padding_mask = sequence_mask(feature_lengths)
+            mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
 
-        self.log(
-            "train/generator/stft",
-            loss_stft,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
+            assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
+            gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
+            gt_mels = gt_mels[:, :gt_mel_length]
+            mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
 
-        # L1 Mel-Spectrogram Loss
-        # This is not used in backpropagation currently
-        audio_mel = self.mel_transforms.loss(audio.squeeze(1))
-        fake_audio_mel = self.mel_transforms.loss(fake_audio.squeeze(1))
-        loss_mel = F.l1_loss(audio_mel, fake_audio_mel)
+            assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
+            gt_feature_length = min(features.shape[1], key_padding_mask.shape[1])
+            features = features[:, :gt_feature_length]
+            key_padding_mask = key_padding_mask[:, :gt_feature_length]
 
-        self.log(
-            "train/generator/mel",
-            loss_mel,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
+        # Generator
+        encoded = self.encoder(
+            x=features,
+            mels=gt_mels,
+            key_padding_mask=key_padding_mask,
+            mels_key_padding_mask=mels_key_padding_mask,
         )
 
-        # Now, we need to reduce the length of the audio to save memory
-        if self.crop_length is not None and audio.shape[2] > self.crop_length:
-            slice_idx = torch.randint(0, audio.shape[-1] - self.crop_length, (1,))
+        features = encoded.features
+        audios = audios[:, None, :]
 
-            audio = audio[..., slice_idx : slice_idx + self.crop_length]
-            fake_audio = fake_audio[..., slice_idx : slice_idx + self.crop_length]
-            audio_mask = audio_mask[..., slice_idx : slice_idx + self.crop_length]
-
-            assert audio.shape == fake_audio.shape == audio_mask.shape
-
-        # Adv Loss
-        loss_adv_all = 0
-
-        for key, disc in self.discriminators.items():
-            score_fakes, feat_fake = disc(fake_audio)
-
-            # Adversarial Loss
-            score_fakes = torch.cat(score_fakes, dim=1)
-            loss_fake = torch.mean((1 - score_fakes) ** 2)
-
-            self.log(
-                f"train/generator/adv_{key}",
-                loss_fake,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
+        # Get slice of audio
+        if audios.shape[-1] > self.segment_size:
+            start = torch.randint(
+                0, audios.shape[-1] - self.segment_size, (1,), device=audios.device
+            ).item()
+            start = start // self.hop_length * self.hop_length
 
-            loss_adv_all += loss_fake
+            audios = audios[:, :, start : start + self.segment_size]
+            audio_masks = sequence_mask(audio_lengths)[
+                :, None, start : start + self.segment_size
+            ]
 
-            if self.feature_matching is False:
-                continue
+            mel_start = start // self.hop_length
+            mel_size = self.segment_size // self.hop_length
+            gt_mels = gt_mels[:, mel_start : mel_start + mel_size]
+            mels_key_padding_mask = mels_key_padding_mask[
+                :, mel_start : mel_start + mel_size
+            ]
 
-            # Feature Matching Loss
-            _, feat_real = disc(audio)
-            loss_fm = 0
-            for dr, dg in zip(feat_real, feat_fake):
-                for rl, gl in zip(dr, dg):
-                    loss_fm += F.l1_loss(rl, gl)
+            features = features[:, :, mel_start : mel_start + mel_size]
 
-            loss_fm /= len(feat_real)
+        fake_audios = self.generator(features)
+        audios = torch.masked_fill(audios, audio_masks, 0.0)
+        fake_audios = torch.masked_fill(fake_audios, audio_masks, 0.0)
+        assert fake_audios.shape == audios.shape
 
-            self.log(
-                f"train/generator/adv_fm_{key}",
-                loss_fm,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
-
-            loss_adv_all += loss_fm
+        # Discriminator
+        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios.detach())
 
-        loss_adv_all /= len(self.discriminators)
-        loss_gen_all = base_loss + loss_stft * 2.5 + loss_mel * 45 + loss_adv_all
+        with torch.autocast(device_type=audios.device.type, enabled=False):
+            loss_disc_all, _, _ = self.discriminator_loss(y_d_hat_r, y_d_hat_g)
 
         self.log(
-            "train/generator/all",
-            loss_gen_all,
+            "train/discriminator/loss",
+            loss_disc_all,
             on_step=True,
             on_epoch=False,
             prog_bar=True,
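Note: the three new static methods are the standard HiFi-GAN objectives in their LSGAN form: real scores pushed toward 1 and fake toward 0 for the discriminator, fake toward 1 for the generator, plus an L1 feature-matching term scaled by 2. A toy check of the new static methods on dummy discriminator outputs (shapes are arbitrary; assumes the VQGAN class above is importable):

    import torch

    # Two sub-discriminator outputs, as an ensemble discriminator would return.
    real_scores = [torch.ones(2, 10), torch.ones(2, 7)]
    fake_scores = [torch.zeros(2, 10), torch.zeros(2, 7)]

    loss_d, r_losses, g_losses = VQGAN.discriminator_loss(real_scores, fake_scores)
    assert loss_d.item() == 0.0  # perfect discriminator -> zero LSGAN loss

    loss_g, _ = VQGAN.generator_loss(fake_scores)
    assert loss_g.item() == 2.0  # (1 - 0)^2 per sub-discriminator, summed over both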
@@ -176,99 +191,79 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
-        return loss_gen_all, audio, fake_audio
+        optim_d.zero_grad()
+        self.manual_backward(loss_disc_all)
+        self.clip_gradients(
+            optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+        )
+        optim_d.step()
 
-    def training_discriminator(self, audio, fake_audio):
-        loss_disc_all = 0
+        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.discriminator(audios, fake_audios)
+        fake_mels = self.mel_transform(fake_audios.squeeze(1)).transpose(1, 2)
 
-        for key, disc in self.discriminators.items():
-            if self.training and self.checkpointing:
-                scores, _ = gradient_checkpointing(disc, audio, use_reentrant=False)
-                score_fakes, _ = gradient_checkpointing(
-                    disc, fake_audio.detach(), use_reentrant=False
-                )
-            else:
-                scores, _ = disc(audio)
-                score_fakes, _ = disc(fake_audio.detach())
-
-            scores = torch.cat(scores, dim=1)
-            score_fakes = torch.cat(score_fakes, dim=1)
-            loss_disc = torch.mean((scores - 1) ** 2) + torch.mean((score_fakes) ** 2)
-
-            self.log(
-                f"train/discriminator/{key}",
-                loss_disc,
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
+        # Fill mel mask
+        fake_mels = torch.masked_fill(fake_mels, mels_key_padding_mask[:, :, None], 0.0)
+        gt_mels = torch.masked_fill(gt_mels, mels_key_padding_mask[:, :, None], 0.0)
 
-            loss_disc_all += loss_disc
+        with torch.autocast(device_type=audios.device.type, enabled=False):
+            loss_mel = F.l1_loss(gt_mels, fake_mels)
+            loss_adv, _ = self.generator_loss(y_d_hat_g)
+            loss_fm = self.feature_loss(fmap_r, fmap_g)
 
-        loss_disc_all /= len(self.discriminators)
+        loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + encoded.loss
 
         self.log(
-            "train/discriminator/all",
-            loss_disc_all,
+            "train/generator/loss",
+            loss_gen_all,
             on_step=True,
             on_epoch=False,
             prog_bar=True,
             logger=True,
             sync_dist=True,
         )
-
-        return loss_disc_all
-
-    def training_step(self, batch, batch_idx):
-        optim_g, optim_d = self.optimizers()
-
-        audio, lengths = batch["audio"], batch["lengths"]
-        audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32)
-
-        # Generator
-        optim_g.zero_grad()
-        loss_gen_all, audio, fake_audio = self.training_generator(audio, audio_mask)
-        self.manual_backward(loss_gen_all)
-
         self.log(
-            "train/generator/grad_norm",
-            grad_norm(self.generator.parameters()),
+            "train/generator/loss_mel",
+            loss_mel,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train/generator/loss_fm",
+            loss_fm,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
             logger=True,
             sync_dist=True,
         )
-
-        self.clip_gradients(
-            optim_g, gradient_clip_val=1000, gradient_clip_algorithm="norm"
+        self.log(
+            "train/generator/loss_adv",
+            loss_adv,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train/generator/loss_vq",
+            encoded.loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
         )
-        optim_g.step()
-
-        # Discriminator
-        assert fake_audio.shape == audio.shape
-
-        optim_d.zero_grad()
-        loss_disc_all = self.training_discriminator(audio, fake_audio)
-        self.manual_backward(loss_disc_all)
-
-        for key, disc in self.discriminators.items():
-            self.log(
-                f"train/discriminator/grad_norm_{key}",
-                grad_norm(disc.parameters()),
-                on_step=True,
-                on_epoch=False,
-                prog_bar=False,
-                logger=True,
-                sync_dist=True,
-            )
 
+        optim_g.zero_grad()
+        self.manual_backward(loss_gen_all)
         self.clip_gradients(
-            optim_d, gradient_clip_val=1000, gradient_clip_algorithm="norm"
+            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
        )
-        optim_d.step()
+        optim_g.step()
 
         # Manual LR Scheduler
         scheduler_g, scheduler_d = self.lr_schedulers()
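Note: with automatic_optimization already disabled, the rewritten training_step fixes the update order: one discriminator step on detached fake audio, then one generator step on a fresh discriminator forward. A runnable toy mirror of that order (plain torch with stand-in modules, not the patch's code):

    import torch

    g = torch.nn.Linear(8, 8)   # stand-in generator
    d = torch.nn.Linear(8, 1)   # stand-in discriminator
    optim_g = torch.optim.AdamW(g.parameters(), lr=1e-3)
    optim_d = torch.optim.AdamW(d.parameters(), lr=1e-3)

    x = torch.randn(4, 8)
    fake = g(x)

    # 1) Discriminator step: .detach() keeps gradients out of the generator.
    loss_d = torch.mean((d(x) - 1) ** 2) + torch.mean(d(fake.detach()) ** 2)
    optim_d.zero_grad()
    loss_d.backward()
    torch.nn.utils.clip_grad_norm_(d.parameters(), 1000.0)  # mirrors clip_gradients
    optim_d.step()

    # 2) Generator step: a second discriminator forward on the live fake.
    loss_g = torch.mean((1 - d(fake)) ** 2)
    optim_g.zero_grad()
    loss_g.backward()
    torch.nn.utils.clip_grad_norm_(g.parameters(), 1000.0)
    optim_g.step()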
@@ -276,25 +271,55 @@ class VQGAN(L.LightningModule):
         scheduler_d.step()
 
     def validation_step(self, batch: Any, batch_idx: int):
-        audio, lengths = batch["audio"], batch["lengths"]
-        audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32)
+        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+        features, feature_lengths = batch["features"], batch["feature_lengths"]
+
+        with torch.no_grad():
+            gt_mels = self.mel_transform(audios).transpose(1, 2)
+            key_padding_mask = sequence_mask(feature_lengths)
+            mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
+            audio_masks = sequence_mask(audio_lengths)
+
+            assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
+            gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
+            gt_mels = gt_mels[:, :gt_mel_length]
+            mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
+
+            assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
+            gt_feature_length = min(features.shape[1], key_padding_mask.shape[1])
+            features = features[:, :gt_feature_length]
+            key_padding_mask = key_padding_mask[:, :gt_feature_length]
 
         # Generator
-        fake_audio, _ = self.forward(audio, audio_mask)
-        assert fake_audio.shape == audio.shape
+        encoded = self.encoder(
+            x=features,
+            mels=gt_mels,
+            key_padding_mask=key_padding_mask,
+            mels_key_padding_mask=mels_key_padding_mask,
+        )
+
+        features = encoded.features
+        audios = audios[:, None, :]
 
-        # Apply mask
-        audio = audio * audio_mask
-        fake_audio = fake_audio * audio_mask
+        fake_audios = self.generator(features)
+        min_audio_length = min(audios.shape[-1], fake_audios.shape[-1])
 
-        # L1 Mel-Spectrogram Loss
-        audio_mel = self.mel_transforms.loss(audio.squeeze(1))
-        fake_audio_mel = self.mel_transforms.loss(fake_audio.squeeze(1))
-        loss_mel = F.l1_loss(audio_mel, fake_audio_mel)
+        audios = audios[:, :, :min_audio_length]
+        fake_audios = fake_audios[:, :, :min_audio_length]
+        audio_masks = audio_masks[:, None, :min_audio_length]
 
+        audios = torch.masked_fill(audios, audio_masks, 0.0)
+        fake_audios = torch.masked_fill(fake_audios, audio_masks, 0.0)
+        assert fake_audios.shape == audios.shape
+
+        fake_mels = self.mel_transform(fake_audios.squeeze(1)).transpose(1, 2)
+        gt_mels = torch.masked_fill(gt_mels, mels_key_padding_mask[:, :, None], 0.0)
+        fake_mels = torch.masked_fill(fake_mels, mels_key_padding_mask[:, :, None], 0.0)
+
+        mel_loss = F.l1_loss(gt_mels, fake_mels)
         self.log(
-            "val/metrics/mel",
-            loss_mel,
+            "val/mel_loss",
+            mel_loss,
             on_step=False,
             on_epoch=True,
             prog_bar=True,
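Note: every masked_fill in this patch writes zeros at True positions, which implies sequence_mask returns a padding mask (True = padded), consistent with the key_padding_mask naming. The real helper lives in fish_speech.models.vqgan.utils and is not shown in this diff; a minimal version consistent with that usage would be (an assumption, not the repository's code):

    import torch

    def sequence_mask(lengths, max_length=None):
        # True marks PADDED positions, so masked_fill(x, mask, 0.0) zeroes padding.
        if max_length is None:
            max_length = int(lengths.max())
        positions = torch.arange(max_length, device=lengths.device)
        return positions[None, :] >= lengths[:, None]

    print(sequence_mask(torch.tensor([3, 5]), max_length=5))
    # tensor([[False, False, False,  True,  True],
    #         [False, False, False, False, False]])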
@@ -302,5 +327,61 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
-        # Report other metrics
-        self.report_val_metrics(fake_audio, audio, lengths)
+        for idx, (mel, gen_mel, audio, gen_audio, audio_len) in enumerate(
+            zip(
+                gt_mels.transpose(1, 2),
+                fake_mels.transpose(1, 2),
+                audios,
+                fake_audios,
+                audio_lengths,
+            )
+        ):
+            mel_len = audio_len // self.hop_length
+
+            image_mels = plot_mel(
+                [
+                    gen_mel[:, :mel_len],
+                    mel[:, :mel_len],
+                ],
+                ["Sampled Spectrogram", "Ground-Truth Spectrogram"],
+            )
+
+            if isinstance(self.logger, WandbLogger):
+                self.logger.experiment.log(
+                    {
+                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
+                        "wavs": [
+                            wandb.Audio(
+                                audio[0, :audio_len].float().cpu().numpy(),
+                                sample_rate=self.sampling_rate,
+                                caption="gt",
+                            ),
+                            wandb.Audio(
+                                gen_audio[0, :audio_len].float().cpu().numpy(),
+                                sample_rate=self.sampling_rate,
+                                caption="prediction",
+                            ),
+                        ],
+                    },
+                )
+
+            if isinstance(self.logger, TensorBoardLogger):
+                self.logger.experiment.add_figure(
+                    f"sample-{idx}/mels",
+                    image_mels,
+                    global_step=self.global_step,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/gt",
+                    audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/prediction",
+                    gen_audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+
+            plt.close(image_mels)
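Note: plot_mel is imported but not defined in this diff. Given how it is called here, a list of 2-D spectrograms plus a list of panel titles, returning a figure that wandb.Image / add_figure accept and plt.close can dispose of, a compatible sketch would be (an assumption about the helper, not the repository's implementation):

    from matplotlib import pyplot as plt

    def plot_mel(data, titles=None):
        # One panel per spectrogram; returns the Figure so callers can log and close it.
        fig, axes = plt.subplots(len(data), 1, squeeze=False, figsize=(10, 3 * len(data)))
        titles = titles or [None] * len(data)
        for ax, mel, title in zip(axes[:, 0], data, titles):
            if hasattr(mel, "detach"):  # accept torch tensors, as passed above
                mel = mel.detach().float().cpu().numpy()
            ax.imshow(mel, origin="lower", aspect="auto", interpolation="none")
            if title is not None:
                ax.set_title(title)
            ax.set_ylabel("Mel bin")
        axes[-1, 0].set_xlabel("Frame")
        fig.tight_layout()
        return fig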