views.py

import io
import os
import time
from http import HTTPStatus

import numpy as np
import ormsgpack
import soundfile as sf
import torch
from kui.asgi import (
    Body,
    HTTPException,
    HttpView,
    JSONResponse,
    Routes,
    StreamResponse,
    request,
)
from loguru import logger
from typing_extensions import Annotated

from fish_speech.utils.schema import (
    ServeASRRequest,
    ServeASRResponse,
    ServeTTSRequest,
    ServeVQGANDecodeRequest,
    ServeVQGANDecodeResponse,
    ServeVQGANEncodeRequest,
    ServeVQGANEncodeResponse,
)
from tools.server.api_utils import (
    buffer_to_async_generator,
    get_content_type,
    inference_async,
)
from tools.server.inference import inference_wrapper as inference
from tools.server.model_manager import ModelManager
from tools.server.model_utils import (
    batch_asr,
    batch_vqgan_decode,
    cached_vqgan_batch_encode,
)

MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))

routes = Routes()


@routes.http("/v1/health")
class Health(HttpView):
    @classmethod
    async def get(cls):
        return JSONResponse({"status": "ok"})

    @classmethod
    async def post(cls):
        return JSONResponse({"status": "ok"})
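

# Example client call for /v1/health -- a minimal sketch, not used by the server.
# It assumes the server is reachable at http://localhost:8080 and that the
# `requests` package is installed on the client side; adjust both to your setup.
def _example_health_check():
    import requests  # client-side dependency, assumed for this sketch only

    resp = requests.get("http://localhost:8080/v1/health")
    return resp.json()  # expected: {"status": "ok"}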


@routes.http.post("/v1/vqgan/encode")
async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
    # Get the model from the app
    model_manager: ModelManager = request.app.state.model_manager
    decoder_model = model_manager.decoder_model

    # Encode the audio
    start_time = time.time()
    tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
    logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")

    # Return the response
    return ormsgpack.packb(
        ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )


@routes.http.post("/v1/vqgan/decode")
async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    # Get the model from the app
    model_manager: ModelManager = request.app.state.model_manager
    decoder_model = model_manager.decoder_model

    # Decode the audio
    tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
    start_time = time.time()
    audios = batch_vqgan_decode(decoder_model, tokens)
    logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
    audios = [audio.astype(np.float16).tobytes() for audio in audios]

    # Return the response
    return ormsgpack.packb(
        ServeVQGANDecodeResponse(audios=audios),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
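

# Example client call for /v1/vqgan/decode -- a minimal sketch, not used by the
# server. It assumes the server accepts MessagePack-encoded request bodies
# (mirroring the ormsgpack-packed responses above), listens on
# http://localhost:8080, and that the `requests` package is available.
# The field names ("tokens", "audios") follow the handler above; the returned
# audio is raw float16 PCM, as produced by the handler.
def _example_vqgan_decode_client(tokens: list) -> np.ndarray:
    import requests  # client-side dependency, assumed for this sketch only

    # `tokens` is the nested list of codebook indices, shaped as the decoder expects.
    resp = requests.post(
        "http://localhost:8080/v1/vqgan/decode",
        data=ormsgpack.packb({"tokens": tokens}),
        headers={"Content-Type": "application/msgpack"},
    )
    payload = ormsgpack.unpackb(resp.content)

    # Each entry in "audios" is a raw float16 buffer; convert the first one
    # back into a NumPy array for further processing.
    return np.frombuffer(payload["audios"][0], dtype=np.float16).astype(np.float32)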


@routes.http.post("/v1/asr")
async def asr(req: Annotated[ServeASRRequest, Body(exclusive=True)]):
    # Get the model from the app
    model_manager: ModelManager = request.app.state.model_manager
    asr_model = model_manager.asr_model
    lock = request.app.state.lock

    # Perform ASR
    start_time = time.time()
    audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
    audios = [torch.from_numpy(audio).float() for audio in audios]

    # Reject clips of 30 seconds or longer
    if any(audio.shape[-1] >= 30 * req.sample_rate for audio in audios):
        raise HTTPException(status_code=400, content="Audio length is too long")

    transcriptions = batch_asr(
        asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
    )
    logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")

    # Return the response
    return ormsgpack.packb(
        ServeASRResponse(transcriptions=transcriptions),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
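

# Example client call for /v1/asr -- a minimal sketch, not used by the server.
# The handler above reads each audio as a raw float16 buffer plus a sample rate
# and language, so the client packs its waveform the same way. Host, port, the
# MessagePack body encoding, and the `requests` dependency are assumptions for
# this sketch; valid "language" values are defined by ServeASRRequest.
def _example_asr_client(waveform: np.ndarray, sample_rate: int = 16000):
    import requests  # client-side dependency, assumed for this sketch only

    payload = ormsgpack.packb(
        {
            "audios": [waveform.astype(np.float16).tobytes()],
            "sample_rate": sample_rate,
            "language": "en",
        }
    )
    resp = requests.post(
        "http://localhost:8080/v1/asr",
        data=payload,
        headers={"Content-Type": "application/msgpack"},
    )
    return ormsgpack.unpackb(resp.content)["transcriptions"]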


@routes.http.post("/v1/tts")
async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
    # Get the model from the app
    app_state = request.app.state
    model_manager: ModelManager = app_state.model_manager
    engine = model_manager.tts_inference_engine
    sample_rate = engine.decoder_model.sample_rate

    # Check if the text is too long
    if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {app_state.max_text_length}",
        )

    # Check if streaming is enabled
    if req.streaming and req.format != "wav":
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Streaming only supports WAV format",
        )

    # Perform TTS
    if req.streaming:
        return StreamResponse(
            iterable=inference_async(req, engine),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
    else:
        fake_audios = next(inference(req, engine))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            sample_rate,
            format=req.format,
        )

        return StreamResponse(
            iterable=buffer_to_async_generator(buffer.getvalue()),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
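

# Example client call for /v1/tts -- a minimal sketch, not used by the server.
# Request fields ("text", "format", "streaming") mirror those read by the
# handler above; all other ServeTTSRequest fields keep their defaults. The
# MessagePack body encoding, host, port, and the `requests` dependency are
# assumptions for this sketch.
def _example_tts_client(text: str, path: str = "audio.wav") -> str:
    import requests  # client-side dependency, assumed for this sketch only

    resp = requests.post(
        "http://localhost:8080/v1/tts",
        data=ormsgpack.packb({"text": text, "format": "wav", "streaming": False}),
        headers={"Content-Type": "application/msgpack"},
    )

    # The non-streaming branch above returns the whole encoded file as the body.
    with open(path, "wb") as f:
        f.write(resp.content)
    return path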