|
@@ -2,12 +2,15 @@ import librosa
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import soundfile as sf
|
|
import soundfile as sf
|
|
|
import torch
|
|
import torch
|
|
|
|
|
+import torch.nn.functional as F
|
|
|
from hydra import compose, initialize
|
|
from hydra import compose, initialize
|
|
|
from hydra.utils import instantiate
|
|
from hydra.utils import instantiate
|
|
|
from lightning import LightningModule
|
|
from lightning import LightningModule
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
|
from omegaconf import OmegaConf
|
|
from omegaconf import OmegaConf
|
|
|
|
|
|
|
|
|
|
+from fish_speech.models.vqgan.utils import sequence_mask
|
|
|
|
|
+
|
|
|
# register eval resolver
|
|
# register eval resolver
|
|
|
OmegaConf.register_new_resolver("eval", eval)
|
|
OmegaConf.register_new_resolver("eval", eval)
|
|
|
|
|
|
|
@@ -16,12 +19,11 @@ OmegaConf.register_new_resolver("eval", eval)
|
|
|
@torch.autocast(device_type="cuda", enabled=True)
|
|
@torch.autocast(device_type="cuda", enabled=True)
|
|
|
def main():
|
|
def main():
|
|
|
with initialize(version_base="1.3", config_path="../fish_speech/configs"):
|
|
with initialize(version_base="1.3", config_path="../fish_speech/configs"):
|
|
|
- cfg = compose(config_name="vq_naive_40hz")
|
|
|
|
|
|
|
+ cfg = compose(config_name="vqgan")
|
|
|
|
|
|
|
|
model: LightningModule = instantiate(cfg.model)
|
|
model: LightningModule = instantiate(cfg.model)
|
|
|
state_dict = torch.load(
|
|
state_dict = torch.load(
|
|
|
- "results/vq_naive_40hz/checkpoints/step_000675000.ckpt",
|
|
|
|
|
- # "results/vq_naive_25hz/checkpoints/step_000100000.ckpt",
|
|
|
|
|
|
|
+ "results/vqgan/checkpoints/step_000110000.ckpt",
|
|
|
map_location=model.device,
|
|
map_location=model.device,
|
|
|
)["state_dict"]
|
|
)["state_dict"]
|
|
|
model.load_state_dict(state_dict, strict=True)
|
|
model.load_state_dict(state_dict, strict=True)
|
|
@@ -30,7 +32,7 @@ def main():
|
|
|
logger.info("Restored model from checkpoint")
|
|
logger.info("Restored model from checkpoint")
|
|
|
|
|
|
|
|
# Load audio
|
|
# Load audio
|
|
|
- audio = librosa.load("record1.wav", sr=model.sampling_rate, mono=True)[0]
|
|
|
|
|
|
|
+ audio = librosa.load("0.wav", sr=model.sampling_rate, mono=True)[0]
|
|
|
audios = torch.from_numpy(audio).to(model.device)[None, None, :]
|
|
audios = torch.from_numpy(audio).to(model.device)[None, None, :]
|
|
|
logger.info(
|
|
logger.info(
|
|
|
f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
|
|
f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
|
|
@@ -40,30 +42,40 @@ def main():
|
|
|
audio_lengths = torch.tensor(
|
|
audio_lengths = torch.tensor(
|
|
|
[audios.shape[2]], device=model.device, dtype=torch.long
|
|
[audios.shape[2]], device=model.device, dtype=torch.long
|
|
|
)
|
|
)
|
|
|
- mel_masks, gt_mels, text_features, indices, loss_vq = model.vq_encode(
|
|
|
|
|
- audios, audio_lengths
|
|
|
|
|
|
|
+
|
|
|
|
|
+ features = gt_mels = model.mel_transform(audios, sample_rate=model.sampling_rate)
|
|
|
|
|
+
|
|
|
|
|
+ if model.downsample is not None:
|
|
|
|
|
+ features = model.downsample(features)
|
|
|
|
|
+
|
|
|
|
|
+ mel_lengths = audio_lengths // model.hop_length
|
|
|
|
|
+ feature_lengths = (
|
|
|
|
|
+ audio_lengths
|
|
|
|
|
+ / model.hop_length
|
|
|
|
|
+ / (model.downsample.total_strides if model.downsample is not None else 1)
|
|
|
|
|
+ ).long()
|
|
|
|
|
+
|
|
|
|
|
+ feature_masks = torch.unsqueeze(
|
|
|
|
|
+ sequence_mask(feature_lengths, features.shape[2]), 1
|
|
|
|
|
+ ).to(gt_mels.dtype)
|
|
|
|
|
+ mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
|
|
|
|
|
+ gt_mels.dtype
|
|
|
)
|
|
)
|
|
|
|
|
+
|
|
|
|
|
+ # vq_features is 50 hz, need to convert to true mel size
|
|
|
|
|
+ text_features = model.mel_encoder(features, feature_masks)
|
|
|
|
|
+ text_features, indices, _ = model.vq_encoder(text_features, feature_masks)
|
|
|
|
|
+
|
|
|
logger.info(
|
|
logger.info(
|
|
|
f"VQ Encoded, indices: {indices.shape} equavilent to "
|
|
f"VQ Encoded, indices: {indices.shape} equavilent to "
|
|
|
- + f"{1/(audios.shape[2] / model.sampling_rate / indices.shape[1]):.2f} Hz"
|
|
|
|
|
|
|
+ + f"{1/(audios.shape[2] / model.sampling_rate / indices.shape[2]):.2f} Hz"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- # VQ Decoder
|
|
|
|
|
- audioa = librosa.load(
|
|
|
|
|
- "data/AiShell/wav/train/S0121/BAC009S0121W0125.wav",
|
|
|
|
|
- sr=model.sampling_rate,
|
|
|
|
|
- mono=True,
|
|
|
|
|
- )[0]
|
|
|
|
|
- audioa = torch.from_numpy(audioa).to(model.device)[None, None, :]
|
|
|
|
|
- mel = model.mel_transform(audioa)
|
|
|
|
|
- mel1_masks = torch.ones([mel.shape[0], 1, mel.shape[2]], device=model.device)
|
|
|
|
|
-
|
|
|
|
|
- speaker_features = model.speaker_encoder(mel, mel1_masks)
|
|
|
|
|
-
|
|
|
|
|
- speaker_features = model.speaker_encoder(gt_mels, mel_masks)
|
|
|
|
|
- speaker_features = torch.zeros_like(speaker_features)
|
|
|
|
|
- decoded_mels = model.vq_decode(text_features, speaker_features, gt_mels, mel_masks)
|
|
|
|
|
- fake_audios = model.vocoder(decoded_mels)
|
|
|
|
|
|
|
+ text_features = F.interpolate(text_features, size=gt_mels.shape[2], mode="nearest")
|
|
|
|
|
+
|
|
|
|
|
+ # Sample mels
|
|
|
|
|
+ decoded_mels = model.decoder(text_features, mel_masks)
|
|
|
|
|
+ fake_audios = model.generator(decoded_mels)
|
|
|
|
|
|
|
|
# Save audio
|
|
# Save audio
|
|
|
fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
|
|
fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
|