api.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943
  1. import io
  2. import os
  3. import queue
  4. import re
  5. import time
  6. import traceback
  7. import wave
  8. from argparse import ArgumentParser
  9. from http import HTTPStatus
  10. from pathlib import Path
  11. from typing import Annotated, Any
  12. import librosa
  13. import numpy as np
  14. import ormsgpack
  15. import pyrootutils
  16. import soundfile as sf
  17. import torch
  18. import torchaudio
  19. from baize.datastructures import ContentType
  20. from kui.asgi import (
  21. Body,
  22. FactoryClass,
  23. HTTPException,
  24. HttpRequest,
  25. HttpView,
  26. JSONResponse,
  27. Kui,
  28. OpenAPI,
  29. StreamResponse,
  30. request,
  31. )
  32. from kui.asgi.routing import MultimethodRoutes
  33. from loguru import logger
  34. from transformers import AutoTokenizer
  35. pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
  36. import struct
  37. from threading import Lock
  38. import httpx
  39. from cachetools import LRUCache, cached
  40. from funasr import AutoModel
  41. from silero_vad import get_speech_timestamps, load_silero_vad
  42. from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
  43. from fish_speech.models.text2semantic.llama import BaseModelArgs
  44. # from fish_speech.models.vqgan.lit_module import VQGAN
  45. from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
  46. from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
  47. from fish_speech.utils import autocast_exclude_mps, set_seed
  48. from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
  49. from tools.llama.generate import (
  50. GenerateRequest,
  51. GenerateResponse,
  52. WrappedGenerateResponse,
  53. launch_thread_safe_queue,
  54. launch_thread_safe_queue_agent,
  55. )
  56. from tools.schema import (
  57. GLOBAL_NUM_SAMPLES,
  58. ASRPackRequest,
  59. ServeASRRequest,
  60. ServeASRResponse,
  61. ServeASRSegment,
  62. ServeAudioPart,
  63. ServeForwardMessage,
  64. ServeMessage,
  65. ServeRequest,
  66. ServeResponse,
  67. ServeStreamDelta,
  68. ServeStreamResponse,
  69. ServeTextPart,
  70. ServeTimedASRResponse,
  71. ServeTTSRequest,
  72. ServeVQGANDecodeRequest,
  73. ServeVQGANDecodeResponse,
  74. ServeVQGANEncodeRequest,
  75. ServeVQGANEncodeResponse,
  76. ServeVQPart,
  77. )
  78. from tools.vqgan.inference import load_model as load_decoder_model
# Serializes access to models that are not safe for concurrent calls
# (used by batch_asr when invoking the funasr model).
global_lock = Lock()

# Whether to disable keepalive (which is helpful if the server is in the same cluster)
DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
async_client = httpx.AsyncClient(
    timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
)

# Prefer the ffmpeg torchaudio backend when available, else soundfile.
backends = torchaudio.list_audio_backends()

if "ffmpeg" in backends:
    backend = "ffmpeg"
