"""HTTP API server entry point for Fish Speech (api_server.py)."""
from threading import Lock

import pyrootutils
import uvicorn
from kui.asgi import FactoryClass, HTTPException, HttpRoute, Kui, OpenAPI, Routes
from loguru import logger

# Register the project root (the directory containing ".project-root") on
# sys.path BEFORE the `tools.server.*` imports below — those imports resolve
# relative to that root, so this call must stay between the two import groups.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from tools.server.api_utils import MsgPackRequest, parse_args
from tools.server.exception_handler import ExceptionHandler
from tools.server.model_manager import ModelManager
from tools.server.views import (
    ASRView,
    ChatView,
    HealthView,
    TTSView,
    VQGANDecodeView,
    VQGANEncodeView,
)
  18. class API(ExceptionHandler):
  19. def __init__(self):
  20. self.args = parse_args()
  21. self.routes = [
  22. ("/v1/health", HealthView),
  23. ("/v1/vqgan/encode", VQGANEncodeView),
  24. ("/v1/vqgan/decode", VQGANDecodeView),
  25. ("/v1/asr", ASRView),
  26. ("/v1/tts", TTSView),
  27. ("/v1/chat", ChatView),
  28. ]
  29. self.routes = Routes([HttpRoute(path, view) for path, view in self.routes])
  30. self.openapi = OpenAPI(
  31. {
  32. "title": "Fish Speech API",
  33. "version": "1.5.0",
  34. },
  35. ).routes
  36. # Initialize the app
  37. self.app = Kui(
  38. routes=self.routes + self.openapi[1:], # Remove the default route
  39. exception_handlers={
  40. HTTPException: self.http_exception_handler,
  41. Exception: self.other_exception_handler,
  42. },
  43. factory_class=FactoryClass(http=MsgPackRequest),
  44. cors_config={},
  45. )
  46. # Add the state variables
  47. self.app.state.lock = Lock()
  48. self.app.state.device = self.args.device
  49. self.app.state.max_text_length = self.args.max_text_length
  50. # Associate the app with the model manager
  51. self.app.on_startup(self.initialize_app)
  52. async def initialize_app(self, app: Kui):
  53. # Make the ModelManager available to the views
  54. app.state.model_manager = ModelManager(
  55. mode=self.args.mode,
  56. device=self.args.device,
  57. half=self.args.half,
  58. compile=self.args.compile,
  59. asr_enabled=self.args.load_asr_model,
  60. llama_checkpoint_path=self.args.llama_checkpoint_path,
  61. decoder_checkpoint_path=self.args.decoder_checkpoint_path,
  62. decoder_config_name=self.args.decoder_config_name,
  63. )
  64. logger.info(f"Startup done, listening server at http://{self.args.listen}")
  65. # Each worker process created by Uvicorn has its own memory space,
  66. # meaning that models and variables are not shared between processes.
  67. # Therefore, any variables (like `llama_queue` or `decoder_model`)
  68. # will not be shared across workers.
  69. # Multi-threading for deep learning can cause issues, such as inconsistent
  70. # outputs if multiple threads access the same buffers simultaneously.
  71. # Instead, it's better to use multiprocessing or independent models per thread.
  72. if __name__ == "__main__":
  73. api = API()
  74. host, port = api.args.listen.split(":")
  75. uvicorn.run(
  76. api.app,
  77. host=host,
  78. port=int(port),
  79. workers=api.args.workers,
  80. log_level="info",
  81. )