@@ -11,6 +11,7 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
 from tqdm import tqdm
+from transformers import HubertModel
 
 from fish_speech.models.vq_diffusion.convnext_1d import ConvNext1DModel
 from fish_speech.models.vqgan.modules.encoders import (
@@ -86,7 +87,8 @@ class VQDiffusion(L.LightningModule):
         features, feature_lengths = batch["features"], batch["feature_lengths"]
 
         audios = audios.float()
-        features = features.float().mT
+        # features = features.float().mT
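+        # assumption: "features" are 50 Hz HuBERT outputs already laid out (B, T, C)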
         audios = audios[:, None, :]
 
         with torch.no_grad():
@@ -95,18 +97,22 @@ class VQDiffusion(L.LightningModule):
         mel_lengths = audio_lengths // self.hop_length
 
         feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
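+            # time is dim 1 now that features stay (B, T, C); it was dim 2 after the old .mT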
+            sequence_mask(feature_lengths, features.shape[1]), 1
         ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
         )
 
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        vq_features, vq_loss = self.vq_encoder(features, feature_masks)
+        # vq_features, vq_loss = self.vq_encoder(features, feature_masks)
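+        # the VQ bottleneck is bypassed: features now feed the text encoder directly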
 
-        # vq_features is 50 hz, need to convert to true mel size
-        vq_features = F.interpolate(vq_features, size=gt_mels.shape[2], mode="nearest")
-        text_features = self.text_encoder(vq_features, mel_masks, g=speaker_features)
+        # text_features is 50 Hz, need to convert to true mel size
+        text_features = self.text_encoder(features, feature_masks, g=speaker_features)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
 
         # Sample noise that we'll add to the images
         normalized_gt_mels = self.normalize_mels(gt_mels)
@@ -144,41 +150,44 @@ class VQDiffusion(L.LightningModule):
             sync_dist=True,
         )
 
-        self.log(
-            "train/vq_loss",
-            vq_loss,
-            on_step=True,
-            on_epoch=False,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
+        # self.log(
+        #     "train/vq_loss",
+        #     vq_loss,
+        #     on_step=True,
+        #     on_epoch=False,
+        #     prog_bar=True,
+        #     logger=True,
+        #     sync_dist=True,
+        # )
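+        # NOTE: vq_loss is gone along with the VQ encoder, hence the disabled log and trimmed return below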
 
-        return noise_loss + vq_loss
+        return noise_loss  # + vq_loss
 
     def validation_step(self, batch: Any, batch_idx: int):
         audios, audio_lengths = batch["audios"], batch["audio_lengths"]
         features, feature_lengths = batch["features"], batch["feature_lengths"]
 
         audios = audios.float()
-        features = features.float().mT
+        # features = features.float().mT
         audios = audios[:, None, :]
         gt_mels = self.mel_transform(audios)
         mel_lengths = audio_lengths // self.hop_length
 
         feature_masks = torch.unsqueeze(
-            sequence_mask(feature_lengths, features.shape[2]), 1
+            sequence_mask(feature_lengths, features.shape[1]), 1
         ).to(gt_mels.dtype)
         mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
             gt_mels.dtype
         )
 
         speaker_features = self.speaker_encoder(gt_mels, mel_masks)
-        vq_features, _ = self.vq_encoder(features, feature_masks)
+        # vq_features, vq_loss = self.vq_encoder(features, feature_masks)
 
-        # vq_features is 50 hz, need to convert to true mel size
-        vq_features = F.interpolate(vq_features, size=gt_mels.shape[2], mode="nearest")
-        text_features = self.text_encoder(vq_features, mel_masks, g=speaker_features)
+        # text_features is 50 Hz, need to convert to true mel size
+        text_features = self.text_encoder(features, feature_masks, g=speaker_features)
+        text_features = F.interpolate(
+            text_features, size=gt_mels.shape[2], mode="nearest"
+        )
 
         # Begin sampling
         sampled_mels = torch.randn_like(gt_mels)