else:
    backend = "soundfile"
  90. def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
  91. buffer = io.BytesIO()
  92. with wave.open(buffer, "wb") as wav_file:
  93. wav_file.setnchannels(channels)
  94. wav_file.setsampwidth(bit_depth // 8)
  95. wav_file.setframerate(sample_rate)
  96. wav_header_bytes = buffer.getvalue()
  97. buffer.close()
  98. return wav_header_bytes
  99. # Define utils for web server
  100. async def http_execption_handler(exc: HTTPException):
  101. return JSONResponse(
  102. dict(
  103. statusCode=exc.status_code,
  104. message=exc.content,
  105. error=HTTPStatus(exc.status_code).phrase,
  106. ),
  107. exc.status_code,
  108. exc.headers,
  109. )
  110. async def other_exception_handler(exc: "Exception"):
  111. traceback.print_exc()
  112. status = HTTPStatus.INTERNAL_SERVER_ERROR
  113. return JSONResponse(
  114. dict(statusCode=status, message=str(exc), error=status.phrase),
  115. status,
  116. )
def load_audio(reference_audio, sr):
    # `reference_audio` is either a filesystem path or raw audio bytes.
    # Anything longer than 255 chars cannot be a path (filename limit),
    # so it — or any non-existent path — is treated as in-memory content.
    # NOTE(review): raw bytes shorter than 256 would hit Path(bytes) and
    # raise — presumably callers only pass paths or large blobs; confirm.
    if len(reference_audio) > 255 or not Path(reference_audio).exists():
        audio_data = reference_audio
        reference_audio = io.BytesIO(audio_data)

    waveform, original_sr = torchaudio.load(reference_audio, backend=backend)

    # Downmix multi-channel audio to mono.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Resample to the requested sample rate when needed.
    if original_sr != sr:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
        waveform = resampler(waveform)

    audio = waveform.squeeze().numpy()
    return audio
def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
    """Encode a reference clip into VQ prompt tokens for voice cloning.

    Returns the prompt token tensor, or None when no reference is used.
    """
    if enable_reference_audio and reference_audio is not None:
        # Load audios, and prepare basic info here
        reference_audio_content = load_audio(
            reference_audio, decoder_model.spec_transform.sample_rate
        )
        audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
            None, None, :
        ]
        audio_lengths = torch.tensor(
            [audios.shape[2]], device=decoder_model.device, dtype=torch.long
        )
        logger.info(
            f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
        )

        # VQ Encoder
        if isinstance(decoder_model, FireflyArchitecture):
            prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]

        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
        # NOTE(review): if decoder_model were not a FireflyArchitecture,
        # prompt_tokens would be unbound here — presumably only Firefly
        # decoders reach this path; confirm.
    else:
        prompt_tokens = None
        logger.info("No reference audio provided")

    return prompt_tokens
def decode_vq_tokens(
    *,
    decoder_model,
    codes,
):
    """Decode VQ codes (codebooks x frames) back into a waveform tensor."""
    feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
    logger.info(f"VQ features: {codes.shape}")

    if isinstance(decoder_model, FireflyArchitecture):
        # VQGAN Inference
        return decoder_model.decode(
            indices=codes[None],
            feature_lengths=feature_lengths,
        )[0].squeeze()

    raise ValueError(f"Unknown model type: {type(decoder_model)}")
# Route table; MultimethodRoutes dispatches one path across HTTP methods.
routes = MultimethodRoutes(base_class=HttpView)
  167. def get_content_type(audio_format):
  168. if audio_format == "wav":
  169. return "audio/wav"
  170. elif audio_format == "flac":
  171. return "audio/flac"
  172. elif audio_format == "mp3":
  173. return "audio/mpeg"
  174. else:
  175. return "application/octet-stream"
  176. @torch.no_grad()
  177. @torch.autocast(device_type="cuda", dtype=torch.half)
  178. def batch_encode(model, audios: list[bytes | torch.Tensor]):
  179. audios = [
  180. (
  181. torch.from_numpy(
  182. librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
  183. )[None]
  184. if isinstance(audio, bytes)
  185. else audio
  186. )
  187. for audio in audios
  188. ]
  189. # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
  190. # raise ValueError("Single audio length is too long (>120s)")
  191. max_length = max(audio.shape[-1] for audio in audios)
  192. print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
  193. lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
  194. max_length = lengths.max().item()
  195. padded = torch.stack(
  196. [
  197. torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
  198. for audio in audios
  199. ]
  200. ).to(model.device)
  201. features, feature_lengths = model.encode(padded, audio_lengths=lengths)
  202. features, feature_lengths = features.cpu(), feature_lengths.cpu()
  203. return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
# Memoize encode results keyed by device + the exact audio byte strings,
# so repeated reference audios skip the VQGAN forward pass.
@cached(
    cache=LRUCache(maxsize=10000),
    key=lambda model, audios: (model.device, tuple(audios)),
)
def cached_vqgan_batch_encode(model, audios: list[bytes]):
    return batch_encode(model, audios)
