| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576 |
- from argparse import ArgumentParser
- from http import HTTPStatus
- from typing import Annotated, Any
- import ormsgpack
- from baize.datastructures import ContentType
- from kui.asgi import HTTPException, HttpRequest
- from fish_speech.inference_engine import TTSInferenceEngine
- from fish_speech.utils.schema import ServeTTSRequest
- from tools.server.inference import inference_wrapper as inference
def parse_args(argv=None):
    """Build the argument parser and parse the command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            behaviour callers had before this parameter existed. Passing an
            explicit list lets the function be used programmatically.

    Returns:
        The ``argparse.Namespace`` holding the parsed values: ``mode``,
        ``load_asr_model``, ``llama_checkpoint_path``,
        ``decoder_checkpoint_path``, ``decoder_config_name``, ``device``,
        ``half``, ``compile``, ``max_text_length``, ``listen``, ``workers``
        and ``api_key``.
    """
    parser = ArgumentParser()

    parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
    parser.add_argument("--load-asr-model", action="store_true")

    # Checkpoint and decoder configuration.
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.5",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.5/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")

    # Runtime switches.
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-text-length", type=int, default=0)

    # Serving options.
    parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--api-key", type=str, default=None)

    # ``parse_args(None)`` is equivalent to ``parse_args()``.
    return parser.parse_args(argv)
class MsgPackRequest(HttpRequest):
    """HTTP request whose ``data`` accepts either a msgpack or a JSON body."""

    async def data(
        self,
    ) -> Annotated[
        Any, ContentType("application/msgpack"), ContentType("application/json")
    ]:
        """Decode the request body according to its content type.

        Returns:
            The object produced by ``ormsgpack.unpackb`` for
            ``application/msgpack`` bodies, or the awaited ``self.json``
            value for ``application/json`` bodies.

        Raises:
            HTTPException: With ``HTTPStatus.UNSUPPORTED_MEDIA_TYPE`` and an
                ``Accept`` header naming both supported types when the
                content type matches neither.
        """
        if self.content_type == "application/msgpack":
            raw_body = await self.body
            return ormsgpack.unpackb(raw_body)

        if self.content_type == "application/json":
            return await self.json

        # Neither supported media type matched: tell the client what is accepted.
        accepted = {"Accept": "application/msgpack, application/json"}
        raise HTTPException(HTTPStatus.UNSUPPORTED_MEDIA_TYPE, headers=accepted)
async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
    """Yield the ``bytes`` items produced by ``inference(req, engine)``.

    Items that are not ``bytes`` instances are dropped; the remaining ones
    are yielded in the order they are produced.
    """
    # NOTE(review): ``inference`` is consumed with a plain synchronous ``for``,
    # so each step runs without awaiting — confirm that is acceptable here.
    for item in inference(req, engine):
        if not isinstance(item, bytes):
            continue
        yield item
async def buffer_to_async_generator(buffer):
    """Expose ``buffer`` as an async generator that yields it exactly once.

    The object is yielded as-is (it is not split or copied), after which the
    generator is exhausted.
    """
    yield buffer
def get_content_type(audio_format):
    """Return the MIME type string for an audio format name.

    ``"wav"``, ``"flac"`` and ``"mp3"`` map to ``"audio/wav"``,
    ``"audio/flac"`` and ``"audio/mpeg"``. Any other value, including
    different letter case, yields ``"application/octet-stream"``.
    """
    known_types = (
        ("wav", "audio/wav"),
        ("flac", "audio/flac"),
        ("mp3", "audio/mpeg"),
    )
    # Compare with ``==`` in declaration order so that any input value,
    # hashable or not, falls through to the generic type.
    for name, mime in known_types:
        if audio_format == name:
            return mime
    return "application/octet-stream"
|