api.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953
  1. import io
  2. import json
  3. import os
  4. import queue
  5. import re
  6. import time
  7. import traceback
  8. import wave
  9. from argparse import ArgumentParser
  10. from http import HTTPStatus
  11. from pathlib import Path
  12. from typing import Annotated, Any
  13. import librosa
  14. import numpy as np
  15. import ormsgpack
  16. import pyrootutils
  17. import soundfile as sf
  18. import torch
  19. import torchaudio
  20. from baize.datastructures import ContentType
  21. from kui.asgi import (
  22. Body,
  23. FactoryClass,
  24. HTTPException,
  25. HttpRequest,
  26. HttpView,
  27. JSONResponse,
  28. Kui,
  29. OpenAPI,
  30. StreamResponse,
  31. request,
  32. )
  33. from kui.asgi.routing import MultimethodRoutes
  34. from loguru import logger
  35. pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
  36. import struct
  37. from threading import Lock
  38. import httpx
  39. from cachetools import LRUCache, cached
  40. from funasr import AutoModel
  41. from silero_vad import get_speech_timestamps, load_silero_vad
  42. from fish_speech.models.text2semantic.llama import BaseModelArgs
  43. # from fish_speech.models.vqgan.lit_module import VQGAN
  44. from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
  45. from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
  46. # from fish_speech.conversation import IM_END_TOKEN, SEMANTIC_TOKEN
  47. from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
  48. from fish_speech.utils import autocast_exclude_mps, set_seed
  49. from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
  50. from tools.llama.generate import (
  51. GenerateRequest,
  52. GenerateResponse,
  53. WrappedGenerateResponse,
  54. launch_thread_safe_queue,
  55. launch_thread_safe_queue_agent,
  56. )
  57. from tools.schema import (
  58. GLOBAL_NUM_SAMPLES,
  59. ASRPackRequest,
  60. ServeASRRequest,
  61. ServeASRResponse,
  62. ServeASRSegment,
  63. ServeAudioPart,
  64. ServeForwardMessage,
  65. ServeMessage,
  66. ServeRequest,
  67. ServeResponse,
  68. ServeStreamDelta,
  69. ServeStreamResponse,
  70. ServeTextPart,
  71. ServeTimedASRResponse,
  72. ServeTTSRequest,
  73. ServeVQGANDecodeRequest,
  74. ServeVQGANDecodeResponse,
  75. ServeVQGANEncodeRequest,
  76. ServeVQGANEncodeResponse,
  77. ServeVQPart,
  78. )
  79. from tools.vqgan.inference import load_model as load_decoder_model
# Serializes access to the ASR model (see batch_asr); model.generate calls
# are guarded by this module-wide lock.
global_lock = Lock()

# Whether to disable keepalive (which is helpful if the server is in the same cluster)
DISABLE_KEEPALIVE = os.getenv("DISABLE_KEEPALIVE", "false").lower() == "true"
async_client = httpx.AsyncClient(
    timeout=120, limits=httpx.Limits(keepalive_expiry=0 if DISABLE_KEEPALIVE else None)
)

# Pick the torchaudio decode backend: prefer ffmpeg when present,
# otherwise fall back to soundfile.
backends = torchaudio.list_audio_backends()

if "ffmpeg" in backends:
    backend = "ffmpeg"
