api_server.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. from threading import Lock
  2. import pyrootutils
  3. import uvicorn
  4. from kui.asgi import (
  5. Depends,
  6. FactoryClass,
  7. HTTPException,
  8. HttpRoute,
  9. Kui,
  10. OpenAPI,
  11. Routes,
  12. )
  13. from kui.cors import CORSConfig
  14. from kui.openapi.specification import Info
  15. from kui.security import bearer_auth
  16. from loguru import logger
  17. from typing_extensions import Annotated
  18. pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
  19. from tools.server.api_utils import MsgPackRequest, parse_args
  20. from tools.server.exception_handler import ExceptionHandler
  21. from tools.server.model_manager import ModelManager
  22. from tools.server.views import routes
class API(ExceptionHandler):
    """Top-level application object for the Fish Speech API server.

    Parses CLI arguments, builds the Kui ASGI app (routes + OpenAPI docs,
    CORS, msgpack request parsing, exception handlers), stashes shared
    state on ``app.state``, and registers a startup hook that loads the
    models via :class:`ModelManager`.
    """

    def __init__(self):
        # CLI arguments drive everything below (device, checkpoints, auth, ...).
        self.args = parse_args()
        self.routes = routes

        # Optional bearer-token guard: returns a wrapper that checks the
        # token against --api-key, or a no-op passthrough when no key is set.
        # NOTE(review): api_auth is defined but not referenced anywhere in
        # this file's visible code — confirm whether the routes in
        # tools.server.views are expected to apply it.
        def api_auth(endpoint):
            async def verify(token: Annotated[str, Depends(bearer_auth)]):
                if token != self.args.api_key:
                    raise HTTPException(401, None, "Invalid token")
                return await endpoint()

            async def passthrough():
                return await endpoint()

            if self.args.api_key is not None:
                return verify
            else:
                return passthrough

        # Auto-generated OpenAPI (Swagger) routes for the API docs UI.
        self.openapi = OpenAPI(
            Info(
                {
                    "title": "Fish Speech API",
                    "version": "1.5.0",
                }
            ),
        ).routes

        # Initialize the app
        self.app = Kui(
            routes=self.routes + self.openapi[1:],  # Remove the default route
            exception_handlers={
                HTTPException: self.http_exception_handler,
                Exception: self.other_exception_handler,
            },
            # Parse request bodies as msgpack instead of JSON.
            factory_class=FactoryClass(http=MsgPackRequest),
            cors_config=CORSConfig(),
        )

        # Add the state variables shared by all request handlers.
        # NOTE(review): threading.Lock is a *synchronous* lock inside an
        # async app — holding it blocks the event loop; confirm this is
        # intentional (it does serialize model access across requests).
        self.app.state.lock = Lock()
        self.app.state.device = self.args.device
        self.app.state.max_text_length = self.args.max_text_length

        # Defer model loading to the ASGI startup event so the (slow)
        # checkpoint loads happen once the server process is up.
        self.app.on_startup(self.initialize_app)

    async def initialize_app(self, app: Kui):
        # Make the ModelManager available to the views via app.state.
        app.state.model_manager = ModelManager(
            mode=self.args.mode,
            device=self.args.device,
            half=self.args.half,
            compile=self.args.compile,
            asr_enabled=self.args.load_asr_model,
            llama_checkpoint_path=self.args.llama_checkpoint_path,
            decoder_checkpoint_path=self.args.decoder_checkpoint_path,
            decoder_config_name=self.args.decoder_config_name,
        )

        logger.info(f"Startup done, listening server at http://{self.args.listen}")
  75. # Each worker process created by Uvicorn has its own memory space,
  76. # meaning that models and variables are not shared between processes.
  77. # Therefore, any variables (like `llama_queue` or `decoder_model`)
  78. # will not be shared across workers.
  79. # Multi-threading for deep learning can cause issues, such as inconsistent
  80. # outputs if multiple threads access the same buffers simultaneously.
  81. # Instead, it's better to use multiprocessing or independent models per thread.
  82. if __name__ == "__main__":
  83. api = API()
  84. host, port = api.args.listen.split(":")
  85. uvicorn.run(
  86. api.app,
  87. host=host,
  88. port=int(port),
  89. workers=api.args.workers,
  90. log_level="info",
  91. )