|
|
@@ -1,4 +1,8 @@
|
|
|
+import json
|
|
|
+import multiprocessing
|
|
|
+import os
|
|
|
import re
|
|
|
+from argparse import Namespace
|
|
|
from threading import Lock
|
|
|
|
|
|
import pyrootutils
|
|
|
@@ -25,10 +29,12 @@ from tools.server.exception_handler import ExceptionHandler
|
|
|
from tools.server.model_manager import ModelManager
|
|
|
from tools.server.views import routes
|
|
|
|
|
|
+ENV_ARGS_KEY = "FISH_API_SERVER_ARGS"
|
|
|
+
|
|
|
|
|
|
class API(ExceptionHandler):
|
|
|
- def __init__(self):
|
|
|
- self.args = parse_args()
|
|
|
+ def __init__(self, args: Namespace | None = None):
|
|
|
+ self.args = args or parse_args()
|
|
|
|
|
|
def api_auth(endpoint):
|
|
|
async def verify(token: Annotated[str, Depends(bearer_auth)]):
|
|
|
@@ -93,6 +99,19 @@ class API(ExceptionHandler):
|
|
|
logger.info(f"Startup done, listening server at http://{self.args.listen}")
|
|
|
|
|
|
|
|
|
+def create_app():
|
|
|
+ args_env = os.environ.get(ENV_ARGS_KEY)
|
|
|
+ args = None
|
|
|
+
|
|
|
+ if args_env:
|
|
|
+ try:
|
|
|
+ args = Namespace(**json.loads(args_env))
|
|
|
+ except Exception as exc:
|
|
|
+ logger.warning(f"Failed to load args from {ENV_ARGS_KEY}: {exc}")
|
|
|
+
|
|
|
+ return API(args=args).app
|
|
|
+
|
|
|
+
|
|
|
# Each worker process created by Uvicorn has its own memory space,
|
|
|
# meaning that models and variables are not shared between processes.
|
|
|
# Therefore, any variables (like `llama_queue` or `decoder_model`)
|
|
|
@@ -103,19 +122,24 @@ class API(ExceptionHandler):
|
|
|
# Instead, it's better to use multiprocessing or independent models per thread.
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
- api = API()
|
|
|
+
|
|
|
+ multiprocessing.set_start_method("spawn", force=True)
|
|
|
+
|
|
|
+ args = parse_args()
|
|
|
+ os.environ[ENV_ARGS_KEY] = json.dumps(vars(args))
|
|
|
|
|
|
# IPv6 address format is [xxxx:xxxx::xxxx]:port
|
|
|
- match = re.search(r"\[([^\]]+)\]:(\d+)$", api.args.listen)
|
|
|
+ match = re.search(r"\[([^\]]+)\]:(\d+)$", args.listen)
|
|
|
if match:
|
|
|
host, port = match.groups() # IPv6
|
|
|
else:
|
|
|
- host, port = api.args.listen.split(":") # IPv4
|
|
|
+ host, port = args.listen.split(":") # IPv4
|
|
|
|
|
|
uvicorn.run(
|
|
|
- api.app,
|
|
|
+ "tools.api_server:create_app",
|
|
|
host=host,
|
|
|
port=int(port),
|
|
|
- workers=api.args.workers,
|
|
|
+ workers=args.workers,
|
|
|
log_level="info",
|
|
|
+ factory=True,
|
|
|
)
|