@@ -1,5 +1,6 @@
 import base64
 import io
+import threading
 import traceback
 from argparse import ArgumentParser
 from http import HTTPStatus
@@ -17,15 +18,13 @@ from kui.wsgi import (
     Kui,
     OpenAPI,
     StreamResponse,
-    allow_cors,
 )
 from kui.wsgi.routing import MultimethodRoutes
 from loguru import logger
 from pydantic import BaseModel
 from transformers import AutoTokenizer
 
-from tools.llama.generate import generate_long
-from tools.llama.generate import load_model as load_llama_model
+from tools.llama.generate import launch_thread_safe_queue
 from tools.vqgan.inference import load_model as load_vqgan_model
 from tools.webui import inference
 
@@ -95,11 +94,9 @@ def inference(req: InvokeRequest):
         prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
 
     # LLAMA Inference
-    result = generate_long(
-        model=llama_model,
+    request = dict(
         tokenizer=llama_tokenizer,
         device=vqgan_model.device,
-        decode_one_token=decode_one_token,
         max_new_tokens=req.max_new_tokens,
         text=req.text,
         top_k=int(req.top_k) if req.top_k > 0 else None,
@@ -115,7 +112,18 @@ def inference(req: InvokeRequest):
         prompt_text=req.reference_text,
     )
 
-    codes = next(result)
+    payload = dict(
+        event=threading.Event(),
+        request=request,
+    )
+    llama_queue.put(payload)
+
+    # Wait for the result
+    payload["event"].wait()
+    if payload["success"] is False:
+        raise payload["response"]
+
+    codes = payload["response"][0]
 
     # VQGAN Inference
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
@@ -128,7 +136,7 @@ def inference(req: InvokeRequest):
     return fake_audios
 
 
-@routes.http.post("/invoke")
+@routes.http.post("/v1/invoke")
 def api_invoke_model(
     req: Annotated[InvokeRequest, Body(exclusive=True)],
 ):
@@ -139,7 +147,7 @@ def api_invoke_model(
     if args.max_gradio_length > 0 and len(req.text) > args.max_gradio_length:
         raise HTTPException(
             HTTPStatus.BAD_REQUEST,
-            f"Text is too long, max length is {args.max_gradio_length}",
+            content=f"Text is too long, max length is {args.max_gradio_length}",
         )
 
     try:
@@ -147,7 +155,11 @@ def api_invoke_model(
         lock.acquire()
         fake_audios = inference(req)
     except Exception as e:
-        raise HTTPException(HTTPStatus.INTERNAL_SERVER_ERROR, str(e))
+        import traceback
+
+        traceback.print_exc()
+
+        raise HTTPException(HTTPStatus.INTERNAL_SERVER_ERROR, content=str(e))
     finally:
         # Release lock
         lock.release()
@@ -159,12 +171,14 @@ def api_invoke_model(
         iterable=[buffer.getvalue()],
         headers={
             "Content-Disposition": f"attachment; filename=audio.{req.format}",
-            "Content-Type": "application/octet-stream",
         },
+        # Make swagger-ui happy
+        # content_type=f"audio/{req.format}",
+        content_type="application/octet-stream",
     )
 
 
-@routes.http.post("/health")
+@routes.http.post("/v1/health")
 def api_health():
     """
     Health check
@@ -201,7 +215,14 @@ def parse_args():
 
 
 # Define Kui app
+openapi = OpenAPI(
+    {
+        "title": "Fish Speech API",
+    },
+).routes
+
 app = Kui(
+    routes=routes + openapi[1:],  # Remove the default route
     exception_handlers={
         HTTPException: http_execption_handler,
         Exception: other_exception_handler,
@@ -209,9 +230,6 @@ app = Kui(
     cors_config={},
 )
 
-# Swagger UI & routes
-app.router << ("/v1" // routes) << ("/docs" // OpenAPI().routes)
-
 
 if __name__ == "__main__":
     import threading
@@ -222,7 +240,7 @@ if __name__ == "__main__":
     args.precision = torch.half if args.half else torch.bfloat16
 
     logger.info("Loading Llama model...")
-    llama_model, decode_one_token = load_llama_model(
+    llama_queue = launch_thread_safe_queue(
         config_name=args.llama_config_name,
         checkpoint_path=args.llama_checkpoint_path,
         device=args.device,
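
# Note on the queue protocol introduced above (not part of the diff itself):
# the API thread packs the generation kwargs into payload["request"], adds a
# threading.Event, enqueues the payload on llama_queue, blocks on
# payload["event"].wait(), and then reads payload["success"] and
# payload["response"] (codes at response[0], or an exception on failure).
# A minimal sketch of a compatible worker loop is shown below; it is
# illustrative only, run_generation is a placeholder, and the real
# launch_thread_safe_queue in tools/llama/generate may differ.
import queue
import threading


def launch_thread_safe_queue_sketch(run_generation):
    q = queue.Queue()

    def worker():
        while True:
            payload = q.get()
            try:
                # Run generation with the kwargs the API thread provided.
                payload["response"] = run_generation(**payload["request"])
                payload["success"] = True
            except Exception as e:
                # Hand the exception back so the API thread can re-raise it.
                payload["response"] = e
                payload["success"] = False
            finally:
                # Wake the API thread blocked on payload["event"].wait().
                payload["event"].set()

    threading.Thread(target=worker, daemon=True).start()
    return q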