Explorar el Código

fully support ormsgpack (#518)

* fully support ormsgpack

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* dependency

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama hace 1 año
padre
commit
237f4fdcb9
Se han modificado 5 ficheros con 64 adiciones y 77 borrados
  1. 1 0
      pyproject.toml
  2. 1 33
      tools/api.py
  3. 35 0
      tools/commons.py
  4. 1 35
      tools/msgpack_api.py
  5. 26 9
      tools/post_api.py

+ 1 - 0
pyproject.toml

@@ -42,6 +42,7 @@ dependencies = [
     "funasr==1.1.5",
     "opencc-python-reimplemented==0.1.7",
     "silero-vad",
+    "ormsgpack",
 ]
 
 [project.optional-dependencies]

+ 1 - 33
tools/api.py

@@ -39,7 +39,7 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from fish_speech.utils import autocast_exclude_mps
-from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
+from tools.commons import ServeReferenceAudio, ServeTTSRequest
 from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
 from tools.llama.generate import (
     GenerateRequest,
@@ -156,38 +156,6 @@ def decode_vq_tokens(
 routes = MultimethodRoutes(base_class=HttpView)
 
 
-class ServeReferenceAudio(BaseModel):
-    audio: bytes
-    text: str
-
-
-class ServeTTSRequest(BaseModel):
-    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
-    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
-    # Audio format
-    format: Literal["wav", "pcm", "mp3"] = "wav"
-    mp3_bitrate: Literal[64, 128, 192] = 128
-    # References audios for in-context learning
-    references: list[ServeReferenceAudio] = []
-    # Reference id
-    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
-    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
-    reference_id: str | None = None
-    # Normalize text for en & zh, this increase stability for numbers
-    normalize: bool = True
-    mp3_bitrate: Optional[int] = 64
-    opus_bitrate: Optional[int] = -1000
-    # Balance mode will reduce latency to 300ms, but may decrease stability
-    latency: Literal["normal", "balanced"] = "normal"
-    # not usually used below
-    streaming: bool = False
-    emotion: Optional[str] = None
-    max_new_tokens: int = 1024
-    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
-    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-
-
 def get_content_type(audio_format):
     if audio_format == "wav":
         return "audio/wav"

+ 35 - 0
tools/commons.py

@@ -0,0 +1,35 @@
+from typing import Annotated, Literal, Optional
+
+from pydantic import BaseModel, Field, conint
+
+
+class ServeReferenceAudio(BaseModel):
+    audio: bytes
+    text: str
+
+
+class ServeTTSRequest(BaseModel):
+    text: str
+    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+    # Audio format
+    format: Literal["wav", "pcm", "mp3"] = "wav"
+    mp3_bitrate: Literal[64, 128, 192] = 128
+    # References audios for in-context learning
+    references: list[ServeReferenceAudio] = []
+    # Reference id
+    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+    reference_id: str | None = None
+    # Normalize text for en & zh, this increase stability for numbers
+    normalize: bool = True
+    mp3_bitrate: Optional[int] = 64
+    opus_bitrate: Optional[int] = -1000
+    # Balance mode will reduce latency to 300ms, but may decrease stability
+    latency: Literal["normal", "balanced"] = "normal"
+    # not usually used below
+    streaming: bool = False
+    emotion: Optional[str] = None
+    max_new_tokens: int = 1024
+    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7

+ 1 - 35
tools/msgpack_api.py

@@ -1,41 +1,7 @@
-from typing import Annotated, AsyncGenerator, Literal, Optional
-
 import httpx
 import ormsgpack
-from pydantic import AfterValidator, BaseModel, Field, conint
-
-
-class ServeReferenceAudio(BaseModel):
-    audio: bytes
-    text: str
-
-
-class ServeTTSRequest(BaseModel):
-    text: str
-    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
-    # Audio format
-    format: Literal["wav", "pcm", "mp3"] = "wav"
-    mp3_bitrate: Literal[64, 128, 192] = 128
-    # References audios for in-context learning
-    references: list[ServeReferenceAudio] = []
-    # Reference id
-    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
-    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
-    reference_id: str | None = None
-    # Normalize text for en & zh, this increase stability for numbers
-    normalize: bool = True
-    mp3_bitrate: Optional[int] = 64
-    opus_bitrate: Optional[int] = -1000
-    # Balance mode will reduce latency to 300ms, but may decrease stability
-    latency: Literal["normal", "balanced"] = "normal"
-    # not usually used below
-    streaming: bool = False
-    emotion: Optional[str] = None
-    max_new_tokens: int = 1024
-    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
-    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
 
+from tools.commons import ServeReferenceAudio, ServeTTSRequest
 
 # priority: ref_id > references
 request = ServeTTSRequest(

+ 26 - 9
tools/post_api.py

@@ -1,13 +1,14 @@
 import argparse
 import base64
 import wave
-from pathlib import Path
 
+import ormsgpack
 import pyaudio
 import requests
 from pydub import AudioSegment
 from pydub.playback import play
 
+from tools.commons import ServeReferenceAudio, ServeTTSRequest
 from tools.file import audio_to_bytes, read_ref_text
 
 
@@ -113,20 +114,26 @@ if __name__ == "__main__":
     idstr: str | None = args.reference_id
     # priority: ref_id > [{text, audio},...]
     if idstr is None:
-        base64_audios = [
-            audio_to_bytes(ref_audio) for ref_audio in args.reference_audio
-        ]
-        ref_texts = [read_ref_text(ref_text) for ref_text in args.reference_text]
+        ref_audios = args.reference_audio
+        ref_texts = args.reference_text
+        if ref_audios is None:
+            byte_audios = []
+        else:
+            byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
+        if ref_texts is None:
+            ref_texts = []
+        else:
+            ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]
     else:
-        base64_audios = []
+        byte_audios = []
         ref_texts = []
         pass  # in api.py
 
     data = {
         "text": args.text,
         "references": [
-            dict(text=ref_text, audio=ref_audio)
-            for ref_text, ref_audio in zip(ref_texts, base64_audios)
+            ServeReferenceAudio(audio=ref_audio, text=ref_text)
+            for ref_text, ref_audio in zip(ref_texts, byte_audios)
         ],
         "reference_id": idstr,
         "normalize": args.normalize,
@@ -143,7 +150,17 @@ if __name__ == "__main__":
         "streaming": args.streaming,
     }
 
-    response = requests.post(args.url, json=data, stream=args.streaming)
+    pydantic_data = ServeTTSRequest(**data)
+
+    response = requests.post(
+        args.url,
+        data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+        stream=args.streaming,
+        headers={
+            "authorization": "Bearer YOUR_API_KEY",
+            "content-type": "application/msgpack",
+        },
+    )
 
     if response.status_code == 200:
         if args.streaming: