# inference.py
  1. from pathlib import Path
  2. import click
  3. import hydra
  4. import librosa
  5. import numpy as np
  6. import soundfile as sf
  7. import torch
  8. from hydra import compose, initialize
  9. from hydra.utils import instantiate
  10. from lightning import LightningModule
  11. from loguru import logger
  12. from omegaconf import OmegaConf
  13. from transformers import AutoTokenizer
  14. from fish_speech.utils.file import AUDIO_EXTENSIONS
# Register an "eval" resolver so Hydra/OmegaConf configs can use ${eval:...}
# expressions that are evaluated as Python at config-resolution time.
# NOTE(review): this runs eval() on config strings — configs must be trusted.
OmegaConf.register_new_resolver("eval", eval)
  17. def load_model(config_name, checkpoint_path, device="cuda"):
  18. hydra.core.global_hydra.GlobalHydra.instance().clear()
  19. with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
  20. cfg = compose(config_name=config_name)
  21. model: LightningModule = instantiate(cfg.model)
  22. state_dict = torch.load(
  23. checkpoint_path,
  24. map_location=model.device,
  25. )
  26. if "state_dict" in state_dict:
  27. state_dict = state_dict["state_dict"]
  28. model.load_state_dict(state_dict, strict=False)
  29. model.eval()
  30. model.to(device)
  31. logger.info("Restored model from checkpoint")
  32. return model
  33. @torch.no_grad()
  34. @click.command()
  35. @click.option(
  36. "--input-path",
  37. "-i",
  38. default="test.npy",
  39. type=click.Path(exists=True, path_type=Path),
  40. )
  41. @click.option(
  42. "--reference-path",
  43. "-r",
  44. type=click.Path(exists=True, path_type=Path),
  45. default=None,
  46. )
  47. @click.option(
  48. "--text",
  49. type=str,
  50. default="-",
  51. )
  52. @click.option(
  53. "--tokenizer",
  54. type=str,
  55. default="fishaudio/fish-speech-1",
  56. )
  57. @click.option(
  58. "--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
  59. )
  60. @click.option("--config-name", "-cfg", default="vits_decoder_finetune")
  61. @click.option(
  62. "--checkpoint-path",
  63. "-ckpt",
  64. default="checkpoints/vq-gan-group-fsq-2x1024.pth",
  65. )
  66. @click.option(
  67. "--device",
  68. "-d",
  69. default="cuda",
  70. )
  71. def main(
  72. input_path,
  73. reference_path,
  74. text,
  75. tokenizer,
  76. output_path,
  77. config_name,
  78. checkpoint_path,
  79. device,
  80. ):
  81. model = load_model(config_name, checkpoint_path, device=device)
  82. assert input_path.suffix == ".npy", f"Expected .npy file, got {input_path.suffix}"
  83. logger.info(f"Processing precomputed indices from {input_path}")
  84. indices = np.load(input_path)
  85. indices = torch.from_numpy(indices).to(model.device).long()
  86. assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
  87. # Extract reference audio
  88. if reference_path is not None:
  89. assert (
  90. reference_path.suffix in AUDIO_EXTENSIONS
  91. ), f"Expected audio file, got {reference_path.suffix}"
  92. reference_audio, sr = librosa.load(reference_path, sr=model.sampling_rate)
  93. reference_audio = torch.from_numpy(reference_audio).to(model.device).float()
  94. reference_spec = model.spec_transform(reference_audio[None])
  95. reference_embedding = model.generator.encode_ref(
  96. reference_spec,
  97. torch.tensor([reference_spec.shape[-1]], device=model.device),
  98. )
  99. logger.info(
  100. f"Loaded reference audio from {reference_path}, shape: {reference_audio.shape}"
  101. )
  102. else:
  103. reference_embedding = torch.zeros(
  104. 1, model.generator.gin_channels, 1, device=model.device
  105. )
  106. logger.info("No reference audio provided, use zero embedding")
  107. # Extract text
  108. tokenizer = AutoTokenizer.from_pretrained(tokenizer)
  109. encoded_text = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
  110. logger.info(f"Encoded text: {encoded_text.shape}")
  111. # Restore
  112. feature_lengths = torch.tensor([indices.shape[1]], device=model.device)
  113. quantized = model.generator.vq.indicies_to_vq_features(
  114. indices=indices[None], feature_lengths=feature_lengths
  115. )
  116. logger.info(f"Restored VQ features: {quantized.shape}")
  117. # Decode
  118. fake_audios = model.generator.decode(
  119. quantized,
  120. torch.tensor([quantized.shape[-1]], device=model.device),
  121. encoded_text,
  122. torch.tensor([encoded_text.shape[-1]], device=model.device),
  123. ge=reference_embedding,
  124. )
  125. logger.info(
  126. f"Generated audio: {fake_audios.shape}, equivalent to {fake_audios.shape[-1] / model.sampling_rate:.2f} seconds"
  127. )
  128. # Save audio
  129. fake_audio = fake_audios[0, 0].float().cpu().numpy()
  130. sf.write(output_path, fake_audio, model.sampling_rate)
  131. logger.info(f"Saved audio to {output_path}")
# Script entry point: invoke the click command when run directly.
if __name__ == "__main__":
    main()