from __future__ import annotations import html import json import os import platform import random import shutil import signal import subprocess import sys from pathlib import Path import gradio as gr import psutil import yaml from loguru import logger from tqdm import tqdm from fish_speech.webui.launch_utils import Seafoam, versions_html PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python") sys.path.insert(0, "") print(sys.path) cur_work_dir = Path(os.getcwd()).resolve() print("You are in ", str(cur_work_dir)) config_path = cur_work_dir / "fish_speech" / "configs" vqgan_yml_path = config_path / "vqgan_finetune.yaml" llama_yml_path = config_path / "text2semantic_sft.yaml" env = os.environ.copy() env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0" seafoam = Seafoam() def build_html_error_message(error): return f"""
{html.escape(error)}
""" def build_html_ok_message(msg): return f"""
{html.escape(msg)}
""" def load_data_in_raw(path): with open(path, "r", encoding="utf-8") as file: data = file.read() return str(data) def kill_proc_tree(pid, including_parent=True): try: parent = psutil.Process(pid) except psutil.NoSuchProcess: # Process already terminated return children = parent.children(recursive=True) for child in children: try: os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL except OSError: pass if including_parent: try: os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL except OSError: pass system = platform.system() p_label = None p_infer = None def kill_process(pid): if system == "Windows": cmd = "taskkill /t /f /pid %s" % pid # os.system(cmd) subprocess.run(cmd) else: kill_proc_tree(pid) def change_label(if_label): global p_label if if_label == True and p_label == None: cmd = ["asr-label-win-x64.exe"] yield f"打标工具WebUI已开启, 访问:http://localhost:{3000}" p_label = subprocess.Popen(cmd, shell=True, env=env) elif if_label == False and p_label != None: kill_process(p_label.pid) p_label = None yield "打标工具WebUI已关闭" def change_infer( if_infer, host, port, infer_vqgan_model, infer_llama_model, infer_compile ): global p_infer if if_infer == True and p_infer == None: env = os.environ.copy() env["GRADIO_SERVER_NAME"] = host env["GRADIO_SERVER_PORT"] = port # 启动第二个进程 yield build_html_ok_message(f"推理界面已开启, 访问 http://{host}:{port}") p_infer = subprocess.Popen( [ PYTHON, "tools/webui.py", "--vqgan-checkpoint-path", infer_vqgan_model, "--llama-checkpoint-path", infer_llama_model, "--tokenizer", "checkpoints", ] + (["--compile"] if infer_compile == "Yes" else []), env=env, ) elif if_infer == False and p_infer != None: kill_process(p_infer.pid) p_infer = None yield build_html_error_message("推理界面已关闭") js = load_data_in_raw("fish_speech/webui/js/animate.js") css = load_data_in_raw("fish_speech/webui/css/style.css") data_pre_output = (cur_work_dir / "data").resolve() default_model_output = (cur_work_dir / "results").resolve() default_filelist = data_pre_output / "detect.list" data_pre_output.mkdir(parents=True, exist_ok=True) items = [] dict_items = {} def load_yaml_data_in_fact(yml_path): with open(yml_path, "r", encoding="utf-8") as file: yml = yaml.safe_load(file) return yml def write_yaml_data_in_fact(yml, yml_path): with open(yml_path, "w", encoding="utf-8") as file: yaml.safe_dump(yml, file, allow_unicode=True) return yml def generate_tree(directory, depth=0, max_depth=None, prefix=""): if max_depth is not None and depth > max_depth: return "" tree_str = "" files = [] directories = [] for item in os.listdir(directory): if os.path.isdir(os.path.join(directory, item)): directories.append(item) else: files.append(item) entries = directories + files for i, entry in enumerate(entries): connector = "├── " if i < len(entries) - 1 else "└── " tree_str += f"{prefix}{connector}{entry}
" if i < len(directories): extension = "│ " if i < len(entries) - 1 else " " tree_str += generate_tree( os.path.join(directory, entry), depth + 1, max_depth, prefix=prefix + extension, ) return tree_str def new_explorer(data_path, max_depth): return gr.Markdown( elem_classes=["scrollable-component"], value=generate_tree(data_path, max_depth=max_depth), ) def add_item(folder: str, method: str, filelist: str, label_lang: str): folder = folder.strip(" ").strip('"') filelist = filelist.strip(" ").strip('"') folder_path = Path(folder) filelist_path = Path(filelist) if folder and folder not in items and data_pre_output not in folder_path.parents: if folder_path.is_dir(): items.append(folder) dict_items[folder] = dict( type="folder", method=method, label_lang=label_lang ) elif folder: err = folder return gr.Checkboxgroup(choices=items), build_html_error_message( f"添加文件夹路径无效: {err}" ) if ( filelist and filelist not in items and data_pre_output not in filelist_path.parents ): if filelist_path.is_file(): items.append(filelist) dict_items[filelist] = dict( type="file", method=method, label_lang=label_lang ) elif filelist: err = filelist return gr.Checkboxgroup(choices=items), build_html_error_message( f"添加文件路径无效: {err}" ) formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) logger.info(formatted_data) return gr.Checkboxgroup(choices=items), build_html_ok_message("添加文件(夹)路径成功!") def remove_items(selected_items): global items, dict_items to_remove = [item for item in items if item in selected_items] for item in to_remove: del dict_items[item] items = [item for item in items if item in dict_items.keys()] formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4) logger.info(formatted_data) return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message( "删除文件(夹)路径成功!" ) def show_selected(options): selected_options = ", ".join(options) return f"你选中了: {selected_options}" if options else "你没有选中任何选项" def list_copy(list_file_path, method): wav_root = data_pre_output lst = [] with list_file_path.open("r", encoding="utf-8") as file: for line in tqdm(file, desc="Processing audio/transcript"): wav_path, speaker_name, language, text = line.strip().split("|") original_wav_path = Path(wav_path) target_wav_path = ( wav_root / original_wav_path.parent.name / original_wav_path.name ) lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}") if target_wav_path.is_file(): continue target_wav_path.parent.mkdir(parents=True, exist_ok=True) if method == "复制一份": shutil.copy(original_wav_path, target_wav_path) else: shutil.move(original_wav_path, target_wav_path.parent) original_lab_path = original_wav_path.with_suffix(".lab") target_lab_path = ( wav_root / original_wav_path.parent.name / original_wav_path.with_suffix(".lab").name ) if target_lab_path.is_file(): continue if method == "复制一份": shutil.copy(original_lab_path, target_lab_path) else: shutil.move(original_lab_path, target_lab_path.parent) if method == "直接移动": with list_file_path.open("w", encoding="utf-8") as file: file.writelines("\n".join(lst)) del lst return build_html_ok_message("使用filelist") def check_files(data_path: str, max_depth: int, label_model: str, label_device: str): dict_to_language = {"中文": "ZH", "英文": "EN", "日文": "JP", "不打标": "WTF"} global dict_items data_path = Path(data_path) for item, content in dict_items.items(): item_path = Path(item) tar_path = data_path / item_path.name if content["type"] == "folder" and item_path.is_dir(): cur_lang = dict_to_language[content["label_lang"]] if cur_lang != "WTF": try: subprocess.run( [ PYTHON, "tools/whisper_asr.py", "--model-size", label_model, "--device", label_device, "--audio-dir", item_path, "--save-dir", item_path, "--language", cur_lang, ], env=env, ) except Exception: print("Transcription error occurred") if content["method"] == "复制一份": os.makedirs(tar_path, exist_ok=True) shutil.copytree( src=str(item_path), dst=str(tar_path), dirs_exist_ok=True ) elif not tar_path.is_dir(): shutil.move(src=str(item_path), dst=str(tar_path)) elif content["type"] == "file" and item_path.is_file(): list_copy(item_path, content["method"]) return build_html_ok_message("文件移动完毕"), new_explorer(data_path, max_depth=max_depth) def train_process( data_path: str, option: str, # vq-gan config vqgan_lr, vqgan_maxsteps, vqgan_data_num_workers, vqgan_data_batch_size, vqgan_data_val_batch_size, vqgan_precision, vqgan_check_interval, # llama config llama_lr, llama_maxsteps, llama_limit_val_batches, llama_data_num_workers, llama_data_batch_size, llama_data_max_length, llama_precision, llama_check_interval, llama_grad_batches, llama_use_speaker, ): backend = "nccl" if sys.platform == "linux" else "gloo" if option == "VQGAN" or option == "all": subprocess.run( [ PYTHON, "tools/vqgan/create_train_split.py", str(data_pre_output.relative_to(cur_work_dir)), ] ) train_cmd = [ PYTHON, "fish_speech/train.py", "--config-name", "vqgan_finetune", f"trainer.strategy.process_group_backend={backend}", f"model.optimizer.lr={vqgan_lr}", f"trainer.max_steps={vqgan_maxsteps}", f"data.num_workers={vqgan_data_num_workers}", f"data.batch_size={vqgan_data_batch_size}", f"data.val_batch_size={vqgan_data_val_batch_size}", f"trainer.precision={vqgan_precision}", f"trainer.val_check_interval={vqgan_check_interval}", f"train_dataset.filelist={str(data_pre_output / 'vq_train_filelist.txt')}", f"val_dataset.filelist={str(data_pre_output / 'vq_val_filelist.txt')}", ] logger.info(train_cmd) subprocess.run(train_cmd) if option == "LLAMA" or option == "all": subprocess.run( [ PYTHON, "tools/vqgan/extract_vq.py", str(data_pre_output), "--num-workers", "1", "--batch-size", "16", "--config-name", "vqgan_pretrain", "--checkpoint-path", "checkpoints/vq-gan-group-fsq-2x1024.pth", ] ) subprocess.run( [ PYTHON, "tools/llama/build_dataset.py", "--input", str(data_pre_output), "--text-extension", ".lab", "--num-workers", "16", ] ) train_cmd = [ PYTHON, "fish_speech/train.py", "--config-name", "text2semantic_sft", f"trainer.strategy.process_group_backend={backend}", "model@model.model=dual_ar_2_codebook_medium", "tokenizer.pretrained_model_name_or_path=checkpoints", f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}", f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}", f"model.optimizer.lr={llama_lr}", f"trainer.max_steps={llama_maxsteps}", f"trainer.limit_val_batches={llama_limit_val_batches}", f"data.num_workers={llama_data_num_workers}", f"data.batch_size={llama_data_batch_size}", f"max_length={llama_data_max_length}", f"trainer.precision={llama_precision}", f"trainer.val_check_interval={llama_check_interval}", f"trainer.accumulate_grad_batches={llama_grad_batches}", f"train_dataset.use_speaker={llama_use_speaker}", ] logger.info(train_cmd) subprocess.run(train_cmd) return build_html_ok_message("训练终止") init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path) init_llama_yml = load_yaml_data_in_fact(llama_yml_path) with gr.Blocks( head="", js=js, theme=seafoam, analytics_enabled=False, title="Fish-Speech 鱼语", ) as demo: with gr.Row(): with gr.Column(): with gr.Tab("\U0001F4D6 数据集准备"): with gr.Row(): textbox = gr.Textbox( label="\U0000270F 输入音频&转写源文件夹路径", info="音频装在一个以说话人命名的文件夹内作为区分", interactive=True, ) transcript_path = gr.Textbox( label="\U0001F4DD 转写文本filelist所在路径", info="支持 Bert-Vits2 / GPT-SoVITS 格式", interactive=True, ) with gr.Row(equal_height=False): with gr.Column(): output_radio = gr.Radio( label="\U0001F4C1 选择源文件(夹)处理方式", choices=["复制一份", "直接移动"], value="复制一份", interactive=True, ) with gr.Column(): error = gr.HTML(label="错误信息") if_label = gr.Checkbox( label="是否开启打标WebUI", scale=0, show_label=True ) with gr.Row(): add_button = gr.Button("\U000027A1提交到处理区", variant="primary") remove_button = gr.Button("\U000026D4 取消所选内容") with gr.Row(): label_device = gr.Dropdown( label="打标设备", info="建议使用cuda, 实在是低配置再用cpu", choices=["cpu", "cuda"], value="cuda", interactive=True, ) label_model = gr.Dropdown( label="打标模型大小", info="显存10G以上用large, 5G用medium, 2G用small", choices=["large", "medium", "small"], value="small", interactive=True, ) label_radio = gr.Dropdown( label="(可选)打标语言", info="如果没有音频对应的文本,则进行辅助打标, 支持.txt或.lab格式", choices=["中文", "日文", "英文", "不打标"], value="不打标", interactive=True, ) with gr.Tab("\U0001F6E0 训练配置项"): # hammer with gr.Column(): with gr.Row(): model_type_radio = gr.Radio( label="选择要训练的模型类型", interactive=True, choices=["VQGAN", "LLAMA", "all"], value="all", ) with gr.Row(): with gr.Accordion("VQGAN配置项", open=False): with gr.Row(equal_height=False): vqgan_lr_slider = gr.Slider( label="初始学习率", interactive=True, minimum=1e-5, maximum=1e-4, step=1e-5, value=init_vqgan_yml["model"]["optimizer"]["lr"], ) vqgan_maxsteps_slider = gr.Slider( label="训练最大步数", interactive=True, minimum=1000, maximum=100000, step=1000, value=init_vqgan_yml["trainer"]["max_steps"], ) with gr.Row(equal_height=False): vqgan_data_num_workers_slider = gr.Slider( label="num_workers", interactive=True, minimum=1, maximum=16, step=1, value=init_vqgan_yml["data"]["num_workers"], ) vqgan_data_batch_size_slider = gr.Slider( label="batch_size", interactive=True, minimum=1, maximum=32, step=1, value=init_vqgan_yml["data"]["batch_size"], ) with gr.Row(equal_height=False): vqgan_data_val_batch_size_slider = gr.Slider( label="val_batch_size", interactive=True, minimum=1, maximum=32, step=1, value=init_vqgan_yml["data"]["val_batch_size"], ) vqgan_precision_dropdown = gr.Dropdown( label="训练精度", interactive=True, choices=["32", "bf16-true", "bf16-mixed"], value=str(init_vqgan_yml["trainer"]["precision"]), ) with gr.Row(equal_height=False): vqgan_check_interval_slider = gr.Slider( label="每n步保存一个模型", interactive=True, minimum=500, maximum=10000, step=500, value=init_vqgan_yml["trainer"][ "val_check_interval" ], ) with gr.Row(): with gr.Accordion("LLAMA配置项", open=False): with gr.Row(equal_height=False): llama_lr_slider = gr.Slider( label="初始学习率", interactive=True, minimum=1e-5, maximum=1e-4, step=1e-5, value=init_llama_yml["model"]["optimizer"]["lr"], ) llama_maxsteps_slider = gr.Slider( label="训练最大步数", interactive=True, minimum=1000, maximum=100000, step=1000, value=init_llama_yml["trainer"]["max_steps"], ) with gr.Row(equal_height=False): llama_limit_val_batches_slider = gr.Slider( label="limit_val_batches", interactive=True, minimum=1, maximum=20, step=1, value=init_llama_yml["trainer"][ "limit_val_batches" ], ) llama_data_num_workers_slider = gr.Slider( label="num_workers", minimum=0, maximum=16, step=1, value=init_llama_yml["data"]["num_workers"] if sys.platform == "linux" else 0, ) with gr.Row(equal_height=False): llama_data_batch_size_slider = gr.Slider( label="batch_size", interactive=True, minimum=1, maximum=32, step=1, value=init_llama_yml["data"]["batch_size"], ) llama_data_max_length_slider = gr.Slider( label="max_length", interactive=True, minimum=1024, maximum=4096, step=128, value=init_llama_yml["max_length"], ) with gr.Row(equal_height=False): llama_precision_dropdown = gr.Dropdown( label="训练精度", interactive=True, choices=["32", "bf16-true", "16-mixed"], value="bf16-true", ) llama_check_interval_slider = gr.Slider( label="每n步保存一个模型", interactive=True, minimum=500, maximum=10000, step=500, value=init_llama_yml["trainer"][ "val_check_interval" ], ) with gr.Row(equal_height=False): llama_grad_batches = gr.Slider( label="accumulate_grad_batches", interactive=True, minimum=1, maximum=20, step=1, value=init_llama_yml["trainer"][ "accumulate_grad_batches" ], ) llama_use_speaker = gr.Slider( label="use_speaker_ratio", interactive=True, minimum=0.1, maximum=1.0, step=0.05, value=init_llama_yml["train_dataset"][ "use_speaker" ], ) with gr.Tab("\U0001F9E0 进入推理界面"): with gr.Column(): with gr.Row(): with gr.Accordion(label="\U0001F5A5 推理服务器配置", open=False): with gr.Row(): infer_host_textbox = gr.Textbox( label="Webui启动服务器地址", value="127.0.0.1" ) infer_port_textbox = gr.Textbox( label="Webui启动服务器端口", value="7862" ) with gr.Row(): infer_vqgan_model = gr.Textbox( label="VQGAN模型位置", placeholder="填写pth/ckpt文件路径", value="checkpoints/vq-gan-group-fsq-2x1024.pth", ) with gr.Row(): infer_llama_model = gr.Textbox( label="LLAMA模型位置", placeholder="填写pth/ckpt文件路径", value="checkpoints/text2semantic-medium-v1-2k.pth", ) with gr.Row(): infer_compile = gr.Radio( label="是否编译模型?", choices=["Yes", "No"], value="Yes" ) with gr.Row(): infer_checkbox = gr.Checkbox(label="是否打开推理界面") infer_error = gr.HTML(label="推理界面错误信息") with gr.Column(): train_error = gr.HTML(label="训练时的报错信息") checkbox_group = gr.CheckboxGroup( label="\U0001F4CA 数据源列表", info="左侧输入文件夹所在路径或filelist。无论是否勾选,在此列表中都会被用以后续训练。", elem_classes=["data_src"], ) train_box = gr.Textbox( label="数据预处理文件夹路径", value=str(data_pre_output), interactive=False ) model_box = gr.Textbox( label="\U0001F4BE 模型输出路径", value=str(default_model_output), interactive=False, ) with gr.Accordion( "查看预处理文件夹状态 (滑块为显示深度大小)", elem_classes=["scrollable-component"], elem_id="file_accordion", ): tree_slider = gr.Slider( minimum=0, maximum=3, value=0, step=1, show_label=False, container=False, ) file_markdown = new_explorer(str(data_pre_output), 0) with gr.Row(equal_height=False): admit_btn = gr.Button( "\U00002705 文件预处理", scale=0, min_width=160, variant="primary" ) fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80) help_button = gr.Button("\U00002753", scale=0, min_width=80) # question train_btn = gr.Button("训练启动!", variant="primary") footer = load_data_in_raw("fish_speech/webui/html/footer.html") footer = footer.format( versions=versions_html(), api_docs="https://speech.fish.audio/inference/#http-api", ) gr.HTML(footer, elem_id="footer") add_button.click( fn=add_item, inputs=[textbox, output_radio, transcript_path, label_radio], outputs=[checkbox_group, error], ) remove_button.click( fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error] ) checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error]) help_button.click( fn=None, js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, ' 'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}', ) if_label.change(fn=change_label, inputs=[if_label], outputs=[error]) train_btn.click( fn=train_process, inputs=[ train_box, model_type_radio, # vq-gan config vqgan_lr_slider, vqgan_maxsteps_slider, vqgan_data_num_workers_slider, vqgan_data_batch_size_slider, vqgan_data_val_batch_size_slider, vqgan_precision_dropdown, vqgan_check_interval_slider, # llama config llama_lr_slider, llama_maxsteps_slider, llama_limit_val_batches_slider, llama_data_num_workers_slider, llama_data_batch_size_slider, llama_data_max_length_slider, llama_precision_dropdown, llama_check_interval_slider, llama_grad_batches, llama_use_speaker, ], outputs=[train_error], ) admit_btn.click( fn=check_files, inputs=[train_box, tree_slider, label_model, label_device], outputs=[error, file_markdown], ) fresh_btn.click( fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown] ) infer_checkbox.change( fn=change_infer, inputs=[ infer_checkbox, infer_host_textbox, infer_port_textbox, infer_vqgan_model, infer_llama_model, infer_compile, ], outputs=[infer_error], ) demo.launch(inbrowser=True)