# views.py

import io
import os
import time
from http import HTTPStatus

import numpy as np
import ormsgpack
import soundfile as sf
import torch
from kui.asgi import (
    Body,
    HTTPException,
    HttpView,
    JSONResponse,
    Routes,
    StreamResponse,
    request,
)
from loguru import logger
from typing_extensions import Annotated

from fish_speech.utils.schema import (
    ServeASRRequest,
    ServeASRResponse,
    ServeChatRequest,
    ServeTTSRequest,
    ServeVQGANDecodeRequest,
    ServeVQGANDecodeResponse,
    ServeVQGANEncodeRequest,
    ServeVQGANEncodeResponse,
)
from tools.server.agent import get_response_generator
from tools.server.api_utils import (
    buffer_to_async_generator,
    get_content_type,
    inference_async,
)
from tools.server.inference import inference_wrapper as inference
from tools.server.model_manager import ModelManager
from tools.server.model_utils import (
    batch_asr,
    batch_vqgan_decode,
    cached_vqgan_batch_encode,
)

MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))

routes = Routes()


@routes.http("/v1/health")
class Health(HttpView):
    @classmethod
    async def get(cls):
        return JSONResponse({"status": "ok"})

    @classmethod
    async def post(cls):
        return JSONResponse({"status": "ok"})


@routes.http.post("/v1/vqgan/encode")
async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
    # Get the model from the app
    model_manager: ModelManager = request.app.state.model_manager
    decoder_model = model_manager.decoder_model

    # Encode the audio
    start_time = time.time()
    tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
    logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")

    # Return the response
    return ormsgpack.packb(
        ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
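
# Example client call for this endpoint (a sketch, not part of the server;
# it assumes the service listens on http://localhost:8080, a hypothetical
# address, and accepts msgpack request bodies, mirroring the msgpack
# responses above; `audio_bytes` holds the raw contents of an audio file):
#
#     import httpx
#     import ormsgpack
#
#     resp = httpx.post(
#         "http://localhost:8080/v1/vqgan/encode",
#         content=ormsgpack.packb({"audios": [audio_bytes]}),
#         headers={"Content-Type": "application/msgpack"},
#     )
#     tokens = ormsgpack.unpackb(resp.content)["tokens"]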


@routes.http.post("/v1/vqgan/decode")
async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    # Get the model from the app
    model_manager: ModelManager = request.app.state.model_manager
    decoder_model = model_manager.decoder_model

    # Decode the audio
    tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
    start_time = time.time()
    audios = batch_vqgan_decode(decoder_model, tokens)
    logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
    audios = [audio.astype(np.float16).tobytes() for audio in audios]

    # Return the response
    return ormsgpack.packb(
        ServeVQGANDecodeResponse(audios=audios),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
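
# Each returned waveform is packed as raw float16 bytes (see the .tobytes()
# call above); a client can recover a numpy array from it (sketch; `resp` is
# a hypothetical response from this decode endpoint):
#
#     import numpy as np
#     import ormsgpack
#
#     audios = ormsgpack.unpackb(resp.content)["audios"]
#     waveform = np.frombuffer(audios[0], dtype=np.float16).astype(np.float32)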


@routes.http.post("/v1/asr")
async def asr(req: Annotated[ServeASRRequest, Body(exclusive=True)]):
    # Get the model from the app
    model_manager: ModelManager = request.app.state.model_manager
    asr_model = model_manager.asr_model
    lock = request.app.state.lock

    # Perform ASR
    start_time = time.time()
    audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
    audios = [torch.from_numpy(audio).float() for audio in audios]

    # Reject any clip of 30 seconds or longer
    if any(audio.shape[-1] >= 30 * req.sample_rate for audio in audios):
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Audio length is too long",
        )

    transcriptions = batch_asr(
        asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
    )
    logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")

    # Return the response
    return ormsgpack.packb(
        ServeASRResponse(transcriptions=transcriptions),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
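
# Request audio must arrive as raw float16 samples at `sample_rate`, matching
# the np.frombuffer call above. A client-side sketch (`waveform` is a
# hypothetical 1-D float32 numpy array; imports and host as in the encode
# sketch):
#
#     payload = ormsgpack.packb(
#         {
#             "audios": [waveform.astype(np.float16).tobytes()],
#             "sample_rate": 44100,
#             "language": "en",
#         }
#     )
#     resp = httpx.post(
#         "http://localhost:8080/v1/asr",
#         content=payload,
#         headers={"Content-Type": "application/msgpack"},
#     )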


@routes.http.post("/v1/tts")
async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
    # Get the model from the app
    app_state = request.app.state
    model_manager: ModelManager = app_state.model_manager
    engine = model_manager.tts_inference_engine
    sample_rate = engine.decoder_model.spec_transform.sample_rate

    # Check if the text is too long
    if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {app_state.max_text_length}",
        )

    # Check if streaming is enabled
    if req.streaming and req.format != "wav":
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Streaming only supports WAV format",
        )

    # Perform TTS
    if req.streaming:
        return StreamResponse(
            iterable=inference_async(req, engine),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
    else:
        fake_audios = next(inference(req, engine))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            sample_rate,
            format=req.format,
        )
        return StreamResponse(
            iterable=buffer_to_async_generator(buffer.getvalue()),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
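
# Non-streaming TTS sketch: post text and save the returned audio bytes.
# Only the `text` and `format` fields read by this handler are set; any
# remaining ServeTTSRequest fields keep their defaults (imports and host as
# in the encode sketch):
#
#     resp = httpx.post(
#         "http://localhost:8080/v1/tts",
#         content=ormsgpack.packb({"text": "Hello!", "format": "wav"}),
#         headers={"Content-Type": "application/msgpack"},
#     )
#     with open("audio.wav", "wb") as f:
#         f.write(resp.content)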


@routes.http.post("/v1/chat")
async def chat(req: Annotated[ServeChatRequest, Body(exclusive=True)]):
    # Check that the number of samples requested is correct
    if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
        )

    # Get the type of content provided
    content_type = request.headers.get("Content-Type", "application/json")
    json_mode = "application/json" in content_type

    # Get the models from the app
    model_manager: ModelManager = request.app.state.model_manager
    llama_queue = model_manager.llama_queue
    tokenizer = model_manager.tokenizer
    config = model_manager.config
    device = request.app.state.device

    # Get the response generator
    response_generator = get_response_generator(
        llama_queue, tokenizer, config, req, device, json_mode
    )

    # Return the response in the correct format
    if req.streaming is False:
        result = response_generator()
        if json_mode:
            return JSONResponse(result.model_dump())
        else:
            return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)

    return StreamResponse(
        iterable=response_generator(), content_type="text/event-stream"
    )
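
# With `streaming` enabled the generator's output is sent as a
# `text/event-stream`; a sketch of consuming it over JSON (only the fields
# this handler reads are shown, other ServeChatRequest fields are omitted;
# imports and host as in the encode sketch):
#
#     with httpx.stream(
#         "POST",
#         "http://localhost:8080/v1/chat",
#         json={"num_samples": 1, "streaming": True},
#     ) as resp:
#         for line in resp.iter_lines():
#             print(line)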