Browse Source

Implement thread safe compile

Lengyue 1 year ago
parent
commit
a095c8f6c7
4 changed files with 95 additions and 28 deletions
  1. 1 1
      pyproject.toml
  2. 34 16
      tools/api.py
  3. 44 4
      tools/llama/generate.py
  4. 16 7
      tools/webui.py

+ 1 - 1
pyproject.toml

@@ -35,7 +35,7 @@ dependencies = [
     "samplerate>=0.2.1",
     "samplerate>=0.2.1",
     "resampy>=0.4.3",
     "resampy>=0.4.3",
     "spaces>=0.26.1",
     "spaces>=0.26.1",
-    "einx[torch]==0.2.0"
+    "einx[torch]==0.2.2"
 ]
 ]
 
 
 [project.optional-dependencies]
 [project.optional-dependencies]

+ 34 - 16
tools/api.py

@@ -1,5 +1,6 @@
 import base64
 import base64
 import io
 import io
+import threading
 import traceback
 import traceback
 from argparse import ArgumentParser
 from argparse import ArgumentParser
 from http import HTTPStatus
 from http import HTTPStatus
@@ -17,15 +18,13 @@ from kui.wsgi import (
     Kui,
     Kui,
     OpenAPI,
     OpenAPI,
     StreamResponse,
     StreamResponse,
-    allow_cors,
 )
 )
 from kui.wsgi.routing import MultimethodRoutes
 from kui.wsgi.routing import MultimethodRoutes
 from loguru import logger
 from loguru import logger
 from pydantic import BaseModel
 from pydantic import BaseModel
 from transformers import AutoTokenizer
 from transformers import AutoTokenizer
 
 
-from tools.llama.generate import generate_long
-from tools.llama.generate import load_model as load_llama_model
+from tools.llama.generate import launch_thread_safe_queue
 from tools.vqgan.inference import load_model as load_vqgan_model
 from tools.vqgan.inference import load_model as load_vqgan_model
 from tools.webui import inference
 from tools.webui import inference
 
 
@@ -95,11 +94,9 @@ def inference(req: InvokeRequest):
         prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
         prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
 
 
     # LLAMA Inference
     # LLAMA Inference
-    result = generate_long(
-        model=llama_model,
+    request = dict(
         tokenizer=llama_tokenizer,
         tokenizer=llama_tokenizer,
         device=vqgan_model.device,
         device=vqgan_model.device,
-        decode_one_token=decode_one_token,
         max_new_tokens=req.max_new_tokens,
         max_new_tokens=req.max_new_tokens,
         text=req.text,
         text=req.text,
         top_k=int(req.top_k) if req.top_k > 0 else None,
         top_k=int(req.top_k) if req.top_k > 0 else None,
@@ -115,7 +112,18 @@ def inference(req: InvokeRequest):
         prompt_text=req.reference_text,
         prompt_text=req.reference_text,
     )
     )
 
 
-    codes = next(result)
+    payload = dict(
+        event=threading.Event(),
+        request=request,
+    )
+    llama_queue.put(payload)
+
+    # Wait for the result
+    payload["event"].wait()
+    if payload["success"] is False:
+        raise payload["response"]
+
+    codes = payload["response"][0]
 
 
     # VQGAN Inference
     # VQGAN Inference
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
@@ -128,7 +136,7 @@ def inference(req: InvokeRequest):
     return fake_audios
     return fake_audios
 
 
 
 
-@routes.http.post("/invoke")
+@routes.http.post("/v1/invoke")
 def api_invoke_model(
 def api_invoke_model(
     req: Annotated[InvokeRequest, Body(exclusive=True)],
     req: Annotated[InvokeRequest, Body(exclusive=True)],
 ):
 ):
