# inference.py
  1. from http import HTTPStatus
  2. import numpy as np
  3. from kui.asgi import HTTPException
  4. from tools.inference_engine import TTSInferenceEngine
  5. from tools.schema import ServeTTSRequest
# 2**15: scales float audio samples (assumed in [-1.0, 1.0] — see the
# astype(np.int16) cast in inference_wrapper) to the signed 16-bit PCM range.
AMPLITUDE = 32768
  7. def inference_wrapper(req: ServeTTSRequest, engine: TTSInferenceEngine):
  8. """
  9. Wrapper for the inference function.
  10. Used in the API server.
  11. """
  12. for result in engine.inference(req):
  13. match result.code:
  14. case "header":
  15. if isinstance(result.audio, tuple):
  16. yield result.audio[1]
  17. case "error":
  18. raise HTTPException(
  19. HTTPStatus.INTERNAL_SERVER_ERROR,
  20. content=str(result.error),
  21. )
  22. case "segment":
  23. if isinstance(result.audio, tuple):
  24. yield (result.audio[1] * AMPLITUDE).astype(np.int16).tobytes()
  25. case "final":
  26. if isinstance(result.audio, tuple):
  27. yield result.audio[1]
  28. return None # Stop the generator
  29. raise HTTPException(
  30. HTTPStatus.INTERNAL_SERVER_ERROR,
  31. content="No audio generated, please check the input text.",
  32. )