| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823 |
- from __future__ import annotations
- import html
- import json
- import os
- import platform
- import random
- import shutil
- import signal
- import subprocess
- import sys
- from pathlib import Path
- import gradio as gr
- import psutil
- import yaml
- from loguru import logger
- from tqdm import tqdm
- from fish_speech.webui.launch_utils import Seafoam, versions_html
# Interpreter used to launch the helper scripts started from this file.
# When PYTHON_FOLDERPATH is unset this is simply "python" (resolved via PATH).
- PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
# An empty entry at the front of sys.path makes the current directory importable first.
- sys.path.insert(0, "")
- print(sys.path)
# All relative paths below are anchored at the directory the script was started from.
- cur_work_dir = Path(os.getcwd()).resolve()
- print("You are in ", str(cur_work_dir))
# Locations of the two fine-tuning configs that are read further down the file.
- config_path = cur_work_dir / "fish_speech" / "configs"
- vqgan_yml_path = config_path / "vqgan_finetune.yaml"
- llama_yml_path = config_path / "text2semantic_sft.yaml"
# Copy of the process environment handed to some child processes; the
# no_proxy entry lists the loopback / wildcard addresses.
- env = os.environ.copy()
- env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
# Theme object passed to gr.Blocks(theme=...) when the UI is built.
- seafoam = Seafoam()
def build_html_error_message(error):
    """Wrap ``error`` in a bold red ``<div>`` for display in an HTML widget.

    The text is passed through ``html.escape`` so it is rendered literally.
    """
    escaped = html.escape(error)
    return (
        '\n<div style="color: red; font-weight: bold;">\n'
        f"{escaped}\n"
        "</div>\n"
    )
def build_html_ok_message(msg):
    """Wrap ``msg`` in a bold green ``<div>`` for display in an HTML widget.

    The text is passed through ``html.escape`` so it is rendered literally.
    """
    escaped = html.escape(msg)
    return (
        '\n<div style="color: green; font-weight: bold;">\n'
        f"{escaped}\n"
        "</div>\n"
    )
def load_data_in_raw(path):
    """Return the whole content of the UTF-8 text file at ``path`` as a string."""
    return Path(path).read_text(encoding="utf-8")
def kill_proc_tree(pid, including_parent=True):
    """Send SIGTERM to every descendant of ``pid`` and optionally to ``pid`` itself.

    Args:
        pid: Process id of the root of the tree.
        including_parent: When true, the root process is signalled after
            its descendants.

    Returns silently if the root process no longer exists; failures to
    signal an individual process are ignored.
    """
    try:
        root = psutil.Process(pid)
    except psutil.NoSuchProcess:
        # Process already terminated
        return

    # Descendants first, the root last.
    targets = list(root.children(recursive=True))
    if including_parent:
        targets.append(root)

    for proc in targets:
        try:
            os.kill(proc.pid, signal.SIGTERM)  # or signal.SIGKILL
        except OSError:
            # Best effort: the process may already be gone.
            pass
# OS name as reported by platform.system(); kill_process() branches on it.
- system = platform.system()
# Handles of the optional child processes (labelling tool / inference WebUI).
# None means "not running"; change_label() and change_infer() rebind them.
- p_label = None
- p_infer = None
def kill_process(pid):
    """Terminate the process ``pid`` together with its child processes.

    On Windows this runs ``taskkill /t /f /pid <pid>``; on every other
    platform it delegates to ``kill_proc_tree``.
    """
    if system == "Windows":
        # Pass an argument list instead of a formatted command string so the
        # pid is never subject to command-line parsing.
        subprocess.run(["taskkill", "/t", "/f", "/pid", str(pid)])
    else:
        kill_proc_tree(pid)
def change_label(if_label):
    """Start or stop the labelling tool according to the checkbox state.

    This is a generator: it yields the status text to display. Nothing is
    yielded when the requested state already matches the current one.

    Args:
        if_label: Truthy to start the tool, falsy to stop it.
    """
    global p_label
    label_port = 3000

    if if_label and p_label is None:
        # The status message is emitted before the process is spawned,
        # exactly as before.
        yield f"打标工具WebUI已开启, 访问:http://localhost:{label_port}"
        p_label = subprocess.Popen(["asr-label-win-x64.exe"], shell=True, env=env)
    elif not if_label and p_label is not None:
        kill_process(p_label.pid)
        p_label = None
        yield "打标工具WebUI已关闭"
