# api.py (8.0 KB)
  1. import base64
  2. import io
  3. import queue
  4. import threading
  5. import traceback
  6. from argparse import ArgumentParser
  7. from http import HTTPStatus
  8. from threading import Lock
  9. from typing import Annotated, Literal, Optional
  10. import librosa
  11. import pyrootutils
  12. import soundfile as sf
  13. import torch
  14. from kui.wsgi import (
  15. Body,
  16. HTTPException,
  17. HttpView,
  18. JSONResponse,
  19. Kui,
  20. OpenAPI,
  21. StreamResponse,
  22. )
  23. from kui.wsgi.routing import MultimethodRoutes
  24. from loguru import logger
  25. from pydantic import BaseModel
  26. from transformers import AutoTokenizer
  27. pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
  28. from tools.llama.generate import launch_thread_safe_queue
  29. from tools.vqgan.inference import load_model as load_vqgan_model
  30. from tools.webui import inference
# Global lock serializing model inference: only one request may run the
# GPU pipeline at a time (acquired in api_invoke_model below).
lock = Lock()

# Define utils for web server
  33. def http_execption_handler(exc: HTTPException):
  34. return JSONResponse(
  35. dict(
  36. statusCode=exc.status_code,
  37. message=exc.content,
  38. error=HTTPStatus(exc.status_code).phrase,
  39. ),
  40. exc.status_code,
  41. exc.headers,
  42. )
  43. def other_exception_handler(exc: "Exception"):
  44. traceback.print_exc()
  45. status = HTTPStatus.INTERNAL_SERVER_ERROR
  46. return JSONResponse(
  47. dict(statusCode=status, message=str(exc), error=status.phrase),
  48. status,
  49. )
# Route table the endpoint decorators below register into; HttpView is the
# base class kui uses to group multiple HTTP methods on one path.
routes = MultimethodRoutes(base_class=HttpView)
class InvokeRequest(BaseModel):
    """Payload schema for the /v1/invoke endpoint."""

    # Text to synthesize (default is a demo sentence in Chinese).
    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
    # Transcript of the reference audio, forwarded to the Llama worker
    # as prompt_text.
    reference_text: Optional[str] = None
    # Base64-encoded reference audio used as a voice prompt.
    reference_audio: Optional[str] = None
    # Cap on generated tokens; 0 presumably means "no explicit cap"
    # (worker-defined) — confirm in tools.llama.generate.
    max_new_tokens: int = 0
    # Iterative-prompt chunk size; 0 disables iterative prompting
    # (see `iterative_prompt=req.chunk_length > 0` in inference()).
    chunk_length: int = 30
    # Nucleus sampling probability mass.
    top_p: float = 0.7
    repetition_penalty: float = 1.5
    temperature: float = 0.7
    # Optional speaker identifier passed through to the Llama worker.
    speaker: Optional[str] = None
    # Output container format used when encoding the generated audio.
    format: Literal["wav", "mp3", "flac"] = "wav"
