| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- import json
- import multiprocessing
- import os
- import re
- from argparse import Namespace
- from threading import Lock
- import pyrootutils
- import uvicorn
- from kui.asgi import (
- Depends,
- FactoryClass,
- HTTPException,
- HttpRoute,
- Kui,
- OpenAPI,
- Routes,
- )
- from kui.cors import CORSConfig
- from kui.openapi.specification import Info
- from kui.security import bearer_auth
- from loguru import logger
- from typing_extensions import Annotated
- pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
- from tools.server.api_utils import MsgPackRequest, parse_args
- from tools.server.exception_handler import ExceptionHandler
- from tools.server.model_manager import ModelManager
- from tools.server.views import routes
- ENV_ARGS_KEY = "FISH_API_SERVER_ARGS"
class API(ExceptionHandler):
    """Wrapper that assembles the Kui ASGI application for the API server.

    The constructor wraps the project's routes with an optional bearer-token
    check, creates the ``Kui`` app, stores shared values on ``app.state`` and
    registers :meth:`initialize_app` as a startup hook.
    """

    # Attributes of ``self.args`` that are forwarded, under the same names,
    # as keyword arguments to ``ModelManager`` and logged at startup.
    _MODEL_ARG_NAMES = (
        "mode",
        "device",
        "half",
        "compile",
        "llama_checkpoint_path",
        "decoder_checkpoint_path",
        "decoder_config_name",
    )

    def __init__(self, args: Namespace | None = None):
        """Build the application.

        Args:
            args: Parsed server options. When omitted, ``parse_args()`` is
                called to obtain them.
        """
        # Function-scope import so this block does not depend on a change to
        # the file-level import list.
        import hmac

        self.args = args or parse_args()

        def api_auth(endpoint):
            """HTTP middleware: guard ``endpoint`` with the configured API key.

            Returns ``verify`` when an API key is configured and
            ``passthrough`` otherwise; the decision is made once, when the
            middleware wraps the endpoint.
            """

            async def verify(token: Annotated[str, Depends(bearer_auth)]):
                expected = str(self.args.api_key).encode("utf-8")
                # Compare as bytes in constant time: the token comes from the
                # client, and a plain ``!=`` can leak how many leading
                # characters matched through response timing. Encoding first
                # also avoids compare_digest's ASCII-only restriction on str.
                if not isinstance(token, str) or not hmac.compare_digest(
                    token.encode("utf-8"), expected
                ):
                    raise HTTPException(401, None, "Invalid token")
                return await endpoint()

            async def passthrough():
                return await endpoint()

            if self.args.api_key is not None:
                return verify
            return passthrough

        self.routes = Routes(
            routes,  # keep existing routes
            http_middlewares=[api_auth],  # apply api_auth middleware
        )

        # OpenAPI configuration (currently disabled)
        # self.openapi = OpenAPI(
        #     Info(
        #         {
        #             "title": "Fish Speech API",
        #             "version": "1.5.0",
        #         }
        #     ),
        # ).routes

        # Initialize the app
        self.app = Kui(
            routes=self.routes,
            exception_handlers={
                # NOTE(review): both handlers are presumably inherited from
                # ExceptionHandler; they are not defined in this class.
                HTTPException: self.http_exception_handler,
                Exception: self.other_exception_handler,
            },
            factory_class=FactoryClass(http=MsgPackRequest),
            cors_config=CORSConfig(),
        )

        # Add the state variables
        self.app.state.lock = Lock()
        self.app.state.device = self.args.device
        self.app.state.max_text_length = self.args.max_text_length

        # Associate the app with the model manager
        self.app.on_startup(self.initialize_app)

    async def initialize_app(self, app: Kui):
        """Startup hook: create the ``ModelManager`` and log the settings used.

        Args:
            app: The application whose ``state`` receives ``model_manager``.
        """
        model_kwargs = {
            name: getattr(self.args, name) for name in self._MODEL_ARG_NAMES
        }

        # Make the ModelManager available to the views
        app.state.model_manager = ModelManager(**model_kwargs)

        for name, value in model_kwargs.items():
            logger.info(f"self.args.{name}={value}")
        logger.info(f"Startup done, listening server at http://{self.args.listen}")
def create_app():
    """Application factory: build the ``API`` and return its ``app``.

    Server options are read as JSON from the ``ENV_ARGS_KEY`` environment
    variable. If the variable is unset, empty, or cannot be turned into a
    ``Namespace``, ``API`` is created with ``args=None`` (a decoding failure
    is logged as a warning first).
    """
    serialized = os.environ.get(ENV_ARGS_KEY)
    if not serialized:
        return API(args=None).app

    try:
        restored = Namespace(**json.loads(serialized))
    except Exception as exc:
        # Best effort: report the problem and fall back to no explicit args.
        logger.warning(f"Failed to load args from {ENV_ARGS_KEY}: {exc}")
        restored = None

    return API(args=restored).app
- # Each worker process created by Uvicorn has its own memory space,
- # meaning that models and variables are not shared between processes.
- # Therefore, any variables (like `llama_queue` or `decoder_model`)
- # will not be shared across workers.
- # Multi-threading for deep learning can cause issues, such as inconsistent
- # outputs if multiple threads access the same buffers simultaneously.
- # Instead, it's better to use multiprocessing or independent models per thread.
def parse_listen_address(listen: str) -> tuple[str, int]:
    """Split a listen address into ``(host, port)``.

    Two forms are accepted:

    * ``[xxxx:xxxx::xxxx]:port`` -- bracketed IPv6; the brackets are removed
      from the returned host.
    * ``host:port`` -- exactly one colon separating host and port.

    Args:
        listen: The address string, e.g. ``"127.0.0.1:8080"``.

    Returns:
        The host string and the port as an ``int``.

    Raises:
        ValueError: If the address matches neither form or the port is not
            an integer.
    """
    # IPv6 address format is [xxxx:xxxx::xxxx]:port
    bracketed = re.search(r"\[([^\]]+)\]:(\d+)$", listen)
    if bracketed:
        host, port = bracketed.groups()
    else:
        # Anything other than one colon would otherwise fail with an opaque
        # tuple-unpacking error; report the offending value instead.
        if listen.count(":") != 1:
            raise ValueError(
                f"Invalid listen address {listen!r}: "
                "expected 'host:port' or '[ipv6]:port'"
            )
        host, port = listen.split(":")

    try:
        return host, int(port)
    except ValueError:
        raise ValueError(
            f"Invalid port {port!r} in listen address {listen!r}"
        ) from None


if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)

    args = parse_args()
    # The options are published through the environment as JSON; create_app
    # reads the same variable to rebuild them.
    os.environ[ENV_ARGS_KEY] = json.dumps(vars(args))

    host, port = parse_listen_address(args.listen)

    uvicorn.run(
        "tools.api_server:create_app",
        host=host,
        port=port,
        workers=args.workers,
        log_level="info",
        factory=True,
    )
|