# views.py — HTTP API views for the fish-speech TTS server.
  1. import io
  2. import os
  3. import time
  4. from http import HTTPStatus
  5. import numpy as np
  6. import ormsgpack
  7. import soundfile as sf
  8. import torch
  9. from kui.asgi import (
  10. Body,
  11. HTTPException,
  12. HttpView,
  13. JSONResponse,
  14. Routes,
  15. StreamResponse,
  16. request,
  17. )
  18. from loguru import logger
  19. from typing_extensions import Annotated
  20. from fish_speech.utils.schema import (
  21. ServeTTSRequest,
  22. ServeVQGANDecodeRequest,
  23. ServeVQGANDecodeResponse,
  24. ServeVQGANEncodeRequest,
  25. ServeVQGANEncodeResponse,
  26. )
  27. from tools.server.api_utils import (
  28. buffer_to_async_generator,
  29. get_content_type,
  30. inference_async,
  31. )
  32. from tools.server.inference import inference_wrapper as inference
  33. from tools.server.model_manager import ModelManager
  34. from tools.server.model_utils import (
  35. batch_vqgan_decode,
  36. cached_vqgan_batch_encode,
  37. )
  38. MAX_NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", 1))
  39. routes = Routes()
  40. @routes.http("/v1/health")
  41. class Health(HttpView):
  42. @classmethod
  43. async def get(cls):
  44. return JSONResponse({"status": "ok"})
  45. @classmethod
  46. async def post(cls):
  47. return JSONResponse({"status": "ok"})
  48. @routes.http.post("/v1/vqgan/encode")
  49. async def vqgan_encode(req: Annotated[ServeVQGANEncodeRequest, Body(exclusive=True)]):
  50. # Get the model from the app
  51. model_manager: ModelManager = request.app.state.model_manager
  52. decoder_model = model_manager.decoder_model
  53. # Encode the audio
  54. start_time = time.time()
  55. tokens = cached_vqgan_batch_encode(decoder_model, req.audios)
  56. logger.info(f"[EXEC] VQGAN encode time: {(time.time() - start_time) * 1000:.2f}ms")
  57. # Return the response
  58. return ormsgpack.packb(
  59. ServeVQGANEncodeResponse(tokens=[i.tolist() for i in tokens]),
  60. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  61. )
  62. @routes.http.post("/v1/vqgan/decode")
  63. async def vqgan_decode(req: Annotated[ServeVQGANDecodeRequest, Body(exclusive=True)]):
  64. # Get the model from the app
  65. model_manager: ModelManager = request.app.state.model_manager
  66. decoder_model = model_manager.decoder_model
  67. # Decode the audio
  68. tokens = [torch.tensor(token, dtype=torch.int) for token in req.tokens]
  69. start_time = time.time()
  70. audios = batch_vqgan_decode(decoder_model, tokens)
  71. logger.info(f"[EXEC] VQGAN decode time: {(time.time() - start_time) * 1000:.2f}ms")
  72. audios = [audio.astype(np.float16).tobytes() for audio in audios]
  73. # Return the response
  74. return ormsgpack.packb(
  75. ServeVQGANDecodeResponse(audios=audios),
  76. option=ormsgpack.OPT_SERIALIZE_PYDANTIC,
  77. )
  78. @routes.http.post("/v1/tts")
  79. async def tts(req: Annotated[ServeTTSRequest, Body(exclusive=True)]):
  80. # Get the model from the app
  81. app_state = request.app.state
  82. model_manager: ModelManager = app_state.model_manager
  83. engine = model_manager.tts_inference_engine
  84. sample_rate = engine.decoder_model.sample_rate
  85. # Check if the text is too long
  86. if app_state.max_text_length > 0 and len(req.text) > app_state.max_text_length:
  87. raise HTTPException(
  88. HTTPStatus.BAD_REQUEST,
  89. content=f"Text is too long, max length is {app_state.max_text_length}",
  90. )
  91. # Check if streaming is enabled
  92. if req.streaming and req.format != "wav":
  93. raise HTTPException(
  94. HTTPStatus.BAD_REQUEST,
  95. content="Streaming only supports WAV format",
  96. )
  97. # Perform TTS
  98. if req.streaming:
  99. return StreamResponse(
  100. iterable=inference_async(req, engine),
  101. headers={
  102. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  103. },
  104. content_type=get_content_type(req.format),
  105. )
  106. else:
  107. fake_audios = next(inference(req, engine))
  108. buffer = io.BytesIO()
  109. sf.write(
  110. buffer,
  111. fake_audios,
  112. sample_rate,
  113. format=req.format,
  114. )
  115. return StreamResponse(
  116. iterable=buffer_to_async_generator(buffer.getvalue()),
  117. headers={
  118. "Content-Disposition": f"attachment; filename=audio.{req.format}",
  119. },
  120. content_type=get_content_type(req.format),
  121. )