__init__.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
import json
from typing import Callable

import gradio as gr

from fish_speech.i18n import i18n
from fish_speech.inference_engine.utils import normalize_text
from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER
  6. def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
  7. with gr.Blocks(theme=gr.themes.Base()) as app:
  8. gr.Markdown(HEADER_MD)
  9. # Use light theme by default
  10. app.load(
  11. None,
  12. None,
  13. js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
  14. % theme,
  15. )
  16. # Inference
  17. with gr.Row():
  18. with gr.Column(scale=3):
  19. text = gr.Textbox(
  20. label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
  21. )
  22. refined_text = gr.Textbox(
  23. label=i18n("Realtime Transform Text"),
  24. placeholder=i18n(
  25. "Normalization Result Preview (Currently Only Chinese)"
  26. ),
  27. lines=5,
  28. interactive=False,
  29. )
  30. with gr.Row():
  31. normalize = gr.Checkbox(
  32. label=i18n("Text Normalization"),
  33. value=False,
  34. )
  35. with gr.Row():
  36. with gr.Column():
  37. with gr.Tab(label=i18n("Advanced Config")):
  38. with gr.Row():
  39. chunk_length = gr.Slider(
  40. label=i18n("Iterative Prompt Length, 0 means off"),
  41. minimum=0,
  42. maximum=300,
  43. value=200,
  44. step=8,
  45. )
  46. max_new_tokens = gr.Slider(
  47. label=i18n(
  48. "Maximum tokens per batch, 0 means no limit"
  49. ),
  50. minimum=0,
  51. maximum=2048,
  52. value=0,
  53. step=8,
  54. )
  55. with gr.Row():
  56. top_p = gr.Slider(
  57. label="Top-P",
  58. minimum=0.6,
  59. maximum=0.9,
  60. value=0.7,
  61. step=0.01,
  62. )
  63. repetition_penalty = gr.Slider(
  64. label=i18n("Repetition Penalty"),
  65. minimum=1,
  66. maximum=1.5,
  67. value=1.2,
  68. step=0.01,
  69. )
  70. with gr.Row():
  71. temperature = gr.Slider(
  72. label="Temperature",
  73. minimum=0.6,
  74. maximum=0.9,
  75. value=0.7,
  76. step=0.01,
  77. )
  78. seed = gr.Number(
  79. label="Seed",
  80. info="0 means randomized inference, otherwise deterministic",
  81. value=0,
  82. )
  83. with gr.Tab(label=i18n("Reference Audio")):
  84. with gr.Row():
  85. gr.Markdown(
  86. i18n(
  87. "5 to 10 seconds of reference audio, useful for specifying speaker."
  88. )
  89. )
  90. with gr.Row():
  91. reference_id = gr.Textbox(
  92. label=i18n("Reference ID"),
  93. placeholder="Leave empty to use uploaded references",
  94. )
  95. with gr.Row():
  96. use_memory_cache = gr.Radio(
  97. label=i18n("Use Memory Cache"),
  98. choices=["on", "off"],
  99. value="on",
  100. )
  101. with gr.Row():
  102. reference_audio = gr.Audio(
  103. label=i18n("Reference Audio"),
  104. type="filepath",
  105. )
  106. with gr.Row():
  107. reference_text = gr.Textbox(
  108. label=i18n("Reference Text"),
  109. lines=1,
  110. placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
  111. value="",
  112. )
  113. with gr.Column(scale=3):
  114. with gr.Row():
  115. error = gr.HTML(
  116. label=i18n("Error Message"),
  117. visible=True,
  118. )
  119. with gr.Row():
  120. audio = gr.Audio(
  121. label=i18n("Generated Audio"),
  122. type="numpy",
  123. interactive=False,
  124. visible=True,
  125. )
  126. with gr.Row():
  127. with gr.Column(scale=3):
  128. generate = gr.Button(
  129. value="\U0001f3a7 " + i18n("Generate"),
  130. variant="primary",
  131. )
  132. text.input(fn=normalize_text, inputs=[text, normalize], outputs=[refined_text])
  133. # Submit
  134. generate.click(
  135. inference_fct,
  136. [
  137. refined_text,
  138. normalize,
  139. reference_id,
  140. reference_audio,
  141. reference_text,
  142. max_new_tokens,
  143. chunk_length,
  144. top_p,
  145. repetition_penalty,
  146. temperature,
  147. seed,
  148. use_memory_cache,
  149. ],
  150. [audio, error],
  151. concurrency_limit=1,
  152. )
  153. return app