# inference.py
from pathlib import Path

import click
import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from einops import rearrange
from hydra import compose, initialize
from hydra.utils import instantiate
from lightning import LightningModule
from loguru import logger
from omegaconf import OmegaConf

from fish_speech.models.vqgan.utils import sequence_mask
from fish_speech.utils.file import AUDIO_EXTENSIONS

# Register the "eval" resolver so Hydra configs may use ${eval:...} expressions.
# NOTE(review): this executes arbitrary Python from config values — config files
# must come from a trusted source.
OmegaConf.register_new_resolver("eval", eval)
  18. @torch.no_grad()
  19. @torch.autocast(device_type="cuda", enabled=True)
  20. @click.command()
  21. @click.option(
  22. "--input-path",
  23. "-i",
  24. default="test.wav",
  25. type=click.Path(exists=True, path_type=Path),
  26. )
  27. @click.option(
  28. "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
  29. )
  30. @click.option("--config-name", "-cfg", default="vqgan_pretrain")
  31. @click.option(
  32. "--checkpoint-path",
  33. "-ckpt",
  34. default="checkpoints/vq-gan-group-fsq-8x1024-wn-20x768-30kh.pth",
  35. )
  36. def main(input_path, output_path, config_name, checkpoint_path):
  37. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  38. cfg = compose(config_name=config_name)
  39. model: LightningModule = instantiate(cfg.model)
  40. state_dict = torch.load(
  41. checkpoint_path,
  42. map_location=model.device,
  43. )
  44. if "state_dict" in state_dict:
  45. state_dict = state_dict["state_dict"]
  46. model.load_state_dict(state_dict, strict=False)
  47. model.eval()
  48. model.cuda()
  49. logger.info("Restored model from checkpoint")
  50. if input_path.suffix in AUDIO_EXTENSIONS:
  51. logger.info(f"Processing in-place reconstruction of {input_path}")
  52. # Load audio
  53. audio, _ = librosa.load(
  54. input_path,
  55. sr=model.sampling_rate,
  56. mono=True,
  57. )
  58. audios = torch.from_numpy(audio).to(model.device)[None, None, :]
  59. logger.info(
  60. f"Loaded audio with {audios.shape[2] / model.sampling_rate:.2f} seconds"
  61. )
  62. # VQ Encoder
  63. audio_lengths = torch.tensor(
  64. [audios.shape[2]], device=model.device, dtype=torch.long
  65. )
  66. indices = model.encode(audios, audio_lengths)[0][0]
  67. logger.info(f"Generated indices of shape {indices.shape}")
  68. # Save indices
  69. np.save(output_path.with_suffix(".npy"), indices.cpu().numpy())
  70. elif input_path.suffix == ".npy":
  71. logger.info(f"Processing precomputed indices from {input_path}")
  72. indices = np.load(input_path)
  73. indices = torch.from_numpy(indices).to(model.device).long()
  74. assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
  75. else:
  76. raise ValueError(f"Unknown input type: {input_path}")
  77. # random destroy 10% of indices
  78. # mask = torch.rand_like(indices, dtype=torch.float) > 0.9
  79. # indices[mask] = torch.randint(0, 1000, mask.shape, device=indices.device, dtype=indices.dtype)[mask]
  80. # Restore
  81. feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
  82. fake_audios = model.decode(
  83. indices=indices[None], feature_lengths=feature_lengths, return_audios=True
  84. )
  85. audio_time = fake_audios.shape[-1] / model.sampling_rate
  86. logger.info(
  87. f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
  88. )
  89. # Save audio
  90. fake_audio = fake_audios[0, 0].cpu().numpy().astype(np.float32)
  91. sf.write("fake.wav", fake_audio, model.sampling_rate)
  92. logger.info(f"Saved audio to {output_path}")
  93. if __name__ == "__main__":
  94. main()