Selaa lähdekoodia

Replace zibai with uvicorn

Lengyue 1 vuosi sitten
vanhempi
commit
46dae9bd1c
2 muutettua tiedostoa jossa 24 lisäystä ja 22 poistoa
  1. 1 1
      pyproject.toml
  2. 23 21
      tools/api.py

+ 1 - 1
pyproject.toml

@@ -26,7 +26,7 @@ dependencies = [
     "wandb>=0.15.11",
     "grpcio>=1.58.0",
     "kui>=1.6.0",
-    "zibai-server>=0.9.0",
+    "uvicorn>=0.30.0",
     "loguru>=0.6.0",
     "loralib>=0.1.2",
     "natsort>=8.4.0",

+ 23 - 21
tools/api.py

@@ -16,8 +16,9 @@ import numpy as np
 import pyrootutils
 import soundfile as sf
 import torch
-from kui.wsgi import (
+from kui.asgi import (
     Body,
+    FileResponse,
     HTTPException,
     HttpView,
     JSONResponse,
@@ -25,7 +26,7 @@ from kui.wsgi import (
     OpenAPI,
     StreamResponse,
 )
-from kui.wsgi.routing import MultimethodRoutes
+from kui.asgi.routing import MultimethodRoutes
 from loguru import logger
 from pydantic import BaseModel, Field
 from transformers import AutoTokenizer
@@ -57,7 +58,7 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
 
 
 # Define utils for web server
-def http_execption_handler(exc: HTTPException):
+async def http_execption_handler(exc: HTTPException):
     return JSONResponse(
         dict(
             statusCode=exc.status_code,
@@ -69,7 +70,7 @@ def http_execption_handler(exc: HTTPException):
     )
 
 
-def other_exception_handler(exc: "Exception"):
+async def other_exception_handler(exc: "Exception"):
     traceback.print_exc()
 
     status = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -334,8 +335,17 @@ def inference(req: InvokeRequest):
     yield fake_audios
 
 
+async def inference_async(req: InvokeRequest):
+    for chunk in inference(req):
+        yield chunk
+
+
+async def buffer_to_async_generator(buffer):
+    yield buffer
+
+
 @routes.http.post("/v1/invoke")
-def api_invoke_model(
+async def api_invoke_model(
     req: Annotated[InvokeRequest, Body(exclusive=True)],
 ):
     """
@@ -354,22 +364,21 @@ def api_invoke_model(
             content="Streaming only supports WAV format",
         )
 
-    generator = inference(req)
     if req.streaming:
         return StreamResponse(
-            iterable=generator,
+            iterable=inference_async(req),
             headers={
                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
             },
             content_type=get_content_type(req.format),
         )
     else:
-        fake_audios = next(generator)
+        fake_audios = next(inference(req))
         buffer = io.BytesIO()
         sf.write(buffer, fake_audios, decoder_model.sampling_rate, format=req.format)
 
         return StreamResponse(
-            iterable=[buffer.getvalue()],
+            iterable=buffer_to_async_generator(buffer.getvalue()),
             headers={
                 "Content-Disposition": f"attachment; filename=audio.{req.format}",
             },
@@ -378,7 +387,7 @@ def api_invoke_model(
 
 
 @routes.http.post("/v1/health")
-def api_health():
+async def api_health():
     """
     Health check
     """
@@ -409,6 +418,7 @@ def parse_args():
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("--max-text-length", type=int, default=0)
     parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
+    parser.add_argument("--workers", type=int, default=1)
 
     return parser.parse_args()
 
@@ -433,7 +443,7 @@ app = Kui(
 if __name__ == "__main__":
     import threading
 
-    from zibai import create_bind_socket, serve
+    import uvicorn
 
     args = parse_args()
     args.precision = torch.half if args.half else torch.bfloat16
@@ -480,13 +490,5 @@ if __name__ == "__main__":
     )
 
     logger.info(f"Warming up done, starting server at http://{args.listen}")
-    sock = create_bind_socket(args.listen)
-    sock.listen()
-
-    # Start server
-    serve(
-        app=app,
-        bind_sockets=[sock],
-        max_workers=10,
-        graceful_exit=threading.Event(),
-    )
+    host, port = args.listen.split(":")
+    uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")