def change_infer(
    if_infer, host, port, infer_vqgan_model, infer_llama_model, infer_compile
):
    """Start or stop the inference WebUI child process.

    This is a generator: it yields an HTML status snippet. Nothing is
    yielded when the requested state already matches the current one.

    Args:
        if_infer: Truthy to start the inference UI, falsy to stop it.
        host: Value exported as ``GRADIO_SERVER_NAME`` for the child.
        port: Value exported as ``GRADIO_SERVER_PORT`` for the child.
        infer_vqgan_model: Path passed as ``--vqgan-checkpoint-path``.
        infer_llama_model: Path passed as ``--llama-checkpoint-path``.
        infer_compile: ``"Yes"`` appends ``--compile`` to the command line.
    """
    global p_infer

    if if_infer and p_infer is None:
        # Use a dedicated name so the module-level ``env`` is not shadowed.
        # Environment values must be strings, hence the explicit str().
        infer_env = os.environ.copy()
        infer_env["GRADIO_SERVER_NAME"] = str(host)
        infer_env["GRADIO_SERVER_PORT"] = str(port)

        # Launch the second process.
        yield build_html_ok_message(f"推理界面已开启, 访问 http://{host}:{port}")

        command = [
            PYTHON,
            "tools/webui.py",
            "--vqgan-checkpoint-path",
            infer_vqgan_model,
            "--llama-checkpoint-path",
            infer_llama_model,
            "--tokenizer",
            "checkpoints",
        ]
        if infer_compile == "Yes":
            command.append("--compile")
        p_infer = subprocess.Popen(command, env=infer_env)
    elif not if_infer and p_infer is not None:
        kill_process(p_infer.pid)
        p_infer = None
        yield build_html_error_message("推理界面已关闭")
# Front-end assets, read once at import time and injected into gr.Blocks below.
- js = load_data_in_raw("fish_speech/webui/js/animate.js")
- css = load_data_in_raw("fish_speech/webui/css/style.css")
# Directory that receives the pre-processed dataset, and the results directory
# shown in the UI; both live under the current working directory.
- data_pre_output = (cur_work_dir / "data").resolve()
- default_model_output = (cur_work_dir / "results").resolve()
- default_filelist = data_pre_output / "detect.list"
# Make sure the data directory exists before the UI lists its content.
- data_pre_output.mkdir(parents=True, exist_ok=True)
# Registered data sources: `items` keeps the paths in insertion order,
# `dict_items` maps each path to its type / method / label_lang settings.
- items = []
- dict_items = {}
def load_yaml_data_in_fact(yml_path):
    """Parse the UTF-8 YAML file at ``yml_path`` with ``yaml.safe_load`` and return the result."""
    with open(yml_path, "r", encoding="utf-8") as stream:
        return yaml.safe_load(stream)
def write_yaml_data_in_fact(yml, yml_path):
    """Dump ``yml`` to ``yml_path`` as UTF-8 YAML and return ``yml`` unchanged.

    Non-ASCII characters are written as-is (``allow_unicode=True``).
    """
    with open(yml_path, "w", encoding="utf-8") as stream:
        yaml.safe_dump(yml, stream, allow_unicode=True)
    return yml
def generate_tree(directory, depth=0, max_depth=None, prefix=""):
    """Render the content of ``directory`` as a ``<br />``-separated text tree.

    Sub-directories are listed before files, and both groups are sorted by
    name so the output is deterministic (``os.listdir`` returns entries in
    arbitrary order).

    Args:
        directory: Directory to list.
        depth: Current recursion depth (0 for the root call).
        max_depth: Deepest level to render; ``None`` means unlimited.
        prefix: Text prepended to every line of this level.

    Returns:
        The rendered tree, or ``""`` for an empty directory or when
        ``depth`` exceeds ``max_depth``.
    """
    if max_depth is not None and depth > max_depth:
        return ""

    directories = []
    files = []
    for name in sorted(os.listdir(directory)):
        if os.path.isdir(os.path.join(directory, name)):
            directories.append(name)
        else:
            files.append(name)

    entries = directories + files
    last_index = len(entries) - 1
    parts = []
    for i, entry in enumerate(entries):
        is_last = i == last_index
        connector = "└── " if is_last else "├── "
        parts.append(f"{prefix}{connector}{entry}<br />")
        # The first len(directories) entries are the sub-directories.
        if i < len(directories):
            extension = " " if is_last else "│ "
            parts.append(
                generate_tree(
                    os.path.join(directory, entry),
                    depth + 1,
                    max_depth,
                    prefix=prefix + extension,
                )
            )
    return "".join(parts)
