| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- import io
- import os
- import time
- from http import HTTPStatus
- import numpy as np
- import ormsgpack
- import soundfile as sf
- import torch
- from kui.asgi import (
- Body,
- HTTPException,
- HttpView,
- JSONResponse,
- Routes,
- StreamResponse,
- request,
- )
- from loguru import logger
- from typing_extensions import Annotated
- from fish_speech.utils.schema import (
- ServeTTSRequest,
- ServeVQGANDecodeRequest,
- ServeVQGANDecodeResponse,
- ServeVQGANEncodeRequest,
- ServeVQGANEncodeResponse,
- )
- from tools.server.api_utils import (
- buffer_to_async_generator,
- get_content_type,
- inference_async,
- )
- from tools.server.inference import inference_wrapper as inference
- from tools.server.model_manager import ModelManager
- from tools.server.model_utils import (
- batch_asr,
- batch_vqgan_decode,
- cached_vqgan_batch_encode,
- )
- MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
- routes = Routes()
@routes.http("/v1/health")
class Health(HttpView):
    """Liveness probe: GET and POST both answer with {"status": "ok"}."""

    @classmethod
    def _ok(cls):
        # Single response builder so both verbs stay in sync.
        return JSONResponse({"status": "ok"})

    @classmethod
    async def get(cls):
        return cls._ok()

    @classmethod
    async def post(cls):
        return cls._ok()
@routes.http.post("/v1/vqgan/encode")
async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
    """Encode a batch of audio clips into VQGAN token sequences.

    Returns msgpack-serialized ``ServeVQGANEncodeResponse`` bytes, with one
    token list per input audio in ``req.audios``.
    """
    # Get the decoder model from the app state.
    model_manager: ModelManager = request.app.state.model_manager
    decoder_model = model_manager.decoder_model

    # Encode the audio. perf_counter is monotonic, so the elapsed-time log
    # cannot be skewed by wall-clock adjustments (time.time can jump).
    start_time = time.perf_counter()
    tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
    logger.info(
        f"[EXEC] VQGAN encode time: {(time.perf_counter() - start_time) * 1000:.2f}ms"
    )

    # Serialize via pydantic-aware msgpack packing.
    return ormsgpack.packb(
        ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
@routes.http.post("/v1/vqgan/decode")
async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    """Decode VQGAN token sequences back into audio waveforms.

    Returns msgpack-serialized ``ServeVQGANDecodeResponse`` bytes; each audio
    is shipped as raw little-endian float16 PCM bytes.
    """
    # Get the decoder model from the app state.
    model_manager: ModelManager = request.app.state.model_manager
    decoder_model = model_manager.decoder_model

    # Materialize the token lists as int tensors for the decoder.
    tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]

    # Decode the audio. perf_counter is monotonic, so the elapsed-time log
    # cannot be skewed by wall-clock adjustments (time.time can jump).
    start_time = time.perf_counter()
    audios = batch_vqgan_decode(decoder_model, tokens)
    logger.info(
        f"[EXEC] VQGAN decode time: {(time.perf_counter() - start_time) * 1000:.2f}ms"
    )

    # Halve the payload size by downcasting to float16 before shipping bytes.
    audios = [audio.astype(np.float16).tobytes() for audio in audios]

    return ormsgpack.packb(
        ServeVQGANDecodeResponse(audios=audios),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
@routes.http.post("/v1/tts")
async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
    """Synthesize speech for *req* and return the encoded audio.

    Returns a ``StreamResponse`` served as a file attachment. Raises
    ``HTTPException`` (400) when the text exceeds the configured limit or
    when streaming is requested with a non-WAV format, and (500) when
    inference produces no audio.
    """
    # Resolve the inference engine and output sample rate from app state.
    app_state = request.app.state
    model_manager: ModelManager = app_state.model_manager
    engine = model_manager.tts_inference_engine
    sample_rate = engine.decoder_model.sample_rate

    # Reject over-long inputs early (a non-positive limit disables the check).
    if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {app_state.max_text_length}",
        )

    # Streaming is only supported for the WAV container.
    if req.streaming and req.format != "wav":
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Streaming only supports WAV format",
        )

    # Both response paths present the audio as a downloadable attachment.
    headers = {
        "Content-Disposition": f"attachment; filename=audio.{req.format}",
    }

    if req.streaming:
        # Stream chunks straight from the async inference generator.
        return StreamResponse(
            iterable=inference_async(req, engine),
            headers=headers,
            content_type=get_content_type(req.format),
        )

    # Non-streaming: take the first (complete) result from the generator.
    # Use next(..., None) — a bare next() raising StopIteration inside an
    # async function would surface as an opaque RuntimeError (PEP 479).
    fake_audios = next(inference(req, engine), None)
    if fake_audios is None:
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR,
            content="No audio generated",
        )

    # Encode the full waveform into the requested container in memory.
    buffer = io.BytesIO()
    sf.write(
        buffer,
        fake_audios,
        sample_rate,
        format=req.format,
    )
    return StreamResponse(
        iterable=buffer_to_async_generator(buffer.getvalue()),
        headers=headers,
        content_type=get_content_type(req.format),
    )
|