import base64
import io
import json
import queue
import random
import sys
import traceback
import wave
from argparse import ArgumentParser
from http import HTTPStatus
from pathlib import Path
from typing import Annotated, Literal, Optional

import numpy as np
import pyrootutils
import soundfile as sf
import torch
import torchaudio
from kui.asgi import (
    Body,
    HTTPException,
    HttpView,
    JSONResponse,
    Kui,
    OpenAPI,
    StreamResponse,
)
from kui.asgi.routing import MultimethodRoutes
from loguru import logger
from pydantic import BaseModel, Field

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# from fish_speech.models.vqgan.lit_module import VQGAN
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.utils import autocast_exclude_mps
from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
from tools.llama.generate import (
    GenerateRequest,
    GenerateResponse,
    WrappedGenerateResponse,
    launch_thread_safe_queue,
)
from tools.vqgan.inference import load_model as load_decoder_model


def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
    buffer = io.BytesIO()

    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(bit_depth // 8)
        wav_file.setframerate(sample_rate)

    wav_header_bytes = buffer.getvalue()
    buffer.close()
    return wav_header_bytes
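

# Streaming sketch: `inference` below yields this header once, then raw 16-bit
# PCM frames. Because `wave` wrote no frames, the header's declared data length
# is zero, which most streaming players tolerate for chunked playback:
#   header = wav_chunk_header(sample_rate=44100)
#   pcm = (audio * 32768).astype(np.int16).tobytes()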


# Define utils for the web server
async def http_exception_handler(exc: HTTPException):
    return JSONResponse(
        dict(
            statusCode=exc.status_code,
            message=exc.content,
            error=HTTPStatus(exc.status_code).phrase,
        ),
        exc.status_code,
        exc.headers,
    )


async def other_exception_handler(exc: "Exception"):
    traceback.print_exc()

    status = HTTPStatus.INTERNAL_SERVER_ERROR
    return JSONResponse(
        dict(statusCode=status, message=str(exc), error=status.phrase),
        status,
    )


def load_audio(reference_audio, sr):
    if len(reference_audio) > 255 or not Path(reference_audio).exists():
        try:
            audio_data = base64.b64decode(reference_audio)
            reference_audio = io.BytesIO(audio_data)
        except base64.binascii.Error:
            raise ValueError("Invalid path or base64 string")

    waveform, original_sr = torchaudio.load(
        reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
    )

    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if original_sr != sr:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
        waveform = resampler(waveform)

    audio = waveform.squeeze().numpy()
    return audio
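

# Input sketch: `reference_audio` may be a filesystem path or a base64-encoded
# audio file; anything longer than 255 characters (or not present on disk) is
# treated as base64. Hypothetical examples (paths are placeholders):
#   load_audio("ref/voice.wav", 44100)
#   load_audio(base64.b64encode(open("ref/voice.wav", "rb").read()).decode(), 44100)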


def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
    if enable_reference_audio and reference_audio is not None:
        # Load audios, and prepare basic info here
        reference_audio_content = load_audio(
            reference_audio, decoder_model.spec_transform.sample_rate
        )

        audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
            None, None, :
        ]
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=decoder_model.device, dtype=torch.long
        )
        logger.info(
            f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        if isinstance(decoder_model, FireflyArchitecture):
            prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]

        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
    else:
        prompt_tokens = None
        logger.info("No reference audio provided")

    return prompt_tokens


def decode_vq_tokens(
    *,
    decoder_model,
    codes,
):
    feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
    logger.info(f"VQ features: {codes.shape}")

    if isinstance(decoder_model, FireflyArchitecture):
        # VQGAN Inference
        return decoder_model.decode(
            indices=codes[None],
            feature_lengths=feature_lengths,
        ).squeeze()

    raise ValueError(f"Unknown model type: {type(decoder_model)}")


routes = MultimethodRoutes(base_class=HttpView)


def get_random_paths(base_path, data, speaker, emotion):
    if base_path and data and speaker and emotion and (Path(base_path).exists()):
        if speaker in data and emotion in data[speaker]:
            files = data[speaker][emotion]
            lab_files = [f for f in files if f.endswith(".lab")]
            wav_files = [f for f in files if f.endswith(".wav")]

            if lab_files and wav_files:
                selected_lab = random.choice(lab_files)
                selected_wav = random.choice(wav_files)

                lab_path = Path(base_path) / speaker / emotion / selected_lab
                wav_path = Path(base_path) / speaker / emotion / selected_wav
                if lab_path.exists() and wav_path.exists():
                    return lab_path, wav_path

    return None, None


def load_json(json_file):
    if not json_file:
        logger.info("Not using a json file")
        return None

    try:
        with open(json_file, "r", encoding="utf-8") as file:
            data = json.load(file)
    except FileNotFoundError:
        logger.warning(f"ref json not found: {json_file}")
        data = None
    except Exception as e:
        logger.warning(f"Loading json failed: {e}")
        data = None

    return data


class InvokeRequest(BaseModel):
    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
    reference_text: Optional[str] = None
    reference_audio: Optional[str] = None
    max_new_tokens: int = 1024
    chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    emotion: Optional[str] = None
    format: Literal["wav", "mp3", "flac"] = "wav"
    streaming: bool = False
    ref_json: Optional[str] = "ref_data.json"
    ref_base: Optional[str] = "ref_data"
    speaker: Optional[str] = None
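

# Example payload (a sketch; only overridden fields need to be sent, and
# passing null for `ref_json`/`ref_base` disables the random reference lookup):
#   {"text": "Hello world.", "format": "wav", "streaming": false,
#    "ref_json": null, "ref_base": null}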


def get_content_type(audio_format):
    if audio_format == "wav":
        return "audio/wav"
    elif audio_format == "flac":
        return "audio/flac"
    elif audio_format == "mp3":
        return "audio/mpeg"
    else:
        return "application/octet-stream"


@torch.inference_mode()
def inference(req: InvokeRequest):
    # Parse the reference audio (aka prompt)
    prompt_tokens = None

    ref_data = load_json(req.ref_json)
    ref_base = req.ref_base

    lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)

    if lab_path and wav_path:
        with open(lab_path, "r", encoding="utf-8") as lab_file:
            ref_text = lab_file.read()

        req.reference_audio = wav_path
        req.reference_text = ref_text
        logger.info("ref_path: " + str(wav_path))
        logger.info("ref_text: " + ref_text)

    # Encode the reference audio into prompt tokens
    prompt_tokens = encode_reference(
        decoder_model=decoder_model,
        reference_audio=req.reference_audio,
        enable_reference_audio=req.reference_audio is not None,
    )
    logger.info(f"ref_text: {req.reference_text}")

    # LLAMA Inference
    request = dict(
        device=decoder_model.device,
        max_new_tokens=req.max_new_tokens,
        text=req.text,
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=2048,
        prompt_tokens=prompt_tokens,
        prompt_text=req.reference_text,
    )

    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    if req.streaming:
        yield wav_chunk_header()

    segments = []
    while True:
        result: WrappedGenerateResponse = response_queue.get()
        if result.status == "error":
            raise result.response

        result: GenerateResponse = result.response
        if result.action == "next":
            break

        with autocast_exclude_mps(
            device_type=decoder_model.device.type, dtype=args.precision
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )

        fake_audios = fake_audios.float().cpu().numpy()

        if req.streaming:
            yield (fake_audios * 32768).astype(np.int16).tobytes()
        else:
            segments.append(fake_audios)

    if req.streaming:
        return

    if len(segments) == 0:
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR,
            content="No audio generated, please check the input text.",
        )

    fake_audios = np.concatenate(segments, axis=0)
    yield fake_audios
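

# Consumption sketch: with streaming=False the generator yields exactly one
# concatenated numpy array, so callers take it with next(...); with
# streaming=True it yields a WAV header, then raw int16 PCM chunks:
#   audio = next(inference(InvokeRequest(text="Hello world.", ref_json=None)))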


def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
    if not use_auto_rerank:
        # Without auto_rerank, call the original inference function directly
        return inference(req)

    zh_model, en_model = load_model()
    max_attempts = 5
    best_wer = float("inf")
    best_audio = None

    for attempt in range(max_attempts):
        # Call the original inference function
        audio_generator = inference(req)
        fake_audios = next(audio_generator)

        asr_result = batch_asr(
            zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
        )[0]
        wer = calculate_wer(req.text, asr_result["text"])

        if wer <= 0.1 and not asr_result["huge_gap"]:
            return fake_audios

        if wer < best_wer:
            best_wer = wer
            best_audio = fake_audios

    return best_audio


async def inference_async(req: InvokeRequest):
    for chunk in inference(req):
        yield chunk


async def buffer_to_async_generator(buffer):
    yield buffer


@routes.http.post("/v1/invoke")
async def api_invoke_model(
    req: Annotated[InvokeRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """

    if args.max_text_length > 0 and len(req.text) > args.max_text_length:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {args.max_text_length}",
        )

    if req.streaming and req.format != "wav":
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Streaming only supports WAV format",
        )

    if req.streaming:
        return StreamResponse(
            iterable=inference_async(req),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
    else:
        fake_audios = next(inference(req))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            decoder_model.spec_transform.sample_rate,
            format=req.format,
        )

        return StreamResponse(
            iterable=buffer_to_async_generator(buffer.getvalue()),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
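

# Example request (a sketch; assumes the server was started with the default
# --listen 127.0.0.1:8000 from parse_args below):
#   curl -X POST http://127.0.0.1:8000/v1/invoke \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello world.", "format": "wav"}' \
#        -o audio.wav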


@routes.http.post("/v1/health")
async def api_health():
    """
    Health check
    """

    return JSONResponse({"status": "ok"})


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--llama-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.2-sft",
    )
    parser.add_argument(
        "--decoder-checkpoint-path",
        type=str,
        default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
    )
    parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--half", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--max-text-length", type=int, default=0)
    parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
    parser.add_argument("--workers", type=int, default=1)
    # argparse's `type=bool` treats any non-empty string (including "False")
    # as True, so parse the flag explicitly
    parser.add_argument(
        "--use-auto-rerank",
        type=lambda x: x.lower() in ("true", "1", "yes"),
        default=True,
    )

    return parser.parse_args()
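

# Launch sketch (assumes this file lives at tools/api.py, as its imports
# suggest, and that the default checkpoints have been downloaded):
#   python tools/api.py --listen 0.0.0.0:8000 --half --compile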


# Define Kui app
openapi = OpenAPI(
    {
        "title": "Fish Speech API",
    },
).routes

app = Kui(
    routes=routes + openapi[1:],  # Remove the default route
    exception_handlers={
        HTTPException: http_exception_handler,
        Exception: other_exception_handler,
    },
    cors_config={},
)


if __name__ == "__main__":
    import uvicorn

    args = parse_args()
    args.precision = torch.half if args.half else torch.bfloat16

    logger.info("Loading Llama model...")
    llama_queue = launch_thread_safe_queue(
        checkpoint_path=args.llama_checkpoint_path,
        device=args.device,
        precision=args.precision,
        compile=args.compile,
    )
    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("VQ-GAN model loaded, warming up...")

    # Dry run to check if the model is loaded correctly and avoid the first-time latency
    list(
        inference(
            InvokeRequest(
                text="Hello world.",
                reference_text=None,
                reference_audio=None,
                max_new_tokens=0,
                top_p=0.7,
                repetition_penalty=1.2,
                temperature=0.7,
                emotion=None,
                format="wav",
                ref_base=None,
                ref_json=None,
            )
        )
    )

    logger.info(f"Warming up done, starting server at http://{args.listen}")
    host, port = args.listen.split(":")

    # Note: uvicorn honors `workers` > 1 only when the app is passed as an
    # import string; with an app object it runs a single worker.
    uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")