def new_explorer(data_path, max_depth):
    """Build a Markdown component showing the tree of ``data_path``.

    Args:
        data_path: Directory whose content is rendered.
        max_depth: Deepest level passed on to ``generate_tree``.
    """
    tree_markup = generate_tree(data_path, max_depth=max_depth)
    return gr.Markdown(
        value=tree_markup,
        elem_classes=["scrollable-component"],
    )
def add_item(folder: str, method: str, filelist: str, label_lang: str):
    """Register a source folder and/or a filelist as a data source.

    Surrounding spaces and double quotes are stripped from both paths. A
    path is only considered when it is non-empty, not registered yet and
    not located inside the pre-processing directory.

    Args:
        folder: Path of a folder containing audio/transcripts (may be empty).
        method: How the source is transferred later (stored as-is).
        filelist: Path of a transcript filelist (may be empty).
        label_lang: Labelling language choice (stored as-is).

    Returns:
        A tuple ``(checkbox group update, HTML status message)``. An error
        message is returned as soon as one of the considered paths does not
        exist with the expected kind (directory / file).
    """
    folder = folder.strip(" ").strip('"')
    filelist = filelist.strip(" ").strip('"')

    folder_path = Path(folder)
    filelist_path = Path(filelist)

    if folder and folder not in items and data_pre_output not in folder_path.parents:
        if not folder_path.is_dir():
            return gr.CheckboxGroup(choices=items), build_html_error_message(
                f"添加文件夹路径无效: {folder}"
            )
        items.append(folder)
        dict_items[folder] = dict(type="folder", method=method, label_lang=label_lang)

    if (
        filelist
        and filelist not in items
        and data_pre_output not in filelist_path.parents
    ):
        if not filelist_path.is_file():
            return gr.CheckboxGroup(choices=items), build_html_error_message(
                f"添加文件路径无效: {filelist}"
            )
        items.append(filelist)
        dict_items[filelist] = dict(type="file", method=method, label_lang=label_lang)

    logger.info(json.dumps(dict_items, ensure_ascii=False, indent=4))
    return gr.CheckboxGroup(choices=items), build_html_ok_message("添加文件(夹)路径成功!")
def remove_items(selected_items):
    """Unregister the data sources ticked in the checkbox group.

    Args:
        selected_items: Paths selected in the UI; ``None`` or an empty list
            removes nothing.

    Returns:
        A tuple ``(checkbox group update with cleared selection, HTML
        status message)``.
    """
    global items

    selected = set(selected_items or [])
    for item in items:
        if item in selected:
            # pop() tolerates an entry that is listed but missing from the dict.
            dict_items.pop(item, None)
    # Keep only the paths that still have settings attached.
    items = [item for item in items if item in dict_items]

    logger.info(json.dumps(dict_items, ensure_ascii=False, indent=4))
    return gr.CheckboxGroup(choices=items, value=[]), build_html_ok_message(
        "删除文件(夹)路径成功!"
    )
def show_selected(options):
    """Return a status line naming the selected options, or a notice when none is selected."""
    joined = ", ".join(options)
    if options:
        return f"你选中了: {joined}"
    return "你没有选中任何选项"
def _copy_or_move(source, target, method):
    """Copy ``source`` to ``target`` for the copy method, otherwise move it into ``target``'s folder."""
    if method == "复制一份":
        shutil.copy(source, target)
    else:
        shutil.move(source, target.parent)


