Ver Fonte

refactor: openapi doc (#770)

* refactor: openapi doc

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: spicysama <a2983352531@outlook.com>
JaysonAlbert há 1 ano
pai
commit
0b48e781ec
3 ficheiros alterados com 159 adições e 224 exclusões
  1. 11 26
      tools/api_server.py
  2. 2 1
      tools/schema.py
  3. 146 197
      tools/server/views.py

+ 11 - 26
tools/api_server.py

@@ -11,6 +11,8 @@ from kui.asgi import (
     OpenAPI,
     Routes,
 )
+from kui.cors import CORSConfig
+from kui.openapi.specification import Info
 from kui.security import bearer_auth
 from loguru import logger
 from typing_extensions import Annotated
@@ -20,27 +22,13 @@ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 from tools.server.api_utils import MsgPackRequest, parse_args
 from tools.server.exception_handler import ExceptionHandler
 from tools.server.model_manager import ModelManager
-from tools.server.views import (
-    ASRView,
-    ChatView,
-    HealthView,
-    TTSView,
-    VQGANDecodeView,
-    VQGANEncodeView,
-)
+from tools.server.views import routes
 
 
 class API(ExceptionHandler):
     def __init__(self):
         self.args = parse_args()
-        self.routes = [
-            ("/v1/health", HealthView),
-            ("/v1/vqgan/encode", VQGANEncodeView),
-            ("/v1/vqgan/decode", VQGANDecodeView),
-            ("/v1/asr", ASRView),
-            ("/v1/tts", TTSView),
-            ("/v1/chat", ChatView),
-        ]
+        self.routes = routes
 
         def api_auth(endpoint):
             async def verify(token: Annotated[str, Depends(bearer_auth)]):
@@ -56,16 +44,13 @@ class API(ExceptionHandler):
             else:
                 return passthrough
 
-        self.routes = Routes(
-            [HttpRoute(path, view) for path, view in self.routes],
-            http_middlewares=[api_auth],
-        )
-
         self.openapi = OpenAPI(
-            {
-                "title": "Fish Speech API",
-                "version": "1.5.0",
-            },
+            Info(
+                {
+                    "title": "Fish Speech API",
+                    "version": "1.5.0",
+                }
+            ),
         ).routes
 
         # Initialize the app
@@ -76,7 +61,7 @@ class API(ExceptionHandler):
                 Exception: self.other_exception_handler,
             },
             factory_class=FactoryClass(http=MsgPackRequest),
-            cors_config={},
+            cors_config=CORSConfig(),
         )
 
         # Add the state variables

+ 2 - 1
tools/schema.py

@@ -1,11 +1,12 @@
 import os
 import queue
 from dataclasses import dataclass
-from typing import Annotated, Literal
+from typing import Literal
 
 import torch
 from pydantic import BaseModel, Field, conint, conlist
 from pydantic.functional_validators import SkipValidation
+from typing_extensions import Annotated
 
 from fish_speech.conversation import Message, TextPart, VQPart
 

+ 146 - 197
tools/server/views.py

@@ -7,8 +7,9 @@ import numpy as np
 import ormsgpack
 import soundfile as sf
 import torch
-from kui.asgi import HTTPException, HttpView, JSONResponse, StreamResponse, request
+from kui.asgi import Body, HTTPException, JSONResponse, Routes, StreamResponse, request
 from loguru import logger
