views.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. import io
  2. import os
  3. import re
  4. import tempfile
  5. import time
  6. from http import HTTPStatus
  7. from pathlib import Path
  8. import numpy as np
  9. import ormsgpack
  10. import soundfile as sf
  11. import torch
  12. from kui.asgi import (
  13. Body,
  14. HTTPException,
  15. HttpView,
  16. JSONResponse,
  17. Routes,
  18. StreamResponse,
  19. UploadFile,
  20. request,
  21. )
  22. from loguru import logger
  23. from typing_extensions import Annotated
  24. from fish_speech.utils.schema import (
  25. AddReferenceResponse,
  26. DeleteReferenceResponse,
  27. ListReferencesResponse,
  28. ServeTTSRequest,
  29. ServeVQGANDecodeRequest,
  30. ServeVQGANDecodeResponse,
  31. ServeVQGANEncodeRequest,
  32. ServeVQGANEncodeResponse,
  33. UpdateReferenceResponse,
  34. )
  35. from tools.server.api_utils import (
  36. buffer_to_async_generator,
  37. format_response,
  38. get_content_type,
  39. inference_async,
  40. )
  41. from tools.server.inference import inference_wrapper as inference
  42. from tools.server.model_manager import ModelManager
  43. from tools.server.model_utils import (
  44. batch_vqgan_decode,
  45. cached_vqgan_batch_encode,
  46. )
  47. MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
  48. _WEBUI_HTML = (
  49. Path(__file__).parent.parent.parent / "awesome_webui" / "dist" / "index.html"
  50. )
  51. routes = Routes()
  52. @routes.http("/ui")
  53. class WebUI(HttpView):
  54. @classmethod
  55. async def get(cls):
  56. from kui.asgi import HTMLResponse
  57. if _WEBUI_HTML.exists():
  58. return HTMLResponse(_WEBUI_HTML.read_text(encoding="utf-8"))
  59. return JSONResponse(
  60. {"error": "WebUI not built. Run: cd awesome_webui && npm run build"},
  61. status_code=404,
  62. )
  63. @routes.http("/v1/health")
  64. class Health(HttpView):
  65. @classmethod
  66. async def get(cls):
  67. return JSONResponse({"status": "ok"})
  68. @classmethod
  69. async def post(cls):
  70. return JSONResponse({"status": "ok"})
  71. @routes.http.post("/v1/vqgan/encode")
  72. async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
  73. """
  74. Encode audio using VQGAN model.
  75. """
  76. try:
  77. # Get the model from the app
  78. model_manager: ModelManager = request.app.state.model_manager
  79. decoder_model = model_manager.decoder_model
  80. # Encode the audio
  81. start_time = time.time()
  82. tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
  83. logger.info(
  84. f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms"
  85. )
  86. # Return the response
  87. return ormsgpack.packb(
  88. ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
  89. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  90. )
  91. except Exception as e:
  92. logger.error(f"Error in VQGAN encode: {e}", exc_info=True)
  93. raise HTTPException(
  94. HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to encode audio"
  95. )
  96. @routes.http.post("/v1/vqgan/decode")
  97. async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
  98. """
  99. Decode tokens to audio using VQGAN model.
  100. """
  101. try:
  102. # Get the model from the app
  103. model_manager: ModelManager = request.app.state.model_manager
  104. decoder_model = model_manager.decoder_model
  105. # Decode the audio
  106. tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
  107. start_time = time.time()
  108. audios = batch_vqgan_decode(decoder_model, tokens)
  109. logger.info(
  110. f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms"
  111. )
  112. audios = [audio.astype(np.float16).tobytes() for audio in audios]
  113. # Return the response
  114. return ormsgpack.packb(
  115. ServeVQGANDecodeResponse(audios=audios),
  116. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  117. )
  118. except Exception as e:
  119. logger.error(f"Error in VQGAN decode: {e}", exc_info=True)
  120. raise HTTPException(
  121. HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to decode tokens to audio"
  122. )
  123. @routes.http.post("/v1/tts")
  124. async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
  125. """
  126. Generate speech from text using TTS model.
  127. """
  128. logger.info(f"/v1/tts param: {req}")
  129. try:
  130. # Get the model from the app
  131. app_state = request.app.state
  132. model_manager: ModelManager = app_state.model_manager
  133. engine = model_manager.tts_inference_engine
  134. sample_rate = engine.decoder_model.sample_rate
  135. # Check if the text is too long
  136. if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
  137. raise HTTPException(
  138. HTTPStatus.BAD_REQUEST,
  139. content=f"Text is too long, max length is {app_state.max_text_length}",
  140. )
  141. # Check if streaming is enabled
  142. if req.streaming and req.format != "wav":
  143. raise HTTPException(
  144. HTTPStatus.BAD_REQUEST,
  145. content="Streaming only supports WAV format",
  146. )
  147. # Check Reference id is existed
  148. if req.reference_id:
  149. ref_dir = Path("references") / req.reference_id
  150. if not ref_dir.exists():
  151. raise HTTPException(
  152. HTTPStatus.BAD_REQUEST,
  153. content="Reference id is not existed",
  154. )
  155. # Perform TTS
  156. if req.streaming:
  157. return StreamResponse(
  158. iterable=inference_async(req, engine),
  159. headers={
  160. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  161. },
  162. content_type=get_content_type(req.format),
  163. )
  164. else:
  165. fake_audios = next(inference(req, engine))
  166. buffer = io.BytesIO()
  167. sf.write(
  168. buffer,
  169. fake_audios,
  170. sample_rate,
  171. format=req.format,
  172. )
  173. return StreamResponse(
  174. iterable=buffer_to_async_generator(buffer.getvalue()),
  175. headers={
  176. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  177. },
  178. content_type=get_content_type(req.format),
  179. )
  180. except HTTPException:
  181. # Re-raise HTTP exceptions as they are already properly formatted
  182. raise
  183. except Exception as e:
  184. logger.error(f"Error in TTS generation: {e}", exc_info=True)
  185. raise HTTPException(
  186. HTTPStatus.INTERNAL_SERVER_ERROR, content="Failed to generate speech"
  187. )
  188. @routes.http.post("/v1/references/add")
  189. async def add_reference(
  190. id: str = Body(...), audio: UploadFile = Body(...), text: str = Body(...)
  191. ):
  192. """
  193. Add a new reference voice with audio file and text.
  194. """
  195. temp_file_path = None
  196. try:
  197. # Validate input parameters
  198. if not id or not id.strip():
  199. raise ValueError("Reference ID cannot be empty")
  200. if not text or not text.strip():
  201. raise ValueError("Reference text cannot be empty")
  202. # Get the model manager to access the reference loader
  203. app_state = request.app.state
  204. model_manager: ModelManager = app_state.model_manager
  205. engine = model_manager.tts_inference_engine
  206. # Read the uploaded audio file
  207. audio_content = audio.read()
  208. if not audio_content:
  209. raise ValueError("Audio file is empty or could not be read")
  210. # Create a temporary file for the audio data
  211. with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
  212. temp_file.write(audio_content)
  213. temp_file_path = temp_file.name
  214. # Add the reference using the engine's reference loader
  215. engine.add_reference(id, temp_file_path, text)
  216. response = AddReferenceResponse(
  217. success=True,
  218. message=f"Reference voice '{id}' added successfully",
  219. reference_id=id,
  220. )
  221. return format_response(response)
  222. except FileExistsError as e:
  223. logger.warning(f"Reference ID '{id}' already exists: {e}")
  224. response = AddReferenceResponse(
  225. success=False,
  226. message=f"Reference ID '{id}' already exists",
  227. reference_id=id,
  228. )
  229. return format_response(response, status_code=409) # Conflict
  230. except ValueError as e:
  231. logger.warning(f"Invalid input for reference '{id}': {e}")
  232. response = AddReferenceResponse(success=False, message=str(e), reference_id=id)
  233. return format_response(response, status_code=400)
  234. except (FileNotFoundError, OSError) as e:
  235. logger.error(f"File system error for reference '{id}': {e}")
  236. response = AddReferenceResponse(
  237. success=False, message="File system error occurred", reference_id=id
  238. )
  239. return format_response(response, status_code=500)
  240. except Exception as e:
  241. logger.error(f"Unexpected error adding reference '{id}': {e}", exc_info=True)
  242. response = AddReferenceResponse(
  243. success=False, message="Internal server error occurred", reference_id=id
  244. )
  245. return format_response(response, status_code=500)
  246. finally:
  247. # Clean up temporary file
  248. if temp_file_path and os.path.exists(temp_file_path):
  249. try:
  250. os.unlink(temp_file_path)
  251. except OSError as e:
  252. logger.warning(
  253. f"Failed to clean up temporary file {temp_file_path}: {e}"
  254. )
  255. @routes.http.get("/v1/references/list")
  256. async def list_references():
  257. """
  258. Get a list of all available reference voice IDs.
  259. """
  260. try:
  261. # Get the model manager to access the reference loader
  262. app_state = request.app.state
  263. model_manager: ModelManager = app_state.model_manager
  264. engine = model_manager.tts_inference_engine
  265. # Get the list of reference IDs
  266. reference_ids = engine.list_reference_ids()
  267. response = ListReferencesResponse(
  268. success=True,
  269. reference_ids=reference_ids,
  270. message=f"Found {len(reference_ids)} reference voices",
  271. )
  272. return format_response(response)
  273. except Exception as e:
  274. logger.error(f"Unexpected error listing references: {e}", exc_info=True)
  275. response = ListReferencesResponse(
  276. success=False, reference_ids=[], message="Internal server error occurred"
  277. )
  278. return format_response(response, status_code=500)
  279. @routes.http.delete("/v1/references/delete")
  280. async def delete_reference(reference_id: str = Body(...)):
  281. """
  282. Delete a reference voice by ID.
  283. """
  284. try:
  285. # Validate input parameters
  286. if not reference_id or not reference_id.strip():
  287. raise ValueError("Reference ID cannot be empty")
  288. id_pattern = r"^[a-zA-Z0-9\-_ ]+$"
  289. if not re.match(id_pattern, reference_id) or len(reference_id) > 255:
  290. raise ValueError("Reference ID contains invalid characters or is too long")
  291. # Get the model manager to access the reference loader
  292. app_state = request.app.state
  293. model_manager: ModelManager = app_state.model_manager
  294. engine = model_manager.tts_inference_engine
  295. # Delete the reference using the engine's reference loader
  296. engine.delete_reference(reference_id)
  297. response = DeleteReferenceResponse(
  298. success=True,
  299. message=f"Reference voice '{reference_id}' deleted successfully",
  300. reference_id=reference_id,
  301. )
  302. return format_response(response)
  303. except FileNotFoundError as e:
  304. logger.warning(f"Reference ID '{reference_id}' not found: {e}")
  305. response = DeleteReferenceResponse(
  306. success=False,
  307. message=f"Reference ID '{reference_id}' not found",
  308. reference_id=reference_id,
  309. )
  310. return format_response(response, status_code=404) # Not Found
  311. except ValueError as e:
  312. logger.warning(f"Invalid input for reference '{reference_id}': {e}")
  313. response = DeleteReferenceResponse(
  314. success=False, message=str(e), reference_id=reference_id
  315. )
  316. return format_response(response, status_code=400)
  317. except OSError as e:
  318. logger.error(f"File system error deleting reference '{reference_id}': {e}")
  319. response = DeleteReferenceResponse(
  320. success=False,
  321. message="File system error occurred",
  322. reference_id=reference_id,
  323. )
  324. return format_response(response, status_code=500)
  325. except Exception as e:
  326. logger.error(
  327. f"Unexpected error deleting reference '{reference_id}': {e}", exc_info=True
  328. )
  329. response = DeleteReferenceResponse(
  330. success=False,
  331. message="Internal server error occurred",
  332. reference_id=reference_id,
  333. )
  334. return format_response(response, status_code=500)
  335. @routes.http.post("/v1/references/update")
  336. async def update_reference(
  337. old_reference_id: str = Body(...), new_reference_id: str = Body(...)
  338. ):
  339. """
  340. Rename a reference voice directory from old_reference_id to new_reference_id.
  341. """
  342. try:
  343. # Validate input parameters
  344. if not old_reference_id or not old_reference_id.strip():
  345. raise ValueError("Old reference ID cannot be empty")
  346. if not new_reference_id or not new_reference_id.strip():
  347. raise ValueError("New reference ID cannot be empty")
  348. if old_reference_id == new_reference_id:
  349. raise ValueError("New reference ID must be different from old reference ID")
  350. # Validate ID format per ReferenceLoader rules
  351. id_pattern = r"^[a-zA-Z0-9\-_ ]+$"
  352. if not re.match(id_pattern, old_reference_id) or len(old_reference_id) > 255:
  353. raise ValueError(
  354. "Old reference ID contains invalid characters or is too long"
  355. )
  356. if not re.match(id_pattern, new_reference_id) or len(new_reference_id) > 255:
  357. raise ValueError(
  358. "New reference ID contains invalid characters or is too long"
  359. )
  360. # Access engine to update caches after renaming
  361. app_state = request.app.state
  362. model_manager: ModelManager = app_state.model_manager
  363. engine = model_manager.tts_inference_engine
  364. refs_base = Path("references")
  365. old_dir = refs_base / old_reference_id
  366. new_dir = refs_base / new_reference_id
  367. # Existence checks
  368. if not old_dir.exists() or not old_dir.is_dir():
  369. raise FileNotFoundError(f"Reference ID '{old_reference_id}' not found")
  370. if new_dir.exists():
  371. # Conflict: destination already exists
  372. response = UpdateReferenceResponse(
  373. success=False,
  374. message=f"Reference ID '{new_reference_id}' already exists",
  375. old_reference_id=old_reference_id,
  376. new_reference_id=new_reference_id,
  377. )
  378. return format_response(response, status_code=409)
  379. # Perform rename
  380. old_dir.rename(new_dir)
  381. # Update in-memory cache key if present
  382. if old_reference_id in engine.ref_by_id:
  383. engine.ref_by_id[new_reference_id] = engine.ref_by_id.pop(old_reference_id)
  384. response = UpdateReferenceResponse(
  385. success=True,
  386. message=(
  387. f"Reference voice renamed from '{old_reference_id}' to '{new_reference_id}' successfully"
  388. ),
  389. old_reference_id=old_reference_id,
  390. new_reference_id=new_reference_id,
  391. )
  392. return format_response(response)
  393. except FileNotFoundError as e:
  394. logger.warning(str(e))
  395. response = UpdateReferenceResponse(
  396. success=False,
  397. message=str(e),
  398. old_reference_id=old_reference_id,
  399. new_reference_id=new_reference_id,
  400. )
  401. return format_response(response, status_code=404)
  402. except ValueError as e:
  403. logger.warning(f"Invalid input for update reference: {e}")
  404. response = UpdateReferenceResponse(
  405. success=False,
  406. message=str(e),
  407. old_reference_id=old_reference_id if "old_reference_id" in locals() else "",
  408. new_reference_id=new_reference_id if "new_reference_id" in locals() else "",
  409. )
  410. return format_response(response, status_code=400)
  411. except OSError as e:
  412. logger.error(f"File system error renaming reference: {e}")
  413. response = UpdateReferenceResponse(
  414. success=False,
  415. message="File system error occurred",
  416. old_reference_id=old_reference_id,
  417. new_reference_id=new_reference_id,
  418. )
  419. return format_response(response, status_code=500)
  420. except Exception as e:
  421. logger.error(f"Unexpected error updating reference: {e}", exc_info=True)
  422. response = UpdateReferenceResponse(
  423. success=False,
  424. message="Internal server error occurred",
  425. old_reference_id=old_reference_id if "old_reference_id" in locals() else "",
  426. new_reference_id=new_reference_id if "new_reference_id" in locals() else "",
  427. )
  428. return format_response(response, status_code=500)