|
@@ -16,8 +16,9 @@ import numpy as np
|
|
|
import pyrootutils
|
|
import pyrootutils
|
|
|
import soundfile as sf
|
|
import soundfile as sf
|
|
|
import torch
|
|
import torch
|
|
|
-from kui.wsgi import (
|
|
|
|
|
|
|
+from kui.asgi import (
|
|
|
Body,
|
|
Body,
|
|
|
|
|
+ FileResponse,
|
|
|
HTTPException,
|
|
HTTPException,
|
|
|
HttpView,
|
|
HttpView,
|
|
|
JSONResponse,
|
|
JSONResponse,
|
|
@@ -25,7 +26,7 @@ from kui.wsgi import (
|
|
|
OpenAPI,
|
|
OpenAPI,
|
|
|
StreamResponse,
|
|
StreamResponse,
|
|
|
)
|
|
)
|
|
|
-from kui.wsgi.routing import MultimethodRoutes
|
|
|
|
|
|
|
+from kui.asgi.routing import MultimethodRoutes
|
|
|
from loguru import logger
|
|
from loguru import logger
|
|
|
from pydantic import BaseModel, Field
|
|
from pydantic import BaseModel, Field
|
|
|
from transformers import AutoTokenizer
|
|
from transformers import AutoTokenizer
|
|
@@ -57,7 +58,7 @@ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
|
|
|
|
|
|
|
|
|
|
|
|
# Define utils for web server
|
|
# Define utils for web server
|
|
|
-def http_execption_handler(exc: HTTPException):
|
|
|
|
|
|
|
+async def http_execption_handler(exc: HTTPException):
|
|
|
return JSONResponse(
|
|
return JSONResponse(
|
|
|
dict(
|
|
dict(
|
|
|
statusCode=exc.status_code,
|
|
statusCode=exc.status_code,
|
|
@@ -69,7 +70,7 @@ def http_execption_handler(exc: HTTPException):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
-def other_exception_handler(exc: "Exception"):
|
|
|
|
|
|
|
+async def other_exception_handler(exc: "Exception"):
|
|
|
traceback.print_exc()
|
|
traceback.print_exc()
|
|
|
|
|
|
|
|
status = HTTPStatus.INTERNAL_SERVER_ERROR
|
|
status = HTTPStatus.INTERNAL_SERVER_ERROR
|
|
@@ -334,8 +335,17 @@ def inference(req: InvokeRequest):
|
|
|
yield fake_audios
|
|
yield fake_audios
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+async def inference_async(req: InvokeRequest):
|
|
|
|
|
+ for chunk in inference(req):
|
|
|
|
|
+ yield chunk
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def buffer_to_async_generator(buffer):
|
|
|
|
|
+ yield buffer
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
@routes.http.post("/v1/invoke")
|
|
@routes.http.post("/v1/invoke")
|
|
|
-def api_invoke_model(
|
|
|
|
|
|
|
+async def api_invoke_model(
|
|
|
req: Annotated[InvokeRequest, Body(exclusive=True)],
|
|
req: Annotated[InvokeRequest, Body(exclusive=True)],
|
|
|
):
|
|
):
|
|
|
"""
|
|
"""
|
|
@@ -354,22 +364,21 @@ def api_invoke_model(
|
|
|
content="Streaming only supports WAV format",
|
|
content="Streaming only supports WAV format",
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- generator = inference(req)
|
|
|
|
|
if req.streaming:
|
|
if req.streaming:
|
|
|
return StreamResponse(
|
|
return StreamResponse(
|
|
|
- iterable=generator,
|
|
|
|
|
|
|
+ iterable=inference_async(req),
|
|
|
headers={
|
|
headers={
|
|
|
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
|
},
|
|
},
|
|
|
content_type=get_content_type(req.format),
|
|
content_type=get_content_type(req.format),
|
|
|
)
|
|
)
|
|
|
else:
|
|
else:
|
|
|
- fake_audios = next(generator)
|
|
|
|
|
|
|
+ fake_audios = next(inference(req))
|
|
|
buffer = io.BytesIO()
|
|
buffer = io.BytesIO()
|
|
|
sf.write(buffer, fake_audios, decoder_model.sampling_rate, format=req.format)
|
|
sf.write(buffer, fake_audios, decoder_model.sampling_rate, format=req.format)
|
|
|
|
|
|
|
|
return StreamResponse(
|
|
return StreamResponse(
|
|
|
- iterable=[buffer.getvalue()],
|
|
|
|
|
|
|
+ iterable=buffer_to_async_generator(buffer.getvalue()),
|
|
|
headers={
|
|
headers={
|
|
|
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
|
},
|
|
},
|
|
@@ -378,7 +387,7 @@ def api_invoke_model(
|
|
|
|
|
|
|
|
|
|
|
|
|
@routes.http.post("/v1/health")
|
|
@routes.http.post("/v1/health")
|
|
|
-def api_health():
|
|
|
|
|
|
|
+async def api_health():
|
|
|
"""
|
|
"""
|
|
|
Health check
|
|
Health check
|
|
|
"""
|
|
"""
|
|
@@ -409,6 +418,7 @@ def parse_args():
|
|
|
parser.add_argument("--compile", action="store_true")
|
|
parser.add_argument("--compile", action="store_true")
|
|
|
parser.add_argument("--max-text-length", type=int, default=0)
|
|
parser.add_argument("--max-text-length", type=int, default=0)
|
|
|
parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
|
|
parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
|
|
|
|
|
+ parser.add_argument("--workers", type=int, default=1)
|
|
|
|
|
|
|
|
return parser.parse_args()
|
|
return parser.parse_args()
|
|
|
|
|
|
|
@@ -433,7 +443,7 @@ app = Kui(
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
import threading
|
|
import threading
|
|
|
|
|
|
|
|
- from zibai import create_bind_socket, serve
|
|
|
|
|
|
|
+ import uvicorn
|
|
|
|
|
|
|
|
args = parse_args()
|
|
args = parse_args()
|
|
|
args.precision = torch.half if args.half else torch.bfloat16
|
|
args.precision = torch.half if args.half else torch.bfloat16
|
|
@@ -480,13 +490,5 @@ if __name__ == "__main__":
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
logger.info(f"Warming up done, starting server at http://{args.listen}")
|
|
logger.info(f"Warming up done, starting server at http://{args.listen}")
|
|
|
- sock = create_bind_socket(args.listen)
|
|
|
|
|
- sock.listen()
|
|
|
|
|
-
|
|
|
|
|
- # Start server
|
|
|
|
|
- serve(
|
|
|
|
|
- app=app,
|
|
|
|
|
- bind_sockets=[sock],
|
|
|
|
|
- max_workers=10,
|
|
|
|
|
- graceful_exit=threading.Event(),
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ host, port = args.listen.split(":")
|
|
|
|
|
+ uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
|