inference.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from pathlib import Path
  2. import click
  3. import hydra
  4. import numpy as np
  5. import soundfile as sf
  6. import torch
  7. import torchaudio
  8. from hydra import compose, initialize
  9. from hydra.utils import instantiate
  10. from lightning import LightningModule
  11. from loguru import logger
  12. from omegaconf import OmegaConf
  13. from fish_speech.utils.file import AUDIO_EXTENSIONS
# Register builtin `eval` as an OmegaConf resolver so configs can use ${eval:...}.
# NOTE(review): this lets any loaded config execute arbitrary Python — only use
# with trusted config files; verify no untrusted configs reach this process.
OmegaConf.register_new_resolver("eval", eval)
  16. def load_model(config_name, checkpoint_path, device="cuda"):
  17. hydra.core.global_hydra.GlobalHydra.instance().clear()
  18. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  19. cfg = compose(config_name=config_name)
  20. model: LightningModule = instantiate(cfg.model)
  21. state_dict = torch.load(
  22. checkpoint_path,
  23. map_location=model.device,
  24. )
  25. if "state_dict" in state_dict:
  26. state_dict = state_dict["state_dict"]
  27. model.load_state_dict(state_dict, strict=False)
  28. model.eval()
  29. model.to(device)
  30. logger.info("Restored model from checkpoint")
  31. return model
  32. @torch.no_grad()
  33. @click.command()
  34. @click.option(
  35. "--input-path",
  36. "-i",
  37. default="test.wav",
  38. type=click.Path(exists=True, path_type=Path),
  39. )
  40. @click.option(
  41. "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
  42. )
  43. @click.option("--config-name", "-cfg", default="vqgan_pretrain")
  44. @click.option(
  45. "--checkpoint-path",
  46. "-ckpt",
  47. default="checkpoints/vq-gan-group-fsq-2x1024.pth",
  48. )
  49. @click.option(
  50. "--device",
  51. "-d",
  52. default="cuda",
  53. )
  54. def main(input_path, output_path, config_name, checkpoint_path, device):
  55. model = load_model(config_name, checkpoint_path, device=device)
  56. if input_path.suffix in AUDIO_EXTENSIONS:
  57. logger.info(f"Processing in-place reconstruction of {input_path}")
  58. # Load audio
  59. audio, sr = torchaudio.load(str(input_path))
  60. if audio.shape[0] > 1:
  61. audio = audio.mean(0, keepdim=True)
  62. audio = torchaudio.functional.resample(audio, sr, model.sampling_rate)
  63. audios = audio[None].to(model.device)
  64. logger.info(
  65. f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
  66. )
  67. # VQ Encoder
  68. audio_lengths = torch.tensor(
  69. [audios.shape[2]], device=model.device, dtype=torch.long
  70. )
  71. indices = model.encode(audios, audio_lengths)[0][0]
  72. logger.info(f"Generated indices of shape {indices.shape}")
  73. # Save indices
  74. np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
  75. elif input_path.suffix == ".npy":
  76. logger.info(f"Processing precomputed indices from {input_path}")
  77. indices = np.load(input_path)
  78. indices = torch.from_numpy(indices).to(model.device).long()
  79. assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
  80. else:
  81. raise ValueError(f"Unknown input type: {input_path}")
  82. # Restore
  83. feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
  84. fake_audios = model.decode(
  85. indices=indices[None], feature_lengths=feature_lengths, return_audios=True
  86. )
  87. audio_time = fake_audios.shape[-1] / model.sampling_rate
  88. logger.info(
  89. f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
  90. )
  91. # Save audio
  92. fake_audio = fake_audios[0, 0].float().cpu().numpy()
  93. sf.write(output_path, fake_audio, model.sampling_rate)
  94. logger.info(f"Saved audio to {output_path}")
# Standard script entry point: run the click CLI when executed directly.
if __name__ == "__main__":
    main()