import base64
import io
import json
import queue
import random
import traceback
import wave
from argparse import ArgumentParser
from http import HTTPStatus
from pathlib import Path
from typing import Annotated, Literal, Optional

import librosa
import numpy as np
import pyrootutils
import soundfile as sf
import torch
from kui.asgi import (
    Body,
    HTTPException,
    HttpView,
    JSONResponse,
    Kui,
    OpenAPI,
    StreamResponse,
)
from kui.asgi.routing import MultimethodRoutes
from loguru import logger
from pydantic import BaseModel, Field

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# from fish_speech.models.vqgan.lit_module import VQGAN
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
from tools.llama.generate import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
    launch_thread_safe_queue,
)
from tools.vqgan.inference import load_model as load_decoder_model


def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()

    return wav_header_bytes
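

# Illustrative usage (a sketch, not exercised anywhere in this module): the
# `wave` module finalizes a standard 44-byte PCM header when the context
# closes, and since no frames were written the buffer holds only that header.
# Streaming responses send it once, then append raw int16 PCM bytes.
#
#     header = wav_chunk_header()  # defaults: 44.1 kHz, 16-bit, mono
#     assert header[:4] == b"RIFF" and header[8:12] == b"WAVE"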


# Define utils for web server
async def http_exception_handler(exc: HTTPException):
    return JSONResponse(
        dict(
            statusCode=exc.status_code,
            message=exc.content,
            error=HTTPStatus(exc.status_code).phrase,
        ),
        exc.status_code,
        exc.headers,
    )


async def other_exception_handler(exc: "Exception"):
    traceback.print_exc()

    status = HTTPStatus.INTERNAL_SERVER_ERROR
    return JSONResponse(
        dict(statusCode=status, message=str(exc), error=status.phrase),
        status,
    )


def load_audio(reference_audio, sr):
    # Accept either a filesystem path or a base64-encoded audio file
    if len(reference_audio) > 255 or not Path(reference_audio).exists():
        try:
            audio_data = base64.b64decode(reference_audio)
            reference_audio = io.BytesIO(audio_data)
        except base64.binascii.Error:
            raise ValueError("Invalid path or base64 string")

    audio, _ = librosa.load(reference_audio, sr=sr, mono=True)
    return audio


def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
    if enable_reference_audio and reference_audio is not None:
        # Load audios, and prepare basic info here
        reference_audio_content = load_audio(
            reference_audio, decoder_model.spec_transform.sample_rate
        )

        audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
            None, None, :
        ]
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=decoder_model.device, dtype=torch.long
        )
        logger.info(
            f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        if isinstance(decoder_model, FireflyArchitecture):
            prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]

        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
    else:
        prompt_tokens = None
        logger.info("No reference audio provided")

    return prompt_tokens


def decode_vq_tokens(
    *,
    decoder_model,
    codes,
):
    feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
    logger.info(f"VQ features: {codes.shape}")

    if isinstance(decoder_model, FireflyArchitecture):
        # VQGAN Inference
        return decoder_model.decode(
            indices=codes[None],
            feature_lengths=feature_lengths,
        ).squeeze()

    raise ValueError(f"Unknown model type: {type(decoder_model)}")


routes = MultimethodRoutes(base_class=HttpView)


def get_random_paths(base_path, data, speaker, emotion):
    if base_path and data and speaker and emotion and (Path(base_path).exists()):
        if speaker in data and emotion in data[speaker]:
            files = data[speaker][emotion]
            lab_files = [f for f in files if f.endswith(".lab")]
            wav_files = [f for f in files if f.endswith(".wav")]

            if lab_files and wav_files:
                selected_lab = random.choice(lab_files)
                selected_wav = random.choice(wav_files)

                lab_path = Path(base_path) / speaker / emotion / selected_lab
                wav_path = Path(base_path) / speaker / emotion / selected_wav
                if lab_path.exists() and wav_path.exists():
                    return lab_path, wav_path

    return None, None


def load_json(json_file):
    if not json_file:
        logger.info("Not using a json file")
        return None

    try:
        with open(json_file, "r", encoding="utf-8") as file:
            data = json.load(file)
    except FileNotFoundError:
        logger.warning(f"ref json not found: {json_file}")
        data = None
    except Exception as e:
        logger.warning(f"Loading json failed: {e}")
        data = None

    return data
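

# Illustrative reference layout (an assumption inferred from get_random_paths,
# not a file shipped with this module): ref_data.json maps speaker -> emotion ->
# a list of .lab transcript and .wav audio filenames, which must exist under
# ref_data/<speaker>/<emotion>/.
#
#     {
#         "speaker_a": {
#             "happy": ["0001.lab", "0001.wav"],
#             "sad": ["0002.lab", "0002.wav"]
#         }
#     }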


class InvokeRequest(BaseModel):
    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
    reference_text: Optional[str] = None
    reference_audio: Optional[str] = None
    max_new_tokens: int = 1024
    chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    emotion: Optional[str] = None
    format: Literal["wav", "mp3", "flac"] = "wav"
    streaming: bool = False
    ref_json: Optional[str] = "ref_data.json"
    ref_base: Optional[str] = "ref_data"
    speaker: Optional[str] = None
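

# Example request body for /v1/invoke (the values are illustrative; the field
# names come from InvokeRequest above, and reference_audio may be a filesystem
# path or a base64-encoded audio file, see load_audio):
#
#     {
#         "text": "Hello world.",
#         "reference_text": "Transcript of the reference clip.",
#         "reference_audio": "ref.wav",
#         "format": "wav",
#         "streaming": false
#     }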


def get_content_type(audio_format):
    if audio_format == "wav":
        return "audio/wav"
    elif audio_format == "flac":
        return "audio/flac"
    elif audio_format == "mp3":
        return "audio/mpeg"
    else:
        return "application/octet-stream"


@torch.inference_mode()
def inference(req: InvokeRequest):
    # Pick a random reference (transcript + audio) for the requested
    # speaker/emotion, if a reference dataset is configured
    prompt_tokens = None

    ref_data = load_json(req.ref_json)
    ref_base = req.ref_base

    lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)

    if lab_path and wav_path:
        with open(lab_path, "r", encoding="utf-8") as lab_file:
            ref_text = lab_file.read()
        req.reference_audio = str(wav_path)
        req.reference_text = ref_text
        logger.info("ref_path: " + str(wav_path))
        logger.info("ref_text: " + ref_text)

    # Parse reference audio aka prompt
    prompt_tokens = encode_reference(
        decoder_model=decoder_model,
        reference_audio=req.reference_audio,
        enable_reference_audio=req.reference_audio is not None,
    )
    logger.info(f"ref_text: {req.reference_text}")

    # LLAMA Inference
    request = dict(
        device=decoder_model.device,
        max_new_tokens=req.max_new_tokens,
        text=req.text,
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=2048,
        prompt_tokens=prompt_tokens,
        prompt_text=req.reference_text,
    )

    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    if req.streaming:
        # 44.1 kHz 16-bit mono header; raw PCM chunks follow
        yield wav_chunk_header()

    segments = []
    while True:
        result: WrappedGenerateResponse = response_queue.get()
        if result.status == "error":
            raise result.response

        result: GenerateResponse = result.response
        if result.action == "next":
            break

        with torch.autocast(
            device_type=decoder_model.device.type, dtype=args.precision
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )

        fake_audios = fake_audios.float().cpu().numpy()

        if req.streaming:
            yield (fake_audios * 32768).astype(np.int16).tobytes()
        else:
            segments.append(fake_audios)

    if req.streaming:
        return

    if len(segments) == 0:
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR,
            content="No audio generated, please check the input text.",
        )

    fake_audios = np.concatenate(segments, axis=0)
    yield fake_audios
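

# Note: inference() is a generator. With streaming=True it yields a WAV header
# followed by raw int16 PCM byte chunks; with streaming=False it yields exactly
# one float32 numpy array containing the full waveform, which is why callers
# below use next(inference(req)).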


def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
    if not use_auto_rerank:
        # If auto-rerank is disabled, fall through to the original inference function
        return inference(req)

    zh_model, en_model = load_model()
    max_attempts = 5
    best_wer = float("inf")
    best_audio = None

    for attempt in range(max_attempts):
        # Call the original inference function
        audio_generator = inference(req)
        fake_audios = next(audio_generator)

        asr_result = batch_asr(
            zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
        )[0]
        wer = calculate_wer(req.text, asr_result["text"])

        if wer <= 0.1 and not asr_result["huge_gap"]:
            return fake_audios

        if wer < best_wer:
            best_wer = wer
            best_audio = fake_audios

        if attempt == max_attempts - 1:
            break

    return best_audio


async def inference_async(req: InvokeRequest):
    for chunk in inference(req):
        yield chunk


async def buffer_to_async_generator(buffer):
    yield buffer


@routes.http.post("/v1/invoke")
async def api_invoke_model(
    req: Annotated[InvokeRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """

    if args.max_text_length > 0 and len(req.text) > args.max_text_length:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {args.max_text_length}",
        )

    if req.streaming and req.format != "wav":
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Streaming only supports WAV format",
        )

    if req.streaming:
        return StreamResponse(
            iterable=inference_async(req),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
    else:
        fake_audios = next(inference(req))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            decoder_model.spec_transform.sample_rate,
            format=req.format,
        )

        return StreamResponse(
            iterable=buffer_to_async_generator(buffer.getvalue()),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
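

# Example client call (illustrative; assumes the server is running on the
# default --listen address):
#
#     curl -X POST http://127.0.0.1:8000/v1/invoke \
#         -H "Content-Type: application/json" \
#         -d '{"text": "Hello world.", "format": "wav"}' \
#         -o audio.wav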


@routes.http.post("/v1/health")
async def api_health():
    """
    Health check
    """

    return JSONResponse({"status": "ok"})


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.2-sft",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-text-length", type=int, default=0)
    parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
    parser.add_argument("--workers", type=int, default=1)
    # argparse's bare `type=bool` treats any non-empty string (including
    # "false") as True, so parse the flag value explicitly
    parser.add_argument(
        "--use-auto-rerank",
        type=lambda x: x.lower() == "true",
        default=True,
    )

    return parser.parse_args()


# Define Kui app
openapi = OpenAPI(
    {
        "title": "Fish Speech API",
    },
).routes

app = Kui(
    routes=routes + openapi[1:],  # Remove the default route
    exception_handlers={
        HTTPException: http_exception_handler,
        Exception: other_exception_handler,
    },
    cors_config={},
)


if __name__ == "__main__":
    import uvicorn

    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("VQ-GAN model loaded, warming up...")

    # Dry run to check if the model is loaded correctly and avoid the first-time latency
    list(
        inference(
            InvokeRequest(
                text="Hello world.",
                reference_text=None,
                reference_audio=None,
                max_new_tokens=0,
                top_p=0.7,
                repetition_penalty=1.2,
                temperature=0.7,
                emotion=None,
                format="wav",
                ref_base=None,
                ref_json=None,
            )
        )
    )

    logger.info(f"Warming up done, starting server at http://{args.listen}")
    host, port = args.listen.split(":")
    uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
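

# Example launch (illustrative; assumes this file lives at tools/api.py, as the
# tools.* imports suggest, and that the checkpoint paths match the argument
# defaults above):
#
#     python tools/api.py --listen 127.0.0.1:8000 \
#         --llama-checkpoint-path checkpoints/fish-speech-1.2-sft \
#         --half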