def list_copy(list_file_path, method):
    """Transfer the audio and ``.lab`` files referenced by a filelist into the data directory.

    Every non-blank line must have the form ``wav_path|speaker|language|text``.
    Each file lands in ``data_pre_output / <parent folder name> / <file name>``.
    An entry whose audio already exists at the target is skipped entirely;
    an existing ``.lab`` target is left untouched.

    Args:
        list_file_path: ``pathlib.Path`` of the filelist.
        method: ``"复制一份"`` copies the files; any other value moves them.
            With ``"直接移动"`` the filelist is additionally rewritten so it
            points at the new audio locations.

    Returns:
        An HTML status message.

    Raises:
        ValueError: If a non-blank line does not contain exactly four
            ``|``-separated fields.
    """
    wav_root = data_pre_output
    rewritten = []

    with list_file_path.open("r", encoding="utf-8") as file:
        for raw_line in tqdm(file, desc="Processing audio/transcript"):
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no entry.
                continue
            wav_path, speaker_name, language, text = line.split("|")

            original_wav_path = Path(wav_path)
            target_dir = wav_root / original_wav_path.parent.name
            target_wav_path = target_dir / original_wav_path.name
            rewritten.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")

            if target_wav_path.is_file():
                continue
            target_wav_path.parent.mkdir(parents=True, exist_ok=True)
            _copy_or_move(original_wav_path, target_wav_path, method)

            original_lab_path = original_wav_path.with_suffix(".lab")
            target_lab_path = target_dir / original_lab_path.name
            if target_lab_path.is_file():
                continue
            _copy_or_move(original_lab_path, target_lab_path, method)

    if method == "直接移动":
        with list_file_path.open("w", encoding="utf-8") as file:
            file.write("\n".join(rewritten))

    return build_html_ok_message("使用filelist")
def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
    """Pre-process every registered data source into ``data_path``.

    Folder sources are optionally transcribed with ``tools/whisper_asr.py``
    (unless their labelling language is "不打标") and then copied or moved
    into ``data_path``. Filelist sources are handed to ``list_copy``.
    Sources whose path no longer exists with the registered kind are skipped.

    Args:
        data_path: Destination directory for the pre-processed data.
        max_depth: Depth used when re-rendering the directory tree.
        label_model: Value passed as ``--model-size`` to the ASR script.
        label_device: Value passed as ``--device`` to the ASR script.

    Returns:
        A tuple ``(HTML status message, refreshed tree component)``.
    """
    dict_to_language = {"中文": "ZH", "英文": "EN", "日文": "JP", "不打标": "WTF"}

    data_path = Path(data_path)
    for item, content in dict_items.items():
        item_path = Path(item)
        tar_path = data_path / item_path.name

        if content["type"] == "folder" and item_path.is_dir():
            cur_lang = dict_to_language[content["label_lang"]]
            if cur_lang != "WTF":
                # Transcription is best effort: a failure is logged and the
                # folder is still transferred below.
                try:
                    result = subprocess.run(
                        [
                            PYTHON,
                            "tools/whisper_asr.py",
                            "--model-size",
                            label_model,
                            "--device",
                            label_device,
                            "--audio-dir",
                            str(item_path),
                            "--save-dir",
                            str(item_path),
                            "--language",
                            cur_lang,
                        ],
                        env=env,
                    )
                except Exception:
                    logger.exception("Transcription error occurred")
                else:
                    # subprocess.run() does not raise on a failing exit code.
                    if result.returncode != 0:
                        logger.warning(
                            f"Transcription of {item_path} exited with code {result.returncode}"
                        )

            if content["method"] == "复制一份":
                os.makedirs(tar_path, exist_ok=True)
                shutil.copytree(
                    src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
                )
            elif not tar_path.is_dir():
                shutil.move(src=str(item_path), dst=str(tar_path))

        elif content["type"] == "file" and item_path.is_file():
            list_copy(item_path, content["method"])

    return build_html_ok_message("文件移动完毕"), new_explorer(data_path, max_depth=max_depth)
