# inference.py — VQ-GAN audio reconstruction CLI (fish-speech)
  1. from pathlib import Path
  2. import click
  3. import librosa
  4. import numpy as np
  5. import soundfile as sf
  6. import torch
  7. from hydra import compose, initialize
  8. from hydra.utils import instantiate
  9. from lightning import LightningModule
  10. from loguru import logger
  11. from omegaconf import OmegaConf
  12. from fish_speech.utils.file import AUDIO_EXTENSIONS
# Register the "eval" resolver so config files may use ${eval:...} expressions,
# which are evaluated as Python at config-resolution time.
# NOTE(review): eval on config contents is only safe for trusted config files.
OmegaConf.register_new_resolver("eval", eval)
  15. def load_model(config_name, checkpoint_path, device="cuda"):
  16. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  17. cfg = compose(config_name=config_name)
  18. model: LightningModule = instantiate(cfg.model)
  19. state_dict = torch.load(
  20. checkpoint_path,
  21. map_location=model.device,
  22. )
  23. if "state_dict" in state_dict:
  24. state_dict = state_dict["state_dict"]
  25. model.load_state_dict(state_dict, strict=False)
  26. model.eval()
  27. model.to(device)
  28. logger.info("Restored model from checkpoint")
  29. return model
  30. @torch.no_grad()
  31. @click.command()
  32. @click.option(
  33. "--input-path",
  34. "-i",
  35. default="test.wav",
  36. type=click.Path(exists=True, path_type=Path),
  37. )
  38. @click.option(
  39. "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
  40. )
  41. @click.option("--config-name", "-cfg", default="vqgan_pretrain")
  42. @click.option(
  43. "--checkpoint-path",
  44. "-ckpt",
  45. default="checkpoints/vq-gan-group-fsq-2x1024.pth",
  46. )
  47. @click.option(
  48. "--device",
  49. "-d",
  50. default="cuda",
  51. )
  52. def main(input_path, output_path, config_name, checkpoint_path, device):
  53. model = load_model(config_name, checkpoint_path, device=device)
  54. if input_path.suffix in AUDIO_EXTENSIONS:
  55. logger.info(f"Processing in-place reconstruction of {input_path}")
  56. # Load audio
  57. audio, _ = librosa.load(
  58. input_path,
  59. sr=model.sampling_rate,
  60. mono=True,
  61. )
  62. audios = torch.from_numpy(audio).to(model.device)[None, None, :]
  63. logger.info(
  64. f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
  65. )
  66. # VQ Encoder
  67. audio_lengths = torch.tensor(
  68. [audios.shape[2]], device=model.device, dtype=torch.long
  69. )
  70. indices = model.encode(audios, audio_lengths)[0][0]
  71. logger.info(f"Generated indices of shape {indices.shape}")
  72. # Save indices
  73. np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
  74. elif input_path.suffix == ".npy":
  75. logger.info(f"Processing precomputed indices from {input_path}")
  76. indices = np.load(input_path)
  77. indices = torch.from_numpy(indices).to(model.device).long()
  78. assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
  79. else:
  80. raise ValueError(f"Unknown input type: {input_path}")
  81. # Restore
  82. feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
  83. fake_audios = model.decode(
  84. indices=indices[None], feature_lengths=feature_lengths, return_audios=True
  85. )
  86. audio_time = fake_audios.shape[-1] / model.sampling_rate
  87. logger.info(
  88. f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
  89. )
  90. # Save audio
  91. fake_audio = fake_audios[0, 0].float().cpu().numpy()
  92. sf.write(output_path, fake_audio, model.sampling_rate)
  93. logger.info(f"Saved audio to {output_path}")
if __name__ == "__main__":
    # Entry point: click parses CLI arguments when run as a script.
    main()