# api_server.py
  1. import re
  2. from threading import Lock
  3. import pyrootutils
  4. import uvicorn
  5. from kui.asgi import (
  6. Depends,
  7. FactoryClass,
  8. HTTPException,
  9. HttpRoute,
  10. Kui,
  11. OpenAPI,
  12. Routes,
  13. )
  14. from kui.cors import CORSConfig
  15. from kui.openapi.specification import Info
  16. from kui.security import bearer_auth
  17. from loguru import logger
  18. from typing_extensions import Annotated
  19. pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
  20. from tools.server.api_utils import MsgPackRequest, parse_args
  21. from tools.server.exception_handler import ExceptionHandler
  22. from tools.server.model_manager import ModelManager
  23. from tools.server.views import routes
  24. class API(ExceptionHandler):
  25. def __init__(self):
  26. self.args = parse_args()
  27. def api_auth(endpoint):
  28. async def verify(token: Annotated[str, Depends(bearer_auth)]):
  29. if token != self.args.api_key:
  30. raise HTTPException(401, None, "Invalid token")
  31. return await endpoint()
  32. async def passthrough():
  33. return await endpoint()
  34. if self.args.api_key is not None:
  35. return verify
  36. else:
  37. return passthrough
  38. self.routes = Routes(
  39. routes, # keep existing routes
  40. http_middlewares=[api_auth], # apply api_auth middleware
  41. )
  42. # OpenAPIの設定
  43. self.openapi = OpenAPI(
  44. Info(
  45. {
  46. "title": "Fish Speech API",
  47. "version": "1.5.0",
  48. }
  49. ),
  50. ).routes
  51. # Initialize the app
  52. self.app = Kui(
  53. routes=self.routes + self.openapi[1:], # Remove the default route
  54. exception_handlers={
  55. HTTPException: self.http_exception_handler,
  56. Exception: self.other_exception_handler,
  57. },
  58. factory_class=FactoryClass(http=MsgPackRequest),
  59. cors_config=CORSConfig(),
  60. )
  61. # Add the state variables
  62. self.app.state.lock = Lock()
  63. self.app.state.device = self.args.device
  64. self.app.state.max_text_length = self.args.max_text_length
  65. # Associate the app with the model manager
  66. self.app.on_startup(self.initialize_app)
  67. async def initialize_app(self, app: Kui):
  68. # Make the ModelManager available to the views
  69. app.state.model_manager = ModelManager(
  70. mode=self.args.mode,
  71. device=self.args.device,
  72. half=self.args.half,
  73. compile=self.args.compile,
  74. asr_enabled=self.args.load_asr_model,
  75. llama_checkpoint_path=self.args.llama_checkpoint_path,
  76. decoder_checkpoint_path=self.args.decoder_checkpoint_path,
  77. decoder_config_name=self.args.decoder_config_name,
  78. )
  79. logger.info(f"Startup done, listening server at http://{self.args.listen}")
# Each worker process created by Uvicorn has its own memory space,
# meaning that models and variables are not shared between processes.
# Therefore, any variables (like `llama_queue` or `decoder_model`)
# will not be shared across workers.

# Multi-threading for deep learning can cause issues, such as inconsistent
# outputs if multiple threads access the same buffers simultaneously.
# Instead, it's better to use multiprocessing or independent models per thread.
  87. if __name__ == "__main__":
  88. api = API()
  89. # IPv6 address format is [xxxx:xxxx::xxxx]:port
  90. match = re.search(r"\[([^\]]+)\]:(\d+)$", api.args.listen)
  91. if match:
  92. host, port = match.groups() # IPv6
  93. else:
  94. host, port = api.args.listen.split(":") # IPv4
  95. uvicorn.run(
  96. api.app,
  97. host=host,
  98. port=int(port),
  99. workers=api.args.workers,
  100. log_level="info",
  101. )