# api_server.py
  1. import json
  2. import multiprocessing
  3. import os
  4. import re
  5. from argparse import Namespace
  6. from threading import Lock
  7. import pyrootutils
  8. import uvicorn
  9. from kui.asgi import (
  10. Depends,
  11. FactoryClass,
  12. HTTPException,
  13. HttpRoute,
  14. Kui,
  15. OpenAPI,
  16. Routes,
  17. )
  18. from kui.cors import CORSConfig
  19. from kui.openapi.specification import Info
  20. from kui.security import bearer_auth
  21. from loguru import logger
  22. from typing_extensions import Annotated
  23. pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
  24. from tools.server.api_utils import MsgPackRequest, parse_args
  25. from tools.server.exception_handler import ExceptionHandler
  26. from tools.server.model_manager import ModelManager
  27. from tools.server.views import routes
  28. ENV_ARGS_KEY = "FISH_API_SERVER_ARGS"
class API(ExceptionHandler):
    """ASGI application wrapper for the Fish Speech HTTP API.

    Builds a Kui app from the project routes with optional bearer-token
    authentication, OpenAPI docs, CORS, and MessagePack request parsing,
    and defers model loading to the ASGI startup hook.
    """

    def __init__(self, args: Namespace | None = None):
        # Use the supplied args (e.g. reconstructed from the environment in a
        # spawned worker process) or fall back to parsing the command line.
        self.args = args or parse_args()

        def api_auth(endpoint):
            # Route middleware: when an API key is configured, wrap every
            # endpoint with bearer-token verification; otherwise pass through
            # to the endpoint unchanged.
            async def verify(token: Annotated[str, Depends(bearer_auth)]):
                if token != self.args.api_key:
                    raise HTTPException(401, None, "Invalid token")
                return await endpoint()

            async def passthrough():
                return await endpoint()

            # The middleware is chosen once at route-registration time,
            # based on whether an API key was provided.
            if self.args.api_key is not None:
                return verify
            else:
                return passthrough

        self.routes = Routes(
            routes,  # keep existing routes
            http_middlewares=[api_auth],  # apply api_auth middleware
        )

        # OpenAPI configuration
        self.openapi = OpenAPI(
            Info(
                {
                    "title": "Fish Speech API",
                    "version": "1.5.0",
                }
            ),
        ).routes

        # Initialize the app
        self.app = Kui(
            routes=self.routes + self.openapi[1:],  # Remove the default route
            exception_handlers={
                HTTPException: self.http_exception_handler,
                Exception: self.other_exception_handler,
            },
            factory_class=FactoryClass(http=MsgPackRequest),
            cors_config=CORSConfig(),
        )

        # Add the state variables
        self.app.state.lock = Lock()
        self.app.state.device = self.args.device
        self.app.state.max_text_length = self.args.max_text_length

        # Associate the app with the model manager
        self.app.on_startup(self.initialize_app)

    async def initialize_app(self, app: Kui):
        """ASGI startup hook: load the models and log the configuration.

        Runs once per worker process, so each worker holds its own
        ModelManager instance (models are not shared across processes).
        """
        # Make the ModelManager available to the views
        app.state.model_manager = ModelManager(
            mode=self.args.mode,
            device=self.args.device,
            half=self.args.half,
            compile=self.args.compile,
            llama_checkpoint_path=self.args.llama_checkpoint_path,
            decoder_checkpoint_path=self.args.decoder_checkpoint_path,
            decoder_config_name=self.args.decoder_config_name,
        )
        # Log the effective configuration for debuggability.
        logger.info(f"self.args.mode={self.args.mode}")
        logger.info(f"self.args.device={self.args.device}")
        logger.info(f"self.args.half={self.args.half}")
        logger.info(f"self.args.compile={self.args.compile}")
        logger.info(f"self.args.llama_checkpoint_path={self.args.llama_checkpoint_path}")
        logger.info(f"self.args.decoder_checkpoint_path={self.args.decoder_checkpoint_path}")
        logger.info(f"self.args.decoder_config_name={self.args.decoder_config_name}")
        logger.info(f"Startup done, listening server at http://{self.args.listen}")
  91. def create_app():
  92. args_env = os.environ.get(ENV_ARGS_KEY)
  93. args = None
  94. if args_env:
  95. try:
  96. args = Namespace(**json.loads(args_env))
  97. except Exception as exc:
  98. logger.warning(f"Failed to load args from {ENV_ARGS_KEY}: {exc}")
  99. return API(args=args).app
  100. # Each worker process created by Uvicorn has its own memory space,
  101. # meaning that models and variables are not shared between processes.
  102. # Therefore, any variables (like `llama_queue` or `decoder_model`)
  103. # will not be shared across workers.
  104. # Multi-threading for deep learning can cause issues, such as inconsistent
  105. # outputs if multiple threads access the same buffers simultaneously.
  106. # Instead, it's better to use multiprocessing or independent models per thread.
  107. if __name__ == "__main__":
  108. multiprocessing.set_start_method("spawn", force=True)
  109. args = parse_args()
  110. os.environ[ENV_ARGS_KEY] = json.dumps(vars(args))
  111. # IPv6 address format is [xxxx:xxxx::xxxx]:port
  112. match = re.search(r"\[([^\]]+)\]:(\d+)$", args.listen)
  113. if match:
  114. host, port = match.groups() # IPv6
  115. else:
  116. host, port = args.listen.split(":") # IPv4
  117. uvicorn.run(
  118. "tools.api_server:create_app",
  119. host=host,
  120. port=int(port),
  121. workers=args.workers,
  122. log_level="info",
  123. factory=True,
  124. )