else:
    backend = "soundfile"
  91. def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
  92. buffer = io.BytesIO()
  93. with wave.open(buffer, "wb") as wav_file:
  94. wav_file.setnchannels(channels)
  95. wav_file.setsampwidth(bit_depth // 8)
  96. wav_file.setframerate(sample_rate)
  97. wav_header_bytes = buffer.getvalue()
  98. buffer.close()
  99. return wav_header_bytes
  100. # Define utils for web server
  101. async def http_execption_handler(exc: HTTPException):
  102. return JSONResponse(
  103. dict(
  104. statusCode=exc.status_code,
  105. message=exc.content,
  106. error=HTTPStatus(exc.status_code).phrase,
  107. ),
  108. exc.status_code,
  109. exc.headers,
  110. )
  111. async def other_exception_handler(exc: "Exception"):
  112. traceback.print_exc()
  113. status = HTTPStatus.INTERNAL_SERVER_ERROR
  114. return JSONResponse(
  115. dict(statusCode=status, message=str(exc), error=status.phrase),
  116. status,
  117. )
  118. def load_audio(reference_audio, sr):
  119. if len(reference_audio) > 255 or not Path(reference_audio).exists():
  120. audio_data = reference_audio
  121. reference_audio = io.BytesIO(audio_data)
  122. waveform, original_sr = torchaudio.load(reference_audio, backend=backend)
  123. if waveform.shape[0] > 1:
  124. waveform = torch.mean(waveform, dim=0, keepdim=True)
  125. if original_sr != sr:
  126. resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
  127. waveform = resampler(waveform)
  128. audio = waveform.squeeze().numpy()
  129. return audio
  130. def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
  131. if enable_reference_audio and reference_audio is not None:
  132. # Load audios, and prepare basic info here
  133. reference_audio_content = load_audio(
  134. reference_audio, decoder_model.spec_transform.sample_rate
  135. )
  136. audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
  137. None, None, :
  138. ]
  139. audio_lengths = torch.tensor(
  140. [audios.shape[2]], device=decoder_model.device, dtype=torch.long
  141. )
  142. logger.info(
  143. f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
  144. )
  145. # VQ Encoder
  146. if isinstance(decoder_model, FireflyArchitecture):
  147. prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
  148. logger.info(f"Encoded prompt: {prompt_tokens.shape}")
  149. else:
  150. prompt_tokens = None
  151. logger.info("No reference audio provided")
  152. return prompt_tokens
  153. def decode_vq_tokens(
  154. *,
  155. decoder_model,
  156. codes,
  157. ):
  158. feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
  159. logger.info(f"VQ features: {codes.shape}")
  160. if isinstance(decoder_model, FireflyArchitecture):
  161. # VQGAN Inference
  162. return decoder_model.decode(
  163. indices=codes[None],
  164. feature_lengths=feature_lengths,
  165. )[0].squeeze()
  166. raise ValueError(f"Unknown model type: {type(decoder_model)}")
# Shared route table for all endpoints below; HttpView enables
# class-based multi-method handlers.
routes = MultimethodRoutes(base_class=HttpView)
  168. def get_content_type(audio_format):
  169. if audio_format == "wav":
  170. return "audio/wav"
  171. elif audio_format == "flac":
  172. return "audio/flac"
  173. elif audio_format == "mp3":
  174. return "audio/mpeg"
  175. else:
  176. return "application/octet-stream"
  177. @torch.no_grad()
  178. @torch.autocast(device_type="cuda", dtype=torch.half)
  179. def batch_encode(model, audios: list[bytes | torch.Tensor]):
  180. audios = [
  181. (
  182. torch.from_numpy(
  183. librosa.load(io.BytesIO(audio), sr=model.spec_transform.sample_rate)[0]
  184. )[None]
  185. if isinstance(audio, bytes)
  186. else audio
  187. )
  188. for audio in audios
  189. ]
  190. # if any(audio.shape[-1] > model.spec_transform.sample_rate * 120 for audio in audios):
  191. # raise ValueError("Single audio length is too long (>120s)")
  192. max_length = max(audio.shape[-1] for audio in audios)
  193. print(f"Encode max length: {max_length / model.spec_transform.sample_rate:.2f}s")
  194. lengths = torch.tensor([audio.shape[-1] for audio in audios], device=model.device)
  195. max_length = lengths.max().item()
  196. padded = torch.stack(
  197. [
  198. torch.nn.functional.pad(audio, (0, max_length - audio.shape[-1]))
  199. for audio in audios
  200. ]
  201. ).to(model.device)
  202. features, feature_lengths = model.encode(padded, audio_lengths=lengths)
  203. features, feature_lengths = features.cpu(), feature_lengths.cpu()
  204. return [feature[..., :length] for feature, length in zip(features, feature_lengths)]
