import base64
import io
import threading
import traceback
from argparse import ArgumentParser
from http import HTTPStatus
from threading import Lock
from typing import Annotated, Literal, Optional

import librosa
import soundfile as sf
import torch
from kui.wsgi import (
    Body,
    HTTPException,
    HttpView,
    JSONResponse,
    Kui,
    OpenAPI,
    StreamResponse,
)
from kui.wsgi.routing import MultimethodRoutes
from loguru import logger
from pydantic import BaseModel
from transformers import AutoTokenizer

from tools.llama.generate import launch_thread_safe_queue
from tools.vqgan.inference import load_model as load_vqgan_model

# Serialize requests: inference is stateful and must not be interrupted
lock = Lock()


# Define utils for web server
def http_exception_handler(exc: HTTPException):
    return JSONResponse(
        dict(
            statusCode=exc.status_code,
            message=exc.content,
            error=HTTPStatus(exc.status_code).phrase,
        ),
        exc.status_code,
        exc.headers,
    )


def other_exception_handler(exc: Exception):
    traceback.print_exc()

    status = HTTPStatus.INTERNAL_SERVER_ERROR
    return JSONResponse(
        dict(statusCode=status, message=str(exc), error=status.phrase),
        status,
    )
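
# Both handlers return the same JSON shape; an illustrative error body
# (values here are hypothetical):
#
#   {"statusCode": 400, "message": "Text is too long, max length is 100",
#    "error": "Bad Request"}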


routes = MultimethodRoutes(base_class=HttpView)


class InvokeRequest(BaseModel):
    # Default demo text; roughly: "You're right, but Genshin Impact is an
    # open-world game developed independently by miHoYo."
    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
    reference_text: Optional[str] = None
    reference_audio: Optional[str] = None
    max_new_tokens: int = 0
    chunk_length: int = 30
    top_k: int = 0
    top_p: float = 0.7
    repetition_penalty: float = 1.5
    temperature: float = 0.7
    speaker: Optional[str] = None
    format: Literal["wav", "mp3", "flac"] = "wav"
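
# Illustrative /v1/invoke request body (field names follow InvokeRequest;
# reference_audio, when present, is the base64-encoded bytes of a prompt
# audio file and reference_text is its transcript):
#
#   {
#     "text": "Hello world",
#     "reference_audio": "<base64-encoded wav/mp3/flac bytes>",
#     "reference_text": "Transcript of the prompt audio",
#     "format": "wav"
#   }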


@torch.inference_mode()
def inference(req: InvokeRequest):
    # Parse reference audio, a.k.a. the prompt
    prompt_tokens = None
    if req.reference_audio is not None:
        buffer = io.BytesIO(base64.b64decode(req.reference_audio))
        reference_audio_content, _ = librosa.load(
            buffer, sr=vqgan_model.sampling_rate, mono=True
        )
        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
            None, None, :
        ]
        logger.info(
            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
        )

        # VQ encoder: turn the prompt waveform into prompt tokens
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
        )
        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]

    # LLAMA inference: text (plus optional prompt tokens) -> semantic codes
    request = dict(
        tokenizer=llama_tokenizer,
        device=vqgan_model.device,
        max_new_tokens=req.max_new_tokens,
        text=req.text,
        top_k=int(req.top_k) if req.top_k > 0 else None,
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=args.max_length,
        speaker=req.speaker,
        prompt_tokens=prompt_tokens,
        prompt_text=req.reference_text,
    )

    payload = dict(
        event=threading.Event(),
        request=request,
    )
    llama_queue.put(payload)

    # Wait for the result; the worker thread sets the event and attaches
    # "success" / "response" keys to the payload
    payload["event"].wait()
    if payload["success"] is False:
        raise payload["response"]

    codes = payload["response"][0]

    # VQGAN inference: decode semantic codes back into a waveform
    feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
    fake_audios = vqgan_model.decode(
        indices=codes[None], feature_lengths=feature_lengths, return_audios=True
    )[0, 0]
    fake_audios = fake_audios.float().cpu().numpy()

    return fake_audios


@routes.http.post("/v1/invoke")
def api_invoke_model(
    req: Annotated[InvokeRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """

    if args.max_text_length > 0 and len(req.text) > args.max_text_length:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {args.max_text_length}",
        )

    try:
        # Hold the lock so concurrent requests cannot interrupt an in-flight
        # inference
        with lock:
            fake_audios = inference(req)
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(HTTPStatus.INTERNAL_SERVER_ERROR, content=str(e))

    buffer = io.BytesIO()
    sf.write(buffer, fake_audios, vqgan_model.sampling_rate, format=req.format)

    return StreamResponse(
        iterable=[buffer.getvalue()],
        headers={
            "Content-Disposition": f"attachment; filename=audio.{req.format}",
        },
        # Use a generic content type to make swagger-ui happy
        # content_type=f"audio/{req.format}",
        content_type="application/octet-stream",
    )
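
# Illustrative client call (assumes the default --listen address; the response
# body is the raw audio file):
#
#   curl -X POST http://127.0.0.1:8000/v1/invoke \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Hello world", "format": "wav"}' \
#     --output audio.wav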


@routes.http.post("/v1/health")
def api_health():
    """
    Health check
    """

    return JSONResponse({"status": "ok"})
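
# Quick liveness check (illustrative): curl -X POST http://127.0.0.1:8000/v1/health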


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/text2semantic-sft-large-v1-4k.pth",
    )
    parser.add_argument(
        "--llama-config-name", type=str, default="dual_ar_2_codebook_large"
    )
    parser.add_argument(
        "--vqgan-checkpoint-path",
        type=str,
        default="checkpoints/vq-gan-group-fsq-2x1024.pth",
    )
    parser.add_argument("--vqgan-config-name", type=str, default="vqgan_pretrain")
    parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--max-length", type=int, default=2048)
    parser.add_argument("--compile", action="store_true")
    # 0 disables the server-side text length limit
    parser.add_argument("--max-text-length", type=int, default=0)
    parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
    return parser.parse_args()
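
# Illustrative launch (the script path is an assumption; checkpoint and
# tokenizer paths fall back to the defaults above):
#
#   python tools/api.py --listen 0.0.0.0:8000 --half --compile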


# Define Kui app
openapi = OpenAPI(
    {
        "title": "Fish Speech API",
    },
).routes

app = Kui(
    routes=routes + openapi[1:],  # Remove the default route
    exception_handlers={
        HTTPException: http_exception_handler,
        Exception: other_exception_handler,
    },
    cors_config={},
)


if __name__ == "__main__":
    from zibai import create_bind_socket, serve

    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        config_name=args.llama_config_name,
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        max_length=args.max_length,
        compile=args.compile,
    )
    llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    logger.info("Llama model loaded, loading VQ-GAN model...")

    vqgan_model = load_vqgan_model(
        config_name=args.vqgan_config_name,
        checkpoint_path=args.vqgan_checkpoint_path,
        device=args.device,
    )

    logger.info("VQ-GAN model loaded, warming up...")

    # Dry run to check that the models are loaded correctly and to avoid
    # first-request latency
    inference(
        InvokeRequest(
            text="A warm-up sentence.",
            reference_text=None,
            reference_audio=None,
            max_new_tokens=0,
            chunk_length=30,
            top_k=0,
            top_p=0.7,
            repetition_penalty=1.5,
            temperature=0.7,
            speaker=None,
            format="wav",
        )
    )

    logger.info(f"Warming up done, starting server at http://{args.listen}")
    sock = create_bind_socket(args.listen)
    sock.listen()

    # Start server
    serve(
        app=app,
        bind_sockets=[sock],
        max_workers=10,
        graceful_exit=threading.Event(),
    )