# infer_vq.py — VQ encode/decode round-trip inference script (~2.4 KB)
  1. import librosa
  2. import numpy as np
  3. import soundfile as sf
  4. import torch
  5. from hydra import compose, initialize
  6. from hydra.utils import instantiate
  7. from lightning import LightningModule
  8. from loguru import logger
  9. from omegaconf import OmegaConf
  10. # register eval resolver
  11. OmegaConf.register_new_resolver("eval", eval)
  12. @torch.no_grad()
  13. @torch.autocast(device_type="cuda", enabled=True)
  14. def main():
  15. with initialize(version_base="1.3", config_path="../fish_speech/configs"):
  16. cfg = compose(config_name="vq_naive_40hz")
  17. model: LightningModule = instantiate(cfg.model)
  18. state_dict = torch.load(
  19. "results/vq_naive_40hz/checkpoints/step_000675000.ckpt",
  20. # "results/vq_naive_25hz/checkpoints/step_000100000.ckpt",
  21. map_location=model.device,
  22. )["state_dict"]
  23. model.load_state_dict(state_dict, strict=True)
  24. model.eval()
  25. model.cuda()
  26. logger.info("Restored model from checkpoint")
  27. # Load audio
  28. audio = librosa.load("record1.wav", sr=model.sampling_rate, mono=True)[0]
  29. audios = torch.from_numpy(audio).to(model.device)[None, None, :]
  30. logger.info(
  31. f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
  32. )
  33. # VQ Encoder
  34. audio_lengths = torch.tensor(
  35. [audios.shape[2]], device=model.device, dtype=torch.long
  36. )
  37. mel_masks, gt_mels, text_features, indices, loss_vq = model.vq_encode(
  38. audios, audio_lengths
  39. )
  40. logger.info(
  41. f"VQ Encoded, indices: {indices.shape} equavilent to "
  42. + f"{1/(audios.shape[2] / model.sampling_rate / indices.shape[1]):.2f} Hz"
  43. )
  44. # VQ Decoder
  45. audioa = librosa.load(
  46. "data/AiShell/wav/train/S0121/BAC009S0121W0125.wav",
  47. sr=model.sampling_rate,
  48. mono=True,
  49. )[0]
  50. audioa = torch.from_numpy(audioa).to(model.device)[None, None, :]
  51. mel = model.mel_transform(audioa)
  52. mel1_masks = torch.ones([mel.shape[0], 1, mel.shape[2]], device=model.device)
  53. speaker_features = model.speaker_encoder(mel, mel1_masks)
  54. speaker_features = model.speaker_encoder(gt_mels, mel_masks)
  55. speaker_features = torch.zeros_like(speaker_features)
  56. decoded_mels = model.vq_decode(text_features, speaker_features, gt_mels, mel_masks)
  57. fake_audios = model.vocoder(decoded_mels)
  58. # Save audio
  59. fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
  60. sf.write("fake.wav", fake_audio, model.sampling_rate)
  61. if __name__ == "__main__":
  62. main()