import os
from argparse import ArgumentParser
from pathlib import Path

import pyrootutils
import torch
from loguru import logger
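# Register the repository root (marked by ".project-root") on sys.path so the
# fish_speech and tools imports below resolve when this script is run directly.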
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from fish_speech.inference_engine import TTSInferenceEngine
from fish_speech.models.dac.inference import load_model as load_decoder_model
from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
from fish_speech.utils.schema import ServeTTSRequest
from tools.webui import build_app
from tools.webui.inference import get_inference_wrapper

# Make einx happy
os.environ["EINX_FILTER_TRACEBACK"] = "false"
def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=Path,
        default="checkpoints/openaudio-s1-mini",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=Path,
        default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-gradio-length", type=int, default=0)
    parser.add_argument("--theme", type=str, default="light")

    return parser.parse_args()
if __name__ == "__main__":
    args = parse_args()
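    # --half runs inference in float16; the default is bfloat16.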
    args.precision = torch.half if args.half else torch.bfloat16

    # Check if MPS or CUDA is available
    if torch.backends.mps.is_available():
        args.device = "mps"
        logger.info("mps is available, running on mps.")
    elif not torch.cuda.is_available():
        logger.info("CUDA is not available, running on CPU.")
        args.device = "cpu"

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )

    logger.info("Loading VQ-GAN model...")
    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("Decoder model loaded, warming up...")

    # Create the inference engine
    inference_engine = TTSInferenceEngine(
        llama_queue=llama_queue,
        decoder_model=decoder_model,
        compile=args.compile,
        precision=args.precision,
    )

    # Dry run to check if the model is loaded correctly and avoid the first-time latency
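    # inference() is a generator; list() consumes it so the whole pipeline runs to
    # completion once (including any torch.compile warm-up) before users connect.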
    list(
        inference_engine.inference(
            ServeTTSRequest(
                text="Hello world.",
                references=[],
                reference_id=None,
                max_new_tokens=1024,
                chunk_length=200,
                top_p=0.7,
                repetition_penalty=1.5,
                temperature=0.7,
                format="wav",
            )
        )
    )

    logger.info("Warming up done, launching the web UI...")

    # Get the inference function with the immutable arguments
    inference_fct = get_inference_wrapper(inference_engine)

    app = build_app(inference_fct, args.theme)
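    # show_api=True keeps Gradio's "Use via API" endpoint documentation enabled.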
    app.launch(show_api=True)