|
|
@@ -22,14 +22,12 @@ from typing_extensions import Annotated
|
|
|
from fish_speech.utils.schema import (
|
|
|
ServeASRRequest,
|
|
|
ServeASRResponse,
|
|
|
- ServeChatRequest,
|
|
|
ServeTTSRequest,
|
|
|
ServeVQGANDecodeRequest,
|
|
|
ServeVQGANDecodeResponse,
|
|
|
ServeVQGANEncodeRequest,
|
|
|
ServeVQGANEncodeResponse,
|
|
|
)
|
|
|
-from tools.server.agent import get_response_generator
|
|
|
from tools.server.api_utils import (
|
|
|
buffer_to_async_generator,
|
|
|
get_content_type,
|
|
|
@@ -130,7 +128,7 @@ async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
|
|
|
app_state = request.app.state
|
|
|
model_manager: ModelManager = app_state.model_manager
|
|
|
engine = model_manager.tts_inference_engine
|
|
|
- sample_rate = engine.decoder_model.spec_transform.sample_rate
|
|
|
+ sample_rate = engine.decoder_model.sample_rate
|
|
|
|
|
|
# Check if the text is too long
|
|
|
if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
|
|
|
@@ -172,42 +170,3 @@ async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
|
|
|
},
|
|
|
content_type=get_content_type(req.format),
|
|
|
)
|
|
|
-
|
|
|
-
|
|
|
-@routes.http.post("/v1/chat")
|
|
|
-async def chat(req: Annotated[ServeChatRequest, Body(exclusive=True)]):
|
|
|
- # Check that the number of samples requested is correct
|
|
|
- if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
|
|
|
- raise HTTPException(
|
|
|
- HTTPStatus.BAD_REQUEST,
|
|
|
- content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
|
|
|
- )
|
|
|
-
|
|
|
- # Get the type of content provided
|
|
|
- content_type = request.headers.get("Content-Type", "application/json")
|
|
|
- json_mode = "application/json" in content_type
|
|
|
-
|
|
|
- # Get the models from the app
|
|
|
- model_manager: ModelManager = request.app.state.model_manager
|
|
|
- llama_queue = model_manager.llama_queue
|
|
|
- tokenizer = model_manager.tokenizer
|
|
|
- config = model_manager.config
|
|
|
-
|
|
|
- device = request.app.state.device
|
|
|
-
|
|
|
- # Get the response generators
|
|
|
- response_generator = get_response_generator(
|
|
|
- llama_queue, tokenizer, config, req, device, json_mode
|
|
|
- )
|
|
|
-
|
|
|
- # Return the response in the correct format
|
|
|
- if req.streaming is False:
|
|
|
- result = response_generator()
|
|
|
- if json_mode:
|
|
|
- return JSONResponse(result.model_dump())
|
|
|
- else:
|
|
|
- return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
|
|
|
-
|
|
|
- return StreamResponse(
|
|
|
- iterable=response_generator(), content_type="text/event-stream"
|
|
|
- )
|