Explorar o código

Improve WebUI Inference Server Configuration Interaction (#204)

蓝梦实 hai 1 ano
pai
achega
9ebe18a5c1

+ 3 - 0
fish_speech/i18n/locale/en_US.json

@@ -1,4 +1,5 @@
 {
+    "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
     "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
     "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
     "Accumulate Gradient Batches": "Accumulate Gradient Batches",
@@ -6,7 +7,9 @@
     "Added path successfully!": "Added path successfully!",
     "Advanced Config": "Advanced Config",
     "Base LLAMA Model": "Base LLAMA Model",
+    "Batch Inference": "Batch Inference",
     "Batch Size": "Batch Size",
+    "Changing with the Model Path": "Changing with the Model Path",
     "Chinese": "Chinese",
     "Compile Model": "Compile Model",
     "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",

+ 3 - 0
fish_speech/i18n/locale/es_ES.json

@@ -1,4 +1,5 @@
 {
+    "16-mixed is recommended for 10+ series GPU": "Se recomienda 16-mixed para GPU de la serie 10+",
     "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
     "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
     "Accumulate Gradient Batches": "Acumular lotes de gradientes",
@@ -6,7 +7,9 @@
     "Added path successfully!": "¡Ruta agregada exitosamente!",
     "Advanced Config": "Configuración Avanzada",
     "Base LLAMA Model": "Modelo Base LLAMA",
+    "Batch Inference": "Inferencia por Lote",
     "Batch Size": "Tamaño del Lote",
+    "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
     "Chinese": "Chino",
     "Compile Model": "Compilar Modelo",
     "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",

+ 3 - 0
fish_speech/i18n/locale/ja_JP.json

@@ -1,4 +1,5 @@
 {
+    "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
     "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
     "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
     "Accumulate Gradient Batches": "勾配バッチの累積",
@@ -6,7 +7,9 @@
     "Added path successfully!": "パスの追加に成功しました!",
     "Advanced Config": "詳細設定",
     "Base LLAMA Model": "基本LLAMAモデル",
+    "Batch Inference": "バッチ推論",
     "Batch Size": "バッチサイズ",
+    "Changing with the Model Path": "モデルのパスに伴って変化する",
     "Chinese": "中国語",
     "Compile Model": "モデルのコンパイル",
     "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",

+ 3 - 0
fish_speech/i18n/locale/zh_CN.json

@@ -1,4 +1,5 @@
 {
+    "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
     "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
     "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
     "Accumulate Gradient Batches": "梯度累积批次",
@@ -6,7 +7,9 @@
     "Added path successfully!": "添加路径成功!",
     "Advanced Config": "高级参数",
     "Base LLAMA Model": "基础 LLAMA 模型",
+    "Batch Inference": "批量推理",
     "Batch Size": "批次大小",
+    "Changing with the Model Path": "随模型路径变化",
     "Chinese": "中文",
     "Compile Model": "编译模型",
     "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",

+ 54 - 21
fish_speech/webui/manage.py

@@ -16,14 +16,15 @@ import yaml
 from loguru import logger
 from tqdm import tqdm
 
-from fish_speech.i18n import i18n
-from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
-
 PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
 sys.path.insert(0, "")
 print(sys.path)
 cur_work_dir = Path(os.getcwd()).resolve()
 print("You are in ", str(cur_work_dir))
+
+# NOTE(review): the project imports sit below sys.path.insert(0, "") --
+# presumably so fish_speech resolves from the working directory; confirm.
+from fish_speech.i18n import i18n
+from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
+
 config_path = cur_work_dir / "fish_speech" / "configs"
 vqgan_yml_path = config_path / "vqgan_finetune.yaml"
 llama_yml_path = config_path / "text2semantic_finetune.yaml"
@@ -131,6 +132,26 @@ def change_label(if_label):
         yield build_html_ok_message("Nothing")
 
 
+def change_decoder_config(decoder_model_path):
+    if "vits" in decoder_model_path:
+        choices = ["vits_decoder_finetune", "vits_decoder_pretrain"]
+        return gr.Dropdown(choices=choices, value=choices[0])
+    elif "vqgan" in decoder_model_path or "vq-gan" in decoder_model_path:
+        choices = ["vqgan_finetune", "vqgan_pretrain"]
+        return gr.Dropdown(choices=choices, value=choices[0])
+    else:
+        raise ValueError("Invalid decoder name")
+
+
+def change_llama_config(llama_model_path):
+    if "large" in llama_model_path:
+        return gr.Dropdown(value="dual_ar_2_codebook_large", interactive=False)
+    elif "medium" in llama_model_path:
+        return gr.Dropdown(value="dual_ar_2_codebook_medium", interactive=False)
+    else:
+        raise ValueError("Invalid model size")
+
+
 def clean_infer_cache():
     import tempfile
 
@@ -685,12 +706,25 @@ def fresh_tb_dir():
 
 
 def list_decoder_models():
-    return (
+    # Collect decoder checkpoint paths: vits*/vq* files directly under
+    # checkpoints/, plus .ckpt files below results/vqgan*/ and results/vits*/.
+    paths = (
         [str(p) for p in Path("checkpoints").glob("vits*.*")]
         + [str(p) for p in Path("checkpoints").glob("vq*.*")]
         + [str(p) for p in Path("results").glob("vqgan*/**/*.ckpt")]
         + [str(p) for p in Path("results").glob("vits*/**/*.ckpt")]
     )
+    # An empty result is only logged; the (empty) list is still returned.
+    if not paths:
+        logger.warning("No decoder model found")
+    return paths
+
+
+def list_llama_models():
+    choices = [
+        str(p).replace("\\", "/") for p in Path("checkpoints").glob("text2sem*.*")
+    ]
+    choices += [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")]
+    if not choices:
+        logger.warning("No LLaMA model found")
+    return choices
 
 
 def fresh_decoder_model():
@@ -720,10 +754,11 @@ def fresh_llama_ckpt():
 
 
 def fresh_llama_model():
-    return gr.Dropdown(
-        choices=[init_llama_yml["ckpt_path"]]
-        + [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")]
-    )
+    choices = [
+        str(p).replace("\\", "/") for p in Path("checkpoints").glob("text2sem*.*")
+    ]
+    choices += [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")]
+    return gr.Dropdown(choices=choices)
 
 
 def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
@@ -1207,9 +1242,7 @@ with gr.Blocks(
                                 )
                                 infer_decoder_config = gr.Dropdown(
                                     label=i18n("Decoder Model Config"),
-                                    info=i18n(
-                                        "Type the path or select from the dropdown"
-                                    ),
+                                    info=i18n("Changing with the Model Path"),
                                     value="vits_decoder_finetune",
                                     choices=[
                                         "vits_decoder_finetune",
@@ -1226,20 +1259,12 @@ with gr.Blocks(
                                         "Type the path or select from the dropdown"
                                     ),
                                     value=init_llama_yml["ckpt_path"],
-                                    choices=[init_llama_yml["ckpt_path"]]
-                                    + [
-                                        str(p)
-                                        for p in Path("results").glob(
-                                            "text2sem*/**/*.ckpt"
-                                        )
-                                    ],
+                                    choices=list_llama_models(),
                                     allow_custom_value=True,
                                 )
                                 infer_llama_config = gr.Dropdown(
                                     label=i18n("LLAMA Model Config"),
-                                    info=i18n(
-                                        "Type the path or select from the dropdown"
-                                    ),
+                                    info=i18n("Changing with the Model Path"),
                                     choices=[
                                         "dual_ar_2_codebook_large",
                                         "dual_ar_2_codebook_medium",
@@ -1333,6 +1358,14 @@ with gr.Blocks(
         'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
     )
     if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
+    infer_decoder_model.change(
+        fn=change_decoder_config,
+        inputs=[infer_decoder_model],
+        outputs=[infer_decoder_config],
+    )
+    infer_llama_model.change(
+        fn=change_llama_config, inputs=[infer_llama_model], outputs=[infer_llama_config]
+    )
     train_btn.click(
         fn=train_process,
         inputs=[