inference.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import html
  2. from functools import partial
  3. from typing import Any, Callable
  4. from fish_speech.i18n import i18n
  5. from tools.schema import ServeReferenceAudio, ServeTTSRequest
  6. def inference_wrapper(
  7. text,
  8. normalize,
  9. reference_id,
  10. reference_audio,
  11. reference_text,
  12. max_new_tokens,
  13. chunk_length,
  14. top_p,
  15. repetition_penalty,
  16. temperature,
  17. seed,
  18. use_memory_cache,
  19. engine,
  20. ):
  21. """
  22. Wrapper for the inference function.
  23. Used in the Gradio interface.
  24. """
  25. if reference_audio:
  26. references = get_reference_audio(reference_audio, reference_text)
  27. else:
  28. references = []
  29. req = ServeTTSRequest(
  30. text=text,
  31. normalize=normalize,
  32. reference_id=reference_id if reference_id else None,
  33. references=references,
  34. max_new_tokens=max_new_tokens,
  35. chunk_length=chunk_length,
  36. top_p=top_p,
  37. repetition_penalty=repetition_penalty,
  38. temperature=temperature,
  39. seed=int(seed) if seed else None,
  40. use_memory_cache=use_memory_cache,
  41. )
  42. for result in engine.inference(req):
  43. match result.code:
  44. case "final":
  45. return result.audio, None
  46. case "error":
  47. return None, build_html_error_message(i18n(result.error))
  48. case _:
  49. pass
  50. return None, i18n("No audio generated")
  51. def get_reference_audio(reference_audio: str, reference_text: str) -> list:
  52. """
  53. Get the reference audio bytes.
  54. """
  55. with open(reference_audio, "rb") as audio_file:
  56. audio_bytes = audio_file.read()
  57. return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)]
  58. def build_html_error_message(error: Any) -> str:
  59. error = error if isinstance(error, Exception) else Exception("Unknown error")
  60. return f"""
  61. <div style="color: red;
  62. font-weight: bold;">
  63. {html.escape(str(error))}
  64. </div>
  65. """
  66. def get_inference_wrapper(engine) -> Callable:
  67. """
  68. Get the inference function with the immutable arguments.
  69. """
  70. return partial(
  71. inference_wrapper,
  72. engine=engine,
  73. )