Lengyue 2 лет назад
Родитель
Commit
e70d2c0373
1 изменённый файл: 35 добавлено и 23 удалено
  1. 35 23
      tools/infer_vq.py

+ 35 - 23
tools/infer_vq.py

@@ -2,12 +2,15 @@ import librosa
 import numpy as np
 import numpy as np
 import soundfile as sf
 import soundfile as sf
 import torch
 import torch
+import torch.nn.functional as F
 from hydra import compose, initialize
 from hydra import compose, initialize
 from hydra.utils import instantiate
 from hydra.utils import instantiate
 from lightning import LightningModule
 from lightning import LightningModule
 from loguru import logger
 from loguru import logger
 from omegaconf import OmegaConf
 from omegaconf import OmegaConf
 
 
+from fish_speech.models.vqgan.utils import sequence_mask
+
 # register eval resolver
 # register eval resolver
 OmegaConf.register_new_resolver("eval", eval)
 OmegaConf.register_new_resolver("eval", eval)
 
 
@@ -16,12 +19,11 @@ OmegaConf.register_new_resolver("eval", eval)
 @torch.autocast(device_type="cuda", enabled=True)
 @torch.autocast(device_type="cuda", enabled=True)
 def main():
 def main():
     with initialize(version_base="1.3", config_path="../fish_speech/configs"):
     with initialize(version_base="1.3", config_path="../fish_speech/configs"):
-        cfg = compose(config_name="vq_naive_40hz")
+        cfg = compose(config_name="vqgan")
 
 
     model: LightningModule = instantiate(cfg.model)
     model: LightningModule = instantiate(cfg.model)
     state_dict = torch.load(
     state_dict = torch.load(
-        "results/vq_naive_40hz/checkpoints/step_000675000.ckpt",
-        # "results/vq_naive_25hz/checkpoints/step_000100000.ckpt",
+        "results/vqgan/checkpoints/step_000110000.ckpt",
         map_location=model.device,
         map_location=model.device,
     )["state_dict"]
     )["state_dict"]
     model.load_state_dict(state_dict, strict=True)
     model.load_state_dict(state_dict, strict=True)
@@ -30,7 +32,7 @@ def main():
     logger.info("Restored model from checkpoint")
     logger.info("Restored model from checkpoint")
 
 
     # Load audio
     # Load audio
-    audio = librosa.load("record1.wav", sr=model.sampling_rate, mono=True)[0]
+    audio = librosa.load("0.wav", sr=model.sampling_rate, mono=True)[0]
     audios = torch.from_numpy(audio).to(model.device)[None, None, :]
     audios = torch.from_numpy(audio).to(model.device)[None, None, :]
     logger.info(
     logger.info(
         f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
         f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
@@ -40,30 +42,40 @@ def main():
     audio_lengths = torch.tensor(
     audio_lengths = torch.tensor(
         [audios.shape[2]], device=model.device, dtype=torch.long
         [audios.shape[2]], device=model.device, dtype=torch.long
     )
     )
-    mel_masks, gt_mels, text_features, indices, loss_vq = model.vq_encode(
-        audios, audio_lengths
+
+    features = gt_mels = model.mel_transform(audios, sample_rate=model.sampling_rate)
+
+    if model.downsample is not None:
+        features = model.downsample(features)
+
+    mel_lengths = audio_lengths // model.hop_length
+    feature_lengths = (
+        audio_lengths
+        / model.hop_length
+        / (model.downsample.total_strides if model.downsample is not None else 1)
+    ).long()
+
+    feature_masks = torch.unsqueeze(
+        sequence_mask(feature_lengths, features.shape[2]), 1
+    ).to(gt_mels.dtype)
+    mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
+        gt_mels.dtype
     )
     )
+
+    # VQ features are at 50 Hz and need to be converted to the true mel size
+    text_features = model.mel_encoder(features, feature_masks)
+    text_features, indices, _ = model.vq_encoder(text_features, feature_masks)
+
     logger.info(
     logger.info(
         f"VQ Encoded, indices: {indices.shape} equavilent to "
         f"VQ Encoded, indices: {indices.shape} equavilent to "
-        + f"{1/(audios.shape[2] / model.sampling_rate / indices.shape[1]):.2f} Hz"
+        + f"{1/(audios.shape[2] / model.sampling_rate / indices.shape[2]):.2f} Hz"
     )
     )
 
 
-    # VQ Decoder
-    audioa = librosa.load(
-        "data/AiShell/wav/train/S0121/BAC009S0121W0125.wav",
-        sr=model.sampling_rate,
-        mono=True,
-    )[0]
-    audioa = torch.from_numpy(audioa).to(model.device)[None, None, :]
-    mel = model.mel_transform(audioa)
-    mel1_masks = torch.ones([mel.shape[0], 1, mel.shape[2]], device=model.device)
-
-    speaker_features = model.speaker_encoder(mel, mel1_masks)
-
-    speaker_features = model.speaker_encoder(gt_mels, mel_masks)
-    speaker_features = torch.zeros_like(speaker_features)
-    decoded_mels = model.vq_decode(text_features, speaker_features, gt_mels, mel_masks)
-    fake_audios = model.vocoder(decoded_mels)
+    text_features = F.interpolate(text_features, size=gt_mels.shape[2], mode="nearest")
+
+    # Sample mels
+    decoded_mels = model.decoder(text_features, mel_masks)
+    fake_audios = model.generator(decoded_mels)
 
 
     # Save audio
     # Save audio
     fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
     fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)