Jelajahi Sumber

keep up with official closed-source API (#513)

* keep up with official closed-source API

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* curl support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* avoid empty ref

* remove unused files

* api CHN normalize

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* ormsgpack support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 tahun lalu
induk
melakukan
e9394c71f0
8 mengubah file dengan 251 tambahan dan 255 penghapusan
  1. 2 1
      .gitignore
  2. 96 119
      tools/api.py
  3. 17 0
      tools/file.py
  4. 0 36
      tools/gen_ref.py
  5. 0 55
      tools/merge_asr_files.py
  6. 68 0
      tools/msgpack_api.py
  7. 67 43
      tools/post_api.py
  8. 1 1
      tools/sensevoice/fun_asr.py

+ 2 - 1
.gitignore

@@ -15,6 +15,7 @@ filelists
 /*.npy
 /*.npy
 /*.wav
 /*.wav
 /*.mp3
 /*.mp3
+/*.lab
 /results
 /results
 /data
 /data
 /.idea
 /.idea
@@ -25,6 +26,6 @@ asr-label*
 /fishenv
 /fishenv
 /.locale
 /.locale
 /demo-audios
 /demo-audios
-ref_data*
+/references
 /example
 /example
 /faster_whisper
 /faster_whisper

+ 96 - 119
tools/api.py

@@ -9,16 +9,20 @@ import wave
 from argparse import ArgumentParser
 from argparse import ArgumentParser
 from http import HTTPStatus
 from http import HTTPStatus
 from pathlib import Path
 from pathlib import Path
-from typing import Annotated, Literal, Optional
+from typing import Annotated, Any, Literal, Optional
 
 
 import numpy as np
 import numpy as np
+import ormsgpack
 import pyrootutils
 import pyrootutils
 import soundfile as sf
 import soundfile as sf
 import torch
 import torch
 import torchaudio
 import torchaudio
+from baize.datastructures import ContentType
 from kui.asgi import (
 from kui.asgi import (
     Body,
     Body,
+    FactoryClass,
     HTTPException,
     HTTPException,
+    HttpRequest,
     HttpView,
     HttpView,
     JSONResponse,
     JSONResponse,
     Kui,
     Kui,
@@ -27,14 +31,16 @@ from kui.asgi import (
 )
 )
 from kui.asgi.routing import MultimethodRoutes
 from kui.asgi.routing import MultimethodRoutes
 from loguru import logger
 from loguru import logger
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, conint
 
 
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 
 # from fish_speech.models.vqgan.lit_module import VQGAN
 # from fish_speech.models.vqgan.lit_module import VQGAN
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
+from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from fish_speech.utils import autocast_exclude_mps
 from fish_speech.utils import autocast_exclude_mps
 from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
 from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
+from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
 from tools.llama.generate import (
 from tools.llama.generate import (
     GenerateRequest,
     GenerateRequest,
     GenerateResponse,
     GenerateResponse,
@@ -82,11 +88,8 @@ async def other_exception_handler(exc: "Exception"):
 
 
 def load_audio(reference_audio, sr):
 def load_audio(reference_audio, sr):
     if len(reference_audio) > 255 or not Path(reference_audio).exists():
     if len(reference_audio) > 255 or not Path(reference_audio).exists():
-        try:
-            audio_data = base64.b64decode(reference_audio)
-            reference_audio = io.BytesIO(audio_data)
-        except base64.binascii.Error:
-            raise ValueError("Invalid path or base64 string")
+        audio_data = reference_audio
+        reference_audio = io.BytesIO(audio_data)
 
 
     waveform, original_sr = torchaudio.load(
     waveform, original_sr = torchaudio.load(
         reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
         reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
@@ -153,56 +156,36 @@ def decode_vq_tokens(
 routes = MultimethodRoutes(base_class=HttpView)
 routes = MultimethodRoutes(base_class=HttpView)
 
 
 
 
-def get_random_paths(base_path, data, speaker, emotion):
-    if base_path and data and speaker and emotion and (Path(base_path).exists()):
-        if speaker in data and emotion in data[speaker]:
-            files = data[speaker][emotion]
-            lab_files = [f for f in files if f.endswith(".lab")]
-            wav_files = [f for f in files if f.endswith(".wav")]
+class ServeReferenceAudio(BaseModel):
+    audio: bytes
+    text: str
 
 
-            if lab_files and wav_files:
-                selected_lab = random.choice(lab_files)
-                selected_wav = random.choice(wav_files)
 
 
-                lab_path = Path(base_path) / speaker / emotion / selected_lab
-                wav_path = Path(base_path) / speaker / emotion / selected_wav
-                if lab_path.exists() and wav_path.exists():
-                    return lab_path, wav_path
-
-    return None, None
-
-
-def load_json(json_file):
-    if not json_file:
-        logger.info("Not using a json file")
-        return None
-    try:
-        with open(json_file, "r", encoding="utf-8") as file:
-            data = json.load(file)
-    except FileNotFoundError:
-        logger.warning(f"ref json not found: {json_file}")
-        data = None
-    except Exception as e:
-        logger.warning(f"Loading json failed: {e}")
-        data = None
-    return data
-
-
-class InvokeRequest(BaseModel):
+class ServeTTSRequest(BaseModel):
     text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
     text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
-    reference_text: Optional[str] = None
-    reference_audio: Optional[str] = None
+    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+    # Audio format
+    format: Literal["wav", "pcm", "mp3"] = "wav"
+    mp3_bitrate: Literal[64, 128, 192] = 128
+    # References audios for in-context learning
+    references: list[ServeReferenceAudio] = []
+    # Reference id
+    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+    reference_id: str | None = None
+    # Normalize text for en & zh, this increase stability for numbers
+    normalize: bool = True
+    mp3_bitrate: Optional[int] = 64
+    opus_bitrate: Optional[int] = -1000
+    # Balance mode will reduce latency to 300ms, but may decrease stability
+    latency: Literal["normal", "balanced"] = "normal"
+    # not usually used below
+    streaming: bool = False
+    emotion: Optional[str] = None
     max_new_tokens: int = 1024
     max_new_tokens: int = 1024
-    chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
     top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
     top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
     repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
     repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
     temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
     temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-    emotion: Optional[str] = None
-    format: Literal["wav", "mp3", "flac"] = "wav"
-    streaming: bool = False
-    ref_json: Optional[str] = "ref_data.json"
-    ref_base: Optional[str] = "ref_data"
-    speaker: Optional[str] = None
 
 
 
 
 def get_content_type(audio_format):
 def get_content_type(audio_format):
@@ -217,35 +200,52 @@ def get_content_type(audio_format):
 
 
 
 
 @torch.inference_mode()
 @torch.inference_mode()
-def inference(req: InvokeRequest):
-    # Parse reference audio aka prompt
-    prompt_tokens = None
-
-    ref_data = load_json(req.ref_json)
-    ref_base = req.ref_base
-
-    lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)
-
-    if lab_path and wav_path:
-        with open(lab_path, "r", encoding="utf-8") as lab_file:
-            ref_text = lab_file.read()
-        req.reference_audio = wav_path
-        req.reference_text = ref_text
-        logger.info("ref_path: " + str(wav_path))
-        logger.info("ref_text: " + ref_text)
-
-    # Parse reference audio aka prompt
-    prompt_tokens = encode_reference(
-        decoder_model=decoder_model,
-        reference_audio=req.reference_audio,
-        enable_reference_audio=req.reference_audio is not None,
-    )
-    logger.info(f"ref_text: {req.reference_text}")
+def inference(req: ServeTTSRequest):
+
+    idstr: str | None = req.reference_id
+    if idstr is not None:
+        ref_folder = Path("references") / idstr
+        ref_folder.mkdir(parents=True, exist_ok=True)
+        ref_audios = list_files(
+            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
+        )
+        prompt_tokens = [
+            encode_reference(
+                decoder_model=decoder_model,
+                reference_audio=audio_to_bytes(str(ref_audio)),
+                enable_reference_audio=True,
+            )
+            for ref_audio in ref_audios
+        ]
+        prompt_texts = [
+            read_ref_text(str(ref_audio.with_suffix(".lab")))
+            for ref_audio in ref_audios
+        ]
+
+    else:
+        # Parse reference audio aka prompt
+        refs = req.references
+        if refs is None:
+            refs = []
+        prompt_tokens = [
+            encode_reference(
+                decoder_model=decoder_model,
+                reference_audio=ref.audio,
+                enable_reference_audio=True,
+            )
+            for ref in refs
+        ]
+        prompt_texts = [ref.text for ref in refs]
+
     # LLAMA Inference
     # LLAMA Inference
     request = dict(
     request = dict(
         device=decoder_model.device,
         device=decoder_model.device,
         max_new_tokens=req.max_new_tokens,
         max_new_tokens=req.max_new_tokens,
-        text=req.text,
+        text=(
+            req.text
+            if not req.normalize
+            else ChnNormedText(raw_text=req.text).normalize()
+        ),
         top_p=req.top_p,
         top_p=req.top_p,
         repetition_penalty=req.repetition_penalty,
         repetition_penalty=req.repetition_penalty,
         temperature=req.temperature,
         temperature=req.temperature,
@@ -254,7 +254,7 @@ def inference(req: InvokeRequest):
         chunk_length=req.chunk_length,
         chunk_length=req.chunk_length,
         max_length=2048,
         max_length=2048,
         prompt_tokens=prompt_tokens,
         prompt_tokens=prompt_tokens,
-        prompt_text=req.reference_text,
+        prompt_text=prompt_texts,
     )
     )
 
 
     response_queue = queue.Queue()
     response_queue = queue.Queue()
@@ -307,40 +307,7 @@ def inference(req: InvokeRequest):
     yield fake_audios
     yield fake_audios
 
 
 
 
-def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
-    if not use_auto_rerank:
-        # 如果不使用 auto_rerank,直接调用原始的 inference 函数
-        return inference(req)
-
-    zh_model, en_model = load_model()
-    max_attempts = 5
-    best_wer = float("inf")
-    best_audio = None
-
-    for attempt in range(max_attempts):
-        # 调用原始的 inference 函数
-        audio_generator = inference(req)
-        fake_audios = next(audio_generator)
-
-        asr_result = batch_asr(
-            zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
-        )[0]
-        wer = calculate_wer(req.text, asr_result["text"])
-
-        if wer <= 0.1 and not asr_result["huge_gap"]:
-            return fake_audios
-
-        if wer < best_wer:
-            best_wer = wer
-            best_audio = fake_audios
-
-        if attempt == max_attempts - 1:
-            break
-
-    return best_audio
-
-
-async def inference_async(req: InvokeRequest):
+async def inference_async(req: ServeTTSRequest):
     for chunk in inference(req):
     for chunk in inference(req):
         yield chunk
         yield chunk
 
 
@@ -349,9 +316,9 @@ async def buffer_to_async_generator(buffer):
     yield buffer
     yield buffer
 
 
 
 
-@routes.http.post("/v1/invoke")
+@routes.http.post("/v1/tts")
 async def api_invoke_model(
 async def api_invoke_model(
-    req: Annotated[InvokeRequest, Body(exclusive=True)],
+    req: Annotated[ServeTTSRequest, Body(exclusive=True)],
 ):
 ):
     """
     """
     Invoke model and generate audio
     Invoke model and generate audio
@@ -422,7 +389,7 @@ def parse_args():
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--half", action="store_true")
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("--compile", action="store_true")
     parser.add_argument("--max-text-length", type=int, default=0)
     parser.add_argument("--max-text-length", type=int, default=0)
-    parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
+    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
     parser.add_argument("--workers", type=int, default=1)
     parser.add_argument("--workers", type=int, default=1)
     parser.add_argument("--use-auto-rerank", type=bool, default=True)
     parser.add_argument("--use-auto-rerank", type=bool, default=True)
 
 
@@ -436,18 +403,30 @@ openapi = OpenAPI(
     },
     },
 ).routes
 ).routes
 
 
+
+class MsgPackRequest(HttpRequest):
+    async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
+        if self.content_type == "application/msgpack":
+            return ormsgpack.unpackb(await self.body)
+
+        raise HTTPException(
+            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+            headers={"Accept": "application/msgpack"},
+        )
+
+
 app = Kui(
 app = Kui(
     routes=routes + openapi[1:],  # Remove the default route
     routes=routes + openapi[1:],  # Remove the default route
     exception_handlers={
     exception_handlers={
         HTTPException: http_execption_handler,
         HTTPException: http_execption_handler,
         Exception: other_exception_handler,
         Exception: other_exception_handler,
     },
     },
+    factory_class=FactoryClass(http=MsgPackRequest),
     cors_config={},
     cors_config={},
 )
 )
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    import threading
 
 
     import uvicorn
     import uvicorn
 
 
@@ -474,18 +453,16 @@ if __name__ == "__main__":
     # Dry run to check if the model is loaded correctly and avoid the first-time latency
     # Dry run to check if the model is loaded correctly and avoid the first-time latency
     list(
     list(
         inference(
         inference(
-            InvokeRequest(
+            ServeTTSRequest(
                 text="Hello world.",
                 text="Hello world.",
-                reference_text=None,
-                reference_audio=None,
+                references=[],
+                reference_id=None,
                 max_new_tokens=0,
                 max_new_tokens=0,
                 top_p=0.7,
                 top_p=0.7,
                 repetition_penalty=1.2,
                 repetition_penalty=1.2,
                 temperature=0.7,
                 temperature=0.7,
                 emotion=None,
                 emotion=None,
                 format="wav",
                 format="wav",
-                ref_base=None,
-                ref_json=None,
             )
             )
         )
         )
     )
     )

+ 17 - 0
tools/file.py

@@ -1,3 +1,4 @@
+import base64
 from pathlib import Path
 from pathlib import Path
 from typing import Union
 from typing import Union
 
 
@@ -23,6 +24,22 @@ VIDEO_EXTENSIONS = {
 }
 }
 
 
 
 
+def audio_to_bytes(file_path):
+    if not file_path or not Path(file_path).exists():
+        return None
+    with open(file_path, "rb") as wav_file:
+        wav = wav_file.read()
+    return wav
+
+
+def read_ref_text(ref_text):
+    path = Path(ref_text)
+    if path.exists() and path.is_file():
+        with path.open("r", encoding="utf-8") as file:
+            return file.read()
+    return ref_text
+
+
 def list_files(
 def list_files(
     path: Union[Path, str],
     path: Union[Path, str],
     extensions: set[str] = None,
     extensions: set[str] = None,

+ 0 - 36
tools/gen_ref.py

@@ -1,36 +0,0 @@
-import json
-from pathlib import Path
-
-
-def scan_folder(base_path):
-    wav_lab_pairs = {}
-
-    base = Path(base_path)
-    for suf in ["wav", "lab"]:
-        for f in base.rglob(f"*.{suf}"):
-            relative_path = f.relative_to(base)
-            parts = relative_path.parts
-            print(parts)
-            if len(parts) >= 3:
-                character = parts[0]
-                emotion = parts[1]
-
-                if character not in wav_lab_pairs:
-                    wav_lab_pairs[character] = {}
-                if emotion not in wav_lab_pairs[character]:
-                    wav_lab_pairs[character][emotion] = []
-                wav_lab_pairs[character][emotion].append(str(f.name))
-
-    return wav_lab_pairs
-
-
-def save_to_json(data, output_file):
-    with open(output_file, "w", encoding="utf-8") as file:
-        json.dump(data, file, ensure_ascii=False, indent=2)
-
-
-base_path = "ref_data"
-out_ref_file = "ref_data.json"
-
-wav_lab_pairs = scan_folder(base_path)
-save_to_json(wav_lab_pairs, out_ref_file)

+ 0 - 55
tools/merge_asr_files.py

@@ -1,55 +0,0 @@
-import os
-from pathlib import Path
-
-from pydub import AudioSegment
-from tqdm import tqdm
-
-from tools.file import AUDIO_EXTENSIONS, list_files
-
-
-def merge_and_delete_files(save_dir, original_files):
-    save_path = Path(save_dir)
-    audio_slice_files = list_files(
-        path=save_dir, extensions=AUDIO_EXTENSIONS.union([".lab"]), recursive=True
-    )
-    audio_files = {}
-    label_files = {}
-    for file_path in tqdm(audio_slice_files, desc="Merging audio files"):
-        rel_path = Path(file_path).relative_to(save_path)
-        (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
-        if file_path.suffix == ".wav":
-            prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
-            if prefix == rel_path.parent / file_path.stem:
-                continue
-            audio = AudioSegment.from_wav(file_path)
-            if prefix in audio_files.keys():
-                audio_files[prefix] = audio_files[prefix] + audio
-            else:
-                audio_files[prefix] = audio
-
-        elif file_path.suffix == ".lab":
-            prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0]
-            if prefix == rel_path.parent / file_path.stem:
-                continue
-            with open(file_path, "r", encoding="utf-8") as f:
-                label = f.read()
-            if prefix in label_files.keys():
-                label_files[prefix] = label_files[prefix] + ", " + label
-            else:
-                label_files[prefix] = label
-
-    for prefix, audio in audio_files.items():
-        output_audio_path = save_path / f"{prefix}.wav"
-        audio.export(output_audio_path, format="wav")
-
-    for prefix, label in label_files.items():
-        output_label_path = save_path / f"{prefix}.lab"
-        with open(output_label_path, "w", encoding="utf-8") as f:
-            f.write(label)
-
-    for file_path in original_files:
-        os.remove(file_path)
-
-
-if __name__ == "__main__":
-    merge_and_delete_files("/made/by/spicysama/laziman", [__file__])

+ 68 - 0
tools/msgpack_api.py

@@ -0,0 +1,68 @@
+from typing import Annotated, AsyncGenerator, Literal, Optional
+
+import httpx
+import ormsgpack
+from pydantic import AfterValidator, BaseModel, Field, conint
+
+
+class ServeReferenceAudio(BaseModel):
+    audio: bytes
+    text: str
+
+
+class ServeTTSRequest(BaseModel):
+    text: str
+    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+    # Audio format
+    format: Literal["wav", "pcm", "mp3"] = "wav"
+    mp3_bitrate: Literal[64, 128, 192] = 128
+    # References audios for in-context learning
+    references: list[ServeReferenceAudio] = []
+    # Reference id
+    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+    reference_id: str | None = None
+    # Normalize text for en & zh, this increase stability for numbers
+    normalize: bool = True
+    mp3_bitrate: Optional[int] = 64
+    opus_bitrate: Optional[int] = -1000
+    # Balance mode will reduce latency to 300ms, but may decrease stability
+    latency: Literal["normal", "balanced"] = "normal"
+    # not usually used below
+    streaming: bool = False
+    emotion: Optional[str] = None
+    max_new_tokens: int = 1024
+    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+
+
+# priority: ref_id > references
+request = ServeTTSRequest(
+    text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+    # reference_id="114514",
+    references=[
+        ServeReferenceAudio(
+            audio=open("lengyue.wav", "rb").read(),
+            text=open("lengyue.lab", "r", encoding="utf-8").read(),
+        )
+    ],
+    streaming=True,
+)
+
+with (
+    httpx.Client() as client,
+    open("hello.wav", "wb") as f,
+):
+    with client.stream(
+        "POST",
+        "http://127.0.0.1:8080/v1/tts",
+        content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+        headers={
+            "authorization": "Bearer YOUR_API_KEY",
+            "content-type": "application/msgpack",
+        },
+        timeout=None,
+    ) as response:
+        for chunk in response.iter_bytes():
+            f.write(chunk)

+ 67 - 43
tools/post_api.py

@@ -1,40 +1,18 @@
 import argparse
 import argparse
 import base64
 import base64
-import json
 import wave
 import wave
 from pathlib import Path
 from pathlib import Path
 
 
 import pyaudio
 import pyaudio
 import requests
 import requests
+from pydub import AudioSegment
+from pydub.playback import play
 
 
+from tools.file import audio_to_bytes, read_ref_text
 
 
-def wav_to_base64(file_path):
-    if not file_path or not Path(file_path).exists():
-        return None
-    with open(file_path, "rb") as wav_file:
-        wav_content = wav_file.read()
-        base64_encoded = base64.b64encode(wav_content)
-        return base64_encoded.decode("utf-8")
 
 
+def parse_args():
 
 
-def read_ref_text(ref_text):
-    path = Path(ref_text)
-    if path.exists() and path.is_file():
-        with path.open("r", encoding="utf-8") as file:
-            return file.read()
-    return ref_text
-
-
-def play_audio(audio_content, format, channels, rate):
-    p = pyaudio.PyAudio()
-    stream = p.open(format=format, channels=channels, rate=rate, output=True)
-    stream.write(audio_content)
-    stream.stop_stream()
-    stream.close()
-    p.terminate()
-
-
-if __name__ == "__main__":
     parser = argparse.ArgumentParser(
     parser = argparse.ArgumentParser(
         description="Send a WAV file and text to a server and receive synthesized audio."
         description="Send a WAV file and text to a server and receive synthesized audio."
     )
     )
@@ -43,16 +21,24 @@ if __name__ == "__main__":
         "--url",
         "--url",
         "-u",
         "-u",
         type=str,
         type=str,
-        default="http://127.0.0.1:8080/v1/invoke",
+        default="http://127.0.0.1:8080/v1/tts",
         help="URL of the server",
         help="URL of the server",
     )
     )
     parser.add_argument(
     parser.add_argument(
         "--text", "-t", type=str, required=True, help="Text to be synthesized"
         "--text", "-t", type=str, required=True, help="Text to be synthesized"
     )
     )
+    parser.add_argument(
+        "--reference_id",
+        "-id",
+        type=str,
+        default=None,
+        help="ID of the reference model o be used for the speech",
+    )
     parser.add_argument(
     parser.add_argument(
         "--reference_audio",
         "--reference_audio",
         "-ra",
         "-ra",
         type=str,
         type=str,
+        nargs="+",
         default=None,
         default=None,
         help="Path to the WAV file",
         help="Path to the WAV file",
     )
     )
@@ -60,9 +46,30 @@ if __name__ == "__main__":
         "--reference_text",
         "--reference_text",
         "-rt",
         "-rt",
         type=str,
         type=str,
+        nargs="+",
         default=None,
         default=None,
         help="Reference text for voice synthesis",
         help="Reference text for voice synthesis",
     )
     )
+    parser.add_argument(
+        "--output",
+        "-o",
+        type=str,
+        default="generated_audio",
+        help="Output audio file name",
+    )
+    parser.add_argument(
+        "--play",
+        type=bool,
+        default=True,
+        help="Whether to play audio after receiving data",
+    )
+    parser.add_argument("--normalize", type=bool, default=True)
+    parser.add_argument(
+        "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
+    )
+    parser.add_argument("--mp3_bitrate", type=int, default=64)
+    parser.add_argument("--opus_bitrate", type=int, default=-1000)
+    parser.add_argument("--latency", type=str, default="normal", help="延迟选项")
     parser.add_argument(
     parser.add_argument(
         "--max_new_tokens",
         "--max_new_tokens",
         type=int,
         type=int,
@@ -88,7 +95,6 @@ if __name__ == "__main__":
         "--speaker", type=str, default=None, help="Speaker ID for voice synthesis"
         "--speaker", type=str, default=None, help="Speaker ID for voice synthesis"
     )
     )
     parser.add_argument("--emotion", type=str, default=None, help="Speaker's Emotion")
     parser.add_argument("--emotion", type=str, default=None, help="Speaker's Emotion")
-    parser.add_argument("--format", type=str, default="wav", help="Audio format")
     parser.add_argument(
     parser.add_argument(
         "--streaming", type=bool, default=False, help="Enable streaming response"
         "--streaming", type=bool, default=False, help="Enable streaming response"
     )
     )
@@ -97,18 +103,36 @@ if __name__ == "__main__":
     )
     )
     parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
     parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
 
 
-    args = parser.parse_args()
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
 
 
-    base64_audio = wav_to_base64(args.reference_audio)
+    args = parse_args()
 
 
-    ref_text = args.reference_text
-    if ref_text:
-        ref_text = read_ref_text(ref_text)
+    idstr: str | None = args.reference_id
+    # priority: ref_id > [{text, audio},...]
+    if idstr is None:
+        base64_audios = [
+            audio_to_bytes(ref_audio) for ref_audio in args.reference_audio
+        ]
+        ref_texts = [read_ref_text(ref_text) for ref_text in args.reference_text]
+    else:
+        base64_audios = []
+        ref_texts = []
+        pass  # in api.py
 
 
     data = {
     data = {
         "text": args.text,
         "text": args.text,
-        "reference_text": ref_text,
-        "reference_audio": base64_audio,
+        "references": [
+            dict(text=ref_text, audio=ref_audio)
+            for ref_text, ref_audio in zip(ref_texts, base64_audios)
+        ],
+        "reference_id": idstr,
+        "normalize": args.normalize,
+        "format": args.format,
+        "mp3_bitrate": args.mp3_bitrate,
+        "opus_bitrate": args.opus_bitrate,
         "max_new_tokens": args.max_new_tokens,
         "max_new_tokens": args.max_new_tokens,
         "chunk_length": args.chunk_length,
         "chunk_length": args.chunk_length,
         "top_p": args.top_p,
         "top_p": args.top_p,
@@ -116,22 +140,20 @@ if __name__ == "__main__":
         "temperature": args.temperature,
         "temperature": args.temperature,
         "speaker": args.speaker,
         "speaker": args.speaker,
         "emotion": args.emotion,
         "emotion": args.emotion,
-        "format": args.format,
         "streaming": args.streaming,
         "streaming": args.streaming,
     }
     }
 
 
     response = requests.post(args.url, json=data, stream=args.streaming)
     response = requests.post(args.url, json=data, stream=args.streaming)
 
 
-    audio_format = pyaudio.paInt16  # Assuming 16-bit PCM format
-
     if response.status_code == 200:
     if response.status_code == 200:
         if args.streaming:
         if args.streaming:
             p = pyaudio.PyAudio()
             p = pyaudio.PyAudio()
+            audio_format = pyaudio.paInt16  # Assuming 16-bit PCM format
             stream = p.open(
             stream = p.open(
                 format=audio_format, channels=args.channels, rate=args.rate, output=True
                 format=audio_format, channels=args.channels, rate=args.rate, output=True
             )
             )
 
 
-            wf = wave.open("generated_audio.wav", "wb")
+            wf = wave.open(f"{args.output}.wav", "wb")
             wf.setnchannels(args.channels)
             wf.setnchannels(args.channels)
             wf.setsampwidth(p.get_sample_size(audio_format))
             wf.setsampwidth(p.get_sample_size(audio_format))
             wf.setframerate(args.rate)
             wf.setframerate(args.rate)
@@ -153,12 +175,14 @@ if __name__ == "__main__":
                 wf.close()
                 wf.close()
         else:
         else:
             audio_content = response.content
             audio_content = response.content
-
-            with open("generated_audio.wav", "wb") as audio_file:
+            audio_path = f"{args.output}.{args.format}"
+            with open(audio_path, "wb") as audio_file:
                 audio_file.write(audio_content)
                 audio_file.write(audio_content)
 
 
-            play_audio(audio_content, audio_format, args.channels, args.rate)
-            print("Audio has been saved to 'generated_audio.wav'.")
+            audio = AudioSegment.from_file(audio_path, format=args.format)
+            if args.play:
+                play(audio)
+            print(f"Audio has been saved to '{audio_path}'.")
     else:
     else:
         print(f"Request failed with status code {response.status_code}")
         print(f"Request failed with status code {response.status_code}")
         print(response.json())
         print(response.json())

+ 1 - 1
tools/sensevoice/fun_asr.py

@@ -26,7 +26,7 @@ def uvr5_cli(
     output_folder: Path,
     output_folder: Path,
     audio_files: list[Path] | None = None,
     audio_files: list[Path] | None = None,
     output_format: str = "flac",
     output_format: str = "flac",
-    model: str = "BS-Roformer-Viperx-1296.ckpt",
+    model: str = "BS-Roformer-Viperx-1297.ckpt",
 ):
 ):
     # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
     # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
     sepr = Separator(
     sepr = Separator(