Просмотр исходного кода

Enable API stream generation & reduce WebUI GPU memory usage

Lengyue 1 год назад
Родитель
Commit
231b8ed6ad
2 измененных файла с 103 добавлено и 72 удалено
  1. 79 56
      tools/api.py
  2. 24 16
      tools/webui.py

+ 79 - 56
tools/api.py

@@ -3,12 +3,13 @@ import io
 import queue
 import threading
 import traceback
+import wave
 from argparse import ArgumentParser
 from http import HTTPStatus
-from threading import Lock
 from typing import Annotated, Literal, Optional
 
 import librosa
+import numpy as np
 import pyrootutils
 import soundfile as sf
 import torch
@@ -23,7 +24,7 @@ from kui.wsgi import (
 )
 from kui.wsgi.routing import MultimethodRoutes
 from loguru import logger
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from transformers import AutoTokenizer
 
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
@@ -32,7 +33,18 @@ from tools.llama.generate import launch_thread_safe_queue
 from tools.vqgan.inference import load_model as load_vqgan_model
 from tools.webui import inference
 
-lock = Lock()
+
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
+    buffer = io.BytesIO()
+
+    with wave.open(buffer, "wb") as wav_file:
+        wav_file.setnchannels(channels)
+        wav_file.setsampwidth(bit_depth // 8)
+        wav_file.setframerate(sample_rate)
+
+    wav_header_bytes = buffer.getvalue()
+    buffer.close()
+    return wav_header_bytes
 
 
 # Define utils for web server
@@ -66,12 +78,13 @@ class InvokeRequest(BaseModel):
     reference_text: Optional[str] = None
     reference_audio: Optional[str] = None
     max_new_tokens: int = 0
-    chunk_length: int = 30
-    top_p: float = 0.7
-    repetition_penalty: float = 1.5
-    temperature: float = 0.7
+    chunk_length: Annotated[int, Field(ge=0, le=200, strict=True)] = 30
+    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.5
+    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
     speaker: Optional[str] = None
     format: Literal["wav", "mp3", "flac"] = "wav"
+    streaming: bool = False
 
 
 @torch.inference_mode()
@@ -113,6 +126,7 @@ def inference(req: InvokeRequest):
         speaker=req.speaker,
         prompt_tokens=prompt_tokens,
         prompt_text=req.reference_text,
+        is_streaming=True,
     )
 
     payload = dict(
@@ -121,7 +135,10 @@ def inference(req: InvokeRequest):
     )
     llama_queue.put(payload)
 
-    codes = []
+    if req.streaming:
+        yield wav_chunk_header()
+
+    segments = []
     while True:
         result = payload["response_queue"].get()
         if result == "next":
@@ -133,19 +150,22 @@ def inference(req: InvokeRequest):
                 raise payload["response"]
             break
 
-        codes.append(result)
-
-    codes = torch.cat(codes, dim=1)
+        # VQGAN Inference
+        feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
+        fake_audios = vqgan_model.decode(
+            indices=result[None], feature_lengths=feature_lengths, return_audios=True
+        )[0, 0]
+        fake_audios = fake_audios.float().cpu().numpy()
+        fake_audios = np.concatenate([fake_audios, np.zeros((11025,))], axis=0)
 
-    # VQGAN Inference
-    feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
-    fake_audios = vqgan_model.decode(
-        indices=codes[None], feature_lengths=feature_lengths, return_audios=True
-    )[0, 0]
+        if req.streaming:
+            yield (fake_audios * 32768).astype(np.int16).tobytes()
+        else:
+            segments.append(fake_audios)
 
-    fake_audios = fake_audios.float().cpu().numpy()
-
-    return fake_audios
+    if req.streaming is False:
+        fake_audios = np.concatenate(segments, axis=0)
+        yield fake_audios
 
 
 @routes.http.post("/v1/invoke")
@@ -162,32 +182,33 @@ def api_invoke_model(
             content=f"Text is too long, max length is {args.max_text_length}",
         )
 
-    try:
-        # Lock, avoid interrupting the inference process
-        lock.acquire()
-        fake_audios = inference(req)
-    except Exception as e:
-        import traceback
-
-        traceback.print_exc()
-
-        raise HTTPException(HTTPStatus.INTERNAL_SERVER_ERROR, content=str(e))
-    finally:
-        # Release lock
-        lock.release()
+    if req.streaming and req.format != "wav":
+        raise HTTPException(
+            HTTPStatus.BAD_REQUEST,
+            content="Streaming only supports WAV format",
+        )
 
