|
|
@@ -9,16 +9,20 @@ import wave
|
|
|
from argparse import ArgumentParser
|
|
|
from http import HTTPStatus
|
|
|
from pathlib import Path
|
|
|
-from typing import Annotated, Literal, Optional
|
|
|
+from typing import Annotated, Any, Literal, Optional
|
|
|
|
|
|
import numpy as np
|
|
|
+import ormsgpack
|
|
|
import pyrootutils
|
|
|
import soundfile as sf
|
|
|
import torch
|
|
|
import torchaudio
|
|
|
+from baize.datastructures import ContentType
|
|
|
from kui.asgi import (
|
|
|
Body,
|
|
|
+ FactoryClass,
|
|
|
HTTPException,
|
|
|
+ HttpRequest,
|
|
|
HttpView,
|
|
|
JSONResponse,
|
|
|
Kui,
|
|
|
@@ -27,14 +31,16 @@ from kui.asgi import (
|
|
|
)
|
|
|
from kui.asgi.routing import MultimethodRoutes
|
|
|
from loguru import logger
|
|
|
-from pydantic import BaseModel, Field
|
|
|
+from pydantic import BaseModel, Field, conint
|
|
|
|
|
|
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
|
|
|
|
|
# from fish_speech.models.vqgan.lit_module import VQGAN
|
|
|
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
|
|
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
|
|
from fish_speech.utils import autocast_exclude_mps
|
|
|
from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
|
|
|
+from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
|
|
|
from tools.llama.generate import (
|
|
|
GenerateRequest,
|
|
|
GenerateResponse,
|
|
|
@@ -82,11 +88,8 @@ async def other_exception_handler(exc: "Exception"):
|
|
|
|
|
|
def load_audio(reference_audio, sr):
|
|
|
if len(reference_audio) > 255 or not Path(reference_audio).exists():
|
|
|
- try:
|
|
|
- audio_data = base64.b64decode(reference_audio)
|
|
|
- reference_audio = io.BytesIO(audio_data)
|
|
|
- except base64.binascii.Error:
|
|
|
- raise ValueError("Invalid path or base64 string")
|
|
|
+ audio_data = reference_audio
|
|
|
+ reference_audio = io.BytesIO(audio_data)
|
|
|
|
|
|
waveform, original_sr = torchaudio.load(
|
|
|
reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
|
|
|
@@ -153,56 +156,36 @@ def decode_vq_tokens(
|
|
|
routes = MultimethodRoutes(base_class=HttpView)
|
|
|
|
|
|
|
|
|
-def get_random_paths(base_path, data, speaker, emotion):
|
|
|
- if base_path and data and speaker and emotion and (Path(base_path).exists()):
|
|
|
- if speaker in data and emotion in data[speaker]:
|
|
|
- files = data[speaker][emotion]
|
|
|
- lab_files = [f for f in files if f.endswith(".lab")]
|
|
|
- wav_files = [f for f in files if f.endswith(".wav")]
|
|
|
+class ServeReferenceAudio(BaseModel):
|
|
|
+ audio: bytes
|
|
|
+ text: str
|
|
|
|
|
|
- if lab_files and wav_files:
|
|
|
- selected_lab = random.choice(lab_files)
|
|
|
- selected_wav = random.choice(wav_files)
|
|
|
|
|
|
- lab_path = Path(base_path) / speaker / emotion / selected_lab
|
|
|
- wav_path = Path(base_path) / speaker / emotion / selected_wav
|
|
|
- if lab_path.exists() and wav_path.exists():
|
|
|
- return lab_path, wav_path
|
|
|
-
|
|
|
- return None, None
|
|
|
-
|
|
|
-
|
|
|
-def load_json(json_file):
|
|
|
- if not json_file:
|
|
|
- logger.info("Not using a json file")
|
|
|
- return None
|
|
|
- try:
|
|
|
- with open(json_file, "r", encoding="utf-8") as file:
|
|
|
- data = json.load(file)
|
|
|
- except FileNotFoundError:
|
|
|
- logger.warning(f"ref json not found: {json_file}")
|
|
|
- data = None
|
|
|
- except Exception as e:
|
|
|
- logger.warning(f"Loading json failed: {e}")
|
|
|
- data = None
|
|
|
- return data
|
|
|
-
|
|
|
-
|
|
|
-class InvokeRequest(BaseModel):
|
|
|
+class ServeTTSRequest(BaseModel):
|
|
|
text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
|
|
|
- reference_text: Optional[str] = None
|
|
|
- reference_audio: Optional[str] = None
|
|
|
+ chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
|
|
|
+ # Audio format
|
|
|
+ format: Literal["wav", "pcm", "mp3"] = "wav"
|
|
|
+ mp3_bitrate: Literal[64, 128, 192] = 128
|
|
|
+ # References audios for in-context learning
|
|
|
+ references: list[ServeReferenceAudio] = []
|
|
|
+ # Reference id
|
|
|
+ # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
|
|
|
+ # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
|
|
|
+ reference_id: str | None = None
|
|
|
+ # Normalize text for en & zh, this increase stability for numbers
|
|
|
+ normalize: bool = True
|
|
|
+ mp3_bitrate: Optional[int] = 64
|
|
|
+ opus_bitrate: Optional[int] = -1000
|
|
|
+ # Balance mode will reduce latency to 300ms, but may decrease stability
|
|
|
+ latency: Literal["normal", "balanced"] = "normal"
|
|
|
+ # not usually used below
|
|
|
+ streaming: bool = False
|
|
|
+ emotion: Optional[str] = None
|
|
|
max_new_tokens: int = 1024
|
|
|
- chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
|
|
|
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
|
|
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
|
|
|
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
|
|
- emotion: Optional[str] = None
|
|
|
- format: Literal["wav", "mp3", "flac"] = "wav"
|
|
|
- streaming: bool = False
|
|
|
- ref_json: Optional[str] = "ref_data.json"
|
|
|
- ref_base: Optional[str] = "ref_data"
|
|
|
- speaker: Optional[str] = None
|
|
|
|
|
|
|
|
|
def get_content_type(audio_format):
|
|
|
@@ -217,35 +200,52 @@ def get_content_type(audio_format):
|
|
|
|
|
|
|
|
|
@torch.inference_mode()
|
|
|
-def inference(req: InvokeRequest):
|
|
|
- # Parse reference audio aka prompt
|
|
|
- prompt_tokens = None
|
|
|
-
|
|
|
- ref_data = load_json(req.ref_json)
|
|
|
- ref_base = req.ref_base
|
|
|
-
|
|
|
- lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)
|
|
|
-
|
|
|
- if lab_path and wav_path:
|
|
|
- with open(lab_path, "r", encoding="utf-8") as lab_file:
|
|
|
- ref_text = lab_file.read()
|
|
|
- req.reference_audio = wav_path
|
|
|
- req.reference_text = ref_text
|
|
|
- logger.info("ref_path: " + str(wav_path))
|
|
|
- logger.info("ref_text: " + ref_text)
|
|
|
-
|
|
|
- # Parse reference audio aka prompt
|
|
|
- prompt_tokens = encode_reference(
|
|
|
- decoder_model=decoder_model,
|
|
|
- reference_audio=req.reference_audio,
|
|
|
- enable_reference_audio=req.reference_audio is not None,
|
|
|
- )
|
|
|
- logger.info(f"ref_text: {req.reference_text}")
|
|
|
+def inference(req: ServeTTSRequest):
|
|
|
+
|
|
|
+ idstr: str | None = req.reference_id
|
|
|
+ if idstr is not None:
|
|
|
+ ref_folder = Path("references") / idstr
|
|
|
+ ref_folder.mkdir(parents=True, exist_ok=True)
|
|
|
+ ref_audios = list_files(
|
|
|
+ ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
|
|
|
+ )
|
|
|
+ prompt_tokens = [
|
|
|
+ encode_reference(
|
|
|
+ decoder_model=decoder_model,
|
|
|
+ reference_audio=audio_to_bytes(str(ref_audio)),
|
|
|
+ enable_reference_audio=True,
|
|
|
+ )
|
|
|
+ for ref_audio in ref_audios
|
|
|
+ ]
|
|
|
+ prompt_texts = [
|
|
|
+ read_ref_text(str(ref_audio.with_suffix(".lab")))
|
|
|
+ for ref_audio in ref_audios
|
|
|
+ ]
|
|
|
+
|
|
|
+ else:
|
|
|
+ # Parse reference audio aka prompt
|
|
|
+ refs = req.references
|
|
|
+ if refs is None:
|
|
|
+ refs = []
|
|
|
+ prompt_tokens = [
|
|
|
+ encode_reference(
|
|
|
+ decoder_model=decoder_model,
|
|
|
+ reference_audio=ref.audio,
|
|
|
+ enable_reference_audio=True,
|
|
|
+ )
|
|
|
+ for ref in refs
|
|
|
+ ]
|
|
|
+ prompt_texts = [ref.text for ref in refs]
|
|
|
+
|
|
|
# LLAMA Inference
|
|
|
request = dict(
|
|
|
device=decoder_model.device,
|
|
|
max_new_tokens=req.max_new_tokens,
|
|
|
- text=req.text,
|
|
|
+ text=(
|
|
|
+ req.text
|
|
|
+ if not req.normalize
|
|
|
+ else ChnNormedText(raw_text=req.text).normalize()
|
|
|
+ ),
|
|
|
top_p=req.top_p,
|
|
|
repetition_penalty=req.repetition_penalty,
|
|
|
temperature=req.temperature,
|
|
|
@@ -254,7 +254,7 @@ def inference(req: InvokeRequest):
|
|
|
chunk_length=req.chunk_length,
|
|
|
max_length=2048,
|
|
|
prompt_tokens=prompt_tokens,
|
|
|
- prompt_text=req.reference_text,
|
|
|
+ prompt_text=prompt_texts,
|
|
|
)
|
|
|
|
|
|
response_queue = queue.Queue()
|
|
|
@@ -307,40 +307,7 @@ def inference(req: InvokeRequest):
|
|
|
yield fake_audios
|
|
|
|
|
|
|
|
|
-def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
|
|
|
- if not use_auto_rerank:
|
|
|
- # 如果不使用 auto_rerank,直接调用原始的 inference 函数
|
|
|
- return inference(req)
|
|
|
-
|
|
|
- zh_model, en_model = load_model()
|
|
|
- max_attempts = 5
|
|
|
- best_wer = float("inf")
|
|
|
- best_audio = None
|
|
|
-
|
|
|
- for attempt in range(max_attempts):
|
|
|
- # 调用原始的 inference 函数
|
|
|
- audio_generator = inference(req)
|
|
|
- fake_audios = next(audio_generator)
|
|
|
-
|
|
|
- asr_result = batch_asr(
|
|
|
- zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
|
|
|
- )[0]
|
|
|
- wer = calculate_wer(req.text, asr_result["text"])
|
|
|
-
|
|
|
- if wer <= 0.1 and not asr_result["huge_gap"]:
|
|
|
- return fake_audios
|
|
|
-
|
|
|
- if wer < best_wer:
|
|
|
- best_wer = wer
|
|
|
- best_audio = fake_audios
|
|
|
-
|
|
|
- if attempt == max_attempts - 1:
|
|
|
- break
|
|
|
-
|
|
|
- return best_audio
|
|
|
-
|
|
|
-
|
|
|
-async def inference_async(req: InvokeRequest):
|
|
|
+async def inference_async(req: ServeTTSRequest):
|
|
|
for chunk in inference(req):
|
|
|
yield chunk
|
|
|
|
|
|
@@ -349,9 +316,9 @@ async def buffer_to_async_generator(buffer):
|
|
|
yield buffer
|
|
|
|
|
|
|
|
|
-@routes.http.post("/v1/invoke")
|
|
|
+@routes.http.post("/v1/tts")
|
|
|
async def api_invoke_model(
|
|
|
- req: Annotated[InvokeRequest, Body(exclusive=True)],
|
|
|
+ req: Annotated[ServeTTSRequest, Body(exclusive=True)],
|
|
|
):
|
|
|
"""
|
|
|
Invoke model and generate audio
|
|
|
@@ -422,7 +389,7 @@ def parse_args():
|
|
|
parser.add_argument("--half", action="store_true")
|
|
|
parser.add_argument("--compile", action="store_true")
|
|
|
parser.add_argument("--max-text-length", type=int, default=0)
|
|
|
- parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
|
|
|
+ parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
|
|
|
parser.add_argument("--workers", type=int, default=1)
|
|
|
parser.add_argument("--use-auto-rerank", type=bool, default=True)
|
|
|
|
|
|
@@ -436,18 +403,30 @@ openapi = OpenAPI(
|
|
|
},
|
|
|
).routes
|
|
|
|
|
|
+
|
|
|
+class MsgPackRequest(HttpRequest):
|
|
|
+ async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
|
|
|
+ if self.content_type == "application/msgpack":
|
|
|
+ return ormsgpack.unpackb(await self.body)
|
|
|
+
|
|
|
+ raise HTTPException(
|
|
|
+ HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
|
|
|
+ headers={"Accept": "application/msgpack"},
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
app = Kui(
|
|
|
routes=routes + openapi[1:], # Remove the default route
|
|
|
exception_handlers={
|
|
|
HTTPException: http_execption_handler,
|
|
|
Exception: other_exception_handler,
|
|
|
},
|
|
|
+ factory_class=FactoryClass(http=MsgPackRequest),
|
|
|
cors_config={},
|
|
|
)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
- import threading
|
|
|
|
|
|
import uvicorn
|
|
|
|
|
|
@@ -474,18 +453,16 @@ if __name__ == "__main__":
|
|
|
# Dry run to check if the model is loaded correctly and avoid the first-time latency
|
|
|
list(
|
|
|
inference(
|
|
|
- InvokeRequest(
|
|
|
+ ServeTTSRequest(
|
|
|
text="Hello world.",
|
|
|
- reference_text=None,
|
|
|
- reference_audio=None,
|
|
|
+ references=[],
|
|
|
+ reference_id=None,
|
|
|
max_new_tokens=0,
|
|
|
top_p=0.7,
|
|
|
repetition_penalty=1.2,
|
|
|
temperature=0.7,
|
|
|
emotion=None,
|
|
|
format="wav",
|
|
|
- ref_base=None,
|
|
|
- ref_json=None,
|
|
|
)
|
|
|
)
|
|
|
)
|