# api_server.py
from threading import Lock

import pyrootutils
import uvicorn
from kui.asgi import (
    Depends,
    FactoryClass,
    HTTPException,
    HttpRoute,
    Kui,
    OpenAPI,
    Routes,
)
from kui.security import bearer_auth
from loguru import logger
from typing_extensions import Annotated

# Locate the project root (the directory containing `.project-root`) and add
# it to sys.path. NOTE: this must run before the `tools.*` imports below, or
# they will not resolve — do not let a formatter hoist them above this call.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from tools.server.api_utils import MsgPackRequest, parse_args
from tools.server.exception_handler import ExceptionHandler
from tools.server.model_manager import ModelManager
from tools.server.views import (
    ASRView,
    ChatView,
    HealthView,
    TTSView,
    VQGANDecodeView,
    VQGANEncodeView,
)
  28. class API(ExceptionHandler):
  29. def __init__(self):
  30. self.args = parse_args()
  31. self.routes = [
  32. ("/v1/health", HealthView),
  33. ("/v1/vqgan/encode", VQGANEncodeView),
  34. ("/v1/vqgan/decode", VQGANDecodeView),
  35. ("/v1/asr", ASRView),
  36. ("/v1/tts", TTSView),
  37. ("/v1/chat", ChatView),
  38. ]
  39. def api_auth(endpoint):
  40. async def verify(token: Annotated[str, Depends(bearer_auth)]):
  41. if token != self.args.api_key:
  42. raise HTTPException(401, None, "Invalid token")
  43. return await endpoint()
  44. async def passthrough():
  45. return await endpoint()
  46. if self.args.api_key is not None:
  47. return verify
  48. else:
  49. return passthrough
  50. self.routes = Routes(
  51. [HttpRoute(path, view) for path, view in self.routes],
  52. http_middlewares=[api_auth],
  53. )
  54. self.openapi = OpenAPI(
  55. {
  56. "title": "Fish Speech API",
  57. "version": "1.5.0",
  58. },
  59. ).routes
  60. # Initialize the app
  61. self.app = Kui(
  62. routes=self.routes + self.openapi[1:], # Remove the default route
  63. exception_handlers={
  64. HTTPException: self.http_exception_handler,
  65. Exception: self.other_exception_handler,
  66. },
  67. factory_class=FactoryClass(http=MsgPackRequest),
  68. cors_config={},
  69. )
  70. # Add the state variables
  71. self.app.state.lock = Lock()
  72. self.app.state.device = self.args.device
  73. self.app.state.max_text_length = self.args.max_text_length
  74. # Associate the app with the model manager
  75. self.app.on_startup(self.initialize_app)
  76. async def initialize_app(self, app: Kui):
  77. # Make the ModelManager available to the views
  78. app.state.model_manager = ModelManager(
  79. mode=self.args.mode,
  80. device=self.args.device,
  81. half=self.args.half,
  82. compile=self.args.compile,
  83. asr_enabled=self.args.load_asr_model,
  84. llama_checkpoint_path=self.args.llama_checkpoint_path,
  85. decoder_checkpoint_path=self.args.decoder_checkpoint_path,
  86. decoder_config_name=self.args.decoder_config_name,
  87. )
  88. logger.info(f"Startup done, listening server at http://{self.args.listen}")
  89. # Each worker process created by Uvicorn has its own memory space,
  90. # meaning that models and variables are not shared between processes.
  91. # Therefore, any variables (like `llama_queue` or `decoder_model`)
  92. # will not be shared across workers.
  93. # Multi-threading for deep learning can cause issues, such as inconsistent
  94. # outputs if multiple threads access the same buffers simultaneously.
  95. # Instead, it's better to use multiprocessing or independent models per thread.
  96. if __name__ == "__main__":
  97. api = API()
  98. host, port = api.args.listen.split(":")
  99. uvicorn.run(
  100. api.app,
  101. host=host,
  102. port=int(port),
  103. workers=api.args.workers,
  104. log_level="info",
  105. )