# inference.py
  1. from http import HTTPStatus
  2. import numpy as np
  3. from kui.asgi import HTTPException
  4. from fish_speech.inference_engine import TTSInferenceEngine
  5. from fish_speech.utils.schema import ServeTTSRequest
  6. AMPLITUDE = 32768 # Needs an explaination
  7. def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
  8. """
  9. Wrapper for the inference function.
  10. Used in the API server.
  11. """
  12. count = 0
  13. for result in engine.inference(req):
  14. match result.code:
  15. case "header":
  16. if isinstance(result.audio, tuple):
  17. yield result.audio[1]
  18. case "error":
  19. raise HTTPException(
  20. HTTPStatus.INTERNAL_SERVER_ERROR,
  21. content=str(result.error),
  22. )
  23. case "segment":
  24. count += 1
  25. if isinstance(result.audio, tuple):
  26. yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
  27. case "final":
  28. count += 1
  29. if isinstance(result.audio, tuple):
  30. yield result.audio[1]
  31. return None # Stop the generator
  32. if count == 0:
  33. raise HTTPException(
  34. HTTPStatus.INTERNAL_SERVER_ERROR,
  35. content="No audio generated, please check the input text.",
  36. )