|
|
@@ -167,6 +167,54 @@ def inference(
|
|
|
|
|
|
# Streaming entry point: `inference` with `streaming=True` pre-bound.
inference_stream = partial(inference, streaming=True)
|
|
|
|
|
|
# Number of result players laid out in the UI. `inference_wrapper` always
# returns exactly this many audio updates so the handler's return value lines
# up with the outputs registered on the generate button.
n_audios = 3

# Audio components appended by the UI builder, in display order. Kept at
# module level so the click handler can list them as outputs.
global_audio_list = []


def inference_wrapper(
    text,
    enable_reference_audio,
    reference_audio,
    reference_text,
    max_new_tokens,
    chunk_length,
    top_p,
    repetition_penalty,
    temperature,
    speaker,
    batch_infer_num,
):
    """Run `inference` several times and spread the results over the audio slots.

    Every argument except ``batch_infer_num`` is forwarded unchanged to
    ``inference``. Only the first item produced by each run is consumed.

    Args:
        batch_infer_num: Requested number of runs. It is coerced to ``int``
            and clamped to ``0..n_audios`` because there are only
            ``n_audios`` players to show results in.

    Returns:
        A tuple of ``n_audios + 2`` values: ``None`` for the streaming
        player, one ``gr.Audio`` update per slot (visible with data for each
        successful run, hidden and empty for the rest), and the last error
        value reported by a failed run (``None`` when every run succeeded).
    """
    # A slider value may arrive as a float, which `range` rejects; more runs
    # than slots would also produce more outputs than the UI registered.
    runs = max(0, min(int(batch_infer_num), n_audios))

    audios = []
    error = None
    for _ in range(runs):
        items = inference(
            text,
            enable_reference_audio,
            reference_audio,
            reference_text,
            max_new_tokens,
            chunk_length,
            top_p,
            repetition_penalty,
            temperature,
            speaker,
        )
        try:
            item = next(items)
        except StopIteration:
            print("No more audio data available.")
            continue

        if item and item[1]:
            audios.append(gr.Audio(value=item[1], visible=True))
        elif item and len(item) > 2:
            # `inference` was previously wired to (stream, audio, error)
            # outputs, so the third element is taken to be the error payload.
            # NOTE(review): confirm against `inference`'s yield shape.
            error = item[2]

    # Pad by what is actually missing, not by the requested run count: a run
    # that produced no audio must still occupy a (hidden) slot, otherwise the
    # number of returned values no longer matches the registered outputs.
    audios.extend(
        gr.Audio(value=None, visible=False) for _ in range(n_audios - len(audios))
    )

    return None, *audios, error
|
|
|
+
|
|
|
|
|
|
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
|
|
buffer = io.BytesIO()
|
|
|
@@ -263,16 +311,28 @@ def build_app():
|
|
|
lines=1,
|
|
|
value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
|
|
|
)
|
|
|
+ with gr.Tab(label=i18n("Batch Inference")):
|
|
|
+ batch_infer_num = gr.Slider(
|
|
|
+ label="Batch infer nums",
|
|
|
+ minimum=1,
|
|
|
+ maximum=3,
|
|
|
+ step=1,
|
|
|
+ value=1,
|
|
|
+ )
|
|
|
|
|
|
with gr.Column(scale=3):
|
|
|
with gr.Row():
|
|
|
- error = gr.HTML(label=i18n("Error Message"))
|
|
|
- with gr.Row():
|
|
|
- audio = gr.Audio(
|
|
|
- label=i18n("Generated Audio"),
|
|
|
- type="numpy",
|
|
|
- interactive=False,
|
|
|
- )
|
|
|
+ error = gr.HTML(label=i18n("Error Message"), visible=False)
|
|
|
+ for _ in range(n_audios):
|
|
|
+ with gr.Row():
|
|
|
+ audio = gr.Audio(
|
|
|
+ label=i18n("Generated Audio"),
|
|
|
+ type="numpy",
|
|
|
+ interactive=False,
|
|
|
+ visible=False,
|
|
|
+ )
|
|
|
+ global_audio_list.append(audio)
|
|
|
+
|
|
|
with gr.Row():
|
|
|
stream_audio = gr.Audio(
|
|
|
label=i18n("Streaming Audio"),
|
|
|
@@ -291,7 +351,7 @@ def build_app():
|
|
|
)
|
|
|
# # Submit
|
|
|
generate.click(
|
|
|
- inference,
|
|
|
+ inference_wrapper,
|
|
|
[
|
|
|
text,
|
|
|
enable_reference_audio,
|
|
|
@@ -303,8 +363,9 @@ def build_app():
|
|
|
repetition_penalty,
|
|
|
temperature,
|
|
|
speaker,
|
|
|
+ batch_infer_num,
|
|
|
],
|
|
|
- [stream_audio, audio, error],
|
|
|
+ [stream_audio, *global_audio_list, error],
|
|
|
concurrency_limit=1,
|
|
|
)
|
|
|
|