Просмотр исходного кода

Fix error info (#202)

* Fix error info

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix infer UI

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 год назад
Родитель
Сommit
0df704bf9b
1 измененных файлов с 23 добавлено и 9 удалено
  1. 23 9
      tools/webui.py

+ 23 - 9
tools/webui.py

@@ -167,9 +167,10 @@ def inference(
 
 inference_stream = partial(inference, streaming=True)
 
-n_audios = 3
+n_audios = 4
 
 global_audio_list = []
+global_error_list = []
 
 
 def inference_wrapper(
@@ -186,6 +187,8 @@ def inference_wrapper(
     batch_infer_num,
 ):
     audios = []
+    errors = []
+
     for _ in range(batch_infer_num):
         items = inference(
             text,
@@ -199,6 +202,7 @@ def inference_wrapper(
             temperature,
             speaker,
         )
+
         try:
             item = next(items)
         except StopIteration:
@@ -207,13 +211,19 @@ def inference_wrapper(
         audios.append(
             gr.Audio(value=item[1] if (item and item[1]) else None, visible=True),
         )
+        errors.append(
+            gr.HTML(value=item[2] if (item and item[2]) else None, visible=True),
+        )
 
-    for _ in range(n_audios - batch_infer_num):
+    for _ in range(batch_infer_num, n_audios):
         audios.append(
             gr.Audio(value=None, visible=False),
         )
+        errors.append(
+            gr.HTML(value=None, visible=False),
+        )
 
-    return None, *audios, None
+    return None, *audios, *errors
 
 
 def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
@@ -315,21 +325,25 @@ def build_app():
                         batch_infer_num = gr.Slider(
                             label="Batch infer nums",
                             minimum=1,
-                            maximum=3,
+                            maximum=n_audios,
                             step=1,
                             value=1,
                         )
 
             with gr.Column(scale=3):
-                with gr.Row():
-                    error = gr.HTML(label=i18n("Error Message"), visible=False)
                 for _ in range(n_audios):
+                    with gr.Row():
+                        error = gr.HTML(
+                            label=i18n("Error Message"),
+                            visible=True if _ == 0 else False,
+                        )
+                        global_error_list.append(error)
                     with gr.Row():
                         audio = gr.Audio(
                             label=i18n("Generated Audio"),
                             type="numpy",
                             interactive=False,
-                            visible=False,
+                            visible=True if _ == 0 else False,
                         )
                         global_audio_list.append(audio)
 
@@ -365,7 +379,7 @@ def build_app():
                 speaker,
                 batch_infer_num,
             ],
-            [stream_audio, *global_audio_list, error],
+            [stream_audio, *global_audio_list, *global_error_list],
             concurrency_limit=1,
         )
 
@@ -383,7 +397,7 @@ def build_app():
                 temperature,
                 speaker,
             ],
-            [stream_audio, audio, error],
+            [stream_audio, global_audio_list[0], global_error_list[0]],
             concurrency_limit=10,
         )
     return app