فهرست منبع

Implement thread safe compile

Lengyue 1 سال پیش
والد
کامیت
a095c8f6c7
4فایلهای تغییر یافته به همراه95 افزوده شده و 28 حذف شده
  1. 1 1
      pyproject.toml
  2. 34 16
      tools/api.py
  3. 44 4
      tools/llama/generate.py
  4. 16 7
      tools/webui.py

+ 1 - 1
pyproject.toml

@@ -35,7 +35,7 @@ dependencies = [
     "samplerate>=0.2.1",
     "resampy>=0.4.3",
     "spaces>=0.26.1",
-    "einx[torch]==0.2.0"
+    "einx[torch]==0.2.2"
 ]
 
 [project.optional-dependencies]

+ 34 - 16
tools/api.py

@@ -1,5 +1,6 @@
 import base64
 import io
+import threading
 import traceback
 from argparse import ArgumentParser
 from http import HTTPStatus
@@ -17,15 +18,13 @@ from kui.wsgi import (
     Kui,
     OpenAPI,
     StreamResponse,
-    allow_cors,
 )
 from kui.wsgi.routing import MultimethodRoutes
 from loguru import logger
 from pydantic import BaseModel
 from transformers import AutoTokenizer
 
-from tools.llama.generate import generate_long
-from tools.llama.generate import load_model as load_llama_model
+from tools.llama.generate import launch_thread_safe_queue
 from tools.vqgan.inference import load_model as load_vqgan_model
 from tools.webui import inference
 
@@ -95,11 +94,9 @@ def inference(req: InvokeRequest):
         prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
 
     # LLAMA Inference