def train_process(
    data_path: str,
    option: str,
    # vq-gan config
    vqgan_lr,
    vqgan_maxsteps,
    vqgan_data_num_workers,
    vqgan_data_batch_size,
    vqgan_data_val_batch_size,
    vqgan_precision,
    vqgan_check_interval,
    # llama config
    llama_lr,
    llama_maxsteps,
    llama_limit_val_batches,
    llama_data_num_workers,
    llama_data_batch_size,
    llama_data_max_length,
    llama_precision,
    llama_check_interval,
    llama_grad_batches,
    llama_use_speaker,
):
    """Run the selected fine-tuning pipeline(s) as blocking child processes.

    ``option`` selects the stages: ``"VQGAN"``, ``"LLAMA"`` or ``"all"``.
    The VQGAN stage creates the train/val split and then runs
    ``fish_speech/train.py`` with the ``vqgan_finetune`` config. The LLAMA
    stage extracts VQ tokens, builds the dataset and then runs
    ``fish_speech/train.py`` with the ``text2semantic_sft`` config. The
    remaining arguments are forwarded as config overrides of the matching
    stage. ``data_path`` is accepted but not read here.

    Returns:
        An HTML status message once all selected stages have returned.
    """
    backend = "nccl" if sys.platform == "linux" else "gloo"
    backend_override = f"trainer.strategy.process_group_backend={backend}"

    if option in ("VQGAN", "all"):
        split_cmd = [
            PYTHON,
            "tools/vqgan/create_train_split.py",
            str(data_pre_output.relative_to(cur_work_dir)),
        ]
        subprocess.run(split_cmd)

        vq_train_list = data_pre_output / "vq_train_filelist.txt"
        vq_val_list = data_pre_output / "vq_val_filelist.txt"
        vqgan_overrides = [
            backend_override,
            f"model.optimizer.lr={vqgan_lr}",
            f"trainer.max_steps={vqgan_maxsteps}",
            f"data.num_workers={vqgan_data_num_workers}",
            f"data.batch_size={vqgan_data_batch_size}",
            f"data.val_batch_size={vqgan_data_val_batch_size}",
            f"trainer.precision={vqgan_precision}",
            f"trainer.val_check_interval={vqgan_check_interval}",
            f"train_dataset.filelist={vq_train_list}",
            f"val_dataset.filelist={vq_val_list}",
        ]
        train_cmd = [
            PYTHON,
            "fish_speech/train.py",
            "--config-name",
            "vqgan_finetune",
            *vqgan_overrides,
        ]
        logger.info(train_cmd)
        subprocess.run(train_cmd)

    if option in ("LLAMA", "all"):
        dataset_dir = str(data_pre_output)

        extract_cmd = [
            PYTHON,
            "tools/vqgan/extract_vq.py",
            dataset_dir,
            "--num-workers",
            "1",
            "--batch-size",
            "16",
            "--config-name",
            "vqgan_pretrain",
            "--checkpoint-path",
            "checkpoints/vq-gan-group-fsq-2x1024.pth",
        ]
        subprocess.run(extract_cmd)

        build_cmd = [
            PYTHON,
            "tools/llama/build_dataset.py",
            "--input",
            dataset_dir,
            "--text-extension",
            ".lab",
            "--num-workers",
            "16",
        ]
        subprocess.run(build_cmd)

        # Rendered as the Python list repr: ['data/quantized-dataset-ft']
        proto_files = str(["data/quantized-dataset-ft"])
        llama_overrides = [
            backend_override,
            "model@model.model=dual_ar_2_codebook_medium",
            "tokenizer.pretrained_model_name_or_path=checkpoints",
            f"train_dataset.proto_files={proto_files}",
            f"val_dataset.proto_files={proto_files}",
            f"model.optimizer.lr={llama_lr}",
            f"trainer.max_steps={llama_maxsteps}",
            f"trainer.limit_val_batches={llama_limit_val_batches}",
            f"data.num_workers={llama_data_num_workers}",
            f"data.batch_size={llama_data_batch_size}",
            f"max_length={llama_data_max_length}",
            f"trainer.precision={llama_precision}",
            f"trainer.val_check_interval={llama_check_interval}",
            f"trainer.accumulate_grad_batches={llama_grad_batches}",
            f"train_dataset.use_speaker={llama_use_speaker}",
        ]
        train_cmd = [
            PYTHON,
            "fish_speech/train.py",
            "--config-name",
            "text2semantic_sft",
            *llama_overrides,
        ]
        logger.info(train_cmd)
        subprocess.run(train_cmd)

    return build_html_ok_message("训练终止")