@torch.inference_mode()
def inference(req: InvokeRequest):
    """Run the full TTS pipeline for one request and return raw audio.

    Pipeline: optional reference audio -> VQ-GAN encode (prompt tokens) ->
    Llama semantic-token generation via the worker queue -> VQ-GAN decode.

    Relies on module-level globals set in the __main__ block: vqgan_model,
    llama_tokenizer, llama_queue, args.

    NOTE(review): this definition shadows the `inference` imported from
    tools.webui at the top of the file — rename one of them if that import
    is still needed.

    Args:
        req: Validated request parameters.

    Returns:
        float32 numpy audio array at vqgan_model.sampling_rate.

    Raises:
        Whatever exception object the Llama worker stores in
        payload["response"] on failure (re-raised here).
    """
    # Parse reference audio aka prompt
    prompt_tokens = None
    if req.reference_audio is not None:
        # reference_audio is base64-encoded audio bytes; decode in memory.
        buffer = io.BytesIO(base64.b64decode(req.reference_audio))
        reference_audio_content, _ = librosa.load(
            buffer, sr=vqgan_model.sampling_rate, mono=True
        )
        # mono load yields a 1-D array; add batch and channel dims ->
        # (1, 1, samples) for the VQ-GAN encoder.
        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
            None, None, :
        ]
        logger.info(
            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
        )

        # VQ Encoder
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
        )
        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]

    # LLAMA Inference
    request = dict(
        tokenizer=llama_tokenizer,
        device=vqgan_model.device,
        max_new_tokens=req.max_new_tokens,
        text=req.text,
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=args.max_length,
        speaker=req.speaker,
        prompt_tokens=prompt_tokens,
        prompt_text=req.reference_text,
    )

    payload = dict(
        response_queue=queue.Queue(),
        request=request,
    )
    llama_queue.put(payload)

    codes = []
    while True:
        # The worker streams results back on response_queue; the strings
        # "next" and "done" act as control sentinels, everything else is a
        # code tensor chunk.
        result = payload["response_queue"].get()
        if result == "next":
            # TODO: handle next sentence
            continue

        if result == "done":
            # NOTE(review): the worker presumably sets payload["success"]
            # and payload["response"] before signalling "done" — confirm in
            # tools.llama.generate.launch_thread_safe_queue.
            if payload["success"] is False:
                raise payload["response"]
            break

        codes.append(result)

    # Concatenate the per-chunk code tensors along the time axis.
    codes = torch.cat(codes, dim=1)

    # VQGAN Inference
    feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
    fake_audios = vqgan_model.decode(
        indices=codes[None], feature_lengths=feature_lengths, return_audios=True
    )[0, 0]
    fake_audios = fake_audios.float().cpu().numpy()

    return fake_audios
  123. @routes.http.post("/v1/invoke")
  124. def api_invoke_model(
  125. req: Annotated[InvokeRequest, Body(exclusive=True)],
  126. ):
  127. """
  128. Invoke model and generate audio
  129. """
  130. if args.max_text_length > 0 and len(req.text) > args.max_text_length:
  131. raise HTTPException(
  132. HTTPStatus.BAD_REQUEST,
  133. content=f"Text is too long, max length is {args.max_text_length}",
  134. )
  135. try:
  136. # Lock, avoid interrupting the inference process
  137. lock.acquire()
  138. fake_audios = inference(req)
  139. except Exception as e:
  140. import traceback
  141. traceback.print_exc()
  142. raise HTTPException(HTTPStatus.INTERNAL_SERVER_ERROR, content=str(e))
  143. finally:
  144. # Release lock
  145. lock.release()
  146. buffer = io.BytesIO()
  147. sf.write(buffer, fake_audios, vqgan_model.sampling_rate, format=req.format)
  148. return StreamResponse(
  149. iterable=[buffer.getvalue()],
  150. headers={
  151. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  152. },
  153. # Make swagger-ui happy
  154. # content_type=f"audio/{req.format}",
  155. content_type="application/octet-stream",
  156. )
  157. @routes.http.post("/v1/health")
  158. def api_health():
  159. """
  160. Health check
  161. """
  162. return JSONResponse({"status": "ok"})
  163. def parse_args():
  164. parser = ArgumentParser()
  165. parser.add_argument(
  166. "--llama-checkpoint-path",
  167. type=str,
  168. default="checkpoints/text2semantic-sft-large-v1-4k.pth",
  169. )
  170. parser.add_argument(
  171. "--llama-config-name", type=str, default="dual_ar_2_codebook_large"
  172. )
  173. parser.add_argument(
  174. "--vqgan-checkpoint-path",
  175. type=str,
  176. default="checkpoints/vq-gan-group-fsq-2x1024.pth",
  177. )
  178. parser.add_argument("--vqgan-config-name", type=str, default="vqgan_pretrain")
  179. parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
  180. parser.add_argument("--device", type=str, default="cuda")
  181. parser.add_argument("--half", action="store_true")
  182. parser.add_argument("--max-length", type=int, default=2048)
  183. parser.add_argument("--compile", action="store_true")
  184. parser.add_argument("--max-text-length", type=int, default=0)
  185. parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
  186. return parser.parse_args()
# Define Kui app
# Build the OpenAPI/Swagger routes so the API is self-documenting.
openapi = OpenAPI(
    {
        "title": "Fish Speech API",
    },
).routes

app = Kui(
    routes=routes + openapi[1:],  # Remove the default route
    exception_handlers={
        # Map framework HTTP errors and any unexpected exception to the JSON
        # error shape produced by the handlers defined above.
        HTTPException: http_execption_handler,
        Exception: other_exception_handler,
    },
    # NOTE(review): an empty dict presumably enables CORS with kui's
    # defaults — confirm against the kui documentation.
    cors_config={},
)
  201. if __name__ == "__main__":
  202. import threading
  203. from zibai import create_bind_socket, serve
  204. args = parse_args()
  205. args.precision = torch.half if args.half else torch.bfloat16
  206. logger.info("Loading Llama model...")
  207. llama_queue = launch_thread_safe_queue(
  208. config_name=args.llama_config_name,
  209. checkpoint_path=args.llama_checkpoint_path,
  210. device=args.device,
  211. precision=args.precision,
  212. max_length=args.max_length,
  213. compile=args.compile,
  214. )
  215. llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
  216. logger.info("Llama model loaded, loading VQ-GAN model...")
  217. vqgan_model = load_vqgan_model(
  218. config_name=args.vqgan_config_name,
  219. checkpoint_path=args.vqgan_checkpoint_path,
  220. device=args.device,
  221. )
  222. logger.info("VQ-GAN model loaded, warming up...")
  223. # Dry run to check if the model is loaded correctly and avoid the first-time latency
  224. inference(
  225. InvokeRequest(
  226. text="A warm-up sentence.",
  227. reference_text=None,
  228. reference_audio=None,
  229. max_new_tokens=0,
  230. chunk_length=30,
  231. top_p=0.7,
  232. repetition_penalty=1.5,
  233. temperature=0.7,
  234. speaker=None,
  235. format="wav",
  236. )
  237. )
  238. logger.info(f"Warming up done, starting server at http://{args.listen}")
  239. sock = create_bind_socket(args.listen)
  240. sock.listen()
  241. # Start server
  242. serve(
  243. app=app,
  244. bind_sockets=[sock],
  245. max_workers=10,
  246. graceful_exit=threading.Event(),
  247. )