Bläddra i källkod

Optimize WebUI. (#108)

* Allow in-place transcription. Fix some bugs. Add options.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update app.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update app.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 2 år sedan
förälder
incheckning
cc4ce60295
1 ändrade filer med 149 tillägg och 16 borttagningar
  1. 149 16
      fish_speech/webui/app.py

+ 149 - 16
fish_speech/webui/app.py

@@ -1,10 +1,13 @@
 import html
 import io
+import os
 import traceback
+import wave
 from pathlib import Path
 
 import gradio as gr
 import librosa
+import numpy as np
 import requests
 
 from fish_speech.text import parse_text_to_segments
@@ -158,8 +161,10 @@ def build_model_config_block():
     llama_ckpt_path = gr.Dropdown(
         label="Llama 模型路径",
         value=str(Path("checkpoints/text2semantic-400m-v0.3-4k.pth")),
-        choices=[str(pth_file) for pth_file in Path("results").rglob("*text*/*.ckpt")]
-        + [str(pth_file) for pth_file in Path("checkpoints").rglob("*text*.pth")],
+        choices=[
+            str(pth_file) for pth_file in Path("results").rglob("**/text*/**/*.ckpt")
+        ]
+        + [str(pth_file) for pth_file in Path("checkpoints").rglob("**/*text*.pth")],
         allow_custom_value=True,
     )
     llama_config_name = gr.Textbox(label="Llama 配置文件", value="text2semantic_finetune")
@@ -172,8 +177,10 @@ def build_model_config_block():
     vqgan_ckpt_path = gr.Dropdown(
         label="VQGAN 模型路径",
         value=str(Path("checkpoints/vqgan-v1.pth")),
-        choices=[str(pth_file) for pth_file in Path("results").rglob("*vqgan*/*.ckpt")]
-        + [str(pth_file) for pth_file in Path("checkpoints").rglob("*vqgan*.pth")],
+        choices=[
+            str(pth_file) for pth_file in Path("results").rglob("**/vqgan*/**/*.ckpt")
+        ]
+        + [str(pth_file) for pth_file in Path("checkpoints").rglob("**/*vqgan*.pth")],
         allow_custom_value=True,
     )
     vqgan_config_name = gr.Dropdown(
@@ -265,6 +272,81 @@ def inference(
     return (sr, content), None
 
 
+def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=22050):
+    """Build a WAV file in memory and return its raw bytes.
+
+    The container is written with the given channel count, sample width
+    (in bytes) and sample rate, followed by ``frame_input`` as its audio
+    frames. With the default empty ``frame_input`` the result is a WAV
+    container that holds no audio frames.
+    """
+    buffer = io.BytesIO()
+    with wave.open(buffer, "wb") as writer:
+        writer.setnchannels(channels)
+        writer.setsampwidth(sample_width)
+        writer.setframerate(sample_rate)
+        writer.writeframes(frame_input)
+
+    # The wave writer leaves a caller-supplied file object open, so the
+    # buffer can still be read back in full here.
+    return buffer.getvalue()
+
+
+def inference_stream(
+    server_url,
+    text,
+    input_mode,
+    language0,
+    language1,
+    language2,
+    enable_reference_audio,
+    reference_audio,
+    reference_text,
+    max_new_tokens,
+    top_k,
+    top_p,
+    repetition_penalty,
+    temperature,
+    speaker,
+):
+    """Stream synthesized speech from the inference server chunk by chunk.
+
+    This is a generator. Every yielded item is an ``(audio_bytes, None)``
+    pair: the first one is the empty WAV container produced by
+    ``wave_header_chunk()``, and each later one holds 16-bit PCM samples
+    decoded from one chunk of the HTTP response body.
+
+    The three language selections are mapped to language codes and joined
+    into the ``order`` field of the request. If two selections are equal,
+    the generator finishes without yielding anything and without contacting
+    the server.
+
+    Raises:
+        KeyError: if a language label is not one of the three known labels.
+        requests.HTTPError: if the server answers with an error status.
+    """
+    # Build the label -> code table once instead of once per language.
+    language_codes = {
+        "中文": "zh",
+        "日文": "jp",
+        "英文": "en",
+    }
+    languages = [language_codes[label] for label in (language0, language1, language2)]
+
+    if len(set(languages)) != len(languages):
+        # Duplicate language selection: end the stream without any output.
+        return
+
+    payload = {
+        "text": text,
+        "prompt_text": reference_text if enable_reference_audio else None,
+        "prompt_tokens": reference_audio if enable_reference_audio else None,
+        "max_new_tokens": int(max_new_tokens),
+        "top_k": int(top_k) if top_k > 0 else None,
+        "top_p": top_p,
+        "repetition_penalty": repetition_penalty,
+        "temperature": temperature,
+        "order": ",".join(languages),
+        "use_g2p": input_mode == "自动音素",
+        "seed": None,
+        # Treat a missing or blank speaker the same way: send no speaker.
+        "speaker": speaker if speaker and speaker.strip() else None,
+    }
+
+    # Hold the streamed response in a context manager so the connection is
+    # released even if the consumer stops early or decoding a chunk fails.
+    with requests.post(
+        f"{server_url}/v1/models/default/invoke_stream", json=payload, stream=True
+    ) as resp:
+        resp.raise_for_status()
+
+        # NOTE(review): the header uses wave_header_chunk's default sample
+        # rate, while the chunks below are decoded at their native rate
+        # (sr=None) - confirm the server emits audio at that same rate.
+        yield wave_header_chunk(), None
+
+        for chunk in resp.iter_content(chunk_size=None):
+            if not chunk:
+                continue
+            audio, _ = librosa.load(io.BytesIO(chunk), sr=None, mono=True)
+            # Clip before casting: a full-scale sample (1.0) scales to 32768,
+            # which does not fit in int16 and would wrap to -32768.
+            pcm = np.clip(audio * 32768, -32768, 32767).astype(np.int16)
+            yield pcm.tobytes(), None
+
+
 with gr.Blocks(theme=gr.themes.Base()) as app:
     gr.Markdown(HEADER_MD)
 
@@ -368,18 +450,34 @@ with gr.Blocks(theme=gr.themes.Base()) as app:
                             value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
                         )
 
-                with gr.Row():
-                    with gr.Column(scale=2):
-                        generate = gr.Button(value="合成", variant="primary")
-                    with gr.Column(scale=1):
-                        clear = gr.Button(value="清空")
-
         with gr.Column(scale=3):
-            error = gr.HTML(label="错误信息")
-            parsed_text = gr.Dataframe(
-                label="解析结果 (仅参考)", headers=["ID", "文本", "语言", "音素"]
-            )
-            audio = gr.Audio(label="合成音频", type="numpy")
+            with gr.Row():
+                error = gr.HTML(label="错误信息")
+            with gr.Row():
+                parsed_text = gr.Dataframe(
+                    label="解析结果 (仅参考)", headers=["ID", "文本", "语言", "音素"]
+                )
+            with gr.Row():
+                audio_once = gr.Audio(label="一次合成音频", type="numpy")
+            with gr.Row():
+                audio_stream = gr.Audio(
+                    label="流式合成音频",
+                    autoplay=True,
+                    streaming=True,
+                    show_label=True,
+                    interactive=False,
+                )
+            with gr.Row():
+                with gr.Column(scale=3):
+                    generate = gr.Button(value="\U0001F3A7 合成", variant="primary")
+                    stream_generate = gr.Button(
+                        value="\U0001F4A7 流式合成", variant="primary"
+                    )
+                with gr.Column(scale=1):
+                    audio_download = gr.Button(
+                        value="\U0001F449 下载流式音频", elem_id="audio_download"
+                    )
+                    clear = gr.Button(value="清空")
 
     # Language & Text Parsing
     kwargs = dict(
@@ -422,7 +520,42 @@ with gr.Blocks(theme=gr.themes.Base()) as app:
             temperature,
             speaker,
         ],
-        [audio, error],
+        [audio_once, error],
+    )
+
+    stream_generate.click(
+        inference_stream,
+        [
+            server_url,
+            text,
+            input_mode,
+            language0,
+            language1,
+            language2,
+            enable_reference_audio,
+            reference_audio,
+            reference_text,
+            max_new_tokens,
+            top_k,
+            top_p,
+            repetition_penalty,
+            temperature,
+            speaker,
+        ],
+        [audio_stream, error],
+    ).then(lambda: gr.update(interactive=True), None, [text], queue=False)
+
+    audio_download.click(
+        None,
+        js="() => { "
+        'var btn = document.getElementById("audio_download"); '
+        "btn.disabled = true; "
+        "setTimeout(() => { btn.disabled = false; }, 1000); "
+        'var win = window.open("http://localhost:8000/v1/models/default/download", '
+        '"newwindow", "height=100, width=400, toolbar=no, menubar=no, scrollbars=no, '
+        'resizable=no, location=no, status=no"); '
+        "setTimeout(function() { win.close(); }, 1000);"
+        "}",
     )