"""Gradio WebUI layout for fish-speech text-to-speech inference."""
  1. from typing import Callable
  2. import gradio as gr
  3. from fish_speech.i18n import i18n
  4. from tools.webui.variables import HEADER_MD, TEXTBOX_PLACEHOLDER
  5. def build_app(inference_fct: Callable, theme: str = "light") -> gr.Blocks:
  6. with gr.Blocks(theme=gr.themes.Base()) as app:
  7. gr.Markdown(HEADER_MD)
  8. # Use light theme by default
  9. app.load(
  10. None,
  11. None,
  12. js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
  13. % theme,
  14. )
  15. # Inference
  16. with gr.Row():
  17. with gr.Column(scale=3):
  18. text = gr.Textbox(
  19. label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
  20. )
  21. with gr.Row():
  22. with gr.Column():
  23. with gr.Tab(label=i18n("Advanced Config")):
  24. with gr.Row():
  25. chunk_length = gr.Slider(
  26. label=i18n("Iterative Prompt Length, 0 means off"),
  27. minimum=0,
  28. maximum=300,
  29. value=200,
  30. step=8,
  31. )
  32. max_new_tokens = gr.Slider(
  33. label=i18n(
  34. "Maximum tokens per batch, 0 means no limit"
  35. ),
  36. minimum=0,
  37. maximum=2048,
  38. value=0,
  39. step=8,
  40. )
  41. with gr.Row():
  42. top_p = gr.Slider(
  43. label="Top-P",
  44. minimum=0.6,
  45. maximum=0.9,
  46. value=0.7,
  47. step=0.01,
  48. )
  49. repetition_penalty = gr.Slider(
  50. label=i18n("Repetition Penalty"),
  51. minimum=1,
  52. maximum=1.5,
  53. value=1.2,
  54. step=0.01,
  55. )
  56. with gr.Row():
  57. temperature = gr.Slider(
  58. label="Temperature",
  59. minimum=0.6,
  60. maximum=0.9,
  61. value=0.7,
  62. step=0.01,
  63. )
  64. seed = gr.Number(
  65. label="Seed",
  66. info="0 means randomized inference, otherwise deterministic",
  67. value=0,
  68. )
  69. with gr.Tab(label=i18n("Reference Audio")):
  70. with gr.Row():
  71. gr.Markdown(
  72. i18n(
  73. "5 to 10 seconds of reference audio, useful for specifying speaker."
  74. )
  75. )
  76. with gr.Row():
  77. reference_id = gr.Textbox(
  78. label=i18n("Reference ID"),
  79. placeholder="Leave empty to use uploaded references",
  80. )
  81. with gr.Row():
  82. use_memory_cache = gr.Radio(
  83. label=i18n("Use Memory Cache"),
  84. choices=["on", "off"],
  85. value="on",
  86. )
  87. with gr.Row():
  88. reference_audio = gr.Audio(
  89. label=i18n("Reference Audio"),
  90. type="filepath",
  91. )
  92. with gr.Row():
  93. reference_text = gr.Textbox(
  94. label=i18n("Reference Text"),
  95. lines=1,
  96. placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
  97. value="",
  98. )
  99. with gr.Column(scale=3):
  100. with gr.Row():
  101. error = gr.HTML(
  102. label=i18n("Error Message"),
  103. visible=True,
  104. )
  105. with gr.Row():
  106. audio = gr.Audio(
  107. label=i18n("Generated Audio"),
  108. type="numpy",
  109. interactive=False,
  110. visible=True,
  111. )
  112. with gr.Row():
  113. with gr.Column(scale=3):
  114. generate = gr.Button(
  115. value="\U0001f3a7 " + i18n("Generate"),
  116. variant="primary",
  117. )
  118. # Submit
  119. generate.click(
  120. inference_fct,
  121. [
  122. text,
  123. reference_id,
  124. reference_audio,
  125. reference_text,
  126. max_new_tokens,
  127. chunk_length,
  128. top_p,
  129. repetition_penalty,
  130. temperature,
  131. seed,
  132. use_memory_cache,
  133. ],
  134. [audio, error],
  135. concurrency_limit=1,
  136. )
  137. return app