-    result = generate_long(
-        model=llama_model,
+    request = dict(
         tokenizer=llama_tokenizer,
         device=vqgan_model.device,
-        decode_one_token=decode_one_token,
         max_new_tokens=req.max_new_tokens,
         text=req.text,
         top_k=int(req.top_k) if req.top_k > 0 else None,
@@ -115,7 +112,18 @@ def inference(req: InvokeRequest):
         prompt_text=req.reference_text,
     )
 
-    codes = next(result)
+    payload = dict(
+        event=threading.Event(),
+        request=request,
+    )
+    llama_queue.put(payload)
+
+    # Wait for the result
+    payload["event"].wait()
+    if payload["success"] is False:
+        raise payload["response"]
+
+    codes = payload["response"][0]
 
     # VQGAN Inference
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
@@ -128,7 +136,7 @@ def inference(req: InvokeRequest):
     return fake_audios
 
 
-@routes.http.post("/invoke")
+@routes.http.post("/v1/invoke")
 def api_invoke_model(
     req: Annotated[InvokeRequest, Body(exclusive=True)],
 ):
@@ -139,7 +147,7 @@ def api_invoke_model(
     if args.max_gradio_length > 0 and len(req.text) > args.max_gradio_length:
         raise HTTPException(
             HTTPStatus.BAD_REQUEST,
-            f"Text is too long, max length is {args.max_gradio_length}",
+            content=f"Text is too long, max length is {args.max_gradio_length}",
         )
 
     try:
@@ -147,7 +155,11 @@ def api_invoke_model(
         lock.acquire()
         fake_audios = inference(req)
     except Exception as e:
-        raise HTTPException(HTTPStatus.INTERNAL_SERVER_ERROR, str(e))
+        import traceback
+
+        traceback.print_exc()
+
+        raise HTTPException(HTTPStatus.INTERNAL_SERVER_ERROR, content=str(e))
     finally:
         # Release lock
         lock.release()
@@ -159,12 +171,14 @@ def api_invoke_model(
         iterable=[buffer.getvalue()],
         headers={
             "Content-Disposition": f"attachment; filename=audio.{req.format}",
-            "Content-Type": "application/octet-stream",
         },
+        # Make swagger-ui happy
+        # content_type=f"audio/{req.format}",
+        content_type="application/octet-stream",
     )
 
 
-@routes.http.post("/health")
+@routes.http.post("/v1/health")
 def api_health():
     """
     Health check
@@ -201,7 +215,14 @@ def parse_args():
 
 
 # Define Kui app
+openapi = OpenAPI(
+    {
+        "title": "Fish Speech API",
+    },
+).routes
+
 app = Kui(
+    routes=routes + openapi[1:],  # Remove the default route
     exception_handlers={
         HTTPException: http_execption_handler,
         Exception: other_exception_handler,
@@ -209,9 +230,6 @@ app = Kui(
     cors_config={},
 )
 
-# Swagger UI & routes
-app.router << ("/v1" // routes) << ("/docs" // OpenAPI().routes)
-
 
 if __name__ == "__main__":
     import threading
@@ -222,7 +240,7 @@ if __name__ == "__main__":
     args.precision = torch.half if args.half else torch.bfloat16
 
     logger.info("Loading Llama model...")
-    llama_model, decode_one_token = load_llama_model(
+    llama_queue = launch_thread_safe_queue(
         config_name=args.llama_config_name,
         checkpoint_path=args.llama_checkpoint_path,
         device=args.device,

+ 44 - 4
tools/llama/generate.py

@@ -1,4 +1,6 @@
 import os
+import queue
+import threading
 import time
 from pathlib import Path
 from typing import Optional, Tuple, Union
@@ -567,10 +569,7 @@ def generate_long(
             codes = y[1:, prompt_length:-2].clone()
 
             codes = codes - 2
-            if not (codes >= 0).all():
-                global_encoded.pop()
-                logger.warning(f"Negative code found: {codes}, retrying ...")
-                continue
+            assert (codes >= 0).all(), f"Negative code found"
 
             decoded = y[:, prompt_length:-1].clone()
             if decoded[0, -1] != im_end_id:  # <im_end>
@@ -599,6 +598,47 @@ def generate_long(
             yield all_codes
 
 
+def launch_thread_safe_queue(
+    config_name,
+    checkpoint_path,
+    device,
+    precision,
+    max_length,
+    compile=False,
+):
+    input_queue = queue.Queue()
+
+    def worker():
+        model, decode_one_token = load_model(
+            config_name, checkpoint_path, device, precision, max_length, compile=compile
+        )
+
+        while True:
+            item = input_queue.get()
+            if item is None:
+                break
+
+            kwargs = item["request"]
+            event = item["event"]
+
+            try:
+                item["success"] = True
+                item["response"] = list(
+                    generate_long(
+                        model=model, decode_one_token=decode_one_token, **kwargs
+                    )
+                )
+            except Exception as e:
+                item["success"] = False
+                item["response"] = e
+
+            event.set()
+
+    threading.Thread(target=worker, daemon=True).start()
+
+    return input_queue
+
+
 @click.command()
 @click.option(
     "--text",

+ 16 - 7
tools/webui.py

@@ -1,5 +1,6 @@
 import html
 import os
+import threading
 from argparse import ArgumentParser
 from io import BytesIO
 from pathlib import Path
@@ -12,8 +13,7 @@ from loguru import logger
 from torchaudio import functional as AF
 from transformers import AutoTokenizer
 
-from tools.llama.generate import generate_long
-from tools.llama.generate import load_model as load_llama_model
+from tools.llama.generate import launch_thread_safe_queue
 from tools.vqgan.inference import load_model as load_vqgan_model
 
 # Make einx happy
@@ -85,11 +85,9 @@ def inference(
         prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
 
     # LLAMA Inference
-    result = generate_long(
-        model=llama_model,
+    request = dict(
         tokenizer=llama_tokenizer,
         device=vqgan_model.device,
-        decode_one_token=decode_one_token,
         max_new_tokens=max_new_tokens,
         text=text,
         top_k=int(top_k) if top_k > 0 else None,
@@ -105,7 +103,18 @@ def inference(
         prompt_text=reference_text if enable_reference_audio else None,
     )
 
-    codes = next(result)
+    payload = dict(
+        event=threading.Event(),
+        request=request,
+    )
+    llama_queue.put(payload)
+
+    # Wait for the result
+    payload["event"].wait()
+    if payload["success"] is False:
+        raise payload["response"]
+
+    codes = payload["response"][0]
 
     # VQGAN Inference
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
@@ -270,7 +279,7 @@ if __name__ == "__main__":
     args.precision = torch.half if args.half else torch.bfloat16
 
     logger.info("Loading Llama model...")
-    llama_model, decode_one_token = load_llama_model(
+    llama_queue = launch_thread_safe_queue(
         config_name=args.llama_config_name,
         checkpoint_path=args.llama_checkpoint_path,
         device=args.device,