Просмотр исходного кода

fix(server): make uvicorn workers>1 work (factory + spawn for CUDA) (#1141)

* fix: pass cli args to uv workers via env vars

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: multi-worker mode with spawns instead of default forks for CUDA

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Vladyslav Tkachenko 2 недель назад
Родитель
Сommit
46806ae6e9
1 измененных файлов с 31 добавлено и 7 удалено
  1. 31 7
      tools/api_server.py

+ 31 - 7
tools/api_server.py

@@ -1,4 +1,8 @@
+import json
+import multiprocessing
+import os
 import re
+from argparse import Namespace
 from threading import Lock
 
 import pyrootutils
@@ -25,10 +29,12 @@ from tools.server.exception_handler import ExceptionHandler
 from tools.server.model_manager import ModelManager
 from tools.server.views import routes
 
+ENV_ARGS_KEY = "FISH_API_SERVER_ARGS"
+
 
 class API(ExceptionHandler):
-    def __init__(self):
-        self.args = parse_args()
+    def __init__(self, args: Namespace | None = None):
+        self.args = args or parse_args()
 
         def api_auth(endpoint):
             async def verify(token: Annotated[str, Depends(bearer_auth)]):
@@ -93,6 +99,19 @@ class API(ExceptionHandler):
         logger.info(f"Startup done, listening server at http://{self.args.listen}")
 
 
+def create_app():
+    args_env = os.environ.get(ENV_ARGS_KEY)
+    args = None
+
+    if args_env:
+        try:
+            args = Namespace(**json.loads(args_env))
+        except Exception as exc:
+            logger.warning(f"Failed to load args from {ENV_ARGS_KEY}: {exc}")
+
+    return API(args=args).app
+
+
 # Each worker process created by Uvicorn has its own memory space,
 # meaning that models and variables are not shared between processes.
 # Therefore, any variables (like `llama_queue` or `decoder_model`)
@@ -103,19 +122,24 @@ class API(ExceptionHandler):
 # Instead, it's better to use multiprocessing or independent models per thread.
 
 if __name__ == "__main__":
-    api = API()
+
+    multiprocessing.set_start_method("spawn", force=True)
+
+    args = parse_args()
+    os.environ[ENV_ARGS_KEY] = json.dumps(vars(args))
 
     # IPv6 address format is [xxxx:xxxx::xxxx]:port
-    match = re.search(r"\[([^\]]+)\]:(\d+)$", api.args.listen)
+    match = re.search(r"\[([^\]]+)\]:(\d+)$", args.listen)
     if match:
         host, port = match.groups()  # IPv6
     else:
-        host, port = api.args.listen.split(":")  # IPv4
+        host, port = args.listen.split(":")  # IPv4
 
     uvicorn.run(
-        api.app,
+        "tools.api_server:create_app",
         host=host,
         port=int(port),
-        workers=api.args.workers,
+        workers=args.workers,
         log_level="info",
+        factory=True,
     )