api_server.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import json
  2. import multiprocessing
  3. import os
  4. import re
  5. from argparse import Namespace
  6. from threading import Lock
  7. import pyrootutils
  8. import uvicorn
  9. from kui.asgi import (
  10. Depends,
  11. FactoryClass,
  12. HTTPException,
  13. Kui,
  14. Routes,
  15. )
  16. from kui.cors import CORSConfig
  17. from kui.security import bearer_auth
  18. from loguru import logger
  19. from typing_extensions import Annotated
  20. pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
  21. from tools.server.api_utils import MsgPackRequest, parse_args
  22. from tools.server.exception_handler import ExceptionHandler
  23. from tools.server.model_manager import ModelManager
  24. from tools.server.views import routes
# Environment variable used to hand the parsed CLI arguments (JSON-encoded by the
# `__main__` launcher) to `create_app()`, which runs inside each Uvicorn worker.
ENV_ARGS_KEY = "FISH_API_SERVER_ARGS"
  26. class API(ExceptionHandler):
  27. def __init__(self, args: Namespace | None = None):
  28. self.args = args or parse_args()
  29. def api_auth(endpoint):
  30. async def verify(token: Annotated[str, Depends(bearer_auth)]):
  31. if token != self.args.api_key:
  32. raise HTTPException(401, None, "Invalid token")
  33. return await endpoint()
  34. async def passthrough():
  35. return await endpoint()
  36. if self.args.api_key is not None:
  37. return verify
  38. else:
  39. return passthrough
  40. self.routes = Routes(
  41. routes, # keep existing routes
  42. http_middlewares=[api_auth], # apply api_auth middleware
  43. )
  44. # OpenAPIの設定
  45. # self.openapi = OpenAPI(
  46. # Info(
  47. # {
  48. # "title": "Fish Speech API",
  49. # "version": "1.5.0",
  50. # }
  51. # ),
  52. # ).routes
  53. # Initialize the app
  54. self.app = Kui(
  55. routes=self.routes,
  56. exception_handlers={
  57. HTTPException: self.http_exception_handler,
  58. Exception: self.other_exception_handler,
  59. },
  60. factory_class=FactoryClass(http=MsgPackRequest),
  61. cors_config=CORSConfig(),
  62. )
  63. # Add the state variables
  64. self.app.state.lock = Lock()
  65. self.app.state.device = self.args.device
  66. self.app.state.max_text_length = self.args.max_text_length
  67. # Associate the app with the model manager
  68. self.app.on_startup(self.initialize_app)
  69. # Print args
  70. self.args_print()
  71. async def initialize_app(self, app: Kui):
  72. # Make the ModelManager available to the views
  73. app.state.model_manager = ModelManager(
  74. mode=self.args.mode,
  75. device=self.args.device,
  76. half=self.args.half,
  77. compile=self.args.compile,
  78. llama_checkpoint_path=self.args.llama_checkpoint_path,
  79. decoder_checkpoint_path=self.args.decoder_checkpoint_path,
  80. decoder_config_name=self.args.decoder_config_name,
  81. )
  82. logger.info(f"Startup done, listening server at http://{self.args.listen}")
  83. def args_print(self):
  84. if self.args:
  85. logger.info("Loaded arguments:")
  86. for key, value in vars(self.args).items():
  87. logger.info(f" self.args.{key}: {value}")
  88. logger.info("environment:")
  89. for key in os.environ.keys():
  90. logger.info(f" env.{key}: {os.getenv(key)}")
  91. def create_app():
  92. args_env = os.environ.get(ENV_ARGS_KEY)
  93. args = None
  94. if args_env:
  95. try:
  96. args = Namespace(**json.loads(args_env))
  97. except Exception as exc:
  98. logger.warning(f"Failed to load args from {ENV_ARGS_KEY}: {exc}")
  99. return API(args=args).app
  100. # Each worker process created by Uvicorn has its own memory space,
  101. # meaning that models and variables are not shared between processes.
  102. # Therefore, any variables (like `llama_queue` or `decoder_model`)
  103. # will not be shared across workers.
  104. # Multi-threading for deep learning can cause issues, such as inconsistent
  105. # outputs if multiple threads access the same buffers simultaneously.
  106. # Instead, it's better to use multiprocessing or independent models per thread.
  107. if __name__ == "__main__":
  108. multiprocessing.set_start_method("spawn", force=True)
  109. args = parse_args()
  110. os.environ[ENV_ARGS_KEY] = json.dumps(vars(args))
  111. # IPv6 address format is [xxxx:xxxx::xxxx]:port
  112. match = re.search(r"\[([^\]]+)\]:(\d+)$", args.listen)
  113. if match:
  114. host, port = match.groups() # IPv6
  115. else:
  116. host, port = args.listen.split(":") # IPv4
  117. uvicorn.run(
  118. "tools.api_server:create_app",
  119. host=host,
  120. port=int(port),
  121. workers=args.workers,
  122. log_level="info",
  123. factory=True,
  124. )