ソースを参照

Gradio UI: Lora update (#142)

* init package

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix spelling

* Decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Specify the backend on the command line

* Fix Encoding Error

* Add start configuration

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add lora config & GC

* Label UI updated

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add llama base config

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* New Project Names & Remove ref audio

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Upgrade UI & Tensorboard

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add user custom path

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 年前
コミット
615c8398ed
2 ファイル変更、402 行追加、217 行削除
  1. 397 216
      fish_speech/webui/manage.py
  2. 5 1
      tools/webui.py

+ 397 - 216
fish_speech/webui/manage.py

@@ -4,11 +4,11 @@ import html
 import json
 import os
 import platform
-import random
 import shutil
 import signal
 import subprocess
 import sys
+import webbrowser
 from pathlib import Path
 
 import gradio as gr
@@ -26,7 +26,7 @@ cur_work_dir = Path(os.getcwd()).resolve()
 print("You are in ", str(cur_work_dir))
 config_path = cur_work_dir / "fish_speech" / "configs"
 vqgan_yml_path = config_path / "vqgan_finetune.yaml"
-llama_yml_path = config_path / "text2semantic_sft.yaml"
+llama_yml_path = config_path / "text2semantic_finetune.yaml"
 
 env = os.environ.copy()
 env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
@@ -79,6 +79,7 @@ def kill_proc_tree(pid, including_parent=True):
 system = platform.system()
 p_label = None
 p_infer = None
+p_tensorboard = None
 
 
 def kill_process(pid):
@@ -92,18 +93,24 @@ def kill_process(pid):
 
 def change_label(if_label):
     global p_label
-    if if_label == True and p_label == None:
-        cmd = ["asr-label-win-x64.exe"]
-        yield f"打标工具WebUI已开启, 访问:http://localhost:{3000}"
-        p_label = subprocess.Popen(cmd, shell=True, env=env)
-    elif if_label == False and p_label != None:
-        kill_process(p_label.pid)
+    if if_label == True:
+        # 设置要访问的URL
+        url = "https://text-labeler.pages.dev/"
+        webbrowser.open(url)
+        yield f"已打开网址"
+    elif if_label == False:
         p_label = None
-        yield "打标工具WebUI已关闭"
+        yield "Nothing"
 
 
 def change_infer(
-    if_infer, host, port, infer_vqgan_model, infer_llama_model, infer_compile
+    if_infer,
+    host,
+    port,
+    infer_vqgan_model,
+    infer_llama_model,
+    infer_llama_config,
+    infer_compile,
 ):
     global p_infer
     if if_infer == True and p_infer == None:
@@ -121,6 +128,8 @@ def change_infer(
                 infer_vqgan_model,
                 "--llama-checkpoint-path",
                 infer_llama_model,
+                "--llama-config-name",
+                infer_llama_config,
                 "--tokenizer",
                 "checkpoints",
             ]
@@ -193,12 +202,10 @@ def new_explorer(data_path, max_depth):
     )
 
 
-def add_item(folder: str, method: str, filelist: str, label_lang: str):
+def add_item(folder: str, method: str, label_lang: str):
     folder = folder.strip(" ").strip('"')
-    filelist = filelist.strip(" ").strip('"')
 
     folder_path = Path(folder)
-    filelist_path = Path(filelist)
 
     if folder and folder not in items and data_pre_output not in folder_path.parents:
         if folder_path.is_dir():
@@ -212,22 +219,6 @@ def add_item(folder: str, method: str, filelist: str, label_lang: str):
                 f"添加文件夹路径无效: {err}"
             )
 
-    if (
-        filelist
-        and filelist not in items
-        and data_pre_output not in filelist_path.parents
-    ):
-        if filelist_path.is_file():
-            items.append(filelist)
-            dict_items[filelist] = dict(
-                type="file", method=method, label_lang=label_lang
-            )
-        elif filelist:
-            err = filelist
-            return gr.Checkboxgroup(choices=items), build_html_error_message(
-                f"添加文件路径无效: {err}"
-            )
-
     formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
     logger.info(formatted_data)
     return gr.Checkboxgroup(choices=items), build_html_ok_message("添加文件(夹)路径成功!")
