
Add queue to support streaming

Lengyue, 1 year ago
parent commit dcbe986fc9
3 changed files with 52 additions and 34 deletions
  1. tools/api.py (+16, -6)
  2. tools/llama/generate.py (+20, -22)
  3. tools/webui.py (+16, -6)

tools/api.py (+16, -6)

@@ -1,5 +1,6 @@
 import base64
 import io
+import queue
 import threading
 import traceback
 from argparse import ArgumentParser
@@ -114,17 +115,26 @@ def inference(req: InvokeRequest):
     )
 
     payload = dict(
-        event=threading.Event(),
+        response_queue=queue.Queue(),
         request=request,
     )
     llama_queue.put(payload)
 
-    # Wait for the result
-    payload["event"].wait()
-    if payload["success"] is False:
-        raise payload["response"]
+    codes = []
+    while True:
+        result = payload["response_queue"].get()
+        if result == "next":
+            # TODO: handle next sentence
+            continue
 
-    codes = payload["response"][0]
+        if result == "done":
+            if payload["success"] is False:
+                raise payload["response"]
+            break
+
+        codes.append(result)
+
+    codes = torch.cat(codes, dim=1)
 
     # VQGAN Inference
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
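
Taken out of context, the new consumer loop in tools/api.py implements a simple sentinel protocol: tensor chunks stream in, "next" marks the end of a sample, and "done" ends the stream. Below is a self-contained sketch of that protocol with a fake worker standing in for the real llama thread; the tensor shape and everything about fake_worker are illustrative assumptions, only the loop structure comes from the diff.

    import queue
    import threading

    import torch


    def fake_worker(payload):
        # Stand-in for the llama worker thread: stream three code chunks,
        # then signal end-of-sample and end-of-stream.
        q = payload["response_queue"]
        payload["success"] = True
        for _ in range(3):
            q.put(torch.zeros(4, 5, dtype=torch.long))  # a [num_codebooks, T] chunk
        q.put("next")  # end of the current sample
        q.put("done")  # end of the stream


    payload = {"response_queue": queue.Queue()}
    threading.Thread(target=fake_worker, args=(payload,), daemon=True).start()

    codes = []
    while True:
        result = payload["response_queue"].get()  # blocks until the worker emits
        if result == "next":
            continue  # sample boundary; a streaming caller could flush audio here
        if result == "done":
            if payload["success"] is False:
                raise payload["response"]
            break
        codes.append(result)

    codes = torch.cat(codes, dim=1)  # chunks joined along the time axis
    print(codes.shape)               # torch.Size([4, 15])
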

tools/llama/generate.py (+20, -22)

@@ -470,16 +470,14 @@ def generate_long(
     texts = split_text(text, chunk_length) if iterative_prompt else [text]
 
     if use_prompt:
-        encoded.append(
-            encode_tokens(
-                tokenizer,
-                prompt_text,
-                prompt_tokens=prompt_tokens,
-                bos=True,
-                device=device,
-                speaker=speaker,
-                num_codebooks=model.config.num_codebooks,
-            )
+        encoded_prompts = encode_tokens(
+            tokenizer,
+            prompt_text,
+            prompt_tokens=prompt_tokens,
+            bos=True,
+            device=device,
+            speaker=speaker,
+            num_codebooks=model.config.num_codebooks,
         )
 
     for idx, text in enumerate(texts):
@@ -501,10 +499,6 @@ def generate_long(
         all_codes = []
         seg_idx = 0
 
-        if use_prompt:
-            seg_idx = 1
-            global_encoded.append(encoded[0])
-
         while seg_idx < len(encoded):
             logger.info(
                 f"Generating sentence {seg_idx + 1}/{len(encoded)} of sample {sample_idx + 1}/{num_samples}"
@@ -531,6 +525,9 @@ def generate_long(
             else:
                 partial_encoded = global_encoded
 
+            if use_prompt:
+                partial_encoded = [encoded_prompts] + partial_encoded
+
             cat_encoded = torch.cat(partial_encoded, dim=1)
             prompt_length = cat_encoded.size(1)
 
@@ -593,7 +590,7 @@ def generate_long(
 
         if is_streaming:
             # This indicates the end of the current sample
-            yield None
+            yield "next"
         else:
             all_codes = torch.cat(all_codes, dim=1)
             assert (all_codes >= 0).all(), f"Negative code found: {codes}"
@@ -623,20 +620,21 @@ def launch_thread_safe_queue(
                 break
 
             kwargs = item["request"]
-            event = item["event"]
+            response_queue = item["response_queue"]
 
             try:
                 item["success"] = True
-                item["response"] = list(
-                    generate_long(
-                        model=model, decode_one_token=decode_one_token, **kwargs
-                    )
-                )
+                for chunk in generate_long(
+                    model=model, decode_one_token=decode_one_token, **kwargs
+                ):
+                    response_queue.put(chunk)
+
+                response_queue.put("done")
             except Exception as e:
                 item["success"] = False
                 item["response"] = e
 
-            event.set()
+                response_queue.put("done")
 
     threading.Thread(target=worker, daemon=True).start()
     init_event.wait()
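
On the producer side, the worker in launch_thread_safe_queue now streams each chunk into the request's own response_queue and pushes "done" on both the success and the error path, so the consumer's blocking get() always terminates. The sketch below isolates that pattern under stated assumptions: generate_chunks is a stand-in for generate_long (which takes model and tokenizer arguments), and the diff's two put("done") calls are folded into a single finally clause, which behaves the same here.

    import queue
    import threading


    def generate_chunks(n):
        # Stand-in for generate_long; the real generator yields code tensors
        # and "next" sample markers.
        for i in range(n):
            yield i


    def launch_worker(job_queue):
        def worker():
            while True:
                item = job_queue.get()
                if item is None:  # shutdown sentinel
                    break

                response_queue = item["response_queue"]
                try:
                    item["success"] = True
                    for chunk in generate_chunks(**item["request"]):
                        response_queue.put(chunk)
                except Exception as e:
                    item["success"] = False
                    item["response"] = e
                finally:
                    # Always unblock the consumer, on success and on error.
                    response_queue.put("done")

        threading.Thread(target=worker, daemon=True).start()


    job_queue = queue.Queue()
    launch_worker(job_queue)

    job = {"response_queue": queue.Queue(), "request": {"n": 3}}
    job_queue.put(job)

    results = []
    while (r := job["response_queue"].get()) != "done":
        results.append(r)
    print(results)  # [0, 1, 2]
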

tools/webui.py (+16, -6)

@@ -1,6 +1,7 @@
 import gc
 import html
 import os
+import queue
 import threading
 from argparse import ArgumentParser
 from pathlib import Path
@@ -119,17 +120,26 @@ def inference(
     )
 
     payload = dict(
-        event=threading.Event(),
+        response_queue=queue.Queue(),
         request=request,
     )
     llama_queue.put(payload)
 
-    # Wait for the result
-    payload["event"].wait()
-    if payload["success"] is False:
-        raise payload["response"]
+    codes = []
+    while True:
+        result = payload["response_queue"].get()
+        if result == "next":
+            # TODO: handle next sentence
+            continue
 
-    codes = payload["response"][0]
+        if result == "done":
+            if payload["success"] is False:
+                raise payload["response"]
+            break
+
+        codes.append(result)
+
+    codes = torch.cat(codes, dim=1)
 
     # VQGAN Inference
     feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
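
tools/webui.py receives the same consumer loop as tools/api.py, including the same TODO at the "next" sentinel. One plausible way to fill that TODO in is a generator that yields a stitched code tensor per finished sample, so the UI could synthesize and play audio incrementally. This is a hypothetical sketch, not code from the commit; iter_samples is an invented name and its behavior at "next" is an assumption.

    import torch


    def iter_samples(response_queue, payload):
        # Hypothetical: yield one stitched code tensor per finished sample.
        pending = []
        while True:
            result = response_queue.get()
            if result == "next":
                if pending:
                    yield torch.cat(pending, dim=1)  # flush the finished sample
                    pending = []
                continue
            if result == "done":
                if payload.get("success") is False:
                    raise payload["response"]
                if pending:
                    yield torch.cat(pending, dim=1)  # flush any trailing chunks
                return
            pending.append(result)
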