# views.py
import io
import os
import time
from http import HTTPStatus

import numpy as np
import ormsgpack
import soundfile as sf
import torch
from kui.asgi import Body, HTTPException, JSONResponse, Routes, StreamResponse, request
from loguru import logger
from typing_extensions import Annotated

from fish_speech.utils.schema import (
    ServeASRRequest,
    ServeASRResponse,
    ServeChatRequest,
    ServeTTSRequest,
    ServeVQGANDecodeRequest,
    ServeVQGANDecodeResponse,
    ServeVQGANEncodeRequest,
    ServeVQGANEncodeResponse,
)
from tools.server.agent import get_response_generator
from tools.server.api_utils import (
    buffer_to_async_generator,
    get_content_type,
    inference_async,
)
from tools.server.inference import inference_wrapper as inference
from tools.server.model_manager import ModelManager
from tools.server.model_utils import (
    batch_asr,
    batch_vqgan_decode,
    cached_vqgan_batch_encode,
)
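
# MAX_NUM_SAMPLES caps the `num_samples` value accepted by /v1/chat below; it
# is read from the NUM_SAMPLES environment variable and defaults to 1.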
MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))

routes = Routes()

@routes.http.post("/v1/health")
async def health():
    return JSONResponse({"status": "ok"})
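
# A minimal client sketch for the health check. Note the route is registered
# as POST, not GET; the base URL here is a placeholder, not part of the code.
#
#     import httpx
#     resp = httpx.post("http://127.0.0.1:8080/v1/health")
#     assert resp.json() == {"status": "ok"}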

@routes.http.post("/v1/vqgan/encode")
async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
    # Get the model from the app
    model_manager: ModelManager = request.app.state.model_manager
    decoder_model = model_manager.decoder_model

    # Encode the audio
    start_time = time.time()
    tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
    logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")

    # Return the response
    return ormsgpack.packb(
        ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
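
# A hedged client sketch for /v1/vqgan/encode. It assumes the body is
# msgpack-encoded and that ServeVQGANEncodeRequest.audios is a list of raw
# audio byte strings; check fish_speech.utils.schema before relying on it.
#
#     import httpx, ormsgpack
#     payload = ormsgpack.packb({"audios": [open("sample.wav", "rb").read()]})
#     resp = httpx.post("http://127.0.0.1:8080/v1/vqgan/encode", content=payload,
#                       headers={"Content-Type": "application/msgpack"})
#     tokens = ormsgpack.unpackb(resp.content)["tokens"]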

@routes.http.post("/v1/vqgan/decode")
async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
    # Get the model from the app
    model_manager: ModelManager = request.app.state.model_manager
    decoder_model = model_manager.decoder_model

    # Decode the audio
    tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
    start_time = time.time()
    audios = batch_vqgan_decode(decoder_model, tokens)
    logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
    audios = [audio.astype(np.float16).tobytes() for audio in audios]

    # Return the response
    return ormsgpack.packb(
        ServeVQGANDecodeResponse(audios=audios),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
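
# /v1/vqgan/decode ships each waveform as raw float16 bytes (see the
# `tobytes()` call above), so clients must undo that packing themselves.
# A minimal sketch; the target dtype is a choice, not part of the API:
#
#     import numpy as np
#     audio = np.frombuffer(audio_bytes, dtype=np.float16).astype(np.float32)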

@routes.http.post("/v1/asr")
async def asr(req: Annotated[ServeASRRequest, Body(exclusive=True)]):
    # Get the model from the app
    model_manager: ModelManager = request.app.state.model_manager
    asr_model = model_manager.asr_model
    lock = request.app.state.lock

    # Perform ASR
    start_time = time.time()
    audios = [np.frombuffer(audio, dtype=np.float16) for audio in req.audios]
    audios = [torch.from_numpy(audio).float() for audio in audios]

    # Reject any clip of 30 seconds or longer
    if any(audio.shape[-1] >= 30 * req.sample_rate for audio in audios):
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Audio length is too long",
        )

    transcriptions = batch_asr(
        asr_model, lock, audios=audios, sr=req.sample_rate, language=req.language
    )
    logger.info(f"[EXEC] ASR time: {(time.time() - start_time) * 1000:.2f}ms")

    # Return the response
    return ormsgpack.packb(
        ServeASRResponse(transcriptions=transcriptions),
        option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
    )
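
# The ASR endpoint likewise expects raw float16 PCM for each entry in
# `audios` (see the `np.frombuffer` call above). A hedged packing sketch;
# the URL and the example field values are placeholders:
#
#     import httpx, numpy as np, ormsgpack, soundfile as sf
#     data, sr = sf.read("clip.wav", dtype="float32")
#     body = ormsgpack.packb({"audios": [data.astype(np.float16).tobytes()],
#                             "sample_rate": sr, "language": "en"})
#     resp = httpx.post("http://127.0.0.1:8080/v1/asr", content=body,
#                       headers={"Content-Type": "application/msgpack"})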

@routes.http.post("/v1/tts")
async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
    # Get the model from the app
    app_state = request.app.state
    model_manager: ModelManager = app_state.model_manager
    engine = model_manager.tts_inference_engine
    sample_rate = engine.decoder_model.spec_transform.sample_rate

    # Check if the text is too long
    if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Text is too long, max length is {app_state.max_text_length}",
        )

    # Check if streaming is enabled
    if req.streaming and req.format != "wav":
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content="Streaming only supports WAV format",
        )

    # Perform TTS
    if req.streaming:
        return StreamResponse(
            iterable=inference_async(req, engine),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
    else:
        fake_audios = next(inference(req, engine))
        buffer = io.BytesIO()
        sf.write(
            buffer,
            fake_audios,
            sample_rate,
            format=req.format,
        )

        return StreamResponse(
            iterable=buffer_to_async_generator(buffer.getvalue()),
            headers={
                "Content-Disposition": f"attachment; filename=audio.{req.format}",
            },
            content_type=get_content_type(req.format),
        )
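
# A hedged client sketch for non-streaming synthesis. Only `text`, `format`,
# and `streaming` appear in the handler above; other ServeTTSRequest fields
# are omitted, and the URL is a placeholder:
#
#     import httpx, ormsgpack
#     body = ormsgpack.packb({"text": "Hello!", "format": "wav",
#                             "streaming": False})
#     resp = httpx.post("http://127.0.0.1:8080/v1/tts", content=body,
#                       headers={"Content-Type": "application/msgpack"})
#     open("audio.wav", "wb").write(resp.content)
#
# With "streaming": True the format must be "wav" and the body arrives as
# chunked audio instead of a finished file.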

@routes.http.post("/v1/chat")
async def chat(req: Annotated[ServeChatRequest, Body(exclusive=True)]):
    # Check that the number of samples requested is correct
    if req.num_samples < 1 or req.num_samples > MAX_NUM_SAMPLES:
        raise HTTPException(
            HTTPStatus.BAD_REQUEST,
            content=f"Number of samples must be between 1 and {MAX_NUM_SAMPLES}",
        )

    # Get the type of content provided
    content_type = request.headers.get("Content-Type", "application/json")
    json_mode = "application/json" in content_type

    # Get the models from the app
    model_manager: ModelManager = request.app.state.model_manager
    llama_queue = model_manager.llama_queue
    tokenizer = model_manager.tokenizer
    config = model_manager.config
    device = request.app.state.device

    # Get the response generator
    response_generator = get_response_generator(
        llama_queue, tokenizer, config, req, device, json_mode
    )

    # Return the response in the correct format
    if req.streaming is False:
        result = response_generator()
        if json_mode:
            return JSONResponse(result.model_dump())
        else:
            return ormsgpack.packb(result, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)

    return StreamResponse(
        iterable=response_generator(), content_type="text/event-stream"
    )
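
# A hedged sketch of consuming the streaming variant. Sending JSON makes the
# handler pick json_mode, per the Content-Type check above; field names other
# than `streaming` and `num_samples` are assumptions about ServeChatRequest:
#
#     import httpx
#     with httpx.stream("POST", "http://127.0.0.1:8080/v1/chat",
#                       json={"messages": [], "streaming": True,
#                             "num_samples": 1}) as resp:
#         for line in resp.iter_lines():
#             print(line)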