api.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import base64
  2. import io
  3. import traceback
  4. from argparse import ArgumentParser
  5. from http import HTTPStatus
  6. from threading import Lock
  7. from typing import Annotated, Literal, Optional
  8. import librosa
  9. import numpy as np
  10. import soundfile as sf
  11. import torch
  12. from kui.wsgi import (
  13. Body,
  14. HTTPException,
  15. HttpView,
  16. JSONResponse,
  17. Kui,
  18. OpenAPI,
  19. StreamResponse,
  20. allow_cors,
  21. )
  22. from kui.wsgi.routing import MultimethodRoutes
  23. from loguru import logger
  24. from pydantic import BaseModel
  25. from transformers import AutoTokenizer
  26. from tools.llama.generate import generate_long
  27. from tools.llama.generate import load_model as load_llama_model
  28. from tools.vqgan.inference import load_model as load_vqgan_model
  29. from tools.webui import inference
# Serializes inference: only one request may drive the GPU models at a time.
lock = Lock()
# Define utils for web server
  32. def http_execption_handler(exc: HTTPException):
  33. return JSONResponse(
  34. dict(
  35. statusCode=exc.status_code,
  36. message=exc.content,
  37. error=HTTPStatus(exc.status_code).phrase,
  38. ),
  39. exc.status_code,
  40. exc.headers,
  41. )
  42. def other_exception_handler(exc: "Exception"):
  43. traceback.print_exc()
  44. status = HTTPStatus.INTERNAL_SERVER_ERROR
  45. return JSONResponse(
  46. dict(statusCode=status, message=str(exc), error=status.phrase),
  47. status,
  48. )
# Route table for the API; MultimethodRoutes lets HttpView subclasses and
# decorated functions register one handler per HTTP method.
routes = MultimethodRoutes(base_class=HttpView)
class InvokeRequest(BaseModel):
    """Request body for POST /v1/invoke (text-to-speech synthesis)."""

    # Text to synthesize (default is a Chinese demo sentence).
    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
    # Transcript of the reference audio, passed to the LLM as the prompt text.
    reference_text: Optional[str] = None
    # Base64-encoded reference audio used to derive prompt tokens.
    reference_audio: Optional[str] = None
    # Token budget for generation; forwarded to generate_long as-is
    # (presumably 0 means "no explicit cap" — confirm in generate_long).
    max_new_tokens: int = 0
    # Iterative-prompt chunk size; a value > 0 enables iterative prompting.
    chunk_length: int = 30
    # Top-k sampling cutoff; 0 disables top-k (None is sent downstream).
    top_k: int = 0
    # Nucleus-sampling probability mass.
    top_p: float = 0.7
    # Penalty applied to repeated tokens during sampling.
    repetition_penalty: float = 1.5
    # Softmax temperature for sampling.
    temperature: float = 0.7
    # Optional speaker identifier forwarded to the LLM.
    speaker: Optional[str] = None
    # Output container format for the rendered audio.
    format: Literal["wav", "mp3", "flac"] = "wav"