@routes.http.post("/v1/vqgan/encode")
def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
    """Encode client-supplied audio bytes into VQ token lists (msgpack body)."""
    start_time = time.time()
    tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
    logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")

    return ormsgpack.packb(
        ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.half)
def vqgan_decode(model, features):
    """Decode a batch of VQ feature tensors back into numpy waveforms."""
    lengths = torch.tensor(
        [feature.shape[-1] for feature in features], device=model.device
    )
    max_length = lengths.max().item()
    # Right-pad all features to a common length so they stack into a batch.
    padded = torch.stack(
        [
            torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
            for feature in features
        ]
    ).to(model.device)

    # If bs too large, we do micro batch decode
    audios, audio_lengths = [], []
    for i in range(0, padded.shape[0], 8):
        audio, audio_length = model.decode(
            padded[i : i + 8], feature_lengths=lengths[i : i + 8]
        )
        audios.append(audio)
        audio_lengths.append(audio_length)
    audios = torch.cat(audios, dim=0)
    audio_lengths = torch.cat(audio_lengths, dim=0)
    audios, audio_lengths = audios.cpu(), audio_lengths.cpu()

    # Trim the padding off each decoded waveform.
    return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
@routes.http.post("/v1/vqgan/decode")
def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    """Decode VQ token lists into float16 PCM byte blobs (msgpack body)."""
    tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
    start_time = time.time()
    audios = vqgan_decode(decoder_model, tokens)
    logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
    # Serialize each waveform as raw float16 bytes to keep payloads small.
    audios = [audio.astype(np.float16).tobytes() for audio in audios]
    return ormsgpack.packb(
        ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
    )
