# views.py
  1. import io
  2. import os
  3. import time
  4. from http import HTTPStatus
  5. import numpy as np
  6. import ormsgpack
  7. import soundfile as sf
  8. import torch
  9. from kui.asgi import HTTPException, HttpView, JSONResponse, StreamResponse, request
  10. from loguru import logger
  11. from tools.schema import (
  12. ServeASRRequest,
  13. ServeASRResponse,
  14. ServeChatRequest,
  15. ServeTTSRequest,
  16. ServeVQGANDecodeRequest,
  17. ServeVQGANDecodeResponse,
  18. ServeVQGANEncodeRequest,
  19. ServeVQGANEncodeResponse,
  20. )
  21. from tools.server.agent import get_response_generator
  22. from tools.server.api_utils import (
  23. buffer_to_async_generator,
  24. get_content_type,
  25. inference_async,
  26. )
  27. from tools.server.inference import inference_wrapper as inference
  28. from tools.server.model_manager import ModelManager
  29. from tools.server.model_utils import batch_asr, cached_vqgan_batch_encode, vqgan_decode
  30. MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
  31. class HealthView(HttpView):
  32. """
  33. Return the health status of the server.
  34. """
  35. @classmethod
  36. async def post(cls):
  37. return JSONResponse({"status": "ok"})
  38. class VQGANEncodeView(HttpView):
  39. """
  40. Encode the audio into symbolic tokens.
  41. """
  42. @classmethod
  43. async def post(cls):
  44. # Decode the request
  45. payload = await request.data()
  46. req = ServeVQGANEncodeRequest(**payload)
  47. # Get the model from the app
  48. model_manager: ModelManager = request.app.state.model_manager
  49. decoder_model = model_manager.decoder_model
  50. # Encode the audio
  51. start_time = time.time()
  52. tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
  53. logger.info(
  54. f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms"
  55. )
  56. # Return the response
  57. return ormsgpack.packb(
  58. ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
  59. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  60. )
  61. class VQGANDecodeView(HttpView):
  62. """
  63. Decode the symbolic tokens into audio.
  64. """
  65. @classmethod
  66. async def post(cls):
  67. # Decode the request
  68. payload = await request.data()
  69. req = ServeVQGANDecodeRequest(**payload)
  70. # Get the model from the app
  71. model_manager: ModelManager = request.app.state.model_manager
  72. decoder_model = model_manager.decoder_model
  73. # Decode the audio
  74. tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
  75. start_time = time.time()
  76. audios = vqgan_decode(decoder_model, tokens)
  77. logger.info(
  78. f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms"
  79. )
  80. audios = [audio.astype(np.float16).tobytes() for audio in audios]
  81. # Return the response
  82. return ormsgpack.packb(
  83. ServeVQGANDecodeResponse(audios=audios),
  84. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  85. )
  86. class ASRView(HttpView):
  87. """
  88. Perform automatic speech recognition on the audio.
  89. """
  90. @classmethod
  91. async def post(cls):
  92. # Decode the request
  93. payload = await request.data()
  94. req = ServeASRRequest(**payload)
  95. # Get the model from the app
  96. model_manager: ModelManager = request.app.state.model_manager
  97. asr_model = model_manager.asr_model
  98. lock = request.app.state.lock
  99. # Perform ASR
  100. start_time = time.time()
  101. audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
  102. audios = [torch.from_numpy(audio).float() for audio in audios]
  103. if any(audios.shape[-1] >= 30 * req.sample_rate for audios in audios):
  104. raise HTTPException(status_code=400, content="Audio length is too long")
  105. transcriptions = batch_asr(
  106. asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
  107. )
  108. logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
  109. # Return the response
  110. return ormsgpack.packb(
  111. ServeASRResponse(transcriptions=transcriptions),
  112. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  113. )
  114. class TTSView(HttpView):
  115. """
  116. Perform text-to-speech on the input text.
  117. """
  118. @classmethod
  119. async def post(cls):
  120. # Decode the request
  121. payload = await request.data()
  122. req = ServeTTSRequest(**payload)
  123. # Get the model from the app
  124. app_state = request.app.state
  125. model_manager: ModelManager = app_state.model_manager
  126. engine = model_manager.tts_inference_engine
  127. sample_rate = engine.decoder_model.spec_transform.sample_rate
  128. # Check if the text is too long
  129. if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
  130. raise HTTPException(
  131. HTTPStatus.BAD_REQUEST,
  132. content=f"Text is too long, max length is {app_state.max_text_length}",
  133. )
  134. # Check if streaming is enabled
  135. if req.streaming and req.format != "wav":
  136. raise HTTPException(
  137. HTTPStatus.BAD_REQUEST,
  138. content="Streaming only supports WAV format",
  139. )
  140. # Perform TTS
  141. if req.streaming:
  142. return StreamResponse(
  143. iterable=inference_async(req, engine),
  144. headers={
  145. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  146. },
  147. content_type=get_content_type(req.format),
  148. )
  149. else:
  150. fake_audios = next(inference(req, engine))
  151. buffer = io.BytesIO()
  152. sf.write(
  153. buffer,
  154. fake_audios,
  155. sample_rate,
  156. format=req.format,
  157. )
  158. return StreamResponse(
  159. iterable=buffer_to_async_generator(buffer.getvalue()),
  160. headers={
  161. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  162. },
  163. content_type=get_content_type(req.format),
  164. )
  165. class ChatView(HttpView):
  166. """
  167. Perform chatbot inference on the input text.
  168. """
  169. @classmethod
  170. async def post(cls):
  171. # Decode the request
  172. payload = await request.data()
  173. req = ServeChatRequest(**payload)
  174. # Check that the number of samples requested is correct
  175. if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
  176. raise HTTPException(
  177. HTTPStatus.BAD_REQUEST,
  178. content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
  179. )
  180. # Get the type of content provided
  181. content_type = request.headers.get("Content-Type", "application/json")
  182. json_mode = "application/json" in content_type
  183. # Get the models from the app
  184. model_manager: ModelManager = request.app.state.model_manager
  185. llama_queue = model_manager.llama_queue
  186. tokenizer = model_manager.tokenizer
  187. config = model_manager.config
  188. device = request.app.state.device
  189. # Get the response generators
  190. response_generator = get_response_generator(
  191. llama_queue, tokenizer, config, req, device, json_mode
  192. )
  193. # Return the response in the correct format
  194. if req.streaming is False:
  195. result = response_generator()
  196. if json_mode:
  197. return JSONResponse(result.model_dump())
  198. else:
  199. return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
  200. return StreamResponse(
  201. iterable=response_generator(), content_type="text/event-stream"
  202. )