def inference(req: InvokeRequest):
    """Run the full text-to-speech pipeline for a single request.

    Pipeline: (optional) decode + VQ-encode the reference audio into prompt
    tokens -> autoregressive LLAMA generation of semantic codes -> VQ-GAN
    decode of those codes into a waveform.

    Relies on module-level globals (vqgan_model, llama_model, llama_tokenizer,
    decode_one_token, args) that are initialized when this module is imported
    by the WSGI worker.

    Returns a float32 numpy array of audio samples at
    vqgan_model.sampling_rate.
    """
    # Parse reference audio aka prompt
    prompt_tokens = None
    if req.reference_audio is not None:
        # reference_audio arrives base64-encoded; decode it and resample to
        # the VQ-GAN's sampling rate as a mono signal.
        buffer = io.BytesIO(base64.b64decode(req.reference_audio))
        reference_audio_content, _ = librosa.load(
            buffer, sr=vqgan_model.sampling_rate, mono=True
        )
        # (1, 1, T): add batch and channel dims expected by the encoder.
        audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
            None, None, :
        ]

        logger.info(
            f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
        )

        # VQ Encoder
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=vqgan_model.device, dtype=torch.long
        )
        prompt_tokens = vqgan_model.encode(audios, audio_lengths)[0][0]

    # LLAMA Inference
    result = generate_long(
        model=llama_model,
        tokenizer=llama_tokenizer,
        device=vqgan_model.device,
        decode_one_token=decode_one_token,
        max_new_tokens=req.max_new_tokens,
        text=req.text,
        top_k=int(req.top_k) if req.top_k > 0 else None,  # 0 disables top-k
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=args.max_length,
        speaker=req.speaker,
        prompt_tokens=prompt_tokens,
        prompt_text=req.reference_text,
    )

    # generate_long is a generator; only its first yielded code batch is used.
    codes = next(result)

    # VQGAN Inference
    feature_lengths = torch.tensor([codes.shape[1]], device=vqgan_model.device)
    fake_audios = vqgan_model.decode(
        indices=codes[None], feature_lengths=feature_lengths, return_audios=True
    )[0, 0]
    fake_audios = fake_audios.float().cpu().numpy()

    return fake_audios
  109. @routes.http.post("/invoke")
  110. def api_invoke_model(
  111. req: Annotated[InvokeRequest, Body(exclusive=True)],
  112. ):
  113. """
  114. Invoke model and generate audio
  115. """
  116. if args.max_gradio_length > 0 and len(req.text) > args.max_gradio_length:
  117. raise HTTPException(
  118. HTTPStatus.BAD_REQUEST,
  119. f"Text is too long, max length is {args.max_gradio_length}",
  120. )
  121. try:
  122. # Lock, avoid interrupting the inference process
  123. lock.acquire()
  124. fake_audios = inference(req)
  125. except Exception as e:
  126. raise HTTPException(HTTPStatus.INTERNAL_SERVER_ERROR, str(e))
  127. finally:
  128. # Release lock
  129. lock.release()
  130. buffer = io.BytesIO()
  131. sf.write(buffer, fake_audios, vqgan_model.sampling_rate, format=req.format)
  132. return StreamResponse(
  133. iterable=[buffer.getvalue()],
  134. headers={
  135. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  136. "Content-Type": "application/octet-stream",
  137. },
  138. )
  139. @routes.http.post("/health")
  140. def api_health():
  141. """
  142. Health check
  143. """
  144. return JSONResponse({"status": "ok"})
  145. def parse_args():
  146. parser = ArgumentParser()
  147. parser.add_argument(
  148. "--llama-checkpoint-path",
  149. type=str,
  150. default="checkpoints/text2semantic-medium-v1-2k.pth",
  151. )
  152. parser.add_argument(
  153. "--llama-config-name", type=str, default="dual_ar_2_codebook_medium"
  154. )
  155. parser.add_argument(
  156. "--vqgan-checkpoint-path",
  157. type=str,
  158. default="checkpoints/vq-gan-group-fsq-2x1024.pth",
  159. )
  160. parser.add_argument("--vqgan-config-name", type=str, default="vqgan_pretrain")
  161. parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
  162. parser.add_argument("--device", type=str, default="cuda")
  163. parser.add_argument("--half", action="store_true")
  164. parser.add_argument("--max-length", type=int, default=2048)
  165. parser.add_argument("--compile", action="store_true")
  166. parser.add_argument("--max-gradio-length", type=int, default=0)
  167. parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
  168. return parser.parse_args()
# Define Kui app
app = Kui(
    exception_handlers={
        # NOTE(review): handler name misspells "exception"; kept as defined above.
        HTTPException: http_execption_handler,
        Exception: other_exception_handler,
    },
    # Empty dict presumably enables CORS with library defaults — confirm
    # against kui's cors_config documentation.
    cors_config={},
)

# Swagger UI & routes
# Mount the API under /v1 and the generated OpenAPI docs under /docs.
app.router << ("/v1" // routes) << ("/docs" // OpenAPI().routes)

# Parsed at import time so both the __main__ launcher and the worker-side
# model loading below can read the flags.
args = parse_args()
if __name__ == "__main__":
    # Launched directly: start the zibai WSGI server pointing at this module.
    # The worker re-imports it as "tools.api:app", which takes the else-branch
    # below and loads the models there.
    from zibai import Options, main

    options = Options(
        app="tools.api:app",
        listen=[args.listen],
    )
    main(options)
else:
    # Imported by the WSGI worker: load models once per process.
    # --half selects float16; otherwise bfloat16.
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    llama_model, decode_one_token = load_llama_model(
        config_name=args.llama_config_name,
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        max_length=args.max_length,
        compile=args.compile,
    )
    llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    logger.info("Llama model loaded, loading VQ-GAN model...")

    vqgan_model = load_vqgan_model(
        config_name=args.vqgan_config_name,
        checkpoint_path=args.vqgan_checkpoint_path,
        device=args.device,
    )

    logger.info("VQ-GAN model loaded, warming up...")

    # Dry run to check if the model is loaded correctly and avoid the first-time latency
    inference(
        InvokeRequest(
            text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
            reference_text=None,
            reference_audio=None,
            max_new_tokens=0,
            chunk_length=30,
            top_k=0,
            top_p=0.7,
            repetition_penalty=1.5,
            temperature=0.7,
            speaker=None,
            format="wav",
        )
    )

    logger.info("Warming up done.")