api_utils.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from argparse import ArgumentParser
  2. from http import HTTPStatus
  3. from typing import Annotated, Any
  4. import ormsgpack
  5. from baize.datastructures import ContentType
  6. from kui.asgi import HTTPException, HttpRequest
  7. from fish_speech.inference_engine import TTSInferenceEngine
  8. from fish_speech.utils.schema import ServeTTSRequest
  9. from tools.server.inference import inference_wrapper as inference
  def parse_args(argv=None):
      """Build the API server's command-line parser and parse the arguments.

      Args:
          argv: Sequence of argument strings to parse. ``None`` (the default)
              makes argparse read ``sys.argv[1:]``, exactly as before; passing
              an explicit list lets tests and embedding code drive the parser
              without touching the process arguments.

      Returns:
          argparse.Namespace holding the parsed options.
      """
      parser = ArgumentParser()
      parser.add_argument("--mode", type=str, choices=["tts"], default="tts")
      parser.add_argument(
          "--llama-checkpoint-path",
          type=str,
          default="checkpoints/openaudio-s1-mini",
      )
      # NOTE(review): this default file name refers to a firefly-gan generator,
      # while --decoder-config-name defaults to "modded_dac_vq"; confirm the two
      # defaults actually belong together.
      parser.add_argument(
          "--decoder-checkpoint-path",
          type=str,
          default="checkpoints/openaudio-s1-mini/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
      )
      parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
      parser.add_argument("--device", type=str, default="cuda")
      parser.add_argument("--half", action="store_true")
      parser.add_argument("--compile", action="store_true")
      parser.add_argument("--max-text-length", type=int, default=0)
      parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
      parser.add_argument("--workers", type=int, default=1)
      parser.add_argument("--api-key", type=str, default=None)
      return parser.parse_args(argv)
  class MsgPackRequest(HttpRequest):
      """HTTP request whose body may be sent as MessagePack or JSON."""

      async def data(
          self,
      ) -> Annotated[
          Any, ContentType("application/msgpack"), ContentType("application/json")
      ]:
          """Decode the request body according to its Content-Type.

          Returns:
              The body decoded with ``ormsgpack.unpackb`` for
              ``application/msgpack``, or via ``self.json`` for
              ``application/json``.

          Raises:
              HTTPException: 400 (Bad Request) when a msgpack body cannot be
                  decoded; 415 (Unsupported Media Type) for any other
                  Content-Type, advertising the accepted types in ``Accept``.
          """
          if self.content_type == "application/msgpack":
              raw = await self.body
              try:
                  return ormsgpack.unpackb(raw)
              except ormsgpack.MsgpackDecodeError as exc:
                  # The body comes from the client: a malformed payload is a
                  # client error, not an internal server failure.
                  raise HTTPException(HTTPStatus.BAD_REQUEST) from exc
          elif self.content_type == "application/json":
              return await self.json
          raise HTTPException(
              HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
              headers={"Accept": "application/msgpack, application/json"},
          )
  async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
      """Re-expose the synchronous ``inference`` generator as an async generator.

      Only the ``bytes`` chunks produced by ``inference(req, engine)`` are
      forwarded to the consumer; chunks of any other type are skipped.
      """
      # NOTE(review): ``inference`` is consumed with a plain ``for`` loop, so each
      # step runs synchronously on the event-loop thread; if a step is expensive
      # it blocks other coroutines until it finishes — confirm this is intended.
      for piece in inference(req, engine):
          if not isinstance(piece, bytes):
              continue
          yield piece
  async def buffer_to_async_generator(buffer):
      """Wrap a single value in an async generator that yields it exactly once."""
      for item in (buffer,):
          yield item
  def get_content_type(audio_format):
      """Return the MIME type to send for an audio format name.

      Recognises ``"wav"``, ``"flac"`` and ``"mp3"``; any other value falls back
      to ``"application/octet-stream"``.
      """
      known_types = (
          ("wav", "audio/wav"),
          ("flac", "audio/flac"),
          ("mp3", "audio/mpeg"),
      )
      for fmt_name, mime in known_types:
          if audio_format == fmt_name:
              return mime
      return "application/octet-stream"