api_utils.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from argparse import ArgumentParser
  2. from http import HTTPStatus
  3. from typing import Annotated, Any
  4. import ormsgpack
  5. from baize.datastructures import ContentType
  6. from kui.asgi import (
  7. HTTPException,
  8. HttpRequest,
  9. JSONResponse,
  10. request,
  11. )
  12. from loguru import logger
  13. from pydantic import BaseModel
  14. from fish_speech.inference_engine import TTSInferenceEngine
  15. from fish_speech.utils.schema import ServeTTSRequest
  16. from tools.server.inference import inference_wrapper as inference
  def parse_args():
      """Build the API server's command-line parser and parse ``sys.argv``.

      Options are registered in the order listed in ``option_specs`` below;
      ``--half`` and ``--compile`` are boolean switches, every other option
      takes a value and has the default shown next to it.

      Returns
      -------
      argparse.Namespace
          The parsed command-line options.
      """
      option_specs = (
          ("--mode", {"type": str, "choices": ["tts"], "default": "tts"}),
          (
              "--llama-checkpoint-path",
              {"type": str, "default": "checkpoints/openaudio-s1-mini"},
          ),
          (
              "--decoder-checkpoint-path",
              {"type": str, "default": "checkpoints/openaudio-s1-mini/codec.pth"},
          ),
          ("--decoder-config-name", {"type": str, "default": "modded_dac_vq"}),
          ("--device", {"type": str, "default": "cuda"}),
          ("--half", {"action": "store_true"}),
          ("--compile", {"action": "store_true"}),
          ("--max-text-length", {"type": int, "default": 0}),
          ("--listen", {"type": str, "default": "127.0.0.1:8080"}),
          ("--workers", {"type": int, "default": 1}),
          ("--api-key", {"type": str, "default": None}),
      )

      cli = ArgumentParser()
      for flag, options in option_specs:
          cli.add_argument(flag, **options)
      return cli.parse_args()
  class MsgPackRequest(HttpRequest):
      """Request type whose ``data()`` decodes msgpack, JSON or multipart bodies."""

      async def data(
          self,
      ) -> Annotated[
          Any,
          ContentType("application/msgpack"),
          ContentType("application/json"),
          ContentType("multipart/form-data"),
      ]:
          """Decode the request body according to its ``Content-Type``.

          msgpack bodies are unpacked with ``ormsgpack.unpackb``; JSON and
          multipart bodies are delegated to the base request's ``json`` and
          ``form`` awaitables.

          Raises
          ------
          HTTPException
              415 Unsupported Media Type, with an ``Accept`` header listing the
              three supported media types, for any other content type.
          """
          if self.content_type == "application/msgpack":
              raw_body = await self.body
              return ormsgpack.unpackb(raw_body)
          if self.content_type == "application/json":
              return await self.json
          if self.content_type == "multipart/form-data":
              return await self.form

          supported = "application/msgpack, application/json, multipart/form-data"
          raise HTTPException(
              HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
              headers={"Accept": supported},
          )
  async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
      """Re-expose the synchronous ``inference`` generator as an async generator.

      Parameters
      ----------
      req : ServeTTSRequest
          The TTS request, passed through to ``inference`` unchanged.
      engine : TTSInferenceEngine
          The inference engine, passed through to ``inference`` unchanged.

      Yields
      ------
      bytes
          Only the ``bytes`` chunks produced by ``inference(req, engine)``;
          chunks of any other type are skipped.

      NOTE(review): ``inference`` is consumed with a plain ``for`` loop and no
      ``await``, so producing each chunk runs on the thread driving this async
      generator (normally the event loop) and stalls other requests while it
      runs. Offloading it to a worker thread would need confirmation that the
      engine is safe to drive from another thread.
      """
      for chunk in inference(req, engine):
          if isinstance(chunk, bytes):
              yield chunk
  async def buffer_to_async_generator(buffer):
      """Wrap a single object in a one-item async generator.

      Parameters
      ----------
      buffer : Any
          The value to emit; it is yielded exactly once, unchanged.
      """
      for only_item in (buffer,):
          yield only_item
  def get_content_type(audio_format):
      """Map an audio format name to the HTTP ``Content-Type`` to send.

      Parameters
      ----------
      audio_format : str
          Format name. Only the exact strings ``"wav"``, ``"flac"`` and
          ``"mp3"`` are recognised (the comparison is case-sensitive).

      Returns
      -------
      str
          The matching MIME type, or ``"application/octet-stream"`` for any
          other value.
      """
      known_formats = (
          ("wav", "audio/wav"),
          ("flac", "audio/flac"),
          ("mp3", "audio/mpeg"),
      )
      for name, mime_type in known_formats:
          if audio_format == name:
              return mime_type
      return "application/octet-stream"
  def wants_json(req):
      """Helper method to determine if the client wants a JSON response

      An explicit ``format`` query parameter takes precedence over the
      ``Accept`` header. Its value is stripped and lower-cased, then:
      ``json`` / ``application/json`` select JSON and ``msgpack`` /
      ``application/msgpack`` select msgpack. Any other value (or no value)
      falls through to the ``Accept`` header check.

      Parameters
      ----------
      req : Request
          The request object; only its ``query_params`` and ``headers``
          mappings are read.

      Returns
      -------
      bool
          True if the client wants a JSON response, False otherwise
      """
      json_values = {"json", "application/json"}
      msgpack_values = {"msgpack", "application/msgpack"}

      fmt = req.query_params.get("format", "").strip().lower()
      # Both spellings of each format are accepted; previously
      # "application/json" was recognised here but still answered False.
      if fmt in json_values:
          return True
      if fmt in msgpack_values:
          return False

      # Without an explicit format, choose JSON only when the client accepts
      # JSON and does not also accept msgpack.
      accept = req.headers.get("Accept", "").strip().lower()
      return "application/json" in accept and "application/msgpack" not in accept
  def format_response(response: BaseModel, status_code=200):
      """
      Helper function to format responses consistently based on client preference.

      ``wants_json`` is evaluated against the ``request`` object imported from
      ``kui.asgi`` (presumably the request currently being handled; confirm
      against kui's context semantics). JSON is returned when it says so;
      otherwise the model is serialised to msgpack.

      Parameters
      ----------
      response : BaseModel
          The response object to format
      status_code : int
          HTTP status code (default: 200)

      Returns
      -------
      Response
          A ``JSONResponse`` built from ``response.model_dump(mode="json")``,
          or a ``(body, status_code, headers)`` tuple carrying msgpack bytes.
          If formatting raises, the error is logged with its traceback and a
          500 ``JSONResponse`` describing the failure is returned instead.
      """
      try:
          if wants_json(request):
              return JSONResponse(
                  response.model_dump(mode="json"), status_code=status_code
              )
          return (
              ormsgpack.packb(
                  response,
                  option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
              ),
              status_code,
              {"Content-Type": "application/msgpack"},
          )
      except Exception as e:
          # ``logger`` is loguru, which does not understand stdlib's
          # ``exc_info=True``: it treats keyword arguments as str.format fields,
          # so no traceback was recorded, and braces in the exception text could
          # make the log call itself raise. ``logger.exception`` attaches the
          # traceback, and passing ``e`` positionally keeps its text literal.
          logger.exception("Error formatting response: {}", e)
          # Fallback to JSON response if formatting fails
          return JSONResponse(
              {"error": "Response formatting failed", "details": str(e)}, status_code=500
          )