|
|
@@ -3,12 +3,13 @@ import io
|
|
|
import queue
|
|
|
import threading
|
|
|
import traceback
|
|
|
+import wave
|
|
|
from argparse import ArgumentParser
|
|
|
from http import HTTPStatus
|
|
|
-from threading import Lock
|
|
|
from typing import Annotated, Literal, Optional
|
|
|
|
|
|
import librosa
|
|
|
+import numpy as np
|
|
|
import pyrootutils
|
|
|
import soundfile as sf
|
|
|
import torch
|
|
|
@@ -23,7 +24,7 @@ from kui.wsgi import (
|
|
|
)
|
|
|
from kui.wsgi.routing import MultimethodRoutes
|
|
|
from loguru import logger
|
|
|
-from pydantic import BaseModel
|
|
|
+from pydantic import BaseModel, Field
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
|
@@ -32,7 +33,18 @@ from tools.llama.generate import launch_thread_safe_queue
|
|
|
from tools.vqgan.inference import load_model as load_vqgan_model
|
|
|
from tools.webui import inference
|
|
|
|
|
|
-lock = Lock()
|
|
|
+
|
|
|
+def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
|
|
+ buffer = io.BytesIO()
|
|
|
+
|
|
|
+ with wave.open(buffer, "wb") as wav_file:
|
|
|
+ wav_file.setnchannels(channels)
|
|
|
+ wav_file.setsampwidth(bit_depth // 8)
|
|
|
+ wav_file.setframerate(sample_rate)
|
|
|
+
|
|
|
+ wav_header_bytes = buffer.getvalue()
|
|
|
+ buffer.close()
|
|
|
+ return wav_header_bytes
|
|
|
|
|
|
|
|
|
# Define utils for web server
|
|
|
@@ -66,12 +78,13 @@ class InvokeRequest(BaseModel):
|
|
|
reference_text: Optional[str] = None
|
|
|
reference_audio: Optional[str] = None
|
|
|
max_new_tokens: int = 0
|
|
|
- chunk_length: int = 30
|
|
|
- top_p: float = 0.7
|
|
|
- repetition_penalty: float = 1.5
|
|
|
- temperature: float = 0.7
|
|
|
+ chunk_length: Annotated[int, Field(ge=0, le=200, strict=True)] = 30
|
|
|
+ top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
|
|
+ repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.5
|
|
|
+ temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
|
|
speaker: Optional[str] = None
|
|
|
format: Literal["wav", "mp3", "flac"] = "wav"
|
|
|
+ streaming: bool = False
|
|
|
|
|
|
|
|
|
@torch.inference_mode()
|
|
|
@@ -113,6 +126,7 @@ def inference(req: InvokeRequest):
|
|
|
speaker=req.speaker,
|
|
|
prompt_tokens=prompt_tokens,
|
|
|
prompt_text=req.reference_text,
|
|
|
+ is_streaming=True,
|
|
|
)
|
|
|
|
|
|
payload = dict(
|
|
|
@@ -121,7 +135,10 @@ def inference(req: InvokeRequest):
|
|
|
)
|
|
|
llama_queue.put(payload)
|
|
|
|
|
|
- codes = []
|
|
|
+ if req.streaming:
|
|
|
+ yield wav_chunk_header()
|
|
|
+
|
|
|
+ segments = []
|
|
|
while True:
|
|
|
result = payload["response_queue"].get()
|
|
|
if result == "next":
|
|
|
@@ -133,19 +150,22 @@ def inference(req: InvokeRequest):
|
|
|
raise payload["response"]
|
|
|
break
|
|
|
|
|
|
- codes.append(result)
|
|
|
-
|
|
|
- codes = torch.cat(codes, dim=1)
|
|
|
+ # VQGAN Inference
|
|
|
+ feature_lengths = torch.tensor([result.shape[1]], device=vqgan_model.device)
|
|
|
+ fake_audios = vqgan_model.decode(
|
|
|
+ indices=result[None], feature_lengths=feature_lengths, return_audios=True
|
|
|
+ )[0, 0]
|
|
|
+ fake_audios = fake_audios.float().cpu().numpy()
|
|
|
+ fake_audios = np.concatenate([fake_audios, np.zeros((11025,))], axis=0)
|
|
|
|
|
|
- # VQGAN Inference
|
|
|
- feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
|
|
|
- fake_audios = vqgan_model.decode(
|
|
|
- indices=codes[None], feature_lengths=feature_lengths, return_audios=True
|
|
|
- )[0, 0]
|
|
|
+ if req.streaming:
|
|
|
+ yield (fake_audios * 32768).astype(np.int16).tobytes()
|
|
|
+ else:
|
|
|
+ segments.append(fake_audios)
|
|
|
|
|
|
- fake_audios = fake_audios.float().cpu().numpy()
|
|
|
-
|
|
|
- return fake_audios
|
|
|
+ if req.streaming is False:
|
|
|
+ fake_audios = np.concatenate(segments, axis=0)
|
|
|
+ yield fake_audios
|
|
|
|
|
|
|
|
|
@routes.http.post("/v1/invoke")
|
|
|
@@ -162,32 +182,33 @@ def api_invoke_model(
|
|
|
content=f"Text is too long, max length is {args.max_text_length}",
|
|
|
)
|
|
|
|
|
|
- try:
|
|
|
- # Lock, avoid interrupting the inference process
|
|
|
- lock.acquire()
|
|
|
- fake_audios = inference(req)
|
|
|
- except Exception as e:
|
|
|
- import traceback
|
|
|
-
|
|
|
- traceback.print_exc()
|
|
|
-
|
|
|
- raise HTTPException(HTTPStatus.INTERNAL_SERVER_ERROR, content=str(e))
|
|
|
- finally:
|
|
|
- # Release lock
|
|
|
- lock.release()
|
|
|
+ if req.streaming and req.format != "wav":
|
|
|
+ raise HTTPException(
|
|
|
+ HTTPStatus.BAD_REQUEST,
|
|
|
+ content="Streaming only supports WAV format",
|
|
|
+ )
|
|
|
|
|
|
- buffer = io.BytesIO()
|
|
|
- sf.write(buffer, fake_audios, vqgan_model.sampling_rate, format=req.format)
|
|
|
-
|
|
|
- return StreamResponse(
|
|
|
- iterable=[buffer.getvalue()],
|
|
|
- headers={
|
|
|
- "Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
|
- },
|
|
|
- # Make swagger-ui happy
|
|
|
- # content_type=f"audio/{req.format}",
|
|
|
- content_type="application/octet-stream",
|
|
|
- )
|
|
|
+ generator = inference(req)
|
|
|
+ if req.streaming:
|
|
|
+ return StreamResponse(
|
|
|
+ iterable=generator,
|
|
|
+ headers={
|
|
|
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
|
+ },
|
|
|
+ content_type="application/octet-stream",
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ fake_audios = next(generator)
|
|
|
+ buffer = io.BytesIO()
|
|
|
+ sf.write(buffer, fake_audios, vqgan_model.sampling_rate, format=req.format)
|
|
|
+
|
|
|
+ return StreamResponse(
|
|
|
+ iterable=[buffer.getvalue()],
|
|
|
+ headers={
|
|
|
+ "Content-Disposition": f"attachment; filename=audio.{req.format}",
|
|
|
+ },
|
|
|
+ content_type="application/octet-stream",
|
|
|
+ )
|
|
|
|
|
|
|
|
|
@routes.http.post("/v1/health")
|
|
|
@@ -272,18 +293,20 @@ if __name__ == "__main__":
|
|
|
logger.info("VQ-GAN model loaded, warming up...")
|
|
|
|
|
|
# Dry run to check if the model is loaded correctly and avoid the first-time latency
|
|
|
- inference(
|
|
|
- InvokeRequest(
|
|
|
- text="A warm-up sentence.",
|
|
|
- reference_text=None,
|
|
|
- reference_audio=None,
|
|
|
- max_new_tokens=0,
|
|
|
- chunk_length=30,
|
|
|
- top_p=0.7,
|
|
|
- repetition_penalty=1.5,
|
|
|
- temperature=0.7,
|
|
|
- speaker=None,
|
|
|
- format="wav",
|
|
|
+ list(
|
|
|
+ inference(
|
|
|
+ InvokeRequest(
|
|
|
+ text="A warm-up sentence.",
|
|
|
+ reference_text=None,
|
|
|
+ reference_audio=None,
|
|
|
+ max_new_tokens=0,
|
|
|
+ chunk_length=30,
|
|
|
+ top_p=0.7,
|
|
|
+ repetition_penalty=1.5,
|
|
|
+ temperature=0.7,
|
|
|
+ speaker=None,
|
|
|
+ format="wav",
|
|
|
+ )
|
|
|
)
|
|
|
)
|
|
|
|