api_server.py

import json
import multiprocessing
import os
import re
from argparse import Namespace
from threading import Lock

import pyrootutils
import uvicorn
from kui.asgi import (
    Depends,
    FactoryClass,
    HTTPException,
    Kui,
    Routes,
)
from kui.cors import CORSConfig
from kui.security import bearer_auth
from loguru import logger
from typing_extensions import Annotated

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
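
# setup_root() with pythonpath=True puts the project root on sys.path,
# so the local `tools.server` imports below must come after it.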
from tools.server.api_utils import MsgPackRequest, parse_args
from tools.server.exception_handler import ExceptionHandler
from tools.server.model_manager import ModelManager
from tools.server.views import routes
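
# CLI arguments are serialized into this environment variable so that each
# uvicorn worker process can rebuild them inside create_app().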
ENV_ARGS_KEY = "FISH_API_SERVER_ARGS"


class API(ExceptionHandler):
    def __init__(self, args: Namespace | None = None):
        self.args = args or parse_args()
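
        # Bearer-token middleware: when an API key is configured, every route
        # checks the Authorization header; otherwise requests pass through.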
        def api_auth(endpoint):
            async def verify(token: Annotated[str, Depends(bearer_auth)]):
                if token != self.args.api_key:
                    raise HTTPException(401, None, "Invalid token")
                return await endpoint()

            async def passthrough():
                return await endpoint()

            if self.args.api_key is not None:
                return verify
            else:
                return passthrough

        self.routes = Routes(
            routes,  # keep existing routes
            http_middlewares=[api_auth],  # apply api_auth middleware
        )

        # OpenAPI configuration (currently disabled)
        # self.openapi = OpenAPI(
        #     Info(
        #         {
        #             "title": "Fish Speech API",
        #             "version": "1.5.0",
        #         }
        #     ),
        # ).routes

        # Initialize the app
        self.app = Kui(
            routes=self.routes,
            exception_handlers={
                HTTPException: self.http_exception_handler,
                Exception: self.other_exception_handler,
            },
            factory_class=FactoryClass(http=MsgPackRequest),
            cors_config=CORSConfig(),
        )

        # Add the state variables
        self.app.state.lock = Lock()
        self.app.state.device = self.args.device
        self.app.state.max_text_length = self.args.max_text_length

        # Args print
        self.args_print()

        # Environment print
        self.environment_print()

        # Associate the app with the model manager
        self.app.on_startup(self.initialize_app)

    async def initialize_app(self, app: Kui):
        # Make the ModelManager available to the views
        app.state.model_manager = ModelManager(
            mode=self.args.mode,
            device=self.args.device,
            half=self.args.half,
            compile=self.args.compile,
            llama_checkpoint_path=self.args.llama_checkpoint_path,
            decoder_checkpoint_path=self.args.decoder_checkpoint_path,
            decoder_config_name=self.args.decoder_config_name,
        )

        logger.info(f"Startup done, listening server at http://{self.args.listen}")

    def args_print(self):
        if not self.args:
            return

        logger.info("Loaded arguments:")
        for key, value in vars(self.args).items():
            logger.info(f"  self.args.{key}: {value}")

    @staticmethod
    def environment_print():
        logger.info("environment:")
        for key in os.environ.keys():
            logger.info(f"  env.{key}: {os.getenv(key)}")
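

# Application factory used by uvicorn (factory=True below). Each worker
# process calls this and restores the CLI arguments from the environment.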
def create_app():
    args_env = os.environ.get(ENV_ARGS_KEY)
    args = None

    if args_env:
        try:
            args = Namespace(**json.loads(args_env))
        except Exception as exc:
            logger.warning(f"Failed to load args from {ENV_ARGS_KEY}: {exc}")

    return API(args=args).app


# Each worker process created by Uvicorn has its own memory space,
# meaning that models and variables are not shared between processes.
# Therefore, any variables (like `llama_queue` or `decoder_model`)
# will not be shared across workers.
#
# Multi-threading for deep learning can cause issues, such as inconsistent
# outputs if multiple threads access the same buffers simultaneously.
# Instead, it's better to use multiprocessing or independent models per thread.
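
# Entry point when run directly: parse the CLI arguments once, hand them to the
# worker processes via ENV_ARGS_KEY, and launch uvicorn with the create_app factory.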
if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)

    args = parse_args()
    os.environ[ENV_ARGS_KEY] = json.dumps(vars(args))

    # IPv6 address format is [xxxx:xxxx::xxxx]:port
    match = re.search(r"\[([^\]]+)\]:(\d+)$", args.listen)
    if match:
        host, port = match.groups()  # IPv6
    else:
        host, port = args.listen.split(":")  # IPv4

    uvicorn.run(
        "tools.api_server:create_app",
        host=host,
        port=int(port),
        workers=args.workers,
        log_level="info",
        factory=True,
    )