views.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import io
  2. import os
  3. import time
  4. from http import HTTPStatus
  5. import numpy as np
  6. import ormsgpack
  7. import soundfile as sf
  8. import torch
  9. from kui.asgi import Body, HTTPException, JSONResponse, Routes, StreamResponse, request
  10. from loguru import logger
  11. from typing_extensions import Annotated
  12. from tools.schema import (
  13. ServeASRRequest,
  14. ServeASRResponse,
  15. ServeChatRequest,
  16. ServeTTSRequest,
  17. ServeVQGANDecodeRequest,
  18. ServeVQGANDecodeResponse,
  19. ServeVQGANEncodeRequest,
  20. ServeVQGANEncodeResponse,
  21. )
  22. from tools.server.agent import get_response_generator
  23. from tools.server.api_utils import (
  24. buffer_to_async_generator,
  25. get_content_type,
  26. inference_async,
  27. )
  28. from tools.server.inference import inference_wrapper as inference
  29. from tools.server.model_manager import ModelManager
  30. from tools.server.model_utils import batch_asr, cached_vqgan_batch_encode, vqgan_decode
# Upper bound on the number of chat samples a single /v1/chat request may ask
# for; configurable via the NUM_SAMPLES environment variable (defaults to 1).
MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))

# Route table collecting every HTTP endpoint defined in this module.
routes = Routes()
  33. @routes.http.post("/v1/health")
  34. async def health():
  35. return JSONResponse({"status": "ok"})
  36. @routes.http.post("/v1/vqgan/encode")
  37. async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
  38. # Get the model from the app
  39. model_manager: ModelManager = request.app.state.model_manager
  40. decoder_model = model_manager.decoder_model
  41. # Encode the audio
  42. start_time = time.time()
  43. tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
  44. logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
  45. # Return the response
  46. return ormsgpack.packb(
  47. ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
  48. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  49. )
  50. @routes.http.post("/v1/vqgan/decode")
  51. async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
  52. # Get the model from the app
  53. model_manager: ModelManager = request.app.state.model_manager
  54. decoder_model = model_manager.decoder_model
  55. # Decode the audio
  56. tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
  57. start_time = time.time()
  58. audios = vqgan_decode(decoder_model, tokens)
  59. logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
  60. audios = [audio.astype(np.float16).tobytes() for audio in audios]
  61. # Return the response
  62. return ormsgpack.packb(
  63. ServeVQGANDecodeResponse(audios=audios),
  64. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  65. )
  66. @routes.http.post("/v1/asr")
  67. async def asr(req: Annotated[ServeASRRequest, Body(exclusive=True)]):
  68. # Get the model from the app
  69. model_manager: ModelManager = request.app.state.model_manager
  70. asr_model = model_manager.asr_model
  71. lock = request.app.state.lock
  72. # Perform ASR
  73. start_time = time.time()
  74. audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
  75. audios = [torch.from_numpy(audio).float() for audio in audios]
  76. if any(audios.shape[-1] >= 30 * req.sample_rate for audios in audios):
  77. raise HTTPException(status_code=400, content="Audio length is too long")
  78. transcriptions = batch_asr(
  79. asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
  80. )
  81. logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")
  82. # Return the response
  83. return ormsgpack.packb(
  84. ServeASRResponse(transcriptions=transcriptions),
  85. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  86. )
  87. @routes.http.post("/v1/tts")
  88. async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
  89. # Get the model from the app
  90. app_state = request.app.state
  91. model_manager: ModelManager = app_state.model_manager
  92. engine = model_manager.tts_inference_engine
  93. sample_rate = engine.decoder_model.spec_transform.sample_rate
  94. # Check if the text is too long
  95. if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
  96. raise HTTPException(
  97. HTTPStatus.BAD_REQUEST,
  98. content=f"Text is too long, max length is {app_state.max_text_length}",
  99. )
  100. # Check if streaming is enabled
  101. if req.streaming and req.format != "wav":
  102. raise HTTPException(
  103. HTTPStatus.BAD_REQUEST,
  104. content="Streaming only supports WAV format",
  105. )
  106. # Perform TTS
  107. if req.streaming:
  108. return StreamResponse(
  109. iterable=inference_async(req, engine),
  110. headers={
  111. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  112. },
  113. content_type=get_content_type(req.format),
  114. )
  115. else:
  116. fake_audios = next(inference(req, engine))
  117. buffer = io.BytesIO()
  118. sf.write(
  119. buffer,
  120. fake_audios,
  121. sample_rate,
  122. format=req.format,
  123. )
  124. return StreamResponse(
  125. iterable=buffer_to_async_generator(buffer.getvalue()),
  126. headers={
  127. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  128. },
  129. content_type=get_content_type(req.format),
  130. )
  131. @routes.http.post("/v1/chat")
  132. async def chat(req: Annotated[ServeChatRequest, Body(exclusive=True)]):
  133. # Check that the number of samples requested is correct
  134. if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
  135. raise HTTPException(
  136. HTTPStatus.BAD_REQUEST,
  137. content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
  138. )
  139. # Get the type of content provided
  140. content_type = request.headers.get("Content-Type", "application/json")
  141. json_mode = "application/json" in content_type
  142. # Get the models from the app
  143. model_manager: ModelManager = request.app.state.model_manager
  144. llama_queue = model_manager.llama_queue
  145. tokenizer = model_manager.tokenizer
  146. config = model_manager.config
  147. device = request.app.state.device
  148. # Get the response generators
  149. response_generator = get_response_generator(
  150. llama_queue, tokenizer, config, req, device, json_mode
  151. )
  152. # Return the response in the correct format
  153. if req.streaming is False:
  154. result = response_generator()
  155. if json_mode:
  156. return JSONResponse(result.model_dump())
  157. else:
  158. return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)
  159. return StreamResponse(
  160. iterable=response_generator(), content_type="text/event-stream"
  161. )