Przeglądaj źródła

Api json format (#588)

* api json support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 rok temu
rodzic
commit
977365535e
1 zmienionych plików z 12 dodań i 8 usunięć
  1. 12 8
      tools/api.py

+ 12 - 8
tools/api.py

@@ -1,15 +1,12 @@
-import base64
 import io
-import json
 import queue
-import random
 import sys
 import traceback
 import wave
 from argparse import ArgumentParser
 from http import HTTPStatus
 from pathlib import Path
-from typing import Annotated, Any, Literal, Optional
+from typing import Annotated, Any
 
 import numpy as np
 import ormsgpack
@@ -31,7 +28,6 @@ from kui.asgi import (
 )
 from kui.asgi.routing import MultimethodRoutes
 from loguru import logger
-from pydantic import BaseModel, Field, conint
 
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
@@ -39,7 +35,7 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from fish_speech.utils import autocast_exclude_mps
-from tools.commons import ServeReferenceAudio, ServeTTSRequest
+from tools.commons import ServeTTSRequest
 from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
 from tools.llama.generate import (
     GenerateRequest,
@@ -367,18 +363,26 @@ def parse_args():
 openapi = OpenAPI(
     {
         "title": "Fish Speech API",
+        "version": "1.4.2",
     },
 ).routes
 
 
 class MsgPackRequest(HttpRequest):
-    async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
+    async def data(
+        self,
+    ) -> Annotated[
+        Any, ContentType("application/msgpack"), ContentType("application/json")
+    ]:
         if self.content_type == "application/msgpack":
             return ormsgpack.unpackb(await self.body)
 
+        elif self.content_type == "application/json":
+            return await self.json
+
         raise HTTPException(
             HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
-            headers={"Accept": "application/msgpack"},
+            headers={"Accept": "application/msgpack, application/json"},
         )