# Memoize encode results keyed by (device, audio bytes) so repeated
# reference audios skip the GPU round-trip.  Inputs must be hashable
# (bytes) for the cache key.
@cached(
    cache=LRUCache(maxsize=10000),
    key=lambda model, audios: (model.device, tuple(audios)),
)
def cached_vqgan_batch_encode(model, audios: list[bytes]):
    return batch_encode(model, audios)
  211. @routes.http.post("/v1/vqgan/encode")
  212. def api_vqgan_encode(payload: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
  213. start_time = time.time()
  214. tokens = cached_vqgan_batch_encode(decoder_model, payload.audios)
  215. logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
  216. return ormsgpack.packb(
  217. ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
  218. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  219. )
  220. @torch.no_grad()
  221. @torch.autocast(device_type="cuda", dtype=torch.half)
  222. def vqgan_decode(model, features):
  223. lengths = torch.tensor(
  224. [feature.shape[-1] for feature in features], device=model.device
  225. )
  226. max_length = lengths.max().item()
  227. padded = torch.stack(
  228. [
  229. torch.nn.functional.pad(feature, (0, max_length - feature.shape[-1]))
  230. for feature in features
  231. ]
  232. ).to(model.device)
  233. # If bs too large, we do micro batch decode
  234. audios, audio_lengths = [], []
  235. for i in range(0, padded.shape[0], 8):
  236. audio, audio_length = model.decode(
  237. padded[i : i + 8], feature_lengths=lengths[i : i + 8]
  238. )
  239. audios.append(audio)
  240. audio_lengths.append(audio_length)
  241. audios = torch.cat(audios, dim=0)
  242. audio_lengths = torch.cat(audio_lengths, dim=0)
  243. audios, audio_lengths = audios.cpu(), audio_lengths.cpu()
  244. return [audio[..., :length].numpy() for audio, length in zip(audios, audio_lengths)]
  245. @routes.http.post("/v1/vqgan/decode")
  246. def api_vqgan_decode(payload: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
  247. tokens = [torch.tensor(token, dtype=torch.int) for token in payload.tokens]
  248. start_time = time.time()
  249. audios = vqgan_decode(decoder_model, tokens)
  250. logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
  251. audios = [audio.astype(np.float16).tobytes() for audio in audios]
  252. return ormsgpack.packb(
  253. ServeVQGANDecodeResponse(audios=audios), option=ormsgpack.OPT_SERIALIZE_PYDANTIC
  254. )
@torch.no_grad()
def batch_asr(model, audios, sr, language="auto"):
    """Transcribe a batch of 1-D audio tensors with the funasr model.

    Returns one dict per input with keys:
      text     -- transcription with <|...|> control tags stripped
      duration -- input length in milliseconds
      huge_gap -- True when timestamps show a >5s silence inside the speech
                  or >3s of trailing silence at the end
    """
    resampled_audios = []
    for audio in audios:
        # Model input is resampled to 16 kHz regardless of the source rate.
        audio = torchaudio.functional.resample(audio, sr, 16000)
        assert audio.ndim == 1
        resampled_audios.append(audio)

    # Serialize generate() calls behind the module-wide lock.
    with global_lock:
        res = model.generate(
            input=resampled_audios,
            batch_size=len(resampled_audios),
            language=language,
            use_itn=True,
        )

    results = []
    for r, audio in zip(res, audios):
        text = r["text"]
        # Strip special tags such as language/emotion markers.
        text = re.sub(r"<\|.*?\|>", "", text)
        # Duration is computed from the ORIGINAL (pre-resample) audio.
        duration = len(audio) / sr * 1000
        huge_gap = False

        if "timestamp" in r and len(r["timestamp"]) > 2:
            for timestamp_a, timestamp_b in zip(
                r["timestamp"][:-1], r["timestamp"][1:]
            ):
                # If there is a gap of more than 5 seconds, we consider it as a huge gap
                if timestamp_b[0] - timestamp_a[1] > 5000:
                    huge_gap = True
                    break

            # Doesn't make sense to have a huge gap at the end
            if duration - r["timestamp"][-1][1] > 3000:
                huge_gap = True

        results.append(
            {
                "text": text,
                "duration": duration,
                "huge_gap": huge_gap,
            }
        )

    return results
  294. @routes.http.post("/v1/asr")
  295. def api_invoke_asr(payload: Annotated[ServeASRRequest, Body(exclusive=True)]):
  296. start_time = time.time()
  297. audios = [np.frombuffer(audio, dtype=np.float16) for audio in payload.audios]
  298. audios = [torch.from_numpy(audio).float() for audio in audios]
  299. if any(audios.shape[-1] >= 30 * payload.sample_rate for audios in audios):
  300. raise HTTPException(status_code=400, detail="Audio length is too long")
  301. transcriptions = batch_asr(
  302. asr_model, audios=audios, sr=payload.sample_rate, language=payload.language
  303. )
  304. logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
  305. return ormsgpack.packb(
  306. ServeASRResponse(transcriptions=transcriptions),
  307. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  308. )
  309. from fish_speech.conversation import Conversation, Message
def execute_request(
    input_queue: queue.Queue,
    tokenizer: FishTokenizer,
    config: BaseModelArgs,
    request: ServeRequest,
    device: str = "cuda:0",
):
    """Run one chat/agent request through the llama worker queue.

    Generator: in streaming mode it yields ServeStreamResponse deltas (role,
    text parts, VQ parts, finish events); otherwise it accumulates parts and
    yields a single ServeResponse at the end.
    """
    im_end_id = tokenizer.get_token_id(IM_END_TOKEN)

    messages = []
    for message in request.messages:
        messages.append(message.to_conversation_message())

    assert len(messages) >= 1, "At least one message is required"
    # assert messages[-1].role == "user", "The last message must be from the user"

    # Normalize the tail of the conversation so generation continues an
    # assistant turn: append an open assistant message after a user turn,
    # pass "raw" content through unframed, or reopen an assistant turn.
    if messages[-1].role == "user":
        messages.append(
            Message(role="assistant", parts=[], add_im_end=False, modality="voice")
        )
    elif messages[-1].role == "raw":
        messages[-1].add_im_start = False
        messages[-1].add_im_end = False
        messages[-1].modality = "voice"
    else:
        assert (
            messages[-1].role == "assistant"
        ), "The last message must be from the assistant"
        messages[-1].add_im_end = False

    conv = Conversation(messages=messages)
    # conv.visualize(tokenizer)
    prompt = conv.encode_for_inference(
        tokenizer=tokenizer, num_codebooks=config.num_codebooks
    ).to(device)

    # Announce the assistant role for every sample before any tokens arrive.
    if request.streaming:
        for i in range(request.num_samples):
            yield ServeStreamResponse(
                sample_id=i,
                delta=ServeStreamDelta(
                    role="assistant",
                ),
            )

    req = {
        "prompt": prompt,
        "max_new_tokens": request.max_new_tokens,
        "im_end_id": im_end_id,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "repetition_penalty": request.repetition_penalty,
        "num_samples": request.num_samples,
        "early_stop_threshold": request.early_stop_threshold,
    }

    start = time.time()
    response_queue = queue.Queue()
    input_queue.put(GenerateRequest(req, response_queue))

    # Decoding
    # decode_buffer collects pending text token ids per sample;
    # parts collects finished ServeTextPart / ServeVQPart per sample.
    decode_buffer = [[] for _ in range(request.num_samples)]
    parts = [[] for _ in range(request.num_samples)]

    def send_reset_buffer(sample_id):
        # Flush the pending text tokens for one sample: in streaming mode
        # yield them as a delta, otherwise append them to parts.  This is a
        # generator, so callers must drive it with `yield from`.
        nonlocal decode_buffer
        if len(decode_buffer[sample_id]) == 0:
            return

        decoded = tokenizer.decode(decode_buffer[sample_id])
        part = ServeTextPart(text=decoded)

        if request.streaming:
            yield ServeStreamResponse(delta=ServeStreamDelta(part=part))
        else:
            parts[sample_id].append(part)

        decode_buffer[sample_id] = []

    # Decode process
    finished = [False for _ in range(request.num_samples)]
    stats = {}
    idx = 0
    while True:
        response = response_queue.get()
        # The worker signals completion with a sentinel string.
        if response in ["stop", "error"]:
            break

        # Each non-sentinel response is one token column per sample;
        # tokens[0] is the text-token row, tokens[1:] are codebook rows.
        for sample_id, tokens in enumerate(response):
            if finished[sample_id]:
                continue

            if tokens[0] == im_end_id:
                finished[sample_id] = True
                if request.streaming:
                    yield from send_reset_buffer(sample_id)
                    yield ServeStreamResponse(
                        sample_id=sample_id,
                        finish_reason="stop",
                        stats=stats,
                    )
                continue

            is_semantic = (
                tokenizer.semantic_begin_id <= tokens[0] <= tokenizer.semantic_end_id
            )
            if is_semantic and request.streaming:
                yield from send_reset_buffer(sample_id)
                # Streaming vq
                _tokens = tokens[1:].clone()

                # Without shared codebook embeddings, each codebook row is
                # offset by codebook_size * row; undo that offset.
                if config.share_codebook_embeddings is False:
                    for i in range(len(_tokens)):
                        _tokens[i] -= config.codebook_size * i

                yield ServeStreamResponse(
                    sample_id=sample_id,
                    delta=ServeStreamDelta(part=ServeVQPart(codes=_tokens.tolist())),
                )
                continue

            # Not streaming vq
            if is_semantic:
                yield from send_reset_buffer(sample_id)
                # None streaming vq
                # Start a new VQ part, or extend the trailing one column-wise.
                if len(parts[sample_id]) == 0 or not isinstance(
                    parts[sample_id][-1], ServeVQPart
                ):
                    _tokens = tokens[1:].clone()

                    if config.share_codebook_embeddings is False:
                        for i in range(len(_tokens)):
                            _tokens[i] -= config.codebook_size * i

                    parts[sample_id].append(ServeVQPart(codes=_tokens.tolist()))
                else:
                    for codebook_id, value in enumerate(tokens[1:, :]):
                        val = value.item()
                        if config.share_codebook_embeddings is False:
                            val -= config.codebook_size * codebook_id

                        parts[sample_id][-1].codes[codebook_id].append(val)

                continue

            if not is_semantic:
                # Stream text decode is not supported now
                decode_buffer[sample_id].append(tokens[0, 0])

        if idx == 0:
            stats["time_to_first_token"] = (time.time() - start) * 1000

        idx += 1

    # Flush any trailing text for every sample.
    for sample_id in range(request.num_samples):
        yield from send_reset_buffer(sample_id)

    stats["total_time"] = (time.time() - start) * 1000
    stats["total_tokens"] = idx

    if request.streaming:
        # Emit a final finish event for samples that never hit im_end;
        # finish_reason carries the worker sentinel ("stop" or "error").
        for sample_id in range(request.num_samples):
            if finished[sample_id]:
                continue
            yield ServeStreamResponse(
                finish_reason=response, stats=stats, sample_id=sample_id
            )
        return

    yield ServeResponse(
        messages=[
            ServeMessage(role="assistant", parts=parts[i])
            for i in range(request.num_samples)
        ],
        finish_reason=response,
        stats=stats,
    )
@routes.http.post("/v1/chat")
def api_invoke_chat(
    req: Annotated[ServeRequest, Body(exclusive=True)],
):
    """
    Invoke model and generate audio.

    Streams SSE when the client sends JSON, or length-prefixed msgpack
    frames otherwise; non-streaming requests get a single response body.
    """

    # This makes torch compile happy
    assert (
        req.num_samples == GLOBAL_NUM_SAMPLES
    ), f"num_samples must be {GLOBAL_NUM_SAMPLES}"

    # Response encoding mirrors the request's Content-Type.
    content_type = request.headers.get("Content-Type", "application/json")
    json_mode = "application/json" in content_type

    async def wrapped_generator():
        # NOTE(review): this drives the synchronous execute_request generator
        # (which blocks on queue.get) from an async generator — presumably
        # acceptable here, but confirm it doesn't stall the event loop.
        generator = execute_request(llama_queue, tokenizer, config, req, args.device)

        for i in generator:
            if json_mode:
                body = i.model_dump_json().encode("utf-8")
                yield b"data: " + body + b"\n\n"
            else:
                # Length-prefixed msgpack frame (4-byte native-endian length).
                body = ormsgpack.packb(i, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
                yield struct.pack("I", len(body)) + body

    # Naive mode
    if req.streaming is False:
        result = next(execute_request(llama_queue, tokenizer, config, req, args.device))

        if json_mode:
            return JSONResponse(result.model_dump())
        else:
            return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)

    return StreamResponse(
        iterable=wrapped_generator(), content_type="text/event-stream"
    )
@torch.inference_mode()
def inference(req: ServeTTSRequest):
    """Run one TTS request end-to-end (llama -> VQ decode).

    Generator: in streaming mode it yields a WAV header followed by int16
    PCM byte chunks; otherwise it yields a single float numpy waveform.
    Reference prompts are cached in the module globals prompt_tokens /
    prompt_texts according to req.use_memory_cache.
    """
    global prompt_tokens, prompt_texts

    idstr: str | None = req.reference_id
    if idstr is not None:
        # Reference audios live on disk under references/<id>/, with matching
        # .lab transcript files next to each audio.
        ref_folder = Path("references") / idstr
        ref_folder.mkdir(parents=True, exist_ok=True)
        ref_audios = list_files(
            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
        )

        # Re-encode when caching is off, or on-demand with an empty cache.
        if req.use_memory_cache == "never" or (
            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
        ):
            prompt_tokens = [
                encode_reference(
                    decoder_model=decoder_model,
                    reference_audio=audio_to_bytes(str(ref_audio)),
                    enable_reference_audio=True,
                )
                for ref_audio in ref_audios
            ]
            prompt_texts = [
                read_ref_text(str(ref_audio.with_suffix(".lab")))
                for ref_audio in ref_audios
            ]
        else:
            logger.info("Use same references")

    else:
        # Parse reference audio aka prompt
        refs = req.references

        if req.use_memory_cache == "never" or (
            req.use_memory_cache == "on-demand" and len(prompt_tokens) == 0
        ):
            prompt_tokens = [
                encode_reference(
                    decoder_model=decoder_model,
                    reference_audio=ref.audio,
                    enable_reference_audio=True,
                )
                for ref in refs
            ]
            prompt_texts = [ref.text for ref in refs]
        else:
            logger.info("Use same references")

    if req.seed is not None:
        set_seed(req.seed)
        logger.warning(f"set seed: {req.seed}")

    # LLAMA Inference
    request = dict(
        device=decoder_model.device,
        max_new_tokens=req.max_new_tokens,
        text=(
            req.text
            if not req.normalize
            else ChnNormedText(raw_text=req.text).normalize()
        ),
        top_p=req.top_p,
        repetition_penalty=req.repetition_penalty,
        temperature=req.temperature,
        compile=args.compile,
        iterative_prompt=req.chunk_length > 0,
        chunk_length=req.chunk_length,
        max_length=4096,
        prompt_tokens=prompt_tokens,
        prompt_text=prompt_texts,
    )

    response_queue = queue.Queue()
    llama_queue.put(
        GenerateRequest(
            request=request,
            response_queue=response_queue,
        )
    )

    # Streaming starts with a frame-less WAV header; PCM chunks follow.
    if req.streaming:
        yield wav_chunk_header()

    segments = []

    while True:
        result: WrappedGenerateResponse = response_queue.get()
        if result.status == "error":
            raise result.response
            break  # NOTE(review): unreachable after raise; kept as-is.

        result: GenerateResponse = result.response
        # "next" signals the end of this request's segment stream.
        if result.action == "next":
            break

        with autocast_exclude_mps(
            device_type=decoder_model.device.type, dtype=args.precision
        ):
            fake_audios = decode_vq_tokens(
                decoder_model=decoder_model,
                codes=result.codes,
            )

        fake_audios = fake_audios.float().cpu().numpy()

        if req.streaming:
            # Convert float [-1, 1] samples to int16 PCM bytes.
            yield (fake_audios * 32768).astype(np.int16).tobytes()
        else:
            segments.append(fake_audios)

    if req.streaming:
        return

    if len(segments) == 0:
        raise HTTPException(
            HTTPStatus.INTERNAL_SERVER_ERROR,
            content="No audio generated, please check the input text.",
        )

    # Non-streaming: hand back one concatenated waveform.
    fake_audios = np.concatenate(segments, axis=0)
    yield fake_audios
async def inference_async(req: ServeTTSRequest):
    """Async wrapper over the synchronous `inference` generator.

    NOTE(review): iterating `inference` blocks on queue.Queue.get on the
    event loop thread — confirm whether this should run in an executor.
    """
    for chunk in inference(req):
        yield chunk
async def buffer_to_async_generator(buffer):
    """Expose an already-materialized buffer as a one-shot async iterable."""
    yield buffer
  599. @routes.http.post("/v1/tts")
  600. async def api_invoke_model(
  601. req: Annotated[ServeTTSRequest, Body(exclusive=True)],
  602. ):
  603. """
  604. Invoke model and generate audio
  605. """
  606. if args.max_text_length > 0 and len(req.text) > args.max_text_length:
  607. raise HTTPException(
  608. HTTPStatus.BAD_REQUEST,
  609. content=f"Text is too long, max length is {args.max_text_length}",
  610. )
  611. if req.streaming and req.format != "wav":
  612. raise HTTPException(
  613. HTTPStatus.BAD_REQUEST,
  614. content="Streaming only supports WAV format",
  615. )
  616. if req.streaming:
  617. return StreamResponse(
  618. iterable=inference_async(req),
  619. headers={
  620. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  621. },
  622. content_type=get_content_type(req.format),
  623. )
  624. else:
  625. fake_audios = next(inference(req))
  626. buffer = io.BytesIO()
  627. sf.write(
  628. buffer,
  629. fake_audios,
  630. decoder_model.spec_transform.sample_rate,
  631. format=req.format,
  632. )
  633. return StreamResponse(
  634. iterable=buffer_to_async_generator(buffer.getvalue()),
  635. headers={
  636. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  637. },
  638. content_type=get_content_type(req.format),
  639. )
@routes.http.post("/v1/health")
async def api_health():
    """
    Health check; always answers {"status": "ok"}.
    """
    return JSONResponse({"status": "ok"})
  646. def parse_args():
  647. parser = ArgumentParser()
  648. parser.add_argument("--mode", type=str, choices=["agent", "tts"], default="tts")
  649. parser.add_argument("--load-asr-model", action="store_true")
  650. parser.add_argument(
  651. "--llama-checkpoint-path",
  652. type=str,
  653. default="checkpoints/fish-speech-1.4",
  654. )
  655. parser.add_argument(
  656. "--decoder-checkpoint-path",
  657. type=str,
  658. default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
  659. )
  660. parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
  661. parser.add_argument("--device", type=str, default="cuda")
  662. parser.add_argument("--half", action="store_true")
  663. parser.add_argument("--compile", action="store_true")
  664. parser.add_argument("--max-text-length", type=int, default=0)
  665. parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
  666. parser.add_argument("--workers", type=int, default=1)
  667. return parser.parse_args()
# Define Kui app
# OpenAPI schema/documentation routes; only the generated sub-routes
# (openapi[1:]) are mounted into the app below.
openapi = OpenAPI(
    {
        "title": "Fish Speech API",
        "version": "1.4.2",
    },
).routes
  675. class MsgPackRequest(HttpRequest):
  676. async def data(
  677. self,
  678. ) -> Annotated[
  679. Any, ContentType("application/msgpack"), ContentType("application/json")
  680. ]:
  681. if self.content_type == "application/msgpack":
  682. return ormsgpack.unpackb(await self.body)
  683. elif self.content_type == "application/json":
  684. return await self.json
  685. raise HTTPException(
  686. HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
  687. headers={"Accept": "application/msgpack, application/json"},
  688. )
# Application instance: MsgPackRequest enables msgpack request bodies, the
# handlers above serve JSON error payloads, and CORS is fully open.
app = Kui(
    routes=routes + openapi[1:],  # Remove the default route
    exception_handlers={
        HTTPException: http_execption_handler,
        Exception: other_exception_handler,
    },
    factory_class=FactoryClass(http=MsgPackRequest),
    cors_config={},
)
  698. def load_asr_model(*, device="cuda", hub="ms"):
  699. return AutoModel(
  700. model="iic/SenseVoiceSmall",
  701. device=device,
  702. disable_pbar=True,
  703. hub=hub,
  704. )
# Each worker process created by Uvicorn has its own memory space,
# meaning that models and variables are not shared between processes.
# Therefore, any global variables (like `llama_queue` or `decoder_model`)
# will not be shared across workers.

# Multi-threading for deep learning can cause issues, such as inconsistent
# outputs if multiple threads access the same buffers simultaneously.
# Instead, it's better to use multiprocessing or independent models per thread.


@app.on_startup
def initialize_app(app: Kui):
    """Load all models and warm them up once per worker process.

    Populates the module globals used by the route handlers (llama_queue,
    decoder_model, vad_model, optional asr_model, prompt caches, args).
    """

    global args, llama_queue, tokenizer, config, decoder_model, vad_model, asr_model, prompt_tokens, prompt_texts

    # Reference-prompt caches consumed by inference().
    prompt_tokens, prompt_texts = [], []

    args = parse_args()  # args same as ones in other processes
    args.precision = torch.half if args.half else torch.bfloat16

    if args.load_asr_model:
        logger.info(f"Loading ASR model...")
        asr_model = load_asr_model(device=args.device)

    logger.info("Loading Llama model...")

    # TTS mode uses the plain generation queue; agent mode also returns the
    # tokenizer and model config needed by execute_request.
    if args.mode == "tts":
        llama_queue = launch_thread_safe_queue(
            checkpoint_path=args.llama_checkpoint_path,
            device=args.device,
            precision=args.precision,
            compile=args.compile,
        )
    else:
        llama_queue, tokenizer, config = launch_thread_safe_queue_agent(
            checkpoint_path=args.llama_checkpoint_path,
            device=args.device,
            precision=args.precision,
            compile=args.compile,
        )

    logger.info("Llama model loaded, loading VQ-GAN model...")

    decoder_model = load_decoder_model(
        config_name=args.decoder_config_name,
        checkpoint_path=args.decoder_checkpoint_path,
        device=args.device,
    )

    logger.info("VQ-GAN model loaded, warming up...")

    vad_model = load_silero_vad()

    logger.info("VAD model loaded, warming up...")

    if args.mode == "tts":
        # Dry run to ensure models work and avoid first-time latency
        list(
            inference(
                ServeTTSRequest(
                    text="Hello world.",
                    references=[],
                    reference_id=None,
                    max_new_tokens=0,
                    chunk_length=200,
                    top_p=0.7,
                    repetition_penalty=1.5,
                    temperature=0.7,
                    emotion=None,
                    format="wav",
                )
            )
        )

    logger.info(f"Warming up done, starting server at http://{args.listen}")
  764. if __name__ == "__main__":
  765. import uvicorn
  766. args = parse_args()
  767. host, port = args.listen.split(":")
  768. uvicorn.run(
  769. "tools.api:app",
  770. host=host,
  771. port=int(port),
  772. workers=args.workers,
  773. log_level="info",
  774. )