Lengyue 2 anos atrás
pai
commit
681567df58
2 arquivos alterados com 283 adições e 47 exclusões
  1. 10 0
      docs/zh/inference.md
  2. 273 47
      fish_speech/webui/app.py

+ 10 - 0
docs/zh/inference.md

@@ -76,3 +76,13 @@ python -m zibai tools.api_server:app --listen 127.0.0.1:8000
 随后, 你可以在 `http://127.0.0.1:8000/docs` 中查看并测试 API.  
 一般来说, 你需要先调用 `PUT /v1/models/default` 来加载模型, 然后调用 `POST /v1/models/default/invoke` 来进行推理. 具体的参数请参考 API 文档.
 
+
+## WebUI 推理
+在运行 WebUI 之前, 你需要先启动 HTTP 服务, 如上所述.
+
+随后你可以使用以下命令来启动 WebUI:
+```bash
+python fish_speech/webui/app.py
+```
+
+祝大家玩得开心!

+ 273 - 47
fish_speech/webui/app.py

@@ -1,14 +1,17 @@
 import html
+import io
 import traceback
 
 import gradio as gr
+import librosa
+import requests
 
 from fish_speech.text import parse_text_to_segments, segments_to_phones
 
 HEADER_MD = """
 # Fish Speech
 
-基于 VITS 和 GPT 的多语种语音合成. 项目很大程度上基于 Rcell 的 GPT-VITS.
+基于 VQ-GAN 和 Llama 的多语种语音合成. 感谢 Rcell 的 GPT-VITS 提供的思路.
 """
 
 TEXTBOX_PLACEHOLDER = """在启用自动音素的情况下, 模型默认会全自动将输入文本转换为音素. 例如:
@@ -66,7 +69,7 @@ def prepare_text(
     else:
         reference_text = ""
 
-    if input_mode != "自动音素转换":
+    if input_mode != "自动音素":
         return [
             [idx, reference_text + line, "-", "-"]
             for idx, line in enumerate(lines)
@@ -92,69 +95,272 @@ def prepare_text(
     return rows, None
 
 
+def load_model(
+    server_url,
+    llama_ckpt_path,
+    llama_config_name,
+    tokenizer,
+    vqgan_ckpt_path,
+    vqgan_config_name,
+    device,
+    precision,
+    compile_model,
+):
+    """Ask the inference HTTP server to load the Llama and VQGAN models.
+
+    Issues ``PUT {server_url}/v1/models/default`` with a JSON payload built
+    from the arguments.
+
+    Args:
+        server_url: Base URL of the HTTP API server; surrounding whitespace
+            and trailing slashes are ignored.
+        llama_ckpt_path: Sent as ``llama.checkpoint_path``.
+        llama_config_name: Sent as ``llama.config_name``.
+        tokenizer: Sent as ``llama.tokenizer``.
+        vqgan_ckpt_path: Sent as ``vqgan.checkpoint_path``.
+        vqgan_config_name: Sent as ``vqgan.config_name``.
+        device: Sent as the top-level ``device`` field.
+        precision: Sent as ``llama.precision``.
+        compile_model: Sent as ``llama.compile``.
+
+    Returns:
+        A success message string, or the result of
+        ``build_html_error_message`` (including the traceback text) when the
+        request raises or the server answers with an error status.
+    """
+    # A trailing slash or stray whitespace typed into the address box would
+    # otherwise produce a "//v1/..." path that the server may fail to route.
+    base_url = server_url.strip().rstrip("/")
+
+    payload = {
+        "device": device,
+        "llama": {
+            "config_name": llama_config_name,
+            "checkpoint_path": llama_ckpt_path,
+            "precision": precision,
+            "tokenizer": tokenizer,
+            "compile": compile_model,
+        },
+        "vqgan": {
+            "config_name": vqgan_config_name,
+            "checkpoint_path": vqgan_ckpt_path,
+        },
+    }
+
+    try:
+        resp = requests.put(f"{base_url}/v1/models/default", json=payload)
+        resp.raise_for_status()
+    except Exception:
+        # UI boundary: log the failure to the console and show it to the user
+        # instead of letting the callback crash.
+        traceback.print_exc()
+        err = traceback.format_exc()
+        return build_html_error_message(f"加载模型时发生错误. \n\n{err}")
+
+    return "模型加载成功."
+
+
+def build_model_config_block():
+    """Create the model-configuration widgets and wire up the load button.
+
+    Builds the server address box, the device / precision / compile controls,
+    the Llama and VQGAN path and config boxes, a "load model" button and an
+    HTML status area. Clicking the button calls ``load_model`` with the
+    widget values and writes its result into the status area.
+
+    Returns:
+        The server-address ``gr.Textbox``, so callers can reuse it as an
+        input for other events.
+    """
+    server_url = gr.Textbox(label="服务器地址", value="http://localhost:8000")
+
+    # Three equally weighted columns on one row for the runtime options.
+    with gr.Row():
+        with gr.Column(scale=1):
+            device = gr.Dropdown(label="设备", choices=["cpu", "cuda"], value="cuda")
+        with gr.Column(scale=1):
+            precision = gr.Dropdown(
+                label="精度", choices=["bfloat16", "float16"], value="float16"
+            )
+        with gr.Column(scale=1):
+            compile_model = gr.Checkbox(label="编译模型", value=True)
+
+    llama_ckpt_path = gr.Textbox(
+        label="Llama 模型路径",
+        value="checkpoints/text2semantic-400m-v0.2-4k.pth",
+    )
+    llama_config_name = gr.Textbox(
+        label="Llama 配置文件",
+        value="text2semantic_finetune",
+    )
+    tokenizer = gr.Textbox(label="Tokenizer", value="fishaudio/speech-lm-v1")
+
+    vqgan_ckpt_path = gr.Textbox(
+        label="VQGAN 模型路径",
+        value="checkpoints/vqgan-v1.pth",
+    )
+    vqgan_config_name = gr.Textbox(label="VQGAN 配置文件", value="vqgan_pretrain")
+
+    load_button = gr.Button(value="加载模型", variant="primary")
+    status = gr.HTML(label="错误信息")
+
+    # Order must match the positional parameters of load_model.
+    load_inputs = [
+        server_url,
+        llama_ckpt_path,
+        llama_config_name,
+        tokenizer,
+        vqgan_ckpt_path,
+        vqgan_config_name,
+        device,
+        precision,
+        compile_model,
+    ]
+    load_button.click(load_model, load_inputs, [status])
+
+    return server_url
+
+
+def inference(
+    server_url,
+    text,
+    input_mode,
+    language0,
+    language1,
+    language2,
+    enable_reference_audio,
+    reference_audio,
+    reference_text,
+    max_new_tokens,
+    top_k,
+    top_p,
+    repetition_penalty,
+    temperature,
+    speaker,
+):
+    """Run text-to-speech through the HTTP API and return audio for the UI.
+
+    Issues ``POST {server_url}/v1/models/default/invoke`` with a JSON payload
+    built from the arguments, then decodes the response body with librosa.
+
+    Args:
+        server_url: Base URL of the HTTP API server; surrounding whitespace
+            and trailing slashes are ignored.
+        text: Text to synthesize.
+        input_mode: UI input mode; ``use_g2p`` is true only for "自动音素".
+        language0, language1, language2: UI language names in priority
+            order ("中文", "日文" or "英文"); they must all differ.
+        enable_reference_audio: When false, ``prompt_text`` and
+            ``prompt_tokens`` are sent as null.
+        reference_audio: Sent as ``prompt_tokens`` when the reference is
+            enabled.
+        reference_text: Sent as ``prompt_text`` when the reference is enabled.
+        max_new_tokens: Sent as an int.
+        top_k: Sent as an int when positive, otherwise null.
+        top_p, repetition_penalty, temperature: Forwarded unchanged.
+        speaker: Sent as null when blank.
+
+    Returns:
+        A ``(audio, error_html)`` pair: ``((sample_rate, samples), None)`` on
+        success, or ``(None, <html error message>)`` on failure.
+    """
+    language_codes = {
+        "中文": "zh",
+        "日文": "jp",
+        "英文": "en",
+    }
+    languages = [language_codes[name] for name in (language0, language1, language2)]
+
+    if len(set(languages)) != len(languages):
+        # Return None, not [], for the audio output: an empty list is not an
+        # audio value and can make the output component fail, hiding the
+        # error message.
+        return None, build_html_error_message("语言优先级不能重复.")
+
+    # A trailing slash or stray whitespace typed into the address box would
+    # otherwise produce a "//v1/..." path that the server may fail to route.
+    base_url = server_url.strip().rstrip("/")
+
+    payload = {
+        "text": text,
+        "prompt_text": reference_text if enable_reference_audio else None,
+        "prompt_tokens": reference_audio if enable_reference_audio else None,
+        "max_new_tokens": int(max_new_tokens),
+        "top_k": int(top_k) if top_k > 0 else None,
+        "top_p": top_p,
+        "repetition_penalty": repetition_penalty,
+        "temperature": temperature,
+        "order": ",".join(languages),
+        "use_g2p": input_mode == "自动音素",
+        "seed": None,
+        "speaker": speaker if speaker.strip() != "" else None,
+    }
+
+    try:
+        resp = requests.post(f"{base_url}/v1/models/default/invoke", json=payload)
+        resp.raise_for_status()
+    except Exception:
+        # UI boundary: log the failure to the console and show it to the user
+        # instead of letting the callback crash.
+        traceback.print_exc()
+        err = traceback.format_exc()
+        return None, build_html_error_message(f"推理时发生错误. \n\n{err}")
+
+    # A fresh BytesIO already starts at offset 0, so no seek is needed.
+    # sr=None keeps the sample rate stored in the returned audio.
+    samples, sr = librosa.load(io.BytesIO(resp.content), sr=None, mono=True)
+
+    return (sr, samples), None
+
+
 with gr.Blocks(theme=gr.themes.Base()) as app:
     gr.Markdown(HEADER_MD)
 
+    # Use light theme by default
+    app.load(
+        None,
+        None,
+        js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}",
+    )
+
+    # Inference
     with gr.Row():
         with gr.Column(scale=3):
-            text = gr.Textbox(label="输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=3)
-
-            with gr.Row():
-                with gr.Tab(label="合成参数"):
-                    gr.Markdown("配置常见合成参数.")
+            with gr.Tab(label="模型配置"):
+                server_url = build_model_config_block()
+
+            with gr.Tab(label="推理配置"):
+                text = gr.Textbox(
+                    label="输入文本", placeholder=TEXTBOX_PLACEHOLDER, lines=15
+                )
+
+                with gr.Row():
+                    with gr.Tab(label="合成参数"):
+                        gr.Markdown("配置常见合成参数. 自动音素会在推理时自动将文本转换为音素.")
+
+                        input_mode = gr.Dropdown(
+                            choices=["文本", "自动音素"],
+                            value="文本",
+                            label="输入模式",
+                        )
 
-                    input_mode = gr.Dropdown(
-                        choices=["手动输入音素/文本", "自动音素转换"],
-                        value="手动输入音素/文本",
-                        label="输入模式",
-                    )
+                        max_new_tokens = gr.Slider(
+                            label="最大生成 Token 数",
+                            minimum=0,
+                            maximum=4096,
+                            value=0,  # 0 means no limit
+                            step=8,
+                        )
 
-                with gr.Tab(label="语言优先级"):
-                    gr.Markdown("该参数只在自动音素转换时生效.")
+                        top_k = gr.Slider(
+                            label="Top-K", minimum=0, maximum=100, value=0, step=1
+                        )
 
-                    with gr.Column(scale=1):
-                        language0 = gr.Dropdown(
-                            choices=["中文", "日文", "英文"],
-                            label="语言 1",
-                            value="中文",
+                        top_p = gr.Slider(
+                            label="Top-P", minimum=0, maximum=1, value=0.5, step=0.01
                         )
 
-                    with gr.Column(scale=1):
-                        language1 = gr.Dropdown(
-                            choices=["中文", "日文", "英文"],
-                            label="语言 2",
-                            value="日文",
+                        repetition_penalty = gr.Slider(
+                            label="重复惩罚", minimum=0, maximum=2, value=1.5, step=0.01
                         )
 
-                    with gr.Column(scale=1):
-                        language2 = gr.Dropdown(
-                            choices=["中文", "日文", "英文"],
-                            label="语言 3",
-                            value="英文",
+                        temperature = gr.Slider(
+                            label="温度", minimum=0, maximum=2, value=0.7, step=0.01
                         )
 
-                with gr.Tab(label="参考音频"):
-                    gr.Markdown("3 秒左右的参考音频, 适用于无微调直接推理.")
+                        speaker = gr.Textbox(
+                            label="说话人",
+                            placeholder="说话人",
+                            lines=1,
+                        )
 
-                    enable_reference_audio = gr.Checkbox(label="启用参考音频", value=False)
-                    reference_audio = gr.Audio(label="参考音频")
-                    reference_text = gr.Textbox(
-                        label="参考文本",
-                        placeholder="参考文本",
-                        lines=1,
-                        value="万一他很崇拜我们呢? 嘿嘿.",
-                    )
+                    with gr.Tab(label="语言优先级"):
+                        gr.Markdown("该参数只在自动音素转换时生效.")
+
+                        with gr.Column(scale=1):
+                            language0 = gr.Dropdown(
+                                choices=["中文", "日文", "英文"],
+                                label="语言 1",
+                                value="中文",
+                            )
+
+                        with gr.Column(scale=1):
+                            language1 = gr.Dropdown(
+                                choices=["中文", "日文", "英文"],
+                                label="语言 2",
+                                value="日文",
+                            )
+
+                        with gr.Column(scale=1):
+                            language2 = gr.Dropdown(
+                                choices=["中文", "日文", "英文"],
+                                label="语言 3",
+                                value="英文",
+                            )
+
+                    with gr.Tab(label="参考音频"):
+                        gr.Markdown("5-10 秒的参考音频, 适用于指定音色.")
+
+                        enable_reference_audio = gr.Checkbox(
+                            label="启用参考音频", value=False
+                        )
+                        reference_audio = gr.Audio(
+                            label="参考音频",
+                            value="docs/assets/audios/0_input.wav",
+                            type="filepath",
+                        )
+                        reference_text = gr.Textbox(
+                            label="参考文本",
+                            placeholder="参考文本",
+                            lines=1,
+                            value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+                        )
 
-            with gr.Row():
-                with gr.Column(scale=2):
-                    generate = gr.Button(value="合成", variant="primary")
-                with gr.Column(scale=1):
-                    clear = gr.Button(value="清空")
+                with gr.Row():
+                    with gr.Column(scale=2):
+                        generate = gr.Button(value="合成", variant="primary")
+                    with gr.Column(scale=1):
+                        clear = gr.Button(value="清空")
 
         with gr.Column(scale=3):
             error = gr.HTML(label="错误信息")
-            parsed_text = gr.Dataframe(label="解析结果", headers=["ID", "文本", "语言", "音素"])
-            audio = gr.Audio(label="合成音频")
+            parsed_text = gr.Dataframe(
+                label="解析结果 (仅参考)", headers=["ID", "文本", "语言", "音素"]
+            )
+            audio = gr.Audio(label="合成音频", type="numpy")
 
     # Language & Text Parsing
     kwargs = dict(
@@ -178,7 +384,27 @@ with gr.Blocks(theme=gr.themes.Base()) as app:
     enable_reference_audio.change(prepare_text, **kwargs)
 
     # Submit
-    generate.click(lambda: None, outputs=[audio])
+    generate.click(
+        inference,
+        [
+            server_url,
+            text,
+            input_mode,
+            language0,
+            language1,
+            language2,
+            enable_reference_audio,
+            reference_audio,
+            reference_text,
+            max_new_tokens,
+            top_k,
+            top_p,
+            repetition_penalty,
+            temperature,
+            speaker,
+        ],
+        [audio, error],
+    )
 
 
 if __name__ == "__main__":