
Fix tools.api_server failing to run (#1011)

* [fix]: Fix tools.api_server failing to run

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cronrpc, 10 months ago
parent
commit
352d076db7
3 files changed, 3 additions and 56 deletions
  1. tools/server/api_utils.py      +1 −1
  2. tools/server/model_manager.py  +1 −13
  3. tools/server/views.py          +1 −42

tools/server/api_utils.py  (+1 −1)

@@ -13,7 +13,7 @@ from tools.server.inference import inference_wrapper as inference
 
 def parse_args():
     parser = ArgumentParser()
-    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
+    parser.add_argument("--mode", type=str, choices=["tts"], default="tts")
     parser.add_argument("--load-asr-model", action="store_true")
     parser.add_argument(
         "--llama-checkpoint-path",

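With the "agent" choice removed, an unsupported --mode value is now rejected by argparse at startup instead of surfacing deeper in the stack. A minimal standalone sketch of the new behavior (the parser below is a stripped-down stand-in, not the full parse_args from api_utils.py):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    # Only the TTS mode remains after this change.
    parser.add_argument("--mode", type=str, choices=["tts"], default="tts")

    print(parser.parse_args([]))                 # Namespace(mode='tts')
    print(parser.parse_args(["--mode", "tts"]))  # Namespace(mode='tts')
    # parser.parse_args(["--mode", "agent"]) would now exit with:
    #   error: argument --mode: invalid choice: 'agent' (choose from 'tts')
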
tools/server/model_manager.py  (+1 −13)

@@ -4,10 +4,7 @@ from loguru import logger
 
 from fish_speech.inference_engine import TTSInferenceEngine
 from fish_speech.models.dac.inference import load_model as load_decoder_model
-from fish_speech.models.text2semantic.inference import (
-    launch_thread_safe_queue,
-    launch_thread_safe_queue_agent,
-)
+from fish_speech.models.text2semantic.inference import launch_thread_safe_queue
 from fish_speech.utils.schema import ServeTTSRequest
 from tools.server.inference import inference_wrapper as inference
 
@@ -84,15 +81,6 @@ class ModelManager:
                 precision=precision,
                 compile=compile,
             )
-        elif mode == "agent":
-            self.llama_queue, self.tokenizer, self.config = (
-                launch_thread_safe_queue_agent(
-                    checkpoint_path=checkpoint_path,
-                    device=device,
-                    precision=precision,
-                    compile=compile,
-                )
-            )
         else:
             raise ValueError(f"Invalid mode: {mode}")
 

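With the agent branch and the launch_thread_safe_queue_agent import gone, ModelManager only knows how to launch the TTS queue; any other mode raises ValueError. A condensed, hypothetical sketch of the remaining control flow (the method name and the attribute assigned are simplifications for illustration, but the keyword arguments mirror the diff above):

    from fish_speech.models.text2semantic.inference import launch_thread_safe_queue

    class ModelManager:
        def load_llama(self, mode, checkpoint_path, device, precision, compile):
            if mode == "tts":
                # The only remaining path: one thread-safe queue for TTS inference.
                self.llama_queue = launch_thread_safe_queue(
                    checkpoint_path=checkpoint_path,
                    device=device,
                    precision=precision,
                    compile=compile,
                )
            else:
                raise ValueError(f"Invalid mode: {mode}")
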
tools/server/views.py  (+1 −42)

@@ -22,14 +22,12 @@ from typing_extensions import Annotated
 from fish_speech.utils.schema import (
     ServeASRRequest,
     ServeASRResponse,
-    ServeChatRequest,
     ServeTTSRequest,
     ServeVQGANDecodeRequest,
     ServeVQGANDecodeResponse,
     ServeVQGANEncodeRequest,
     ServeVQGANEncodeResponse,
 )
-from tools.server.agent import get_response_generator
 from tools.server.api_utils import (
     buffer_to_async_generator,
     get_content_type,
@@ -130,7 +128,7 @@ async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
     app_state = request.app.state
     model_manager: ModelManager = app_state.model_manager
     engine = model_manager.tts_inference_engine
-    sample_rate = engine.decoder_model.spec_transform.sample_rate
+    sample_rate = engine.decoder_model.sample_rate
 
     # Check if the text is too long
     if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
@@ -172,42 +170,3 @@ async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
             },
             content_type=get_content_type(req.format),
         )
-
-
-@routes.http.post("/v1/chat")
-async def chat(req: Annotated[ServeChatRequest, Body(exclusive=True)]):
-    # Check that the number of samples requested is correct
-    if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
-        raise HTTPException(
-            HTTPStatus.BAD_REQUEST,
-            content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
-        )
-
-    # Get the type of content provided
-    content_type = request.headers.get("Content-Type", "application/json")
-    json_mode = "application/json" in content_type
-
-    # Get the models from the app
-    model_manager: ModelManager = request.app.state.model_manager
-    llama_queue = model_manager.llama_queue
-    tokenizer = model_manager.tokenizer
-    config = model_manager.config
-
-    device = request.app.state.device
-
-    # Get the response generators
-    response_generator = get_response_generator(
-        llama_queue, tokenizer, config, req, device, json_mode
-    )
-
-    # Return the response in the correct format
-    if req.streaming is False:
-        result = response_generator()
-        if json_mode:
-            return JSONResponse(result.model_dump())
-        else:
-            return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
-
-    return StreamResponse(
-        iterable=response_generator(), content_type="text/event-stream"
-    )