@torch.no_grad()
def batch_asr(model, audios, sr, language="auto"):
    """Transcribe a batch of 1-D waveforms.

    Returns one dict per input with keys: text (tags stripped), duration
    in ms, and huge_gap (True when transcript timing suggests dead air).
    """
    resampled_audios = []
    for audio in audios:
        # Resample to the 16 kHz rate the ASR model consumes.
        audio = torchaudio.functional.resample(audio, sr, 16000)
        assert audio.ndim == 1
        resampled_audios.append(audio)

    # Serialize generation: the shared model is guarded by the module lock.
    with global_lock:
        res = model.generate(
            input=resampled_audios,
            batch_size=len(resampled_audios),
            language=language,
            use_itn=True,
        )

    results = []
    for r, audio in zip(res, audios):
        text = r["text"]
        # Strip special tag tokens such as <|en|> from the transcript.
        text = re.sub(r"<\|.*?\|>", "", text)
        duration = len(audio) / sr * 1000
        huge_gap = False

        if "timestamp" in r and len(r["timestamp"]) > 2:
            for timestamp_a, timestamp_b in zip(
                r["timestamp"][:-1], r["timestamp"][1:]
            ):
                # If there is a gap of more than 5 seconds, we consider it as a huge gap
                if timestamp_b[0] - timestamp_a[1] > 5000:
                    huge_gap = True
                    break

            # Doesn't make sense to have a huge gap at the end
            if duration - r["timestamp"][-1][1] > 3000:
                huge_gap = True

        results.append(
            {
                "text": text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results
  293. @routes.http.post("/v1/asr")
  294. def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
  295. start_time = time.time()
  296. audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
  297. audios = [torch.from_numpy(audio).float() for audio in audios]
  298. if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios):
  299. raise HTTPException(status_code=400, detail="Audio length is too long")
  300. transcriptions = batch_asr(
  301. asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
  302. )
  303. logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
  304. return ormsgpack.packb(
  305. ServeASRResponse(transcriptions=transcriptions),
  306. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  307. )
  308. from fish_speech.conversation import Conversation, Message
def execute_request(
    input_queue: queue.Queue,
    tokenizer: AutoTokenizer,
    config: BaseModelArgs,
    request: ServeRequest,
    device: str = "cuda:0",
):
    """Run one chat/agent request through the LLAMA worker queue.

    Generator: yields ServeStreamResponse items while request.streaming
    is true, otherwise a single final ServeResponse with all samples.
    """
    semantic_id, im_end_id = tokenizer.convert_tokens_to_ids(
        [SEMANTIC_TOKEN, IM_END_TOKEN]
    )
    messages = []
    for message in request.messages:
        messages.append(message.to_conversation_message())

    assert len(messages) >= 1, "At least one message is required"
    # assert messages[-1].role == "user", "The last message must be from the user"

    if messages[-1].role == "user":
        # Open a fresh assistant turn for the model to fill in.
        messages.append(Message(role="assistant", parts=[], add_im_end=False))
    else:
        assert (
            messages[-1].role == "assistant"
        ), "The last message must be from the assistant"
        # Continue the partial assistant message: keep it open (no im_end).
        messages[-1].add_im_end = False

    conv = Conversation(messages=messages)
    prompt = conv.encode_for_inference(
        tokenizer=tokenizer, num_codebooks=config.num_codebooks
    ).to(device)

    if request.streaming:
        # Announce the assistant role for each sample before any tokens.
        for i in range(request.num_samples):
            yield ServeStreamResponse(
                sample_id=i,
                delta=ServeStreamDelta(
                    role="assistant",
                ),
            )

    req = {
        "prompt": prompt,
        "max_new_tokens": request.max_new_tokens,
        "im_end_id": im_end_id,
        "semantic_id": semantic_id,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "num_samples": request.num_samples,
        "early_stop_threshold": request.early_stop_threshold,
    }

    start = time.time()
    response_queue = queue.Queue()
    input_queue.put(GenerateRequest(req, response_queue))

    # Decoding
    decode_buffer = [[] for _ in range(request.num_samples)]
    parts = [[] for _ in range(request.num_samples)]

    def send_reset_buffer(sample_id):
        # Flush buffered text token ids for one sample.  Written as a
        # generator so callers can `yield from` it in streaming mode.
        nonlocal decode_buffer
        if len(decode_buffer[sample_id]) == 0:
            return

        decoded = tokenizer.decode(decode_buffer[sample_id])
        part = ServeTextPart(text=decoded)

        if request.streaming:
            yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
        else:
            parts[sample_id].append(part)

        decode_buffer[sample_id] = []

    # Decode process
    finished = [False for _ in range(request.num_samples)]
    stats = {}
    idx = 0
    while True:
        response = response_queue.get()
        if response in ["stop", "error"]:
            break

        for sample_id, tokens in enumerate(response):
            if finished[sample_id]:
                continue

            # End-of-message token: flush and mark this sample done.
            if tokens[0] == im_end_id:
                finished[sample_id] = True
                if request.streaming:
                    yield from send_reset_buffer(sample_id)
                    yield ServeStreamResponse(
                        sample_id=sample_id,
                        finish_reason="stop",
                        stats=stats,
                    )
                continue

            if tokens[0] == semantic_id and request.streaming:
                yield from send_reset_buffer(sample_id)
                # Streaming vq
                _tokens = tokens[1:].clone() - 1

                # Un-offset per-codebook token ids when codebooks do not
                # share an embedding table.
                if config.share_codebook_embeddings is False:
                    for i in range(len(_tokens)):
                        _tokens[i] -= config.codebook_size * i

                yield ServeStreamResponse(
                    sample_id=sample_id,
                    delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
                )
                continue

            # Not streaming vq
            if tokens[0] == semantic_id:
                yield from send_reset_buffer(sample_id)
                # None streaming vq
                if len(parts[sample_id]) == 0 or not isinstance(
                    parts[sample_id][-1], ServeVQPart
                ):
                    _tokens = tokens[1:].clone() - 1

                    if config.share_codebook_embeddings is False:
                        for i in range(len(_tokens)):
                            _tokens[i] -= config.codebook_size * i

                    parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
                else:
                    # Extend the trailing VQ part codebook-by-codebook.
                    for codebook_id, value in enumerate(tokens[1:, :]):
                        val = value.item() - 1
                        if config.share_codebook_embeddings is False:
                            val -= config.codebook_size * codebook_id

                        parts[sample_id][-1].codes[codebook_id].append(val)

                continue

            if tokens[0] != semantic_id:
                # Stream text decode is not supported now
                decode_buffer[sample_id].append(tokens[0, 0])

        if idx == 0:
            stats["time_to_first_token"] = (time.time() - start) * 1000

        idx += 1

    # Flush any remaining buffered text for every sample.
    for sample_id in range(request.num_samples):
        yield from send_reset_buffer(sample_id)

    stats["total_time"] = (time.time() - start) * 1000
    stats["total_tokens"] = idx

    if request.streaming:
        # Close out samples that never saw an im_end token; finish_reason
        # carries the worker's terminal status ("stop" or "error").
        for sample_id in range(request.num_samples):
            if finished[sample_id]:
                continue
            yield ServeStreamResponse(
                finish_reason=response, stats=stats, sample_id=sample_id
            )
        return

    yield ServeResponse(
        messages=[
            ServeMessage(role="assistant", parts=parts[i])
            for i in range(request.num_samples)
        ],
        finish_reason=response,
        stats=stats,
    )
@routes.http.post("/v1/chat")
def api_invoke_chat(
    req: Annotated[ServeRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """

    # This makes torch compile happy
    assert (
        req.num_samples == GLOBAL_NUM_SAMPLES
    ), f"num_samples must be {GLOBAL_NUM_SAMPLES}"

    # Response framing follows the request Content-Type: SSE-style JSON
    # lines, or length-prefixed msgpack for binary clients.
    content_type = request.headers.get("Content-Type", "application/json")
    json_mode = "application/json" in content_type

    async def wrapped_generator():
        # NOTE(review): this iterates a synchronous generator inside an
        # async function, so each step blocks the event loop while the
        # model produces tokens — presumably acceptable here; confirm.
        generator = execute_request(llama_queue, tokenizer, config, req, args.device)

        for i in generator:
            if json_mode:
                body = i.model_dump_json().encode("utf-8")
                yield b"data: " + body + b"\n\n"
            else:
                # 4-byte little-or-native-endian length prefix, then msgpack.
                body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
                yield struct.pack("I", len(body)) + body

    # Naive mode
    if req.streaming is False:
        result = next(execute_request(llama_queue, tokenizer, config, req, args.device))

        if json_mode:
            return JSONResponse(result.model_dump())
        else:
            return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)

    return StreamResponse(
        iterable=wrapped_generator(), content_type="text/event-stream"
    )
@torch.inference_mode()
def inference(req: ServeTTSRequest):
    """TTS generator.

    Streaming: yields a WAV header, then int16 PCM byte chunks.
    Non-streaming: yields a single concatenated float waveform (ndarray).
    """
    # Reference prompts are cached at module level across requests.
    global prompt_tokens, prompt_texts

    idstr: str | None = req.reference_id
    if idstr is not None:
        # Reference audio stored on disk under references/<id>/.
        ref_folder = Path("references") / idstr
        ref_folder.mkdir(parents=True, exist_ok=True)
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )

        # Re-encode references unless the memory cache is active and warm.
        if req.use_memory_cache == "never" or (
            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
        ):
            prompt_tokens = [
                encode_reference(
                    decoder_model=decoder_model,
                    reference_audio=audio_to_bytes(str(ref_audio)),
                    enable_reference_audio=True,
                )
                for ref_audio in ref_audios
            ]
            # Transcripts live next to the audio files as .lab files.
            prompt_texts = [
                read_ref_text(str(ref_audio.with_suffix(".lab")))
                for ref_audio in ref_audios
            ]
        else:
            logger.info("Use same references")
    else:
        # Parse reference audio aka prompt
        refs = req.references

        if req.use_memory_cache == "never" or (
            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
        ):
            prompt_tokens = [
                encode_reference(
                    decoder_model=decoder_model,
                    reference_audio=ref.audio,
                    enable_reference_audio=True,
                )
                for ref in refs
            ]
            prompt_texts = [ref.text for ref in refs]
        else:
            logger.info("Use same references")

    if req.seed is not None:
        set_seed(req.seed)
        logger.warning(f"set seed: {req.seed}")

    # LLAMA Inference
    request = dict(
        device=decoder_model.device,
        max_new_tokens=req.max_new_tokens,
        text=(
            req.text
            if not req.normalize
            else ChnNormedText(raw_text=req.text).normalize()
        ),
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=4096,
        prompt_tokens=prompt_tokens,
        prompt_text=prompt_texts,
    )

    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    if req.streaming:
        yield wav_chunk_header()

    segments = []

    while True:
        result: WrappedGenerateResponse = response_queue.get()
        if result.status == "error":
            raise result.response
            break  # NOTE(review): unreachable after raise — dead code.

        result: GenerateResponse = result.response
        # "next" marks the end of this generation round.
        if result.action == "next":
            break

        with autocast_exclude_mps(
            device_type=decoder_model.device.type, dtype=args.precision
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )

        fake_audios = fake_audios.float().cpu().numpy()

        if req.streaming:
            # Raw 16-bit PCM; the WAV header was already sent above.
            yield (fake_audios * 32768).astype(np.int16).tobytes()
        else:
            segments.append(fake_audios)

    if req.streaming:
        return

    if len(segments) == 0:
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR,
            content="No audio generated, please check the input text.",
        )

    fake_audios = np.concatenate(segments, axis=0)
    yield fake_audios
async def inference_async(req: ServeTTSRequest):
    # Adapt the synchronous inference generator to the async-iterable
    # interface StreamResponse expects.  NOTE(review): each chunk is
    # produced synchronously and blocks the event loop while generating.
    for chunk in inference(req):
        yield chunk
async def buffer_to_async_generator(buffer):
    # Wrap an already-built payload in a single-chunk async generator,
    # matching the iterable interface StreamResponse requires.
    yield buffer
@routes.http.post("/v1/tts")
async def api_invoke_model(
    req: Annotated[ServeTTSRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio
    """

    # Optional server-side cap on input length (0 means unlimited).
    if args.max_text_length > 0 and len(req.text) > args.max_text_length:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {args.max_text_length}",
        )

    # Streaming emits a WAV header followed by raw PCM, so only WAV works.
    if req.streaming and req.format != "wav":
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Streaming only supports WAV format",
        )

    if req.streaming:
        return StreamResponse(
            iterable=inference_async(req),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
    else:
        # Non-streaming: generate the full waveform, then encode it with
        # soundfile into the requested container format.
        fake_audios = next(inference(req))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            decoder_model.spec_transform.sample_rate,
            format=req.format,
        )

        return StreamResponse(
            iterable=buffer_to_async_generator(buffer.getvalue()),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
  632. @routes.http.post("/v1/health")
  633. async def api_health():
  634. """
  635. Health check
  636. """
  637. return JSONResponse({"status": "ok"})
  638. def parse_args():
  639. parser = ArgumentParser()
  640. parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
  641. parser.add_argument("--load-asr-model", action="store_true")
  642. parser.add_argument(
  643. "--llama-checkpoint-path",
  644. type=str,
  645. default="checkpoints/fish-speech-1.4",
  646. )
  647. parser.add_argument(
  648. "--decoder-checkpoint-path",
  649. type=str,
  650. default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
  651. )
  652. parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
  653. parser.add_argument("--device", type=str, default="cuda")
  654. parser.add_argument("--half", action="store_true")
  655. parser.add_argument("--compile", action="store_true")
  656. parser.add_argument("--max-text-length", type=int, default=0)
  657. parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
  658. parser.add_argument("--workers", type=int, default=1)
  659. return parser.parse_args()
# Define Kui app
# Build the auto-generated OpenAPI documentation routes.
openapi = OpenAPI(
    {
        "title": "Fish Speech API",
        "version": "1.4.2",
    },
).routes
class MsgPackRequest(HttpRequest):
    # Request subclass that accepts both msgpack and JSON bodies.
    async def data(
        self,
    ) -> Annotated[
        Any, ContentType("application/msgpack"), ContentType("application/json")
    ]:
        """Parse the request body according to its Content-Type.

        Raises a 415 HTTPException for any other media type.
        """
        if self.content_type == "application/msgpack":
            return ormsgpack.unpackb(await self.body)

        elif self.content_type == "application/json":
            return await self.json

        raise HTTPException(
            HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            headers={"Accept": "application/msgpack, application/json"},
        )
# Assemble the ASGI application: API routes + OpenAPI docs, JSON error
# handlers, msgpack-aware request class, and permissive CORS.
app = Kui(
    routes=routes + openapi[1:],  # Remove the default route
    exception_handlers={
        HTTPException: http_execption_handler,
        Exception: other_exception_handler,
    },
    factory_class=FactoryClass(http=MsgPackRequest),
    cors_config={},
)
def load_asr_model(*, device="cuda", hub="ms"):
    """Load the SenseVoiceSmall ASR model via funasr.

    hub="ms" presumably selects the ModelScope hub — confirm against
    funasr's AutoModel documentation.
    """
    return AutoModel(
        model="iic/SenseVoiceSmall",
        device=device,
        disable_pbar=True,
        hub=hub,
    )
  697. # Each worker process created by Uvicorn has its own memory space,
  698. # meaning that models and variables are not shared between processes.
  699. # Therefore, any global variables (like `llama_queue` or `decoder_model`)
  700. # will not be shared across workers.
  701. # Multi-threading for deep learning can cause issues, such as inconsistent
  702. # outputs if multiple threads access the same buffers simultaneously.
  703. # Instead, it's better to use multiprocessing or independent models per thread.
@app.on_startup
def initialize_app(app: Kui):
    # Load all models once per worker process at startup (workers do not
    # share memory — see the note above).
    global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts

    # Reference-prompt cache used by inference().
    prompt_tokens, prompt_texts = [], []

    args = parse_args()  # args same as ones in other processes
    args.precision = torch.half if args.half else torch.bfloat16

    if args.load_asr_model:
        logger.info(f"Loading ASR model...")
        asr_model = load_asr_model(device=args.device)

    logger.info("Loading Llama model...")

    if args.mode == "tts":
        llama_queue = launch_thread_safe_queue(
            checkpoint_path=args.llama_checkpoint_path,
            device=args.device,
            precision=args.precision,
            compile=args.compile,
        )
    else:
        # Agent mode also returns the tokenizer/config needed for chat
        # decoding in execute_request.
        llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
            checkpoint_path=args.llama_checkpoint_path,
            device=args.device,
            precision=args.precision,
            compile=args.compile,
        )

    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("VQ-GAN model loaded, warming up...")

    vad_model = load_silero_vad()

    logger.info("VAD model loaded, warming up...")

    if args.mode == "tts":
        # Dry run to ensure models work and avoid first-time latency
        list(
            inference(
                ServeTTSRequest(
                    text="Hello world.",
                    references=[],
                    reference_id=None,
                    max_new_tokens=0,
                    chunk_length=200,
                    top_p=0.7,
                    repetition_penalty=1.2,
                    temperature=0.7,
                    emotion=None,
                    format="wav",
                )
            )
        )

    logger.info(f"Warming up done, starting server at http://{args.listen}")
if __name__ == "__main__":
    import uvicorn

    args = parse_args()
    host, port = args.listen.split(":")

    # Run via the import string so uvicorn can spawn multiple workers;
    # each worker re-imports the module and runs initialize_app on startup.
    uvicorn.run(
        "tools.api:app",
        host=host,
        port=int(port),
        workers=args.workers,
        log_level="info",
    )