@@ -139,7 +147,7 @@ def api_invoke_model(
     if args.max_gradio_length > 0 and len(req.text) > args.max_gradio_length:
     if args.max_gradio_length > 0 and len(req.text) > args.max_gradio_length:
         raise HTTPException(
         raise HTTPException(
             HTTPStatus.BAD_REQUEST,
             HTTPStatus.BAD_REQUEST,
-            f"Text is too long, max length is {args.max_gradio_length}",
+            content=f"Text is too long, max length is {args.max_gradio_length}",
         )
         )
 
 
     try:
     try:
@@ -147,7 +155,11 @@ def api_invoke_model(
         lock.acquire()
         lock.acquire()
         fake_audios = inference(req)
         fake_audios = inference(req)
     except Exception as e:
     except Exception as e:
-        raise HTTPException(HTTPStatus.INTERNAL_SERVER_ERROR, str(e))
+        import traceback
+
+        traceback.print_exc()
+
+        raise HTTPException(HTTPStatus.INTERNAL_SERVER_ERROR, content=str(e))
     finally:
     finally:
         # Release lock
         # Release lock
         lock.release()
         lock.release()
@@ -159,12 +171,14 @@ def api_invoke_model(
         iterable=[buffer.getvalue()],
         iterable=[buffer.getvalue()],
         headers={
         headers={
             "Content-Disposition": f"attachment; filename=audio.{req.format}",
             "Content-Disposition": f"attachment; filename=audio.{req.format}",
-            "Content-Type": "application/octet-stream",
         },
         },
+        # Make swagger-ui happy
+        # content_type=f"audio/{req.format}",
+        content_type="application/octet-stream",
     )
     )
 
 
 
 
-@routes.http.post("/health")
+@routes.http.post("/v1/health")
 def api_health():
 def api_health():
     """
     """
     Health check
     Health check
@@ -201,7 +215,14 @@ def parse_args():
 
 
 
 
 # Define Kui app
 # Define Kui app
+openapi = OpenAPI(
+    {
+        "title": "Fish Speech API",
+    },
+).routes
+
 app = Kui(
 app = Kui(
+    routes=routes + openapi[1:],  # Remove the default route
     exception_handlers={
     exception_handlers={
         HTTPException: http_execption_handler,
         HTTPException: http_execption_handler,
         Exception: other_exception_handler,
         Exception: other_exception_handler,
@@ -209,9 +230,6 @@ app = Kui(
     cors_config={},
     cors_config={},
 )
 )
 
 
-# Swagger UI & routes
-app.router << ("/v1" // routes) << ("/docs" // OpenAPI().routes)
-
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     import threading
     import threading
@@ -222,7 +240,7 @@ if __name__ == "__main__":
     args.precision = torch.half if args.half else torch.bfloat16
     args.precision = torch.half if args.half else torch.bfloat16
 
 
     logger.info("Loading Llama model...")
     logger.info("Loading Llama model...")
-    llama_model, decode_one_token = load_llama_model(
+    llama_queue = launch_thread_safe_queue(
         config_name=args.llama_config_name,
         config_name=args.llama_config_name,
         checkpoint_path=args.llama_checkpoint_path,
         checkpoint_path=args.llama_checkpoint_path,
         device=args.device,
         device=args.device,

+ 44 - 4
tools/llama/generate.py

@@ -1,4 +1,6 @@
 import os
 import os
+import queue
+import threading
 import time
 import time
 from pathlib import Path
 from pathlib import Path
 from typing import Optional, Tuple, Union
 from typing import Optional, Tuple, Union
@@ -567,10 +569,7 @@ def generate_long(
             codes = y[1:, prompt_length:-2].clone()
             codes = y[1:, prompt_length:-2].clone()
 
 
             codes = codes - 2
             codes = codes - 2
-            if not (codes >= 0).all():
-                global_encoded.pop()
-                logger.warning(f"Negative code found: {codes}, retrying ...")
-                continue
+            assert (codes >= 0).all(), f"Negative code found"
 
 
             decoded = y[:, prompt_length:-1].clone()
             decoded = y[:, prompt_length:-1].clone()
             if decoded[0, -1] != im_end_id:  # <im_end>
             if decoded[0, -1] != im_end_id:  # <im_end>
@@ -599,6 +598,47 @@ def generate_long(
             yield all_codes
             yield all_codes
 
 
 
 
+def launch_thread_safe_queue(
+    config_name,
+    checkpoint_path,
+    device,
+    precision,
+    max_length,
+    compile=False,
+):
+    input_queue = queue.Queue()
+
+    def worker():
+        model, decode_one_token = load_model(
+            config_name, checkpoint_path, device, precision, max_length, compile=compile
+        )
+
+        while True:
+            item = input_queue.get()
+            if item is None:
+                break
+
+            kwargs = item["request"]
+            event = item["event"]
+
+            try:
+                item["success"] = True
+                item["response"] = list(
+                    generate_long(
+                        model=model, decode_one_token=decode_one_token, **kwargs
+                    )
+                )
+            except Exception as e:
+                item["success"] = False
+                item["response"] = e
+
+            event.set()
+
+    threading.Thread(target=worker, daemon=True).start()
+
+    return input_queue
+
+
 @click.command()
 @click.command()
 @click.option(
 @click.option(
     "--text",
     "--text",

+ 16 - 7
tools/webui.py

@@ -1,5 +1,6 @@
 import html
 import html
 import os
 import os
+import threading
 from argparse import ArgumentParser
 from argparse import ArgumentParser
 from io import BytesIO
 from io import BytesIO
 from pathlib import Path
 from pathlib import Path
@@ -12,8 +13,7 @@ from loguru import logger
 from torchaudio import functional as AF
 from torchaudio import functional as AF
 from transformers import AutoTokenizer
 from transformers import AutoTokenizer
 
 
-from tools.llama.generate import generate_long
-from tools.llama.generate import load_model as load_llama_model
+from tools.llama.generate import launch_thread_safe_queue
 from tools.vqgan.inference import load_model as load_vqgan_model
 from tools.vqgan.inference import load_model as load_vqgan_model
 
 
 # Make einx happy
 # Make einx happy
@@ -85,11 +85,9 @@ def inference(
         prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
         prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]
 
 
     # LLAMA Inference
     # LLAMA Inference
-    result = generate_long(
-        model=llama_model,
+    request = dict(
         tokenizer=llama_tokenizer,
         tokenizer=llama_tokenizer,
         device=vqgan_model.device,
         device=vqgan_model.device,
-        decode_one_token=decode_one_token,
         max_new_tokens=max_new_tokens,
         max_new_tokens=max_new_tokens,
         text=text,
         text=text,
         top_k=int(top_k) if top_k > 0 else None,
         top_k=int(top_k) if top_k > 0 else None,
@@ -105,7 +103,18 @@ def inference(
         prompt_text=reference_text if enable_reference_audio else None,
         prompt_text=reference_text if enable_reference_audio else None,
     )
     )
 
 
-    codes = next(result)
+    payload = dict(
+        event=threading.Event(),
+        request=request,
+    )
+    llama_queue.put(payload)
+
+    # Wait for the result
+    payload["event"].wait()
+    if payload["success"] is False:
+        raise payload["response"]
+
+    codes = payload["response"][0]
 
 
     # VQGAN Inference
     # VQGAN Inference
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
@@ -270,7 +279,7 @@ if __name__ == "__main__":
     args.precision = torch.half if args.half else torch.bfloat16
     args.precision = torch.half if args.half else torch.bfloat16
 
 
     logger.info("Loading Llama model...")
     logger.info("Loading Llama model...")
-    llama_model, decode_one_token = load_llama_model(
+    llama_queue = launch_thread_safe_queue(
         config_name=args.llama_config_name,
         config_name=args.llama_config_name,
         checkpoint_path=args.llama_checkpoint_path,
         checkpoint_path=args.llama_checkpoint_path,
         device=args.device,
         device=args.device,