Просмотр исходного кода

Api json format (#588)

* api json support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 год назад
Родитель
Сommit
977365535e
1 измененных файлов с 12 добавлено и 8 удалено
  1. 12 8
      tools/api.py

+ 12 - 8
tools/api.py

@@ -1,15 +1,12 @@
-import base64
 import io
 import io
-import json
 import queue
 import queue
-import random
 import sys
 import sys
 import traceback
 import traceback
 import wave
 import wave
 from argparse import ArgumentParser
 from argparse import ArgumentParser
 from http import HTTPStatus
 from http import HTTPStatus
 from pathlib import Path
 from pathlib import Path
-from typing import Annotated, Any, Literal, Optional
+from typing import Annotated, Any
 
 
 import numpy as np
 import numpy as np
 import ormsgpack
 import ormsgpack
@@ -31,7 +28,6 @@ from kui.asgi import (
 )
 )
 from kui.asgi.routing import MultimethodRoutes
 from kui.asgi.routing import MultimethodRoutes
 from loguru import logger
 from loguru import logger
-from pydantic import BaseModel, Field, conint
 
 
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 
 
@@ -39,7 +35,7 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from fish_speech.utils import autocast_exclude_mps
 from fish_speech.utils import autocast_exclude_mps
-from tools.commons import ServeReferenceAudio, ServeTTSRequest
+from tools.commons import ServeTTSRequest
 from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
 from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
 from tools.llama.generate import (
 from tools.llama.generate import (
     GenerateRequest,
     GenerateRequest,
@@ -367,18 +363,26 @@ def parse_args():
 openapi = OpenAPI(
 openapi = OpenAPI(
     {
     {
         "title": "Fish Speech API",
         "title": "Fish Speech API",
+        "version": "1.4.2",
     },
     },
 ).routes
 ).routes
 
 
 
 
 class MsgPackRequest(HttpRequest):
 class MsgPackRequest(HttpRequest):
-    async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
+    async def data(
+        self,
+    ) -> Annotated[
+        Any, ContentType("application/msgpack"), ContentType("application/json")
+    ]:
         if self.content_type == "application/msgpack":
         if self.content_type == "application/msgpack":
             return ormsgpack.unpackb(await self.body)
             return ormsgpack.unpackb(await self.body)
 
 
+        elif self.content_type == "application/json":
+            return await self.json
+
         raise HTTPException(
         raise HTTPException(
             HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
             HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
-            headers={"Accept": "application/msgpack"},
+            headers={"Accept": "application/msgpack, application/json"},
         )
         )