inference.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import html
  2. from functools import partial
  3. from typing import Any, Callable
  4. from fish_speech.i18n import i18n
  5. from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest
  6. def inference_wrapper(
  7. text,
  8. reference_id,
  9. reference_audio,
  10. reference_text,
  11. max_new_tokens,
  12. chunk_length,
  13. top_p,
  14. repetition_penalty,
  15. temperature,
  16. seed,
  17. use_memory_cache,
  18. engine,
  19. ):
  20. """
  21. Wrapper for the inference function.
  22. Used in the Gradio interface.
  23. """
  24. if reference_audio:
  25. references = get_reference_audio(reference_audio, reference_text)
  26. else:
  27. references = []
  28. req = ServeTTSRequest(
  29. text=text,
  30. reference_id=reference_id if reference_id else None,
  31. references=references,
  32. max_new_tokens=max_new_tokens,
  33. chunk_length=chunk_length,
  34. top_p=top_p,
  35. repetition_penalty=repetition_penalty,
  36. temperature=temperature,
  37. seed=int(seed) if seed else None,
  38. use_memory_cache=use_memory_cache,
  39. )
  40. for result in engine.inference(req):
  41. match result.code:
  42. case "final":
  43. return result.audio, None
  44. case "error":
  45. return None, build_html_error_message(i18n(result.error))
  46. case _:
  47. pass
  48. return None, i18n("No audio generated")
  49. def get_reference_audio(reference_audio: str, reference_text: str) -> list:
  50. """
  51. Get the reference audio bytes.
  52. """
  53. with open(reference_audio, "rb") as audio_file:
  54. audio_bytes = audio_file.read()
  55. return [ServeReferenceAudio(audio=audio_bytes, text=reference_text)]
  56. def build_html_error_message(error: Any) -> str:
  57. error_str = str(error) if error is not None else "Unknown error"
  58. return f"""
  59. <div style="color: red;
  60. font-weight: bold;">
  61. {html.escape(error_str)}
  62. </div>
  63. """
  64. def get_inference_wrapper(engine) -> Callable:
  65. """
  66. Get the inference function with the immutable arguments.
  67. """
  68. return partial(
  69. inference_wrapper,
  70. engine=engine,
  71. )