# infer_vq.py
  1. import librosa
  2. import numpy as np
  3. import soundfile as sf
  4. import torch
  5. import torch.nn.functional as F
  6. from hydra import compose, initialize
  7. from hydra.utils import instantiate
  8. from lightning import LightningModule
  9. from loguru import logger
  10. from omegaconf import OmegaConf
  11. from fish_speech.models.vqgan.utils import sequence_mask
# register eval resolver
# SECURITY NOTE(review): this wires Python's builtin `eval` into OmegaConf, so any
# `${eval:...}` expression inside a loaded config file executes arbitrary Python
# code. Only load trusted config files while this resolver is registered.
OmegaConf.register_new_resolver("eval", eval)
  14. @torch.no_grad()
  15. @torch.autocast(device_type="cuda", enabled=True)
  16. def main():
  17. with initialize(version_base="1.3", config_path="../fish_speech/configs"):
  18. cfg = compose(config_name="vqgan")
  19. model: LightningModule = instantiate(cfg.model)
  20. state_dict = torch.load(
  21. "results/vqgan/checkpoints/step_000110000.ckpt",
  22. map_location=model.device,
  23. )["state_dict"]
  24. model.load_state_dict(state_dict, strict=True)
  25. model.eval()
  26. model.cuda()
  27. logger.info("Restored model from checkpoint")
  28. # Load audio
  29. audio = librosa.load("0.wav", sr=model.sampling_rate, mono=True)[0]
  30. audios = torch.from_numpy(audio).to(model.device)[None, None, :]
  31. logger.info(
  32. f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
  33. )
  34. # VQ Encoder
  35. audio_lengths = torch.tensor(
  36. [audios.shape[2]], device=model.device, dtype=torch.long
  37. )
  38. features = gt_mels = model.mel_transform(audios, sample_rate=model.sampling_rate)
  39. if model.downsample is not None:
  40. features = model.downsample(features)
  41. mel_lengths = audio_lengths // model.hop_length
  42. feature_lengths = (
  43. audio_lengths
  44. / model.hop_length
  45. / (model.downsample.total_strides if model.downsample is not None else 1)
  46. ).long()
  47. feature_masks = torch.unsqueeze(
  48. sequence_mask(feature_lengths, features.shape[2]), 1
  49. ).to(gt_mels.dtype)
  50. mel_masks = torch.unsqueeze(sequence_mask(mel_lengths, gt_mels.shape[2]), 1).to(
  51. gt_mels.dtype
  52. )
  53. # vq_features is 50 hz, need to convert to true mel size
  54. text_features = model.mel_encoder(features, feature_masks)
  55. text_features, indices, _ = model.vq_encoder(text_features, feature_masks)
  56. logger.info(
  57. f"VQ Encoded, indices: {indices.shape} equivalent to "
  58. + f"{1/(audios.shape[2] / model.sampling_rate / indices.shape[2]):.2f} Hz"
  59. )
  60. text_features = F.interpolate(text_features, size=gt_mels.shape[2], mode="nearest")
  61. # Sample mels
  62. decoded_mels = model.decoder(text_features, mel_masks)
  63. fake_audios = model.generator(decoded_mels)
  64. # Save audio
  65. fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
  66. sf.write("fake.wav", fake_audio, model.sampling_rate)
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()