-    buffer = io.BytesIO()
-    sf.write(buffer, fake_audios, vqgan_model.sampling_rate, format=req.format)
-
-    return StreamResponse(
-        iterable=[buffer.getvalue()],
-        headers={
-            "Content-Disposition": f"attachment; filename=audio.{req.format}",
-        },
-        # Make swagger-ui happy
-        # content_type=f"audio/{req.format}",
-        content_type="application/octet-stream",
-    )
+    generator = inference(req)
+    if req.streaming:
+        return StreamResponse(
+            iterable=generator,
+            headers={
+                "Content-Disposition": f"attachment; filename=audio.{req.format}",
+            },
+            content_type="application/octet-stream",
+        )
+    else:
+        fake_audios = next(generator)
+        buffer = io.BytesIO()
+        sf.write(buffer, fake_audios, vqgan_model.sampling_rate, format=req.format)
+
+        return StreamResponse(
+            iterable=[buffer.getvalue()],
+            headers={
+                "Content-Disposition": f"attachment; filename=audio.{req.format}",
+            },
+            content_type="application/octet-stream",
+        )
 
 
 @routes.http.post("/v1/health")
@@ -272,18 +293,20 @@ if __name__ == "__main__":
     logger.info("VQ-GAN model loaded, warming up...")
 
     # Dry run to check if the model is loaded correctly and avoid the first-time latency
-    inference(
-        InvokeRequest(
-            text="A warm-up sentence.",
-            reference_text=None,
-            reference_audio=None,
-            max_new_tokens=0,
-            chunk_length=30,
-            top_p=0.7,
-            repetition_penalty=1.5,
-            temperature=0.7,
-            speaker=None,
-            format="wav",
+    list(
+        inference(
+            InvokeRequest(
+                text="A warm-up sentence.",
+                reference_text=None,
+                reference_audio=None,
+                max_new_tokens=0,
+                chunk_length=30,
+                top_p=0.7,
+                repetition_penalty=1.5,
+                temperature=0.7,
+                speaker=None,
+                format="wav",
+            )
         )
     )
 

+ 24 - 16
tools/webui.py

@@ -121,7 +121,7 @@ def inference(
         speaker=speaker if speaker else None,
         prompt_tokens=prompt_tokens if enable_reference_audio else None,
         prompt_text=reference_text if enable_reference_audio else None,
-        is_streaming=streaming,
+        is_streaming=True,  # Always streaming
     )
 
     payload = dict(
@@ -133,6 +133,7 @@ def inference(
     if streaming:
         yield wav_chunk_header(), None
 
+    segments = []
     while True:
         result = payload["response_queue"].get()
         if result == "next":
@@ -150,13 +151,16 @@ def inference(
             indices=result[None], feature_lengths=feature_lengths, return_audios=True
         )[0, 0]
         fake_audios = fake_audios.float().cpu().numpy()
+        fake_audios = np.concatenate([fake_audios, np.zeros((11025,))], axis=0)
 
         if streaming:
-            yield (
-                np.concatenate([fake_audios, np.zeros((11025,))], axis=0) * 32768
-            ).astype(np.int16).tobytes(), None
+            yield (fake_audios * 32768).astype(np.int16).tobytes()
         else:
-            yield (vqgan_model.sampling_rate, fake_audios), None
+            segments.append(fake_audios)
+
+    if streaming is False:
+        audio = np.concatenate(segments, axis=0)
+        yield (vqgan_model.sampling_rate, audio), None
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -168,10 +172,12 @@ inference_stream = partial(inference, streaming=True)
 
 def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
     buffer = io.BytesIO()
+
     with wave.open(buffer, "wb") as wav_file:
         wav_file.setnchannels(channels)
         wav_file.setsampwidth(bit_depth // 8)
         wav_file.setframerate(sample_rate)
+
     wav_header_bytes = buffer.getvalue()
     buffer.close()
     return wav_header_bytes
@@ -374,17 +380,19 @@ if __name__ == "__main__":
     logger.info("VQ-GAN model loaded, warming up...")
 
     # Dry run to check if the model is loaded correctly and avoid the first-time latency
-    inference(
-        text="Hello, world!",
-        enable_reference_audio=False,
-        reference_audio=None,
-        reference_text="",
-        max_new_tokens=0,
-        chunk_length=0,
-        top_p=0.7,
-        repetition_penalty=1.5,
-        temperature=0.7,
-        speaker=None,
+    list(
+        inference(
+            text="Hello, world!",
+            enable_reference_audio=False,
+            reference_audio=None,
+            reference_text="",
+            max_new_tokens=0,
+            chunk_length=0,
+            top_p=0.7,
+            repetition_penalty=1.5,
+            temperature=0.7,
+            speaker=None,
+        )
     )
 
     logger.info("Warming up done, launching the web UI...")