api_utils.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. from argparse import ArgumentParser
  2. from http import HTTPStatus
  3. from typing import Annotated, Any
  4. import ormsgpack
  5. from baize.datastructures import ContentType
  6. from kui.asgi import (
  7. HTTPException,
  8. HttpRequest,
  9. JSONResponse,
  10. request,
  11. )
  12. from loguru import logger
  13. from pydantic import BaseModel
  14. from fish_speech.inference_engine import TTSInferenceEngine
  15. from fish_speech.utils.schema import ServeTTSRequest
  16. from tools.server.inference import inference_wrapper as inference
  17. def parse_args():
  18. parser = ArgumentParser()
  19. parser.add_argument("--mode", type=str, choices=["tts"], default="tts")
  20. parser.add_argument(
  21. "--llama-checkpoint-path",
  22. type=str,
  23. default="checkpoints/s2-pro",
  24. )
  25. parser.add_argument(
  26. "--decoder-checkpoint-path",
  27. type=str,
  28. default="checkpoints/s2-pro/codec.pth",
  29. )
  30. parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
  31. parser.add_argument("--device", type=str, default="cuda")
  32. parser.add_argument("--half", action="store_true")
  33. parser.add_argument("--compile", action="store_true")
  34. parser.add_argument("--max-text-length", type=int, default=0)
  35. parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
  36. parser.add_argument("--workers", type=int, default=1)
  37. parser.add_argument("--api-key", type=str, default=None)
  38. return parser.parse_args()
  39. class MsgPackRequest(HttpRequest):
  40. async def data(
  41. self,
  42. ) -> Annotated[
  43. Any,
  44. ContentType("application/msgpack"),
  45. ContentType("application/json"),
  46. ContentType("multipart/form-data"),
  47. ]:
  48. if self.content_type == "application/msgpack":
  49. return ormsgpack.unpackb(await self.body)
  50. elif self.content_type == "application/json":
  51. return await self.json
  52. elif self.content_type == "multipart/form-data":
  53. return await self.form
  54. raise HTTPException(
  55. HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
  56. headers={
  57. "Accept": "application/msgpack, application/json, multipart/form-data"
  58. },
  59. )
  60. async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
  61. for chunk in inference(req, engine):
  62. print("Got chunk")
  63. if isinstance(chunk, bytes):
  64. yield chunk
  65. async def buffer_to_async_generator(buffer):
  66. yield buffer
  67. def get_content_type(audio_format):
  68. if audio_format == "wav":
  69. return "audio/wav"
  70. elif audio_format == "flac":
  71. return "audio/flac"
  72. elif audio_format == "mp3":
  73. return "audio/mpeg"
  74. elif audio_format == "opus":
  75. return "audio/ogg"
  76. else:
  77. return "application/octet-stream"
  78. def wants_json(req):
  79. """Helper method to determine if the client wants a JSON response
  80. Parameters
  81. ----------
  82. req : Request
  83. The request object
  84. Returns
  85. -------
  86. bool
  87. True if the client wants a JSON response, False otherwise
  88. """
  89. q = req.query_params.get("format", "").strip().lower()
  90. if q in {"json", "application/json", "msgpack", "application/msgpack"}:
  91. return q in ("json", "application/json")
  92. accept = req.headers.get("Accept", "").strip().lower()
  93. return "application/json" in accept and "application/msgpack" not in accept
  94. def format_response(response: BaseModel, status_code=200):
  95. """
  96. Helper function to format responses consistently based on client preference.
  97. Parameters
  98. ----------
  99. response : BaseModel
  100. The response object to format
  101. status_code : int
  102. HTTP status code (default: 200)
  103. Returns
  104. -------
  105. Response
  106. Formatted response in the client's preferred format
  107. """
  108. try:
  109. if wants_json(request):
  110. return JSONResponse(
  111. response.model_dump(mode="json"), status_code=status_code
  112. )
  113. return (
  114. ormsgpack.packb(
  115. response,
  116. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  117. ),
  118. status_code,
  119. {"Content-Type": "application/msgpack"},
  120. )
  121. except Exception as e:
  122. logger.error(f"Error formatting response: {e}", exc_info=True)
  123. # Fallback to JSON response if formatting fails
  124. return JSONResponse(
  125. {"error": "Response formatting failed", "details": str(e)}, status_code=500
  126. )