# api_server.py — Fish Speech HTTP API server entry point.
import re
from threading import Lock

import pyrootutils
import uvicorn
from kui.asgi import (
    Depends,
    FactoryClass,
    HTTPException,
    HttpRoute,
    Kui,
    OpenAPI,
    Routes,
)
from kui.cors import CORSConfig
from kui.openapi.specification import Info
from kui.security import bearer_auth
from loguru import logger
from typing_extensions import Annotated

# Must run before the `tools.server` imports below: it puts the project root
# on sys.path so those absolute imports resolve.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from tools.server.api_utils import MsgPackRequest, parse_args
from tools.server.exception_handler import ExceptionHandler
from tools.server.model_manager import ModelManager
from tools.server.views import routes
  24. class API(ExceptionHandler):
  25. def __init__(self):
  26. self.args = parse_args()
  27. self.routes = routes
  28. def api_auth(endpoint):
  29. async def verify(token: Annotated[str, Depends(bearer_auth)]):
  30. if token != self.args.api_key:
  31. raise HTTPException(401, None, "Invalid token")
  32. return await endpoint()
  33. async def passthrough():
  34. return await endpoint()
  35. if self.args.api_key is not None:
  36. return verify
  37. else:
  38. return passthrough
  39. self.openapi = OpenAPI(
  40. Info(
  41. {
  42. "title": "Fish Speech API",
  43. "version": "1.5.0",
  44. }
  45. ),
  46. ).routes
  47. # Initialize the app
  48. self.app = Kui(
  49. routes=self.routes + self.openapi[1:], # Remove the default route
  50. exception_handlers={
  51. HTTPException: self.http_exception_handler,
  52. Exception: self.other_exception_handler,
  53. },
  54. factory_class=FactoryClass(http=MsgPackRequest),
  55. cors_config=CORSConfig(),
  56. )
  57. # Add the state variables
  58. self.app.state.lock = Lock()
  59. self.app.state.device = self.args.device
  60. self.app.state.max_text_length = self.args.max_text_length
  61. # Associate the app with the model manager
  62. self.app.on_startup(self.initialize_app)
  63. async def initialize_app(self, app: Kui):
  64. # Make the ModelManager available to the views
  65. app.state.model_manager = ModelManager(
  66. mode=self.args.mode,
  67. device=self.args.device,
  68. half=self.args.half,
  69. compile=self.args.compile,
  70. asr_enabled=self.args.load_asr_model,
  71. llama_checkpoint_path=self.args.llama_checkpoint_path,
  72. decoder_checkpoint_path=self.args.decoder_checkpoint_path,
  73. decoder_config_name=self.args.decoder_config_name,
  74. )
  75. logger.info(f"Startup done, listening server at http://{self.args.listen}")
# Each worker process created by Uvicorn has its own memory space,
# meaning that models and variables are not shared between processes.
# Therefore, any variables (like `llama_queue` or `decoder_model`)
# will not be shared across workers.

# Multi-threading for deep learning can cause issues, such as inconsistent
# outputs if multiple threads access the same buffers simultaneously.
# Instead, it's better to use multiprocessing or independent models per thread.
  83. if __name__ == "__main__":
  84. api = API()
  85. # IPv6 address format is [xxxx:xxxx::xxxx]:port
  86. match = re.search(r"\[([^\]]+)\]:(\d+)$", api.args.listen)
  87. if match:
  88. host, port = match.groups() # IPv6
  89. else:
  90. host, port = api.args.listen.split(":") # IPv4
  91. uvicorn.run(
  92. api.app,
  93. host=host,
  94. port=int(port),
  95. workers=api.args.workers,
  96. log_level="info",
  97. )