@@ -1,7 +1,8 @@
 import io
 import os
 import queue
-import sys
+import re
+import time
 import traceback
 import wave
 from argparse import ArgumentParser
@@ -9,6 +10,7 @@ from http import HTTPStatus
 from pathlib import Path
 from typing import Annotated, Any

+import librosa
 import numpy as np
 import ormsgpack
 import pyrootutils
@@ -26,26 +28,67 @@ from kui.asgi import (
     Kui,
     OpenAPI,
     StreamResponse,
+    request,
 )
 from kui.asgi.routing import MultimethodRoutes
 from loguru import logger
+from transformers import AutoTokenizer

 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+import struct
+from threading import Lock
+
+import httpx
+from cachetools import LRUCache, cached
+from funasr import AutoModel
+from silero_vad import get_speech_timestamps, load_silero_vad
+
+from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
+from fish_speech.models.text2semantic.llama import BaseModelArgs

 # from fish_speech.models.vqgan.lit_module import VQGAN
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from fish_speech.utils import autocast_exclude_mps, set_seed
-from tools.commons import ServeTTSRequest
 from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
 from tools.llama.generate import (
     GenerateRequest,
     GenerateResponse,
     WrappedGenerateResponse,
     launch_thread_safe_queue,
+    launch_thread_safe_queue_agent,
+)
+from tools.schema import (
+    GLOBAL_NUM_SAMPLES,
+    ASRPackRequest,
+    ServeASRRequest,
+    ServeASRResponse,
+    ServeASRSegment,
+    ServeAudioPart,
+    ServeForwardMessage,
+    ServeMessage,
+    ServeRequest,
+    ServeResponse,
+    ServeStreamDelta,
+    ServeStreamResponse,
+    ServeTextPart,
+    ServeTimedASRResponse,
+    ServeTTSRequest,
+    ServeVQGANDecodeRequest,
+    ServeVQGANDecodeResponse,
+    ServeVQGANEncodeRequest,
+    ServeVQGANEncodeResponse,
+    ServeVQPart,
 )
 from tools.vqgan.inference import load_model as load_decoder_model

+global_lock = Lock()
+
+# Whether to disable keepalive (which is helpful if the server is in the same cluster)
+DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
+async_client = httpx.AsyncClient(
+    timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
+)
 backends = torchaudio.list_audio_backends()

 if "ffmpeg" in backends:
@@ -169,6 +212,385 @@ def get_content_type(audio_format):
     return "application/octet-stream"


+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def batch_encode(model, audios: list[bytes | torch.Tensor]):
+    audios = [
+        (
+            torch.from_numpy(
+                librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
+            )[None]
+            if isinstance(audio, bytes)
+            else audio
+        )
+        for audio in audios
+    ]
+
+    # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
+    #     raise ValueError("Single audio length is too long (>120s)")
+
+    max_length = max(audio.shape[-1] for audio in audios)
+    logger.info(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
+
+    lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
+    max_length = lengths.max().item()
+    padded = torch.stack(
+        [
+            torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
+            for audio in audios
+        ]
+    ).to(model.device)
+
+    features, feature_lengths = model.encode(padded, audio_lengths=lengths)
+    features, feature_lengths = features.cpu(), feature_lengths.cpu()
+
+    return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
+
+
+@cached(
+    cache=LRUCache(maxsize=10000),
+    key=lambda model, audios: (model.device, tuple(audios)),
+)
+def cached_vqgan_batch_encode(model, audios: list[bytes]):
+    return batch_encode(model, audios)
+
+
+@routes.http.post("/v1/vqgan/encode")
+def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
+
+    start_time = time.time()
+    tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
+    logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
+
+    return ormsgpack.packb(
+        ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
+        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+    )
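
A minimal client sketch for the new encode endpoint (illustrative only: the URL, port, and file name are assumptions, and it assumes the server accepts msgpack-encoded bodies, as the other fish-speech endpoints do):

import httpx
import ormsgpack

with open("reference.wav", "rb") as f:
    audio_bytes = f.read()

resp = httpx.post(
    "http://localhost:8080/v1/vqgan/encode",
    content=ormsgpack.packb({"audios": [audio_bytes]}),
    headers={"Content-Type": "application/msgpack"},
)
# ServeVQGANEncodeResponse unpacks to {"tokens": [[...], ...]}
tokens = ormsgpack.unpackb(resp.content)["tokens"]
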
+
+
+@torch.no_grad()
+@torch.autocast(device_type="cuda", dtype=torch.half)
+def vqgan_decode(model, features):
+    lengths = torch.tensor(
+        [feature.shape[-1] for feature in features], device=model.device
+    )
+    max_length = lengths.max().item()
+    padded = torch.stack(
+        [
+            torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
+            for feature in features
+        ]
+    ).to(model.device)
+
+    # If the batch is too large, decode in micro-batches of 8
+    audios, audio_lengths = [], []
+    for i in range(0, padded.shape[0], 8):
+        audio, audio_length = model.decode(
+            padded[i : i + 8], feature_lengths=lengths[i : i + 8]
+        )
+        audios.append(audio)
+        audio_lengths.append(audio_length)
+    audios = torch.cat(audios, dim=0)
+    audio_lengths = torch.cat(audio_lengths, dim=0)
+    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
+
+    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
+
+
+@routes.http.post("/v1/vqgan/decode")
+def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
+    tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
+    start_time = time.time()
+    audios = vqgan_decode(decoder_model, tokens)
+    logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
+    audios = [audio.astype(np.float16).tobytes() for audio in audios]
+    return ormsgpack.packb(
+        ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
+    )
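
The decode endpoint returns raw float16 samples (see audio.astype(np.float16).tobytes() above), so a client must reinterpret the bytes itself. A sketch, under the same assumptions as the encode example:

import httpx
import numpy as np
import ormsgpack

resp = httpx.post(
    "http://localhost:8080/v1/vqgan/decode",
    content=ormsgpack.packb({"tokens": tokens}),  # tokens as returned by /v1/vqgan/encode
    headers={"Content-Type": "application/msgpack"},
)
audios = [
    np.frombuffer(buf, dtype=np.float16).astype(np.float32)
    for buf in ormsgpack.unpackb(resp.content)["audios"]
]
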
+
+
+@torch.no_grad()
+def batch_asr(model, audios, sr, language="auto"):
+    resampled_audios = []
+    for audio in audios:
+        audio = torchaudio.functional.resample(audio, sr, 16000)
+        assert audio.ndim == 1
+        resampled_audios.append(audio)
+
+    with global_lock:
+        res = model.generate(
+            input=resampled_audios,
+            batch_size=len(resampled_audios),
+            language=language,
+            use_itn=True,
+        )
+
+    results = []
+    for r, audio in zip(res, audios):
+        text = r["text"]
+        text = re.sub(r"<\|.*?\|>", "", text)
+        duration = len(audio) / sr * 1000
+        huge_gap = False
+
+        if "timestamp" in r and len(r["timestamp"]) > 2:
+            for timestamp_a, timestamp_b in zip(
+                r["timestamp"][:-1], r["timestamp"][1:]
+            ):
+                # If there is a gap of more than 5 seconds, we consider it a huge gap
+                if timestamp_b[0] - timestamp_a[1] > 5000:
+                    huge_gap = True
+                    break
+
+            # Doesn't make sense to have a huge gap at the end
+            if duration - r["timestamp"][-1][1] > 3000:
+                huge_gap = True
+
+        results.append(
+            {
+                "text": text,
+                "duration": duration,
+                "huge_gap": huge_gap,
+            }
+        )
+
+    return results
+
+
+@routes.http.post("/v1/asr")
+def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
+    start_time = time.time()
+    audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
+    audios = [torch.from_numpy(audio).float() for audio in audios]
+
+    if any(audio.shape[-1] >= 30 * payload.sample_rate for audio in audios):
+        raise HTTPException(status_code=400, detail="Audio length is too long")
+
+    transcriptions = batch_asr(
+        asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
+    )
+    logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
+
+    return ormsgpack.packb(
+        ServeASRResponse(transcriptions=transcriptions),
+        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+    )
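
ServeASRRequest mirrors this wire format: audios are raw float16 samples, and anything 30 seconds or longer is rejected. A hypothetical client call (the URL and the silent audio are placeholders):

import httpx
import numpy as np
import ormsgpack

samples = np.zeros(16000, dtype=np.float16)  # one second of silence as a stand-in
resp = httpx.post(
    "http://localhost:8080/v1/asr",
    content=ormsgpack.packb(
        {"audios": [samples.tobytes()], "sample_rate": 16000, "language": "auto"}
    ),
    headers={"Content-Type": "application/msgpack"},
)
print(ormsgpack.unpackb(resp.content)["transcriptions"])
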
+
+
+from fish_speech.conversation import Conversation, Message
+
+
+def execute_request(
+    input_queue: queue.Queue,
+    tokenizer: AutoTokenizer,
+    config: BaseModelArgs,
+    request: ServeRequest,
+    device: str = "cuda:0",
+):
+    semantic_id, im_end_id = tokenizer.convert_tokens_to_ids(
+        [SEMANTIC_TOKEN, IM_END_TOKEN]
+    )
+    messages = []
+    for message in request.messages:
+        messages.append(message.to_conversation_message())
+
+    assert len(messages) >= 1, "At least one message is required"
+    # assert messages[-1].role == "user", "The last message must be from the user"
+
+    if messages[-1].role == "user":
+        messages.append(Message(role="assistant", parts=[], add_im_end=False))
+    else:
+        assert (
+            messages[-1].role == "assistant"
+        ), "The last message must be from the assistant"
+        messages[-1].add_im_end = False
+
+    conv = Conversation(messages=messages)
+    prompt = conv.encode_for_inference(
+        tokenizer=tokenizer, num_codebooks=config.num_codebooks
+    ).to(device)
+
+    if request.streaming:
+        for i in range(request.num_samples):
+            yield ServeStreamResponse(
+                sample_id=i,
+                delta=ServeStreamDelta(
+                    role="assistant",
+                ),
+            )
+
+    req = {
+        "prompt": prompt,
+        "max_new_tokens": request.max_new_tokens,
+        "im_end_id": im_end_id,
+        "semantic_id": semantic_id,
+        "temperature": request.temperature,
+        "top_p": request.top_p,
+        "repetition_penalty": request.repetition_penalty,
+        "num_samples": request.num_samples,
+        "early_stop_threshold": request.early_stop_threshold,
+    }
+
+    start = time.time()
+    response_queue = queue.Queue()
+    input_queue.put(GenerateRequest(req, response_queue))
+
+    # Per-sample decoding buffers
+    decode_buffer = [[] for _ in range(request.num_samples)]
+    parts = [[] for _ in range(request.num_samples)]
+
+    def send_reset_buffer(sample_id):
+        nonlocal decode_buffer
+        if len(decode_buffer[sample_id]) == 0:
+            return
+
+        decoded = tokenizer.decode(decode_buffer[sample_id])
+        part = ServeTextPart(text=decoded)
+
+        if request.streaming:
+            yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
+        else:
+            parts[sample_id].append(part)
+
+        decode_buffer[sample_id] = []
+
+    # Consume generated tokens until the worker signals "stop" or "error"
+    finished = [False for _ in range(request.num_samples)]
+    stats = {}
+    idx = 0
+    while True:
+        response = response_queue.get()
+
+        if response in ["stop", "error"]:
+            break
+
+        for sample_id, tokens in enumerate(response):
+            if finished[sample_id]:
+                continue
+
+            if tokens[0] == im_end_id:
+                finished[sample_id] = True
+                if request.streaming:
+                    yield from send_reset_buffer(sample_id)
+                    yield ServeStreamResponse(
+                        sample_id=sample_id,
+                        finish_reason="stop",
+                        stats=stats,
+                    )
+                continue
+
+            if tokens[0] == semantic_id and request.streaming:
+                yield from send_reset_buffer(sample_id)
+                # Streaming VQ codes
+                _tokens = tokens[1:].clone() - 1
+
+                if config.share_codebook_embeddings is False:
+                    for i in range(len(_tokens)):
+                        _tokens[i] -= config.codebook_size * i
+
+                yield ServeStreamResponse(
+                    sample_id=sample_id,
+                    delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
+                )
+                continue
+
+            # Non-streaming VQ codes
+            if tokens[0] == semantic_id:
+                yield from send_reset_buffer(sample_id)
+                # Start a new VQ part unless the last part is already one
+                if len(parts[sample_id]) == 0 or not isinstance(
+                    parts[sample_id][-1], ServeVQPart
+                ):
+                    _tokens = tokens[1:].clone() - 1
+
+                    if config.share_codebook_embeddings is False:
+                        for i in range(len(_tokens)):
+                            _tokens[i] -= config.codebook_size * i
+
+                    parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
+                else:
+                    for codebook_id, value in enumerate(tokens[1:, :]):
+                        val = value.item() - 1
+                        if config.share_codebook_embeddings is False:
+                            val -= config.codebook_size * codebook_id
+
+                        parts[sample_id][-1].codes[codebook_id].append(val)
+                continue
+
+            if tokens[0] != semantic_id:
+                # Streaming text decode is not supported yet
+                decode_buffer[sample_id].append(tokens[0, 0])
+
+        if idx == 0:
+            stats["time_to_first_token"] = (time.time() - start) * 1000
+
+        idx += 1
+
+    for sample_id in range(request.num_samples):
+        yield from send_reset_buffer(sample_id)
+
+    stats["total_time"] = (time.time() - start) * 1000
+    stats["total_tokens"] = idx
+
+    if request.streaming:
+        for sample_id in range(request.num_samples):
+            if finished[sample_id]:
+                continue
+            yield ServeStreamResponse(
+                finish_reason=response, stats=stats, sample_id=sample_id
+            )
+        return
+
+    yield ServeResponse(
+        messages=[
+            ServeMessage(role="assistant", parts=parts[i])
+            for i in range(request.num_samples)
+        ],
+        finish_reason=response,
+        stats=stats,
+    )
+
+
+@routes.http.post("/v1/chat")
+def api_invoke_chat(
+    req: Annotated[ServeRequest, Body(exclusive=True)],
+):
+    """
+    Invoke the model and generate audio.
+    """
+
+    # This makes torch compile happy
+    assert (
+        req.num_samples == GLOBAL_NUM_SAMPLES
+    ), f"num_samples must be {GLOBAL_NUM_SAMPLES}"
+
+    content_type = request.headers.get("Content-Type", "application/json")
+    json_mode = "application/json" in content_type
+
+    async def wrapped_generator():
+        generator = execute_request(llama_queue, tokenizer, config, req, args.device)
+
+        for i in generator:
+            if json_mode:
+                body = i.model_dump_json().encode("utf-8")
+                yield b"data: " + body + b"\n\n"
+            else:
+                body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+                yield struct.pack("I", len(body)) + body
+
+    # Non-streaming mode: drain the generator and return a single response
+    if req.streaming is False:
+        result = next(execute_request(llama_queue, tokenizer, config, req, args.device))
+
+        if json_mode:
+            return JSONResponse(result.model_dump())
+        else:
+            return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
+
+    return StreamResponse(
+        iterable=wrapped_generator(), content_type="text/event-stream"
+    )
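
In msgpack mode each streamed message is framed as struct.pack("I", len(body)) + body, i.e. a 4-byte native-byte-order length prefix followed by the msgpack body (JSON mode instead yields standard "data: ..." SSE lines). A client-side sketch of the unframing loop; the URL is a placeholder and the request fields are inferred from ServeRequest usage above:

import struct

import httpx
import ormsgpack

chat_request = {"messages": [], "streaming": True, "num_samples": 1}  # fields per ServeRequest

with httpx.stream(
    "POST",
    "http://localhost:8080/v1/chat",
    content=ormsgpack.packb(chat_request),
    headers={"Content-Type": "application/msgpack"},
    timeout=None,
) as resp:
    buffer = b""
    for chunk in resp.iter_bytes():
        buffer += chunk
        # Peel off complete [length][body] frames as they arrive
        while len(buffer) >= 4:
            (length,) = struct.unpack("I", buffer[:4])
            if len(buffer) < 4 + length:
                break
            print(ormsgpack.unpackb(buffer[4 : 4 + length]))
            buffer = buffer[4 + length :]
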
+
+
 @torch.inference_mode()
 def inference(req: ServeTTSRequest):

@@ -360,6 +782,8 @@ async def api_health():

 def parse_args():
     parser = ArgumentParser()
+    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
+    parser.add_argument("--load-asr-model", action="store_true")
     parser.add_argument(
         "--llama-checkpoint-path",
         type=str,
@@ -419,6 +843,15 @@ app = Kui(
 )


+def load_asr_model(*, device="cuda", hub="ms"):
+    return AutoModel(
+        model="iic/SenseVoiceSmall",
+        device=device,
+        disable_pbar=True,
+        hub=hub,
+    )
+
+
 # Each worker process created by Uvicorn has its own memory space,
 # meaning that models and variables are not shared between processes.
 # Therefore, any global variables (like `llama_queue` or `decoder_model`)
@@ -431,20 +864,33 @@ app = Kui(
 @app.on_startup
 def initialize_app(app: Kui):

-    global args, llama_queue, decoder_model, prompt_tokens, prompt_texts
+    global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts

     prompt_tokens, prompt_texts = [], []

     args = parse_args()  # args same as ones in other processes
     args.precision = torch.half if args.half else torch.bfloat16

+    if args.load_asr_model:
+        logger.info("Loading ASR model...")
+        asr_model = load_asr_model(device=args.device)
+
     logger.info("Loading Llama model...")
-    llama_queue = launch_thread_safe_queue(
-        checkpoint_path=args.llama_checkpoint_path,
-        device=args.device,
-        precision=args.precision,
-        compile=args.compile,
-    )
+
+    if args.mode == "tts":
+        llama_queue = launch_thread_safe_queue(
+            checkpoint_path=args.llama_checkpoint_path,
+            device=args.device,
+            precision=args.precision,
+            compile=args.compile,
+        )
+    else:
+        llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
+            checkpoint_path=args.llama_checkpoint_path,
+            device=args.device,
+            precision=args.precision,
+            compile=args.compile,
+        )

     logger.info("Llama model loaded, loading VQ-GAN model...")

@@ -456,23 +902,28 @@ def initialize_app(app: Kui):

     logger.info("VQ-GAN model loaded, warming up...")

-    # Dry run to ensure models work and avoid first-time latency
-    list(
-        inference(
-            ServeTTSRequest(
-                text="Hello world.",
-                references=[],
-                reference_id=None,
-                max_new_tokens=0,
-                chunk_length=200,
-                top_p=0.7,
-                repetition_penalty=1.2,
-                temperature=0.7,
-                emotion=None,
-                format="wav",
+    vad_model = load_silero_vad()
+
+    logger.info("VAD model loaded, warming up...")
+
+    if args.mode == "tts":
+        # Dry run to ensure models work and avoid first-time latency
+        list(
+            inference(
+                ServeTTSRequest(
+                    text="Hello world.",
+                    references=[],
+                    reference_id=None,
+                    max_new_tokens=0,
+                    chunk_length=200,
+                    top_p=0.7,
+                    repetition_penalty=1.2,
+                    temperature=0.7,
+                    emotion=None,
+                    format="wav",
+                )
             )
         )
-    )

     logger.info(f"Warming up done, starting server at http://{args.listen}")