# Load both fine-tuning configs once; their values seed the defaults of the
# sliders / dropdowns declared below.
- init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
- init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
# Declare the management WebUI. The stylesheet is inlined through `head`,
# the script through `js`.
# NOTE(review): indentation was lost in this copy of the file, so the nesting
# of the layout containers below must be confirmed against the upstream source.
- with gr.Blocks(
- head="<style>\n" + css + "\n</style>",
- js=js,
- theme=seafoam,
- analytics_enabled=False,
- title="Fish-Speech 鱼语",
- ) as demo:
- with gr.Row():
- with gr.Column():
# ---- Tab: dataset preparation (source paths, transfer method, labelling) ----
- with gr.Tab("\U0001F4D6 数据集准备"):
- with gr.Row():
- textbox = gr.Textbox(
- label="\U0000270F 输入音频&转写源文件夹路径",
- info="音频装在一个以说话人命名的文件夹内作为区分",
- interactive=True,
- )
- transcript_path = gr.Textbox(
- label="\U0001F4DD 转写文本filelist所在路径",
- info="支持 Bert-Vits2 / GPT-SoVITS 格式",
- interactive=True,
- )
- with gr.Row(equal_height=False):
- with gr.Column():
# The two choices are the exact strings compared against in list_copy() / check_files().
- output_radio = gr.Radio(
- label="\U0001F4C1 选择源文件(夹)处理方式",
- choices=["复制一份", "直接移动"],
- value="复制一份",
- interactive=True,
- )
- with gr.Column():
# Shared status area: several handlers below write their message here.
- error = gr.HTML(label="错误信息")
- if_label = gr.Checkbox(
- label="是否开启打标WebUI", scale=0, show_label=True
- )
- with gr.Row():
- add_button = gr.Button("\U000027A1提交到处理区", variant="primary")
- remove_button = gr.Button("\U000026D4 取消所选内容")
- with gr.Row():
- label_device = gr.Dropdown(
- label="打标设备",
- info="建议使用cuda, 实在是低配置再用cpu",
- choices=["cpu", "cuda"],
- value="cuda",
- interactive=True,
- )
- label_model = gr.Dropdown(
- label="打标模型大小",
- info="显存10G以上用large, 5G用medium, 2G用small",
- choices=["large", "medium", "small"],
- value="small",
- interactive=True,
- )
# These choices are the keys of dict_to_language in check_files().
- label_radio = gr.Dropdown(
- label="(可选)打标语言",
- info="如果没有音频对应的文本,则进行辅助打标, 支持.txt或.lab格式",
- choices=["中文", "日文", "英文", "不打标"],
- value="不打标",
- interactive=True,
- )
# ---- Tab: training configuration (values forwarded to train_process) ----
- with gr.Tab("\U0001F6E0 训练配置项"): # hammer
- with gr.Column():
- with gr.Row():
# The choices match the `option` values tested in train_process().
- model_type_radio = gr.Radio(
- label="选择要训练的模型类型",
- interactive=True,
- choices=["VQGAN", "LLAMA", "all"],
- value="all",
- )
- with gr.Row():
# VQGAN hyper-parameters; defaults come from vqgan_finetune.yaml.
- with gr.Accordion("VQGAN配置项", open=False):
- with gr.Row(equal_height=False):
- vqgan_lr_slider = gr.Slider(
- label="初始学习率",
- interactive=True,
- minimum=1e-5,
- maximum=1e-4,
- step=1e-5,
- value=init_vqgan_yml["model"]["optimizer"]["lr"],
- )
- vqgan_maxsteps_slider = gr.Slider(
- label="训练最大步数",
- interactive=True,
- minimum=1000,
- maximum=100000,
- step=1000,
- value=init_vqgan_yml["trainer"]["max_steps"],
- )
- with gr.Row(equal_height=False):
- vqgan_data_num_workers_slider = gr.Slider(
- label="num_workers",
- interactive=True,
- minimum=1,
- maximum=16,
- step=1,
- value=init_vqgan_yml["data"]["num_workers"],
- )
- vqgan_data_batch_size_slider = gr.Slider(
- label="batch_size",
- interactive=True,
- minimum=1,
- maximum=32,
- step=1,
- value=init_vqgan_yml["data"]["batch_size"],
- )
- with gr.Row(equal_height=False):
- vqgan_data_val_batch_size_slider = gr.Slider(
- label="val_batch_size",
- interactive=True,
- minimum=1,
- maximum=32,
- step=1,
- value=init_vqgan_yml["data"]["val_batch_size"],
- )
# The config value is stringified to be comparable with the string choices.
- vqgan_precision_dropdown = gr.Dropdown(
- label="训练精度",
- interactive=True,
- choices=["32", "bf16-true", "bf16-mixed"],
- value=str(init_vqgan_yml["trainer"]["precision"]),
- )
- with gr.Row(equal_height=False):
- vqgan_check_interval_slider = gr.Slider(
- label="每n步保存一个模型",
- interactive=True,
- minimum=500,
- maximum=10000,
- step=500,
- value=init_vqgan_yml["trainer"][
- "val_check_interval"
- ],
- )
- with gr.Row():
# LLAMA hyper-parameters; defaults come from text2semantic_sft.yaml.
- with gr.Accordion("LLAMA配置项", open=False):
- with gr.Row(equal_height=False):
- llama_lr_slider = gr.Slider(
- label="初始学习率",
- interactive=True,
- minimum=1e-5,
- maximum=1e-4,
- step=1e-5,
- value=init_llama_yml["model"]["optimizer"]["lr"],
- )
- llama_maxsteps_slider = gr.Slider(
- label="训练最大步数",
- interactive=True,
- minimum=1000,
- maximum=100000,
- step=1000,
- value=init_llama_yml["trainer"]["max_steps"],
- )
- with gr.Row(equal_height=False):
- llama_limit_val_batches_slider = gr.Slider(
- label="limit_val_batches",
- interactive=True,
- minimum=1,
- maximum=20,
- step=1,
- value=init_llama_yml["trainer"][
- "limit_val_batches"
- ],
- )
# The configured worker count is used on Linux only; other platforms default to 0.
- llama_data_num_workers_slider = gr.Slider(
- label="num_workers",
- minimum=0,
- maximum=16,
- step=1,
- value=init_llama_yml["data"]["num_workers"]
- if sys.platform == "linux"
- else 0,
- )
- with gr.Row(equal_height=False):
- llama_data_batch_size_slider = gr.Slider(
- label="batch_size",
- interactive=True,
- minimum=1,
- maximum=32,
- step=1,
- value=init_llama_yml["data"]["batch_size"],
- )
- llama_data_max_length_slider = gr.Slider(
- label="max_length",
- interactive=True,
- minimum=1024,
- maximum=4096,
- step=128,
- value=init_llama_yml["max_length"],
- )
- with gr.Row(equal_height=False):
# Unlike the VQGAN dropdown, this default is hard-coded rather than read from the config.
- llama_precision_dropdown = gr.Dropdown(
- label="训练精度",
- interactive=True,
- choices=["32", "bf16-true", "16-mixed"],
- value="bf16-true",
- )
- llama_check_interval_slider = gr.Slider(
- label="每n步保存一个模型",
- interactive=True,
- minimum=500,
- maximum=10000,
- step=500,
- value=init_llama_yml["trainer"][
- "val_check_interval"
- ],
- )
- with gr.Row(equal_height=False):
- llama_grad_batches = gr.Slider(
- label="accumulate_grad_batches",
- interactive=True,
- minimum=1,
- maximum=20,
- step=1,
- value=init_llama_yml["trainer"][
- "accumulate_grad_batches"
- ],
- )
- llama_use_speaker = gr.Slider(
- label="use_speaker_ratio",
- interactive=True,
- minimum=0.1,
- maximum=1.0,
- step=0.05,
- value=init_llama_yml["train_dataset"][
- "use_speaker"
- ],
- )
# ---- Tab: inference UI launcher (values forwarded to change_infer) ----
- with gr.Tab("\U0001F9E0 进入推理界面"):
- with gr.Column():
- with gr.Row():
- with gr.Accordion(label="\U0001F5A5 推理服务器配置", open=False):
- with gr.Row():
- infer_host_textbox = gr.Textbox(
- label="Webui启动服务器地址", value="127.0.0.1"
- )
- infer_port_textbox = gr.Textbox(
- label="Webui启动服务器端口", value="7862"
- )
- with gr.Row():
- infer_vqgan_model = gr.Textbox(
- label="VQGAN模型位置",
- placeholder="填写pth/ckpt文件路径",
- value="checkpoints/vq-gan-group-fsq-2x1024.pth",
- )
- with gr.Row():
- infer_llama_model = gr.Textbox(
- label="LLAMA模型位置",
- placeholder="填写pth/ckpt文件路径",
- value="checkpoints/text2semantic-medium-v1-2k.pth",
- )
- with gr.Row():
# change_infer() appends --compile only for the exact value "Yes".
- infer_compile = gr.Radio(
- label="是否编译模型?", choices=["Yes", "No"], value="Yes"
- )
- with gr.Row():
- infer_checkbox = gr.Checkbox(label="是否打开推理界面")
- infer_error = gr.HTML(label="推理界面错误信息")
# ---- Data-source list, output paths, directory tree and action buttons ----
- with gr.Column():
- train_error = gr.HTML(label="训练时的报错信息")
- checkbox_group = gr.CheckboxGroup(
- label="\U0001F4CA 数据源列表",
- info="左侧输入文件夹所在路径或filelist。无论是否勾选,在此列表中都会被用以后续训练。",
- elem_classes=["data_src"],
- )
# Read-only: shows the fixed pre-processing directory (also used as handler input below).
- train_box = gr.Textbox(
- label="数据预处理文件夹路径", value=str(data_pre_output), interactive=False
- )
- model_box = gr.Textbox(
- label="\U0001F4BE 模型输出路径",
- value=str(default_model_output),
- interactive=False,
- )
- with gr.Accordion(
- "查看预处理文件夹状态 (滑块为显示深度大小)",
- elem_classes=["scrollable-component"],
- elem_id="file_accordion",
- ):
# The slider value is passed as max_depth when the tree is re-rendered.
- tree_slider = gr.Slider(
- minimum=0,
- maximum=3,
- value=0,
- step=1,
- show_label=False,
- container=False,
- )
- file_markdown = new_explorer(str(data_pre_output), 0)
- with gr.Row(equal_height=False):
- admit_btn = gr.Button(
- "\U00002705 文件预处理", scale=0, min_width=160, variant="primary"
- )
- fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
- help_button = gr.Button("\U00002753", scale=0, min_width=80) # question
- train_btn = gr.Button("训练启动!", variant="primary")
# Footer template: the {versions} and {api_docs} placeholders are filled via str.format.
- footer = load_data_in_raw("fish_speech/webui/html/footer.html")
- footer = footer.format(
- versions=versions_html(),
- api_docs="https://speech.fish.audio/inference/#http-api",
- )
- gr.HTML(footer, elem_id="footer")
# ---- Event wiring ----
# Register the typed folder / filelist as data sources.
- add_button.click(
- fn=add_item,
- inputs=[textbox, output_radio, transcript_path, label_radio],
- outputs=[checkbox_group, error],
- )
# Drop the ticked data sources.
- remove_button.click(
- fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
- )
- checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
# Pure client-side handler: opens the documentation site in a small pop-up window.
- help_button.click(
- fn=None,
- js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
- 'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
- )
- if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
# The order of `inputs` must match the positional parameters of train_process().
- train_btn.click(
- fn=train_process,
- inputs=[
- train_box,
- model_type_radio,
- # vq-gan config
- vqgan_lr_slider,
- vqgan_maxsteps_slider,
- vqgan_data_num_workers_slider,
- vqgan_data_batch_size_slider,
- vqgan_data_val_batch_size_slider,
- vqgan_precision_dropdown,
- vqgan_check_interval_slider,
- # llama config
- llama_lr_slider,
- llama_maxsteps_slider,
- llama_limit_val_batches_slider,
- llama_data_num_workers_slider,
- llama_data_batch_size_slider,
- llama_data_max_length_slider,
- llama_precision_dropdown,
- llama_check_interval_slider,
- llama_grad_batches,
- llama_use_speaker,
- ],
- outputs=[train_error],
- )
# Pre-process all registered sources, then refresh the tree view.
- admit_btn.click(
- fn=check_files,
- inputs=[train_box, tree_slider, label_model, label_device],
- outputs=[error, file_markdown],
- )
# Re-render the tree view at the depth selected on the slider.
- fresh_btn.click(
- fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
- )
# The order of `inputs` must match the positional parameters of change_infer().
- infer_checkbox.change(
- fn=change_infer,
- inputs=[
- infer_checkbox,
- infer_host_textbox,
- infer_port_textbox,
- infer_vqgan_model,
- infer_llama_model,
- infer_compile,
- ],
- outputs=[infer_error],
- )
# Start the Gradio server; inbrowser=True asks Gradio to open the UI in a browser tab.
- demo.launch(inbrowser=True)
|