run_webui.py
import os
from argparse import ArgumentParser
from pathlib import Path

import pyrootutils
import torch
from loguru import logger

# Register the project root (marked by ".project-root") on sys.path BEFORE the
# fish_speech / tools imports below — they only resolve once the root is added.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.dac.inference import load_model as load_decoder_model
from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
from fish_speech.utils.schema import ServeTTSRequest
from tools.webui import build_app
from tools.webui.inference import get_inference_wrapper

# Make einx happy
os.environ["EINX_FILTER_TRACEBACK"] = "false"
  16. def parse_args():
  17. parser = ArgumentParser()
  18. parser.add_argument(
  19. "--llama-checkpoint-path",
  20. type=Path,
  21. default="checkpoints/openaudio-s1-mini",
  22. )
  23. parser.add_argument(
  24. "--decoder-checkpoint-path",
  25. type=Path,
  26. default="checkpoints/openaudio-s1-mini/codec.pth",
  27. )
  28. parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
  29. parser.add_argument("--device", type=str, default="cuda")
  30. parser.add_argument("--half", action="store_true")
  31. parser.add_argument("--compile", action="store_true")
  32. parser.add_argument("--max-gradio-length", type=int, default=0)
  33. parser.add_argument("--theme", type=str, default="light")
  34. return parser.parse_args()
  35. if __name__ == "__main__":
  36. args = parse_args()
  37. args.precision = torch.half if args.half else torch.bfloat16
  38. # Check if MPS or CUDA is available
  39. if torch.backends.mps.is_available():
  40. args.device = "mps"
  41. logger.info("mps is available, running on mps.")
  42. elif torch.xpu.is_available():
  43. args.device = "xpu"
  44. logger.info("XPU is available, running on XPU.")
  45. elif not torch.cuda.is_available():
  46. logger.info("CUDA is not available, running on CPU.")
  47. args.device = "cpu"
  48. logger.info("Loading Llama model...")
  49. llama_queue = launch_thread_safe_queue(
  50. checkpoint_path=args.llama_checkpoint_path,
  51. device=args.device,
  52. precision=args.precision,
  53. compile=args.compile,
  54. )
  55. logger.info("Loading VQ-GAN model...")
  56. decoder_model = load_decoder_model(
  57. config_name=args.decoder_config_name,
  58. checkpoint_path=args.decoder_checkpoint_path,
  59. device=args.device,
  60. )
  61. logger.info("Decoder model loaded, warming up...")
  62. # Create the inference engine
  63. inference_engine = TTSInferenceEngine(
  64. llama_queue=llama_queue,
  65. decoder_model=decoder_model,
  66. compile=args.compile,
  67. precision=args.precision,
  68. )
  69. # Dry run to check if the model is loaded correctly and avoid the first-time latency
  70. list(
  71. inference_engine.inference(
  72. ServeTTSRequest(
  73. text="Hello world.",
  74. references=[],
  75. reference_id=None,
  76. max_new_tokens=1024,
  77. chunk_length=200,
  78. top_p=0.7,
  79. repetition_penalty=1.5,
  80. temperature=0.7,
  81. format="wav",
  82. )
  83. )
  84. )
  85. logger.info("Warming up done, launching the web UI...")
  86. # Get the inference function with the immutable arguments
  87. inference_fct = get_inference_wrapper(inference_engine)
  88. app = build_app(inference_fct, args.theme)
  89. app.launch(show_api=True)