# api_utils.py — request parsing and response formatting helpers for the TTS API server.
  1. from argparse import ArgumentParser
  2. from http import HTTPStatus
  3. from typing import Annotated, Any
  4. import ormsgpack
  5. from baize.datastructures import ContentType
  6. from kui.asgi import (
  7. HTTPException,
  8. HttpRequest,
  9. JSONResponse,
  10. request,
  11. )
  12. from loguru import logger
  13. from pydantic import BaseModel
  14. from fish_speech.inference_engine import TTSInferenceEngine
  15. from fish_speech.utils.schema import ServeTTSRequest
  16. from tools.server.inference import inference_wrapper as inference
  17. def parse_args():
  18. parser = ArgumentParser()
  19. parser.add_argument("--mode", type=str, choices=["tts"], default="tts")
  20. parser.add_argument(
  21. "--llama-checkpoint-path",
  22. type=str,
  23. default="checkpoints/s2-pro",
  24. )
  25. parser.add_argument(
  26. "--decoder-checkpoint-path",
  27. type=str,
  28. default="checkpoints/s2-pro/codec.pth",
  29. )
  30. parser.add_argument("--decoder-config-name", type=str, default="modded_dac_vq")
  31. parser.add_argument("--device", type=str, default="cuda")
  32. parser.add_argument("--half", action="store_true")
  33. parser.add_argument("--compile", action="store_true")
  34. parser.add_argument("--max-text-length", type=int, default=0)
  35. parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
  36. parser.add_argument("--workers", type=int, default=1)
  37. parser.add_argument("--num-workers", type=int, default=1, help="Number of model worker threads for parallel inference")
  38. parser.add_argument("--api-key", type=str, default=None)
  39. return parser.parse_args()
  40. class MsgPackRequest(HttpRequest):
  41. async def data(
  42. self,
  43. ) -> Annotated[
  44. Any,
  45. ContentType("application/msgpack"),
  46. ContentType("application/json"),
  47. ContentType("multipart/form-data"),
  48. ]:
  49. if self.content_type == "application/msgpack":
  50. return ormsgpack.unpackb(await self.body)
  51. elif self.content_type == "application/json":
  52. return await self.json
  53. elif self.content_type == "multipart/form-data":
  54. return await self.form
  55. raise HTTPException(
  56. HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
  57. headers={
  58. "Accept": "application/msgpack, application/json, multipart/form-data"
  59. },
  60. )
  61. async def inference_async(req: ServeTTSRequest, engine: TTSInferenceEngine):
  62. for chunk in inference(req, engine):
  63. print("Got chunk")
  64. if isinstance(chunk, bytes):
  65. yield chunk
  66. async def buffer_to_async_generator(buffer):
  67. yield buffer
  68. def get_content_type(audio_format):
  69. if audio_format == "wav":
  70. return "audio/wav"
  71. elif audio_format == "flac":
  72. return "audio/flac"
  73. elif audio_format == "mp3":
  74. return "audio/mpeg"
  75. elif audio_format == "opus":
  76. return "audio/ogg"
  77. else:
  78. return "application/octet-stream"
  79. def wants_json(req):
  80. """Helper method to determine if the client wants a JSON response
  81. Parameters
  82. ----------
  83. req : Request
  84. The request object
  85. Returns
  86. -------
  87. bool
  88. True if the client wants a JSON response, False otherwise
  89. """
  90. q = req.query_params.get("format", "").strip().lower()
  91. if q in {"json", "application/json", "msgpack", "application/msgpack"}:
  92. return q in ("json", "application/json")
  93. accept = req.headers.get("Accept", "").strip().lower()
  94. return "application/json" in accept and "application/msgpack" not in accept
  95. def format_response(response: BaseModel, status_code=200):
  96. """
  97. Helper function to format responses consistently based on client preference.
  98. Parameters
  99. ----------
  100. response : BaseModel
  101. The response object to format
  102. status_code : int
  103. HTTP status code (default: 200)
  104. Returns
  105. -------
  106. Response
  107. Formatted response in the client's preferred format
  108. """
  109. try:
  110. if wants_json(request):
  111. return JSONResponse(
  112. response.model_dump(mode="json"), status_code=status_code
  113. )
  114. return (
  115. ormsgpack.packb(
  116. response,
  117. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  118. ),
  119. status_code,
  120. {"Content-Type": "application/msgpack"},
  121. )
  122. except Exception as e:
  123. logger.error(f"Error formatting response: {e}", exc_info=True)
  124. # Fallback to JSON response if formatting fails
  125. return JSONResponse(
  126. {"error": "Response formatting failed", "details": str(e)}, status_code=500
  127. )