# views.py
  1. import io
  2. import os
  3. import time
  4. from http import HTTPStatus
  5. import numpy as np
  6. import ormsgpack
  7. import soundfile as sf
  8. import torch
  9. from kui.asgi import (
  10. Body,
  11. HTTPException,
  12. HttpView,
  13. JSONResponse,
  14. Routes,
  15. StreamResponse,
  16. request,
  17. )
  18. from loguru import logger
  19. from typing_extensions import Annotated
  20. from fish_speech.utils.schema import (
  21. ServeTTSRequest,
  22. ServeVQGANDecodeRequest,
  23. ServeVQGANDecodeResponse,
  24. ServeVQGANEncodeRequest,
  25. ServeVQGANEncodeResponse,
  26. )
  27. from tools.server.api_utils import (
  28. buffer_to_async_generator,
  29. get_content_type,
  30. inference_async,
  31. )
  32. from tools.server.inference import inference_wrapper as inference
  33. from tools.server.model_manager import ModelManager
  34. from tools.server.model_utils import (
  35. batch_asr,
  36. batch_vqgan_decode,
  37. cached_vqgan_batch_encode,
  38. )
# Maximum number of samples per request, read from the environment.
# NOTE(review): the env var is named NUM_SAMPLES while the constant is
# MAX_NUM_SAMPLES — looks intentional but confirm against deployment docs.
MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))

# Route table mounted by the ASGI app; every endpoint below registers on it.
routes = Routes()
  41. @routes.http("/v1/health")
  42. class Health(HttpView):
  43. @classmethod
  44. async def get(cls):
  45. return JSONResponse({"status": "ok"})
  46. @classmethod
  47. async def post(cls):
  48. return JSONResponse({"status": "ok"})
  49. @routes.http.post("/v1/vqgan/encode")
  50. async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
  51. # Get the model from the app
  52. model_manager: ModelManager = request.app.state.model_manager
  53. decoder_model = model_manager.decoder_model
  54. # Encode the audio
  55. start_time = time.time()
  56. tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
  57. logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
  58. # Return the response
  59. return ormsgpack.packb(
  60. ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
  61. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  62. )
  63. @routes.http.post("/v1/vqgan/decode")
  64. async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
  65. # Get the model from the app
  66. model_manager: ModelManager = request.app.state.model_manager
  67. decoder_model = model_manager.decoder_model
  68. # Decode the audio
  69. tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
  70. start_time = time.time()
  71. audios = batch_vqgan_decode(decoder_model, tokens)
  72. logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
  73. audios = [audio.astype(np.float16).tobytes() for audio in audios]
  74. # Return the response
  75. return ormsgpack.packb(
  76. ServeVQGANDecodeResponse(audios=audios),
  77. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  78. )
  79. @routes.http.post("/v1/tts")
  80. async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
  81. # Get the model from the app
  82. app_state = request.app.state
  83. model_manager: ModelManager = app_state.model_manager
  84. engine = model_manager.tts_inference_engine
  85. sample_rate = engine.decoder_model.sample_rate
  86. # Check if the text is too long
  87. if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
  88. raise HTTPException(
  89. HTTPStatus.BAD_REQUEST,
  90. content=f"Text is too long, max length is {app_state.max_text_length}",
  91. )
  92. # Check if streaming is enabled
  93. if req.streaming and req.format != "wav":
  94. raise HTTPException(
  95. HTTPStatus.BAD_REQUEST,
  96. content="Streaming only supports WAV format",
  97. )
  98. # Perform TTS
  99. if req.streaming:
  100. return StreamResponse(
  101. iterable=inference_async(req, engine),
  102. headers={
  103. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  104. },
  105. content_type=get_content_type(req.format),
  106. )
  107. else:
  108. fake_audios = next(inference(req, engine))
  109. buffer = io.BytesIO()
  110. sf.write(
  111. buffer,
  112. fake_audios,
  113. sample_rate,
  114. format=req.format,
  115. )
  116. return StreamResponse(
  117. iterable=buffer_to_async_generator(buffer.getvalue()),
  118. headers={
  119. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  120. },
  121. content_type=get_content_type(req.format),
  122. )