@@ -10,21 +10,8 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn

-from fish_speech.models.vqgan.utils import plot_mel, sequence_mask, slice_segments
-
-
-@dataclass
-class VQEncodeResult:
-    features: torch.Tensor
-    indices: torch.Tensor
-    loss: torch.Tensor
-    feature_lengths: torch.Tensor
-
-
-@dataclass
-class VQDecodeResult:
-    mels: torch.Tensor
-    audios: Optional[torch.Tensor] = None
+from fish_speech.models.vqgan.modules.wavenet import WaveNet
+from fish_speech.models.vqgan.utils import plot_mel, sequence_mask


 class VQGAN(L.LightningModule):
@@ -32,16 +19,20 @@ class VQGAN(L.LightningModule):
         self,
         optimizer: Callable,
         lr_scheduler: Callable,
-        encoder: nn.Module,
+        encoder: WaveNet,
         quantizer: nn.Module,
-        aux_decoder: nn.Module,
+        decoder: WaveNet,
         reflow: nn.Module,
         vocoder: nn.Module,
         mel_transform: nn.Module,
         weight_reflow: float = 1.0,
         weight_vq: float = 1.0,
-        weight_aux_mel: float = 1.0,
+        weight_mel: float = 1.0,
         sampling_rate: int = 44100,
+        freeze_encoder: bool = False,
+        reflow_use_shallow: bool = False,
+        reflow_inference_steps: int = 10,
+        reflow_inference_start_t: float = 0.5,
     ):
         super().__init__()

@@ -52,10 +43,10 @@ class VQGAN(L.LightningModule):
         # Modules
         self.encoder = encoder
         self.quantizer = quantizer
-        self.aux_decoder = aux_decoder
+        self.decoder = decoder
+        self.vocoder = vocoder
         self.reflow = reflow
         self.mel_transform = mel_transform
-        self.vocoder = vocoder

         # Freeze vocoder
         for param in self.vocoder.parameters():
@@ -64,13 +55,27 @@ class VQGAN(L.LightningModule):
         # Loss weights
         self.weight_reflow = weight_reflow
         self.weight_vq = weight_vq
-        self.weight_aux_mel = weight_aux_mel
+        self.weight_mel = weight_mel

+        # Other parameters
         self.spec_min = -12
         self.spec_max = 3
         self.sampling_rate = sampling_rate
+        self.reflow_use_shallow = reflow_use_shallow
+        self.reflow_inference_steps = reflow_inference_steps
+        self.reflow_inference_start_t = reflow_inference_start_t
+
+        # Disable strict loading
         self.strict_loading = False

+        # If encoder is frozen
+        if freeze_encoder:
+            for param in self.encoder.parameters():
+                param.requires_grad = False
+
+            for param in self.quantizer.parameters():
+                param.requires_grad = False
+
     def on_save_checkpoint(self, checkpoint):
         # Do not save vocoder
         state_dict = checkpoint["state_dict"]
@@ -79,7 +84,6 @@ class VQGAN(L.LightningModule):
             state_dict.pop(name)

     def configure_optimizers(self):
-        # Need two optimizers and two schedulers
         optimizer = self.optimizer_builder(self.parameters())
         lr_scheduler = self.lr_scheduler_builder(optimizer)

@@ -97,7 +101,6 @@ class VQGAN(L.LightningModule):
     def denorm_spec(self, x):
         return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min

-    # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
     def training_step(self, batch, batch_idx):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]

@@ -110,6 +113,7 @@ class VQGAN(L.LightningModule):
         mel_lengths = audio_lengths // self.mel_transform.hop_length
         mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
         mel_masks_float_conv = mel_masks[:, None, :].float()
+        gt_mels = gt_mels * mel_masks_float_conv

         # Encode
         encoded_features = self.encoder(gt_mels) * mel_masks_float_conv
@@ -120,25 +124,31 @@ class VQGAN(L.LightningModule):
         vq_recon_features = vq_result.z * mel_masks_float_conv

         # VQ Decode
-        aux_mel = self.aux_decoder(vq_recon_features)
-        loss_aux_mel = F.l1_loss(
-            aux_mel * mel_masks_float_conv, gt_mels * mel_masks_float_conv
-        )
+        gen_mel = self.decoder(vq_recon_features) * mel_masks_float_conv

-        # Reflow
+        # Mel Loss
+        loss_mel = (gen_mel - gt_mels).abs().mean(
+            dim=1, keepdim=True
+        ).sum() / mel_masks_float_conv.sum()
+
+        # Reflow, given x_1_aux, we want to reconstruct x_1
         x_1 = self.norm_spec(gt_mels)
+
+        if self.reflow_use_shallow:
+            x_1_aux = self.norm_spec(gen_mel)
+        else:
+            x_1_aux = x_1
+
         t = torch.rand(gt_mels.shape[0], device=gt_mels.device)
         x_0 = torch.randn_like(x_1)

         # X_t = t * X_1 + (1 - t) * X_0
-        x_t = x_0 + t[:, None, None] * (x_1 - x_0)
+        x_t = x_0 + t[:, None, None] * (x_1_aux - x_0)

         v_pred = self.reflow(
             x_t,
             1000 * t,
-            vq_recon_features,  # .detach()
-            x_masks=mel_masks_float_conv,
-            cond_masks=mel_masks_float_conv,
+            vq_recon_features,
         )

         # Log L2 loss with
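
Note: the hunk above trains `reflow` as a rectified flow. A minimal self-contained sketch of the same objective, assuming mel tensors shaped [B, C, T] and a velocity model with the (x_t, 1000 * t, cond) call signature used in the diff (`rectified_flow_loss` is an illustrative helper, not part of this PR):

import torch
import torch.nn.functional as F

def rectified_flow_loss(model, x_1, cond):
    # Per-example time t ~ U(0, 1) and a Gaussian source sample x_0.
    t = torch.rand(x_1.shape[0], device=x_1.device)
    x_0 = torch.randn_like(x_1)
    # Straight-line interpolation: x_t = x_0 + t * (x_1 - x_0).
    x_t = x_0 + t[:, None, None] * (x_1 - x_0)
    # The regression target is the constant velocity x_1 - x_0.
    v_pred = model(x_t, 1000 * t, cond)
    return F.mse_loss(v_pred, x_1 - x_0)

With reflow_use_shallow, the hunk interpolates toward the decoder's own mel prediction (x_1_aux) while the next hunk keeps x_1 - x_0 as the regression target, so the flow learns to correct the decoder's residual error.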
@@ -146,21 +156,28 @@ class VQGAN(L.LightningModule):
         loss_reflow = weights[:, None, None] * F.mse_loss(
             x_1 - x_0, v_pred, reduction="none"
         )
-        loss_reflow = (loss_reflow * mel_masks_float_conv).mean()
+        loss_reflow = (loss_reflow * mel_masks_float_conv).mean(
+            dim=1
+        ).sum() / mel_masks_float_conv.sum()

         # Total loss
         loss = (
             self.weight_vq * loss_vq
-            + self.weight_aux_mel * loss_aux_mel
+            + self.weight_mel * loss_mel
             + self.weight_reflow * loss_reflow
         )

         # Log losses
         self.log(
-            "train/loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True
+            "train/generator/loss",
+            loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
         )
         self.log(
-            "train/loss_vq",
+            "train/generator/loss_vq",
             loss_vq,
             on_step=True,
             on_epoch=False,
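
Note: the loss normalization above replaces a plain .mean() over the whole padded tensor with a masked mean, so zero-padded frames no longer dilute the loss. A minimal sketch of the pattern, assuming a per-element loss of shape [B, C, T] and a mask of shape [B, 1, T] (`masked_frame_mean` is an illustrative name, not part of this PR):

import torch

def masked_frame_mean(per_element_loss, mask):
    # Average over the channel axis, then normalize by the number of
    # valid frames (mask.sum()) rather than by B * T.
    per_frame = (per_element_loss * mask).mean(dim=1)  # [B, T]
    return per_frame.sum() / mask.sum()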
@@ -168,15 +185,15 @@ class VQGAN(L.LightningModule):
             logger=True,
         )
         self.log(
-            "train/loss_aux_mel",
-            loss_aux_mel,
+            "train/generator/loss_mel",
+            loss_mel,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
             logger=True,
         )
         self.log(
-            "train/loss_reflow",
+            "train/generator/loss_reflow",
             loss_reflow,
             on_step=True,
             on_epoch=False,
@@ -196,22 +213,23 @@ class VQGAN(L.LightningModule):
         mel_lengths = audio_lengths // self.mel_transform.hop_length
         mel_masks = sequence_mask(mel_lengths, gt_mels.shape[2])
         mel_masks_float_conv = mel_masks[:, None, :].float()
+        gt_mels = gt_mels * mel_masks_float_conv

         # Encode
         encoded_features = self.encoder(gt_mels) * mel_masks_float_conv

         # Quantize
-        vq_result = self.quantizer(encoded_features)
+        vq_recon_features = self.quantizer(encoded_features).z * mel_masks_float_conv

         # VQ Decode
-        aux_mels = self.aux_decoder(vq_result.z)
-        loss_aux_mel = F.l1_loss(
-            aux_mels * mel_masks_float_conv, gt_mels * mel_masks_float_conv
-        )
+        gen_aux_mels = self.decoder(vq_recon_features) * mel_masks_float_conv
+        loss_mel = (gen_aux_mels - gt_mels).abs().mean(
+            dim=1, keepdim=True
+        ).sum() / mel_masks_float_conv.sum()

         self.log(
-            "val/loss_aux_mel",
-            loss_aux_mel,
+            "val/loss_mel",
+            loss_mel,
             on_step=False,
             on_epoch=True,
             prog_bar=False,
@@ -220,37 +238,34 @@ class VQGAN(L.LightningModule):
         )

         # Reflow inference
-        t_start = 0.0
-        infer_step = 10
+        t_start = self.reflow_inference_start_t if self.reflow_use_shallow else 0.0

-        x_1 = self.norm_spec(aux_mels)
+        x_1 = self.norm_spec(gen_aux_mels)
         x_0 = torch.randn_like(x_1)
-        gen_mels = (1 - t_start) * x_0 + t_start * x_1
+        gen_reflow_mels = (1 - t_start) * x_0 + t_start * x_1

         t = torch.zeros(gt_mels.shape[0], device=gt_mels.device)
-        dt = (1.0 - t_start) / infer_step
+        dt = (1.0 - t_start) / self.reflow_inference_steps

-        for _ in range(infer_step):
-            gen_mels += (
+        for _ in range(self.reflow_inference_steps):
+            gen_reflow_mels += (
                 self.reflow(
-                    gen_mels,
+                    gen_reflow_mels,
                     1000 * t,
-                    vq_result.z,
-                    x_masks=mel_masks_float_conv,
-                    cond_masks=mel_masks_float_conv,
+                    vq_recon_features,
                 )
                 * dt
             )
             t += dt

-        gen_mels = self.denorm_spec(gen_mels)
-        loss_recon_reflow = F.l1_loss(
-            gen_mels * mel_masks_float_conv, gt_mels * mel_masks_float_conv
-        )
+        gen_reflow_mels = self.denorm_spec(gen_reflow_mels) * mel_masks_float_conv
+        loss_reflow_mel = (gen_reflow_mels - gt_mels).abs().mean(
+            dim=1, keepdim=True
+        ).sum() / mel_masks_float_conv.sum()

         self.log(
-            "val/loss_recon_reflow",
-            loss_recon_reflow,
+            "val/loss_reflow_mel",
+            loss_reflow_mel,
             on_step=False,
             on_epoch=True,
             prog_bar=False,
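
Note: validation now integrates the learned flow with fixed-step Euler updates from t_start to 1; with reflow_use_shallow, integration starts from a partially noised copy of the auxiliary mel rather than from pure noise. A minimal sketch of such a sampler under the same (x, 1000 * t, cond) model signature; unlike the hunk above, it also initializes t at t_start (`euler_sample` is an illustrative helper, not part of this PR):

import torch

@torch.no_grad()
def euler_sample(model, x_1_aux, cond, steps=10, t_start=0.0):
    # t_start == 0: start from pure noise; 0 < t_start < 1: shallow flow,
    # starting from a partially noised auxiliary prediction.
    x_0 = torch.randn_like(x_1_aux)
    x = (1 - t_start) * x_0 + t_start * x_1_aux
    t = torch.full((x.shape[0],), t_start, device=x.device)
    dt = (1.0 - t_start) / steps
    for _ in range(steps):
        x = x + model(x, 1000 * t, cond) * dt
        t = t + dt
    return x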
@@ -258,9 +273,9 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )

-        gen_audios = self.vocoder(gen_mels)
         recon_audios = self.vocoder(gt_mels)
-        aux_audios = self.vocoder(aux_mels)
+        gen_aux_audios = self.vocoder(gen_aux_mels)
+        gen_reflow_audios = self.vocoder(gen_reflow_mels)

         # only log the first batch
         if batch_idx != 0:
@@ -268,21 +283,21 @@ class VQGAN(L.LightningModule):

         for idx, (
             gt_mel,
-            reflow_mel,
-            aux_mel,
+            gen_aux_mel,
+            gen_reflow_mel,
             audio,
-            reflow_audio,
-            aux_audio,
+            gen_aux_audio,
+            gen_reflow_audio,
             recon_audio,
             audio_len,
         ) in enumerate(
             zip(
                 gt_mels,
-                gen_mels,
-                aux_mels,
+                gen_aux_mels,
+                gen_reflow_mels,
                 audios.float(),
-                gen_audios.float(),
-                aux_audios.float(),
+                gen_aux_audios.float(),
+                gen_reflow_audios.float(),
                 recon_audios.float(),
                 audio_lengths,
             )
@@ -292,13 +307,13 @@ class VQGAN(L.LightningModule):
             image_mels = plot_mel(
                 [
                     gt_mel[:, :mel_len],
-                    reflow_mel[:, :mel_len],
-                    aux_mel[:, :mel_len],
+                    gen_aux_mel[:, :mel_len],
+                    gen_reflow_mel[:, :mel_len],
                 ],
                 [
                     "Ground-Truth",
+                    "Auxiliary",
                     "Reflow",
-                    "Aux",
                 ],
             )

@@ -313,14 +328,14 @@ class VQGAN(L.LightningModule):
                             caption="gt",
                         ),
                         wandb.Audio(
-                            reflow_audio[0, :audio_len],
+                            gen_aux_audio[0, :audio_len],
                             sample_rate=self.sampling_rate,
-                            caption="reflow",
+                            caption="aux",
                         ),
                         wandb.Audio(
-                            aux_audio[0, :audio_len],
+                            gen_reflow_audio[0, :audio_len],
                             sample_rate=self.sampling_rate,
-                            caption="aux",
+                            caption="reflow",
                         ),
                         wandb.Audio(
                             recon_audio[0, :audio_len],
@@ -344,14 +359,14 @@ class VQGAN(L.LightningModule):
                 sample_rate=self.sampling_rate,
             )
             self.logger.experiment.add_audio(
-                f"sample-{idx}/wavs/reflow",
-                reflow_audio[0, :audio_len],
+                f"sample-{idx}/wavs/gen",
+                gen_aux_audio[0, :audio_len],
                 self.global_step,
                 sample_rate=self.sampling_rate,
             )
             self.logger.experiment.add_audio(
-                f"sample-{idx}/wavs/aux",
-                aux_audio[0, :audio_len],
+                f"sample-{idx}/wavs/reflow",
+                gen_reflow_audio[0, :audio_len],
                 self.global_step,
                 sample_rate=self.sampling_rate,
             )