+from typing_extensions import Annotated
 
 from tools.schema import (
     ServeASRRequest,
@@ -32,215 +33,163 @@ from tools.server.model_utils import batch_asr, cached_vqgan_batch_encode, vqgan
 
 MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
 
-
-class HealthView(HttpView):
-    """
-    Return the health status of the server.
-    """
-
-    @classmethod
-    async def post(cls):
-        return JSONResponse({"status": "ok"})
-
-
-class VQGANEncodeView(HttpView):
-    """
-    Encode the audio into symbolic tokens.
-    """
-
-    @classmethod
-    async def post(cls):
-        # Decode the request
-        payload = await request.data()
-        req = ServeVQGANEncodeRequest(**payload)
-
-        # Get the model from the app
-        model_manager: ModelManager = request.app.state.model_manager
-        decoder_model = model_manager.decoder_model
-
-        # Encode the audio
-        start_time = time.time()
-        tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
-        logger.info(
-            f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms"
+routes = Routes()
+
+
+@routes.http.post("/v1/health")
+async def health():
+    return JSONResponse({"status": "ok"})
+
+
+@routes.http.post("/v1/vqgan/encode")
+async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
+    # Get the model from the app
+    model_manager: ModelManager = request.app.state.model_manager
+    decoder_model = model_manager.decoder_model
+
+    # Encode the audio
+    start_time = time.time()
+    tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
+    logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
+
+    # Return the response
+    return ormsgpack.packb(
+        ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
+        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+    )
+
+
+@routes.http.post("/v1/vqgan/decode")
+async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
+    # Get the model from the app
+    model_manager: ModelManager = request.app.state.model_manager
+    decoder_model = model_manager.decoder_model
+
+    # Decode the audio
+    tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
+    start_time = time.time()
+    audios = vqgan_decode(decoder_model, tokens)
+    logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
+    audios = [audio.astype(np.float16).tobytes() for audio in audios]
+
+    # Return the response
+    return ormsgpack.packb(
+        ServeVQGANDecodeResponse(audios=audios),
+        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+    )
+
+
+@routes.http.post("/v1/asr")
+async def asr(req: Annotated[ServeASRRequest, Body(exclusive=True)]):
+    # Get the model from the app
+    model_manager: ModelManager = request.app.state.model_manager
+    asr_model = model_manager.asr_model
+    lock = request.app.state.lock
+
+    # Perform ASR
+    start_time = time.time()
+    audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
+    audios = [torch.from_numpy(audio).float() for audio in audios]
+
+    if any(audios.shape[-1] >= 30 * req.sample_rate for audios in audios):
+        raise HTTPException(status_code=400, content="Audio length is too long")
+
+    transcriptions = batch_asr(
+        asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
+    )
+    logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
+
+    # Return the response
+    return ormsgpack.packb(
+        ServeASRResponse(transcriptions=transcriptions),
+        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+    )
+
+
+@routes.http.post("/v1/tts")
+async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
+    # Get the model from the app
+    app_state = request.app.state
+    model_manager: ModelManager = app_state.model_manager
+    engine = model_manager.tts_inference_engine
+    sample_rate = engine.decoder_model.spec_transform.sample_rate
+
+    # Check if the text is too long
+    if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
+        raise HTTPException(
+            HTTPStatus.BAD_REQUEST,
+            content=f"Text is too long, max length is {app_state.max_text_length}",
         )
 
-        # Return the response
-        return ormsgpack.packb(
-            ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
-            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+    # Check if streaming is enabled
+    if req.streaming and req.format != "wav":
+        raise HTTPException(
+            HTTPStatus.BAD_REQUEST,
+            content="Streaming only supports WAV format",
         )
 
-
-class VQGANDecodeView(HttpView):
-    """
-    Decode the symbolic tokens into audio.
-    """
-
-    @classmethod
-    async def post(cls):
-        # Decode the request
-        payload = await request.data()
-        req = ServeVQGANDecodeRequest(**payload)
-
-        # Get the model from the app
-        model_manager: ModelManager = request.app.state.model_manager
-        decoder_model = model_manager.decoder_model
-
-        # Decode the audio
-        tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
-        start_time = time.time()
-        audios = vqgan_decode(decoder_model, tokens)
-        logger.info(
-            f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms"
+    # Perform TTS
+    if req.streaming:
+        return StreamResponse(
+            iterable=inference_async(req, engine),
+            headers={
+                "Content-Disposition": f"attachment; filename=audio.{req.format}",
+            },
+            content_type=get_content_type(req.format),
         )
-        audios = [audio.astype(np.float16).tobytes() for audio in audios]
-
-        # Return the response
-        return ormsgpack.packb(
-            ServeVQGANDecodeResponse(audios=audios),
-            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
+    else:
+        fake_audios = next(inference(req, engine))
+        buffer = io.BytesIO()
+        sf.write(
+            buffer,
+            fake_audios,
+            sample_rate,
+            format=req.format,
         )
 
+        return StreamResponse(
+            iterable=buffer_to_async_generator(buffer.getvalue()),
+            headers={
+                "Content-Disposition": f"attachment; filename=audio.{req.format}",
+            },
+            content_type=get_content_type(req.format),
+        )
 
-class ASRView(HttpView):
-    """
-    Perform automatic speech recognition on the audio.
-    """
-
-    @classmethod
-    async def post(cls):
-        # Decode the request
-        payload = await request.data()
-        req = ServeASRRequest(**payload)
-
-        # Get the model from the app
-        model_manager: ModelManager = request.app.state.model_manager
-        asr_model = model_manager.asr_model
-        lock = request.app.state.lock
 
-        # Perform ASR
-        start_time = time.time()
-        audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
-        audios = [torch.from_numpy(audio).float() for audio in audios]
+@routes.http.post("/v1/chat")
+async def chat(req: Annotated[ServeChatRequest, Body(exclusive=True)]):
+    # Check that the number of samples requested is correct
+    if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
+        raise HTTPException(
+            HTTPStatus.BAD_REQUEST,
+            content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
+        )
 
-        if any(audios.shape[-1] >= 30 * req.sample_rate for audios in audios):
-            raise HTTPException(status_code=400, content="Audio length is too long")
+    # Get the type of content provided
+    content_type = request.headers.get("Content-Type", "application/json")
+    json_mode = "application/json" in content_type
 
-        transcriptions = batch_asr(
-            asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
-        )
-        logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
+    # Get the models from the app
+    model_manager: ModelManager = request.app.state.model_manager
+    llama_queue = model_manager.llama_queue
+    tokenizer = model_manager.tokenizer
+    config = model_manager.config
 
-        # Return the response
-        return ormsgpack.packb(
-            ServeASRResponse(transcriptions=transcriptions),
-            option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
-        )
+    device = request.app.state.device
 
+    # Get the response generators
+    response_generator = get_response_generator(
+        llama_queue, tokenizer, config, req, device, json_mode
+    )
 
-class TTSView(HttpView):
-    """
-    Perform text-to-speech on the input text.
-    """
-
-    @classmethod
-    async def post(cls):
-        # Decode the request
-        payload = await request.data()
-        req = ServeTTSRequest(**payload)
-
-        # Get the model from the app
-        app_state = request.app.state
-        model_manager: ModelManager = app_state.model_manager
-        engine = model_manager.tts_inference_engine
-        sample_rate = engine.decoder_model.spec_transform.sample_rate
-
-        # Check if the text is too long
-        if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
-            raise HTTPException(
-                HTTPStatus.BAD_REQUEST,
-                content=f"Text is too long, max length is {app_state.max_text_length}",
-            )
-
-        # Check if streaming is enabled
-        if req.streaming and req.format != "wav":
-            raise HTTPException(
-                HTTPStatus.BAD_REQUEST,
-                content="Streaming only supports WAV format",
-            )
-
-        # Perform TTS
-        if req.streaming:
-            return StreamResponse(
-                iterable=inference_async(req, engine),
-                headers={
-                    "Content-Disposition": f"attachment; filename=audio.{req.format}",
-                },
-                content_type=get_content_type(req.format),
-            )
+    # Return the response in the correct format
+    if req.streaming is False:
+        result = response_generator()
+        if json_mode:
+            return JSONResponse(result.model_dump())
         else:
-            fake_audios = next(inference(req, engine))
-            buffer = io.BytesIO()
-            sf.write(
-                buffer,
-                fake_audios,
-                sample_rate,
-                format=req.format,
-            )
-
-            return StreamResponse(
-                iterable=buffer_to_async_generator(buffer.getvalue()),
-                headers={
-                    "Content-Disposition": f"attachment; filename=audio.{req.format}",
-                },
-                content_type=get_content_type(req.format),
-            )
-
-
-class ChatView(HttpView):
-    """
-    Perform chatbot inference on the input text.
-    """
-
-    @classmethod
-    async def post(cls):
-        # Decode the request
-        payload = await request.data()
-        req = ServeChatRequest(**payload)
-
-        # Check that the number of samples requested is correct
-        if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
-            raise HTTPException(
-                HTTPStatus.BAD_REQUEST,
-                content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
-            )
-
-        # Get the type of content provided
-        content_type = request.headers.get("Content-Type", "application/json")
-        json_mode = "application/json" in content_type
-
-        # Get the models from the app
-        model_manager: ModelManager = request.app.state.model_manager
-        llama_queue = model_manager.llama_queue
-        tokenizer = model_manager.tokenizer
-        config = model_manager.config
-
-        device = request.app.state.device
-
-        # Get the response generators
-        response_generator = get_response_generator(
-            llama_queue, tokenizer, config, req, device, json_mode
-        )
+            return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
 
-        # Return the response in the correct format
-        if req.streaming is False:
-            result = response_generator()
-            if json_mode:
-                return JSONResponse(result.model_dump())
-            else:
-                return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
-
-        return StreamResponse(
-            iterable=response_generator(), content_type="text/event-stream"
-        )
+    return StreamResponse(
+        iterable=response_generator(), content_type="text/event-stream"
+    )