@@ -350,9 +341,9 @@ def train_process(
     vqgan_precision,
     vqgan_check_interval,
     # llama config
+    llama_base_config,
     llama_lr,
     llama_maxsteps,
-    llama_limit_val_batches,
     llama_data_num_workers,
     llama_data_batch_size,
     llama_data_max_length,
@@ -360,8 +351,21 @@ def train_process(
     llama_check_interval,
     llama_grad_batches,
     llama_use_speaker,
+    llama_use_lora,
 ):
+    import datetime
+
+    def generate_folder_name():
+        now = datetime.datetime.now()
+        folder_name = now.strftime("%Y%m%d_%H%M%S")
+        return folder_name
+
     backend = "nccl" if sys.platform == "linux" else "gloo"
+
+    new_project = generate_folder_name()
+
+    print("New Project Name: ", new_project)
+
     if option == "VQGAN" or option == "all":
         subprocess.run(
             [
@@ -375,6 +379,7 @@ def train_process(
             "fish_speech/train.py",
             "--config-name",
             "vqgan_finetune",
+            f"project={'vqgan_' + new_project}",
             f"trainer.strategy.process_group_backend={backend}",
             f"model.optimizer.lr={vqgan_lr}",
             f"trainer.max_steps={vqgan_maxsteps}",
@@ -418,20 +423,25 @@ def train_process(
                 "16",
             ]
         )
-
+        ckpt_path = (
+            "text2semantic-pretrain-medium-2k-v1.pth"
+            if llama_base_config == "dual_ar_2_codebook_medium"
+            else "text2semantic-sft-large-v1-4k.pth"
+        )
         train_cmd = [
             PYTHON,
             "fish_speech/train.py",
             "--config-name",
-            "text2semantic_sft",
+            "text2semantic_finetune",
+            f"project={'text2semantic_' + new_project}",
+            f"ckpt_path=checkpoints/{ckpt_path}",
             f"trainer.strategy.process_group_backend={backend}",
-            "model@model.model=dual_ar_2_codebook_medium",
+            f"model@model.model={llama_base_config}",
             "tokenizer.pretrained_model_name_or_path=checkpoints",
             f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
             f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
             f"model.optimizer.lr={llama_lr}",
             f"trainer.max_steps={llama_maxsteps}",
-            f"trainer.limit_val_batches={llama_limit_val_batches}",
             f"data.num_workers={llama_data_num_workers}",
             f"data.batch_size={llama_data_batch_size}",
             f"max_length={llama_data_max_length}",
@@ -439,13 +449,91 @@ def train_process(
             f"trainer.val_check_interval={llama_check_interval}",
             f"trainer.accumulate_grad_batches={llama_grad_batches}",
             f"train_dataset.use_speaker={llama_use_speaker}",
-        ]
+        ] + ([f"+lora@model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
         logger.info(train_cmd)
         subprocess.run(train_cmd)
 
     return build_html_ok_message("训练终止")
 
 
+def tensorboard_process(
+    if_tensorboard: bool,
+    tensorboard_dir: str,
+    host: str,
+    port: str,
+):
+    global p_tensorboard
+    if if_tensorboard == True and p_tensorboard == None:
+        yield build_html_ok_message(f"Tensorboard界面已开启, 访问 http://{host}:{port}")
+        p_tensorboard = subprocess.Popen(
+            [
+                "fishenv/python.exe",
+                "fishenv/Scripts/tensorboard.exe"
+                if Path("fishenv").exists()
+                else "tensorboard",
+                "--logdir",
+                tensorboard_dir,
+                "--host",
+                host,
+                "--port",
+                port,
+                "--reload_interval",
+                "120",
+            ]
+        )
+    elif if_tensorboard == False and p_tensorboard != None:
+        kill_process(p_tensorboard.pid)
+        p_tensorboard = None
+        yield build_html_error_message("Tensorboard界面已关闭")
+
+
+def fresh_tb_dir():
+    return gr.Dropdown(
+        choices=[str(p) for p in Path("results").glob("**/tensorboard/version_*/")]
+    )
+
+
+def fresh_vqgan_model():
+    return gr.Dropdown(
+        choices=[init_vqgan_yml["ckpt_path"]]
+        + [str(p) for p in Path("results").glob("vqgan*/**/*.ckpt")]
+    )
+
+
+def fresh_llama_model():
+    return gr.Dropdown(
+        choices=[init_llama_yml["ckpt_path"]]
+        + [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")]
+    )
+
+
+def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
+    if (
+        lora_weight is None
+        or not Path(lora_weight).exists()
+        or not Path(llama_weight).exists()
+    ):
+        return build_html_error_message("路径错误,请检查模型文件是否存在于对应路径")
+
+    merge_cmd = [
+        PYTHON,
+        "tools/llama/merge_lora.py",
+        "--llama-config",
+        "dual_ar_2_codebook_large",
+        "--lora-config",
+        "r_8_alpha_16",
+        "--llama-weight",
+        llama_weight,
+        "--lora-weight",
+        lora_weight,
+        "--output",
+        llama_lora_output,
+    ]
+    logger.info(merge_cmd)
+    subprocess.run(merge_cmd)
+    return build_html_ok_message("融合终止")
+
+
 init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
 init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
 
@@ -465,11 +553,6 @@ with gr.Blocks(
                         info="音频装在一个以说话人命名的文件夹内作为区分",
                         interactive=True,
                     )
-                    transcript_path = gr.Textbox(
-                        label="\U0001F4DD 转写文本filelist所在路径",
-                        info="支持 Bert-Vits2 / GPT-SoVITS 格式",
-                        interactive=True,
-                    )
                 with gr.Row(equal_height=False):
                     with gr.Column():
                         output_radio = gr.Radio(
@@ -511,173 +594,227 @@ with gr.Blocks(
                     )
 
             with gr.Tab("\U0001F6E0 训练配置项"):  # hammer
-                with gr.Column():
-                    with gr.Row():
-                        model_type_radio = gr.Radio(
-                            label="选择要训练的模型类型",
-                            interactive=True,
-                            choices=["VQGAN", "LLAMA", "all"],
-                            value="all",
-                        )
-                    with gr.Row():
-                        with gr.Accordion("VQGAN配置项", open=False):
-                            with gr.Row(equal_height=False):
-                                vqgan_lr_slider = gr.Slider(
-                                    label="初始学习率",
-                                    interactive=True,
-                                    minimum=1e-5,
-                                    maximum=1e-4,
-                                    step=1e-5,
-                                    value=init_vqgan_yml["model"]["optimizer"]["lr"],
-                                )
-                                vqgan_maxsteps_slider = gr.Slider(
-                                    label="训练最大步数",
-                                    interactive=True,
-                                    minimum=1000,
-                                    maximum=100000,
-                                    step=1000,
-                                    value=init_vqgan_yml["trainer"]["max_steps"],
-                                )
-
-                            with gr.Row(equal_height=False):
-                                vqgan_data_num_workers_slider = gr.Slider(
-                                    label="num_workers",
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=16,
-                                    step=1,
-                                    value=init_vqgan_yml["data"]["num_workers"],
-                                )
-
-                                vqgan_data_batch_size_slider = gr.Slider(
-                                    label="batch_size",
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=32,
-                                    step=1,
-                                    value=init_vqgan_yml["data"]["batch_size"],
-                                )
-                            with gr.Row(equal_height=False):
-                                vqgan_data_val_batch_size_slider = gr.Slider(
-                                    label="val_batch_size",
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=32,
-                                    step=1,
-                                    value=init_vqgan_yml["data"]["val_batch_size"],
-                                )
-                                vqgan_precision_dropdown = gr.Dropdown(
-                                    label="训练精度",
-                                    interactive=True,
-                                    choices=["32", "bf16-true", "bf16-mixed"],
-                                    value=str(init_vqgan_yml["trainer"]["precision"]),
-                                )
-                            with gr.Row(equal_height=False):
-                                vqgan_check_interval_slider = gr.Slider(
-                                    label="每n步保存一个模型",
-                                    interactive=True,
-                                    minimum=500,
-                                    maximum=10000,
-                                    step=500,
-                                    value=init_vqgan_yml["trainer"][
-                                        "val_check_interval"
-                                    ],
-                                )
-
-                    with gr.Row():
-                        with gr.Accordion("LLAMA配置项", open=False):
-                            with gr.Row(equal_height=False):
-                                llama_lr_slider = gr.Slider(
-                                    label="初始学习率",
-                                    interactive=True,
-                                    minimum=1e-5,
-                                    maximum=1e-4,
-                                    step=1e-5,
-                                    value=init_llama_yml["model"]["optimizer"]["lr"],
-                                )
-                                llama_maxsteps_slider = gr.Slider(
-                                    label="训练最大步数",
-                                    interactive=True,
-                                    minimum=1000,
-                                    maximum=100000,
-                                    step=1000,
-                                    value=init_llama_yml["trainer"]["max_steps"],
-                                )
-                            with gr.Row(equal_height=False):
-                                llama_limit_val_batches_slider = gr.Slider(
-                                    label="limit_val_batches",
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=20,
-                                    step=1,
-                                    value=init_llama_yml["trainer"][
-                                        "limit_val_batches"
-                                    ],
-                                )
-                                llama_data_num_workers_slider = gr.Slider(
-                                    label="num_workers",
-                                    minimum=0,
-                                    maximum=16,
-                                    step=1,
-                                    value=init_llama_yml["data"]["num_workers"]
-                                    if sys.platform == "linux"
-                                    else 0,
-                                )
-                            with gr.Row(equal_height=False):
-                                llama_data_batch_size_slider = gr.Slider(
-                                    label="batch_size",
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=32,
-                                    step=1,
-                                    value=init_llama_yml["data"]["batch_size"],
-                                )
-                                llama_data_max_length_slider = gr.Slider(
-                                    label="max_length",
-                                    interactive=True,
-                                    minimum=1024,
-                                    maximum=4096,
-                                    step=128,
-                                    value=init_llama_yml["max_length"],
-                                )
-                            with gr.Row(equal_height=False):
-                                llama_precision_dropdown = gr.Dropdown(
-                                    label="训练精度",
-                                    interactive=True,
-                                    choices=["32", "bf16-true", "16-mixed"],
-                                    value="bf16-true",
-                                )
-                                llama_check_interval_slider = gr.Slider(
-                                    label="每n步保存一个模型",
-                                    interactive=True,
-                                    minimum=500,
-                                    maximum=10000,
-                                    step=500,
-                                    value=init_llama_yml["trainer"][
-                                        "val_check_interval"
-                                    ],
-                                )
-                            with gr.Row(equal_height=False):
-                                llama_grad_batches = gr.Slider(
-                                    label="accumulate_grad_batches",
-                                    interactive=True,
-                                    minimum=1,
-                                    maximum=20,
-                                    step=1,
-                                    value=init_llama_yml["trainer"][
-                                        "accumulate_grad_batches"
-                                    ],
-                                )
-                                llama_use_speaker = gr.Slider(
-                                    label="use_speaker_ratio",
-                                    interactive=True,
-                                    minimum=0.1,
-                                    maximum=1.0,
-                                    step=0.05,
-                                    value=init_llama_yml["train_dataset"][
-                                        "use_speaker"
-                                    ],
-                                )
+                with gr.Row():
+                    model_type_radio = gr.Radio(
+                        label="选择要训练的模型类型",
+                        interactive=True,
+                        choices=["VQGAN", "LLAMA", "all"],
+                        value="all",
+                    )
+                with gr.Row():
+                    with gr.Tab(label="VQGAN配置项"):
+                        with gr.Row(equal_height=False):
+                            vqgan_lr_slider = gr.Slider(
+                                label="初始学习率",
+                                interactive=True,
+                                minimum=1e-5,
+                                maximum=1e-4,
+                                step=1e-5,
+                                value=init_vqgan_yml["model"]["optimizer"]["lr"],
+                            )
+                            vqgan_maxsteps_slider = gr.Slider(
+                                label="训练最大步数",
+                                interactive=True,
+                                minimum=1000,
+                                maximum=100000,
+                                step=1000,
+                                value=init_vqgan_yml["trainer"]["max_steps"],
+                            )
+
+                        with gr.Row(equal_height=False):
+                            vqgan_data_num_workers_slider = gr.Slider(
+                                label="num_workers",
+                                interactive=True,
+                                minimum=1,
+                                maximum=16,
+                                step=1,
+                                value=init_vqgan_yml["data"]["num_workers"],
+                            )
+
+                            vqgan_data_batch_size_slider = gr.Slider(
+                                label="batch_size",
+                                interactive=True,
+                                minimum=1,
+                                maximum=32,
+                                step=1,
+                                value=init_vqgan_yml["data"]["batch_size"],
+                            )
+                        with gr.Row(equal_height=False):
+                            vqgan_data_val_batch_size_slider = gr.Slider(
+                                label="val_batch_size",
+                                interactive=True,
+                                minimum=1,
+                                maximum=32,
+                                step=1,
+                                value=init_vqgan_yml["data"]["val_batch_size"],
+                            )
+                            vqgan_precision_dropdown = gr.Dropdown(
+                                label="训练精度",
+                                interactive=True,
+                                choices=["32", "bf16-true", "bf16-mixed"],
+                                value=str(init_vqgan_yml["trainer"]["precision"]),
+                            )
+                        with gr.Row(equal_height=False):
+                            vqgan_check_interval_slider = gr.Slider(
+                                label="每n步保存一个模型",
+                                interactive=True,
+                                minimum=500,
+                                maximum=10000,
+                                step=500,
+                                value=init_vqgan_yml["trainer"]["val_check_interval"],
+                            )
+
+                    with gr.Tab(label="LLAMA配置项"):
+                        with gr.Row(equal_height=False):
+                            llama_use_lora = gr.Checkbox(
+                                label="使用lora训练?",
+                                value=True,
+                            )
+                        with gr.Row(equal_height=False):
+                            llama_lr_slider = gr.Slider(
+                                label="初始学习率",
+                                interactive=True,
+                                minimum=1e-5,
+                                maximum=1e-4,
+                                step=1e-5,
+                                value=init_llama_yml["model"]["optimizer"]["lr"],
+                            )
+                            llama_maxsteps_slider = gr.Slider(
+                                label="训练最大步数",
+                                interactive=True,
+                                minimum=1000,
+                                maximum=100000,
+                                step=1000,
+                                value=init_llama_yml["trainer"]["max_steps"],
+                            )
+                        with gr.Row(equal_height=False):
+                            llama_base_config = gr.Dropdown(
+                                label="模型基础属性",
+                                choices=[
+                                    "dual_ar_2_codebook_large",
+                                    "dual_ar_2_codebook_medium",
+                                ],
+                                value="dual_ar_2_codebook_large",
+                            )
+                            llama_data_num_workers_slider = gr.Slider(
+                                label="num_workers",
+                                minimum=0,
+                                maximum=16,
+                                step=1,
+                                value=init_llama_yml["data"]["num_workers"]
+                                if sys.platform == "linux"
+                                else 0,
+                            )
+                        with gr.Row(equal_height=False):
+                            llama_data_batch_size_slider = gr.Slider(
+                                label="batch_size",
+                                interactive=True,
+                                minimum=1,
+                                maximum=32,
+                                step=1,
+                                value=init_llama_yml["data"]["batch_size"],
+                            )
+                            llama_data_max_length_slider = gr.Slider(
+                                label="max_length",
+                                interactive=True,
+                                minimum=1024,
+                                maximum=4096,
+                                step=128,
+                                value=init_llama_yml["max_length"],
+                            )
+                        with gr.Row(equal_height=False):
+                            llama_precision_dropdown = gr.Dropdown(
+                                label="训练精度",
+                                interactive=True,
+                                choices=["32", "bf16-true", "16-mixed"],
+                                value="bf16-true",
+                            )
+                            llama_check_interval_slider = gr.Slider(
+                                label="每n步保存一个模型",
+                                interactive=True,
+                                minimum=500,
+                                maximum=10000,
+                                step=500,
+                                value=init_llama_yml["trainer"]["val_check_interval"],
+                            )
+                        with gr.Row(equal_height=False):
+                            llama_grad_batches = gr.Slider(
+                                label="accumulate_grad_batches",
+                                interactive=True,
+                                minimum=1,
+                                maximum=20,
+                                step=1,
+                                value=init_llama_yml["trainer"][
+                                    "accumulate_grad_batches"
+                                ],
+                            )
+                            llama_use_speaker = gr.Slider(
+                                label="use_speaker_ratio",
+                                interactive=True,
+                                minimum=0.1,
+                                maximum=1.0,
+                                step=0.05,
+                                value=init_llama_yml["train_dataset"]["use_speaker"],
+                            )
+
+                    with gr.Tab(label="LLAMA_lora融合"):
+                        with gr.Row(equal_height=False):
+                            llama_weight = gr.Dropdown(
+                                label="要融入的原模型",
+                                info="输入路径,或者下拉选择",
+                                choices=[init_llama_yml["ckpt_path"]],
+                                value=init_llama_yml["ckpt_path"],
+                                allow_custom_value=True,
+                                interactive=True,
+                            )
+                        with gr.Row(equal_height=False):
+                            lora_weight = gr.Dropdown(
+                                label="要融入的lora模型",
+                                info="输入路径,或者下拉选择",
+                                choices=[
+                                    str(p)
+                                    for p in Path("results").glob("text2*ar/**/*.ckpt")
+                                ],
+                                allow_custom_value=True,
+                                interactive=True,
+                            )
+                        with gr.Row(equal_height=False):
+                            llama_lora_output = gr.Dropdown(
+                                label="输出的lora模型",
+                                info="输出路径",
+                                value="checkpoints/merged.ckpt",
+                                choices=["checkpoints/merged.ckpt"],
+                                allow_custom_value=True,
+                                interactive=True,
+                            )
+                        with gr.Row(equal_height=False):
+                            llama_lora_merge_btn = gr.Button(
+                                value="开始融合", variant="primary"
+                            )
+
+                    with gr.Tab(label="Tensorboard"):
+                        with gr.Row(equal_height=False):
+                            tb_host = gr.Textbox(
+                                label="Tensorboard Host", value="127.0.0.1"
+                            )
+                            tb_port = gr.Textbox(
+                                label="Tensorboard Port", value="11451"
+                            )
+                        with gr.Row(equal_height=False):
+                            tb_dir = gr.Dropdown(
+                                label="Tensorboard 日志文件夹",
+                                allow_custom_value=True,
+                                choices=[
+                                    str(p)
+                                    for p in Path("results").glob(
+                                        "**/tensorboard/version_*/"
+                                    )
+                                ],
+                            )
+                        with gr.Row(equal_height=False):
+                            if_tb = gr.Checkbox(
+                                label="是否打开tensorboard?",
+                            )
 
             with gr.Tab("\U0001F9E0 进入推理界面"):
                 with gr.Column():
@@ -691,21 +828,46 @@ with gr.Blocks(
                                     label="Webui启动服务器端口", value="7862"
                                 )
                             with gr.Row():
-                                infer_vqgan_model = gr.Textbox(
+                                infer_vqgan_model = gr.Dropdown(
                                     label="VQGAN模型位置",
-                                    placeholder="填写pth/ckpt文件路径",
-                                    value="checkpoints/vq-gan-group-fsq-2x1024.pth",
+                                    info="填写pth/ckpt文件路径",
+                                    value=init_vqgan_yml["ckpt_path"],
+                                    choices=[init_vqgan_yml["ckpt_path"]]
+                                    + [
+                                        str(p)
+                                        for p in Path("results").glob(
+                                            "vqgan*/**/*.ckpt"
+                                        )
+                                    ],
+                                    allow_custom_value=True,
                                 )
                             with gr.Row():
-                                infer_llama_model = gr.Textbox(
+                                infer_llama_model = gr.Dropdown(
                                     label="LLAMA模型位置",
-                                    placeholder="填写pth/ckpt文件路径",
-                                    value="checkpoints/text2semantic-medium-v1-2k.pth",
+                                    info="填写pth/ckpt文件路径",
+                                    value=init_llama_yml["ckpt_path"],
+                                    choices=[init_llama_yml["ckpt_path"]]
+                                    + [
+                                        str(p)
+                                        for p in Path("results").glob(
+                                            "text2sem*/**/*.ckpt"
+                                        )
+                                    ],
+                                    allow_custom_value=True,
                                 )
                             with gr.Row():
                                 infer_compile = gr.Radio(
                                     label="是否编译模型?", choices=["Yes", "No"], value="Yes"
                                 )
+                                infer_llama_config = gr.Dropdown(
+                                    label="LLAMA模型基础属性",
+                                    choices=[
+                                        "dual_ar_2_codebook_large",
+                                        "dual_ar_2_codebook_medium",
+                                    ],
+                                    value="dual_ar_2_codebook_large",
+                                    allow_custom_value=True,
+                                )
 
                     with gr.Row():
                         infer_checkbox = gr.Checkbox(label="是否打开推理界面")
@@ -758,7 +920,7 @@ with gr.Blocks(
 
     add_button.click(
         fn=add_item,
-        inputs=[textbox, output_radio, transcript_path, label_radio],
+        inputs=[textbox, output_radio, label_radio],
         outputs=[checkbox_group, error],
     )
     remove_button.click(
@@ -785,9 +947,9 @@ with gr.Blocks(
             vqgan_precision_dropdown,
             vqgan_check_interval_slider,
             # llama config
+            llama_base_config,
             llama_lr_slider,
             llama_maxsteps_slider,
-            llama_limit_val_batches_slider,
             llama_data_num_workers_slider,
             llama_data_batch_size_slider,
             llama_data_max_length_slider,
@@ -795,9 +957,23 @@ with gr.Blocks(
             llama_check_interval_slider,
             llama_grad_batches,
             llama_use_speaker,
+            llama_use_lora,
         ],
         outputs=[train_error],
     )
+    if_tb.change(
+        fn=tensorboard_process,
+        inputs=[if_tb, tb_dir, tb_host, tb_port],
+        outputs=[train_error],
+    )
+    tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
+    infer_vqgan_model.change(
+        fn=fresh_vqgan_model, inputs=[], outputs=[infer_vqgan_model]
+    )
+    infer_llama_model.change(
+        fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
+    )
+    llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight])
     admit_btn.click(
         fn=check_files,
         inputs=[train_box, tree_slider, label_model, label_device],
@@ -806,7 +982,11 @@ with gr.Blocks(
     fresh_btn.click(
         fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
     )
-
+    llama_lora_merge_btn.click(
+        fn=llama_lora_merge,
+        inputs=[llama_weight, lora_weight, llama_lora_output],
+        outputs=[train_error],
+    )
     infer_checkbox.change(
         fn=change_infer,
         inputs=[
@@ -815,6 +995,7 @@ with gr.Blocks(
             infer_port_textbox,
             infer_vqgan_model,
             infer_llama_model,
+            infer_llama_config,
             infer_compile,
         ],
         outputs=[infer_error],

+ 5 - 1
tools/webui.py

@@ -1,3 +1,4 @@
+import gc
 import html
 import os
 import threading
@@ -138,6 +139,10 @@ def inference(
 
     fake_audios = fake_audios.float().cpu().numpy()
 
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        gc.collect()
+
     return (vqgan_model.sampling_rate, fake_audios), None
 
 
@@ -217,7 +222,6 @@ def build_app():
                         )
                         reference_audio = gr.Audio(
                             label="Reference Audio / 参考音频",
-                            value="docs/assets/audios/0_input.wav",
                             type="filepath",
                         )
                         reference_text = gr.Textbox(