# inference.py
from pathlib import Path

import click
import hydra
import numpy as np
import soundfile as sf
import torch
import torchaudio
from hydra import compose, initialize
from hydra.utils import instantiate
from loguru import logger
from omegaconf import OmegaConf

from tools.file import AUDIO_EXTENSIONS
# Register an "eval" resolver so hydra/omegaconf configs can embed
# ${eval:"..."} Python expressions.
# NOTE(review): `eval` executes arbitrary code — only safe while configs
# come from trusted, in-repo files; never feed untrusted configs through this.
OmegaConf.register_new_resolver("eval", eval)
  15. def load_model(config_name, checkpoint_path, device="cuda"):
  16. hydra.core.global_hydra.GlobalHydra.instance().clear()
  17. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  18. cfg = compose(config_name=config_name)
  19. model = instantiate(cfg)
  20. state_dict = torch.load(
  21. checkpoint_path, map_location=device, mmap=True, weights_only=True
  22. )
  23. if "state_dict" in state_dict:
  24. state_dict = state_dict["state_dict"]
  25. if any("generator" in k for k in state_dict):
  26. state_dict = {
  27. k.replace("generator.", ""): v
  28. for k, v in state_dict.items()
  29. if "generator." in k
  30. }
  31. result = model.load_state_dict(state_dict, strict=False, assign=True)
  32. model.eval()
  33. model.to(device)
  34. logger.info(f"Loaded model: {result}")
  35. return model
  36. @torch.no_grad()
  37. @click.command()
  38. @click.option(
  39. "--input-path",
  40. "-i",
  41. default="test.wav",
  42. type=click.Path(exists=True, path_type=Path),
  43. )
  44. @click.option(
  45. "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
  46. )
  47. @click.option("--config-name", default="firefly_gan_vq")
  48. @click.option(
  49. "--checkpoint-path",
  50. default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
  51. )
  52. @click.option(
  53. "--device",
  54. "-d",
  55. default="cuda",
  56. )
  57. def main(input_path, output_path, config_name, checkpoint_path, device):
  58. model = load_model(config_name, checkpoint_path, device=device)
  59. if input_path.suffix in AUDIO_EXTENSIONS:
  60. logger.info(f"Processing in-place reconstruction of {input_path}")
  61. # Load audio
  62. audio, sr = torchaudio.load(str(input_path))
  63. if audio.shape[0] > 1:
  64. audio = audio.mean(0, keepdim=True)
  65. audio = torchaudio.functional.resample(
  66. audio, sr, model.spec_transform.sample_rate
  67. )
  68. audios = audio[None].to(device)
  69. logger.info(
  70. f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
  71. )
  72. # VQ Encoder
  73. audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
  74. indices = model.encode(audios, audio_lengths)[0][0]
  75. logger.info(f"Generated indices of shape {indices.shape}")
  76. # Save indices
  77. np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
  78. elif input_path.suffix == ".npy":
  79. logger.info(f"Processing precomputed indices from {input_path}")
  80. indices = np.load(input_path)
  81. indices = torch.from_numpy(indices).to(device).long()
  82. assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
  83. else:
  84. raise ValueError(f"Unknown input type: {input_path}")
  85. # Restore
  86. feature_lengths = torch.tensor([indices.shape[1]], device=device)
  87. fake_audios, _ = model.decode(
  88. indices=indices[None], feature_lengths=feature_lengths
  89. )
  90. audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
  91. logger.info(
  92. f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
  93. )
  94. # Save audio
  95. fake_audio = fake_audios[0, 0].float().cpu().numpy()
  96. sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
  97. logger.info(f"Saved audio to {output_path}")
# CLI entry point: click parses argv and invokes main().
if __name__ == "__main__":
    main()