Преглед изворни кода

Implement webui for training & annotating (#138)

* init package

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix spelling

* Decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Specify the backend on the command line

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama пре 1 година
родитељ
комит
40afd74225

+ 3 - 0
.gitignore

@@ -17,3 +17,6 @@ filelists
 /results
 /data
 /.idea
+ffmpeg.exe
+asr-label-win-x64.exe
+/.cache

+ 3 - 0
fish_speech/configs/text2semantic_sft.yaml

@@ -17,6 +17,9 @@ trainer:
   precision: bf16-true
   limit_val_batches: 10
   val_check_interval: 500
+  strategy:
+    _target_: lightning.pytorch.strategies.DDPStrategy
+    process_group_backend: nccl  # This should be overridden when training on Windows
 
 # Dataset Configuration
 tokenizer:

+ 2 - 0
fish_speech/configs/vqgan_finetune.yaml

@@ -14,6 +14,8 @@ trainer:
   max_steps: 100_000
   val_check_interval: 5000
   strategy:
+    _target_: lightning.pytorch.strategies.DDPStrategy
+    process_group_backend: nccl  # This should be overridden when training on Windows
     find_unused_parameters: true
 
 sample_rate: 44100

+ 3 - 1
fish_speech/datasets/vqgan.py

@@ -28,7 +28,7 @@ class VQGANDataset(Dataset):
 
         self.files = [
             root / line.strip()
-            for line in filelist.read_text().splitlines()
+            for line in filelist.read_text(encoding="utf-8").splitlines()
             if line.strip()
         ]
         self.sample_rate = sample_rate
@@ -120,6 +120,7 @@ class VQGANDataModule(LightningDataModule):
             collate_fn=VQGANCollator(),
             num_workers=self.num_workers,
             shuffle=True,
+            persistent_workers=True,
         )
 
     def val_dataloader(self):
@@ -128,6 +129,7 @@ class VQGANDataModule(LightningDataModule):
             batch_size=self.val_batch_size,
             collate_fn=VQGANCollator(),
             num_workers=self.num_workers,
+            persistent_workers=True,
         )
 
 

+ 5 - 1
fish_speech/train.py

@@ -1,4 +1,5 @@
 import os
+import sys
 from typing import Optional
 
 import hydra
@@ -7,6 +8,7 @@ import pyrootutils
 import torch
 from lightning import Callback, LightningDataModule, LightningModule, Trainer
 from lightning.pytorch.loggers import Logger
+from lightning.pytorch.strategies import DDPStrategy
 from omegaconf import DictConfig, OmegaConf
 
 os.environ.pop("SLURM_NTASKS", None)
@@ -61,7 +63,9 @@ def train(cfg: DictConfig) -> tuple[dict, dict]:
 
     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
     trainer: Trainer = hydra.utils.instantiate(
-        cfg.trainer, callbacks=callbacks, logger=logger
+        cfg.trainer,
+        callbacks=callbacks,
+        logger=logger,
     )
 
     object_dict = {

+ 161 - 0
fish_speech/webui/css/style.css

@@ -0,0 +1,161 @@
+:root {
+  --my-200: #80eeee;
+  --my-50: #ecfdf5;
+  --water-width: 300px;
+  --water-heigh: 300px;
+}
+
+
+/* general styled components */
+.tools {
+  align-items: center;
+  justify-content: center;
+}
+
+.gradio-button {
+    max-width: 2.2em;
+    min-width: 2.2em !important;
+    height: 2.4em;
+    align-self: end;
+    line-height: 1em;
+    border-radius: 0.5em;
+
+}
+
+.gradio-button.secondary-down, .gradio-button.secondary-down:hover{
+    box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset;
+}
+
+/* replace original footer with ours */
+a{
+    font-weight: bold;
+    cursor: pointer;
+    color: #030C14 !important;
+}
+
+footer {
+    display: none !important;
+}
+
+#footer{
+    text-align: center;
+}
+
+#footer div{
+    display: inline-block;
+}
+
+#footer .versions{
+    font-size: 85%;
+    opacity: 0.85;
+}
+
+/*@keyframes moveBackground {*/
+/*  0% {*/
+/*    background-position: 0 0;*/
+/*  }*/
+/*  100% {*/
+/*    background-position: -100px 100px;*/
+/*  }*/
+/*}*/
+@keyframes moveJellyBackground {
+  0% {
+    background-position: 0% 50%;
+  }
+  50% {
+    background-position: 100% 50%;
+  }
+  100% {
+    background-position: 0% 50%;
+  }
+}
+
+.gradio-container {
+  position: absolute;
+  z-index: 10;
+}
+
+
+.quan {
+  position: absolute;
+  bottom: 0;
+  width: var(--water-width);
+  height: var(--water-heigh);
+  border-radius: 0;
+  /*border: 3px solid rgb(246, 247, 248);*/
+  /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/
+  z-index: 0;
+
+}
+
+.quan:last-child {
+  margin-right: 0;
+}
+
+.shui {
+  position: absolute;
+  top: 0;
+  left: 0;
+  width: 100%;
+  height: 100%;
+  background-color: rgb(23, 106, 201);
+  border-radius: 0;
+  overflow: hidden;
+  z-index: 0;
+}
+
+.shui::after {
+
+  content: '';
+  position: absolute;
+  top: 20%;
+  left: 50%;
+  width: 150%;
+  height: 150%;
+  border-radius: 40%;
+  background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%);
+  animation: shi 5s linear infinite;
+}
+
+@keyframes shi {
+  0% {
+    transform: translate(-50%, -65%) rotate(0deg);
+  }
+  100% {
+    transform: translate(-50%, -65%) rotate(360deg);
+  }
+}
+
+.shui::before {
+  content: '';
+  position: absolute;
+  top: 20%;
+  left: 50%;
+  width: 150%;
+  height: 150%;
+  border-radius: 42%;
+  background-color: rgb(240, 228, 228, 0.2);
+  animation: xu 7s linear infinite;
+}
+
+@keyframes xu {
+  0% {
+    transform: translate(-50%, -60%) rotate(0deg);
+  }
+  100% {
+    transform: translate(-50%, -60%) rotate(360deg);
+  }
+}
+
+fieldset.data_src div.wrap label {
+  background: #f8bffee0 !important;
+}
+
+.scrollable-component {
+  max-height: 100px;
+  overflow-y: auto;
+}
+
+#file_accordion {
+  max-height: 220px !important;
+}

+ 11 - 0
fish_speech/webui/html/footer.html

@@ -0,0 +1,11 @@
+<div style="color: rgba(25,255,205,0.7) !important;">
+        <a href="{api_docs}">API</a>
+        &nbsp;•&nbsp;
+        <a href="https://github.com/AnyaCoder/fish-speech">Github</a>
+        &nbsp;•&nbsp;
+        <a href="https://gradio.app">Gradio</a>
+</div>
+<br />
+<div class="versions" style="color: rgba(25,255,205,0.7) !important;">
+{versions}
+</div>

+ 71 - 0
fish_speech/webui/js/animate.js

@@ -0,0 +1,71 @@
+
+function createGradioAnimation() {
+    const params = new URLSearchParams(window.location.search);
+    if (!params.has('__theme')) {
+        params.set('__theme', 'light');
+        window.location.search = params.toString();
+    }
+
+    var gradioApp = document.querySelector('gradio-app');
+    if (gradioApp) {
+
+        document.documentElement.style.setProperty('--my-200', '#80eeee');
+        document.documentElement.style.setProperty('--my-50', '#ecfdf5');
+
+        gradioApp.style.position = 'relative';
+        gradioApp.style.backgroundSize = '200% 200%';
+        gradioApp.style.animation = 'moveJellyBackground 10s ease infinite';
+        gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)';
+        gradioApp.style.display = 'flex';
+        gradioApp.style.justifyContent = 'flex-start';
+        gradioApp.style.flexWrap = 'nowrap';
+        gradioApp.style.overflowX = 'auto';
+
+        for (let i = 0; i < 6; i++) {
+            var quan = document.createElement('div');
+            quan.className = 'quan';
+            gradioApp.insertBefore(quan, gradioApp.firstChild);
+            quan.id = 'quan' + i.toString();
+            quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')';
+            var quanContainer = document.querySelector('.quan');
+            if (quanContainer) {
+                var shui = document.createElement('div');
+                shui.className = 'shui';
+                quanContainer.insertBefore(shui, quanContainer.firstChild)
+            }
+        }
+
+
+    }
+
+    var container = document.createElement('div');
+    container.id = 'gradio-animation';
+    container.style.fontSize = '2em';
+    container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace';
+    container.style.fontWeight = 'bold';
+    container.style.textAlign = 'center';
+    container.style.marginBottom = '20px';
+
+    var text = 'Welcome to Fish-Speech!';
+    for (var i = 0; i < text.length; i++) {
+        (function(i){
+            setTimeout(function(){
+                var letter = document.createElement('span');
+                letter.style.opacity = '0';
+                letter.style.transition = 'opacity 0.5s';
+                letter.innerText = text[i];
+
+                container.appendChild(letter);
+
+                setTimeout(function() {
+                    letter.style.opacity = '1';
+                }, 50);
+            }, i * 200);
+        })(i);
+    }
+
+    var gradioContainer = document.querySelector('.gradio-container');
+    gradioContainer.insertBefore(container, gradioContainer.firstChild);
+
+    return 'Animation created';
+}

+ 118 - 0
fish_speech/webui/launch_utils.py

@@ -0,0 +1,118 @@
+import os
+import subprocess
+import sys
+from functools import lru_cache
+from pathlib import Path
+from typing import Iterable
+
+import gradio as gr
+from gradio.themes.base import Base
+from gradio.themes.utils import colors, fonts, sizes
+
+GIT = (
+    (Path(os.environ.get("GIT_HOME", "")) / "git").resolve()
+    if sys.platform == "win32"
+    else "git"
+)
+GIT = str(GIT)
+
+
+@lru_cache()
+def commit_hash():
+    try:
+        return subprocess.check_output(
+            [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8"
+        ).strip()
+    except Exception:
+        return "<none>"
+
+
+def versions_html():
+    import torch
+
+    python_version = ".".join([str(x) for x in sys.version_info[0:3]])
+    commit = commit_hash()
+    hash = commit.strip("'").split(" ")[0]
+
+    return f"""
+version: <a href="https://github.com/AnyaCoder/fish-speech/commit/{hash}">{hash}</a>
+&#x2000;•&#x2000;
+python: <span title="{sys.version}">{python_version}</span>
+&#x2000;•&#x2000;
+torch: {getattr(torch, '__long_version__',torch.__version__)}
+&#x2000;•&#x2000;
+gradio: {gr.__version__}
+&#x2000;•&#x2000;
+author: <a href="https://github.com/AnyaCoder">laziman/AnyaCoder</a>
+"""
+
+
+def version_check(commit):
+    try:
+        import requests
+
+        commits = requests.get(
+            "https://api.github.com/repos/AnyaCoder/fish-speech/branches/main"
+        ).json()
+        if commit != "<none>" and commits["commit"]["sha"] != commit:
+            print("--------------------------------------------------------")
+            print("| You are not up to date with the most recent release. |")
+            print("| Consider running `git pull` to update.               |")
+            print("--------------------------------------------------------")
+        elif commits["commit"]["sha"] == commit:
+            print("You are up to date with the most recent release.")
+        else:
+            print("Not a git clone, can't perform version check.")
+    except Exception as e:
+        print("version check failed", e)
+
+
+class Seafoam(Base):
+    def __init__(
+        self,
+        *,
+        primary_hue: colors.Color | str = colors.emerald,
+        secondary_hue: colors.Color | str = colors.blue,
+        neutral_hue: colors.Color | str = colors.blue,
+        spacing_size: sizes.Size | str = sizes.spacing_md,
+        radius_size: sizes.Size | str = sizes.radius_md,
+        text_size: sizes.Size | str = sizes.text_lg,
+        font: fonts.Font
+        | str
+        | Iterable[fonts.Font | str] = (
+            fonts.GoogleFont("Quicksand"),
+            "ui-sans-serif",
+            "sans-serif",
+        ),
+        font_mono: fonts.Font
+        | str
+        | Iterable[fonts.Font | str] = (
+            fonts.GoogleFont("IBM Plex Mono"),
+            "ui-monospace",
+            "monospace",
+        ),
+    ):
+        super().__init__(
+            primary_hue=primary_hue,
+            secondary_hue=secondary_hue,
+            neutral_hue=neutral_hue,
+            spacing_size=spacing_size,
+            radius_size=radius_size,
+            text_size=text_size,
+            font=font,
+            font_mono=font_mono,
+        )
+        super().set(
+            button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
+            button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
+            button_primary_text_color="white",
+            button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
+            slider_color="*secondary_300",
+            slider_color_dark="*secondary_600",
+            block_title_text_weight="600",
+            block_border_width="3px",
+            block_shadow="*shadow_drop_lg",
+            button_shadow="*shadow_drop_lg",
+            button_small_padding="0px",
+            button_large_padding="3px",
+        )

+ 797 - 0
fish_speech/webui/manage.py

@@ -0,0 +1,797 @@
+from __future__ import annotations
+
+import html
+import json
+import os
+import platform
+import random
+import shutil
+import signal
+import subprocess
+import sys
+from pathlib import Path
+
+import gradio as gr
+import psutil
+import yaml
+from loguru import logger
+from tqdm import tqdm
+
+from fish_speech.webui.launch_utils import Seafoam, versions_html
+
+PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
+sys.path.insert(0, "")
+print(sys.path)
+cur_work_dir = Path(os.getcwd()).resolve()
+print("You are in ", str(cur_work_dir))
+config_path = cur_work_dir / "fish_speech" / "configs"
+vqgan_yml_path = config_path / "vqgan_finetune.yaml"
+llama_yml_path = config_path / "text2semantic_sft.yaml"
+
+env = os.environ.copy()
+env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
+
+seafoam = Seafoam()
+
+
+def build_html_error_message(error):
+    return f"""
+    <div style="color: red; font-weight: bold;">
+        {html.escape(error)}
+    </div>
+    """
+
+
+def build_html_ok_message(msg):
+    return f"""
+    <div style="color: green; font-weight: bold;">
+        {html.escape(msg)}
+    </div>
+    """
+
+
+def load_data_in_raw(path):
+    with open(path, "r", encoding="utf-8") as file:
+        data = file.read()
+    return str(data)
+
+
+def kill_proc_tree(pid, including_parent=True):
+    try:
+        parent = psutil.Process(pid)
+    except psutil.NoSuchProcess:
+        # Process already terminated
+        return
+
+    children = parent.children(recursive=True)
+    for child in children:
+        try:
+            os.kill(child.pid, signal.SIGTERM)  # or signal.SIGKILL
+        except OSError:
+            pass
+    if including_parent:
+        try:
+            os.kill(parent.pid, signal.SIGTERM)  # or signal.SIGKILL
+        except OSError:
+            pass
+
+
+system = platform.system()
+p_label = None
+p_infer = None
+
+
+def kill_process(pid):
+    if system == "Windows":
+        cmd = "taskkill /t /f /pid %s" % pid
+        # os.system(cmd)
+        subprocess.run(cmd)
+    else:
+        kill_proc_tree(pid)
+
+
+def change_label(if_label):
+    global p_label
+    if if_label == True and p_label == None:
+        cmd = ["asr-label-win-x64.exe"]
+        yield f"打标工具WebUI已开启, 访问:http://localhost:{3000}"
+        p_label = subprocess.Popen(cmd, shell=True, env=env)
+    elif if_label == False and p_label != None:
+        kill_process(p_label.pid)
+        p_label = None
+        yield "打标工具WebUI已关闭"
+
+
+def change_infer(
+    if_infer, host, port, infer_vqgan_model, infer_llama_model, infer_compile
+):
+    global p_infer
+    if if_infer == True and p_infer == None:
+        env = os.environ.copy()
+
+        env["GRADIO_SERVER_NAME"] = host
+        env["GRADIO_SERVER_PORT"] = port
+        # 启动第二个进程
+        yield build_html_ok_message(f"推理界面已开启, 访问 http://{host}:{port}")
+        p_infer = subprocess.Popen(
+            [
+                PYTHON,
+                "tools/webui.py",
+                "--vqgan-checkpoint-path",
+                infer_vqgan_model,
+                "--llama-checkpoint-path",
+                infer_llama_model,
+                "--tokenizer",
+                "checkpoints",
+            ]
+            + (["--compile"] if infer_compile == "Yes" else []),
+            env=env,
+        )
+
+    elif if_infer == False and p_infer != None:
+        kill_process(p_infer.pid)
+        p_infer = None
+        yield build_html_error_message("推理界面已关闭")
+
+
+js = load_data_in_raw("fish_speech/webui/js/animate.js")
+css = load_data_in_raw("fish_speech/webui/css/style.css")
+
+data_pre_output = (cur_work_dir / "data").resolve()
+default_model_output = (cur_work_dir / "results").resolve()
+default_filelist = data_pre_output / "detect.list"
+data_pre_output.mkdir(parents=True, exist_ok=True)
+
+items = []
+dict_items = {}
+
+
+def load_yaml_data_in_fact(yml_path):
+    with open(yml_path, "r", encoding="utf-8") as file:
+        yml = yaml.safe_load(file)
+    return yml
+
+
+def write_yaml_data_in_fact(yml, yml_path):
+    with open(yml_path, "w", encoding="utf-8") as file:
+        yaml.safe_dump(yml, file, allow_unicode=True)
+    return yml
+
+
+def generate_tree(directory, depth=0, max_depth=None, prefix=""):
+    if max_depth is not None and depth > max_depth:
+        return ""
+
+    tree_str = ""
+    files = []
+    directories = []
+    for item in os.listdir(directory):
+        if os.path.isdir(os.path.join(directory, item)):
+            directories.append(item)
+        else:
+            files.append(item)
+
+    entries = directories + files
+    for i, entry in enumerate(entries):
+        connector = "├── " if i < len(entries) - 1 else "└── "
+        tree_str += f"{prefix}{connector}{entry}<br />"
+        if i < len(directories):
+            extension = "│   " if i < len(entries) - 1 else "    "
+            tree_str += generate_tree(
+                os.path.join(directory, entry),
+                depth + 1,
+                max_depth,
+                prefix=prefix + extension,
+            )
+    return tree_str
+
+
+def new_explorer(data_path, max_depth):
+    return gr.Markdown(
+        elem_classes=["scrollable-component"],
+        value=generate_tree(data_path, max_depth=max_depth),
+    )
+
+
+def add_item(folder: str, method: str, filelist: str, label_lang: str):
+    folder = folder.strip(" ").strip('"')
+    filelist = filelist.strip(" ").strip('"')
+
+    folder_path = Path(folder)
+    filelist_path = Path(filelist)
+
+    if folder and folder not in items and data_pre_output not in folder_path.parents:
+        if folder_path.is_dir():
+            items.append(folder)
+            dict_items[folder] = dict(
+                type="folder", method=method, label_lang=label_lang
+            )
+        elif folder:
+            err = folder
+            return gr.Checkboxgroup(choices=items), build_html_error_message(
+                f"添加文件夹路径无效: {err}"
+            )
+
+    if (
+        filelist
+        and filelist not in items
+        and data_pre_output not in filelist_path.parents
+    ):
+        if filelist_path.is_file():
+            items.append(filelist)
+            dict_items[filelist] = dict(
+                type="file", method=method, label_lang=label_lang
+            )
+        elif filelist:
+            err = filelist
+            return gr.Checkboxgroup(choices=items), build_html_error_message(
+                f"添加文件路径无效: {err}"
+            )
+
+    formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
+    logger.info(formatted_data)
+    return gr.Checkboxgroup(choices=items), build_html_ok_message("添加文件(夹)路径成功!")
+
+
+def remove_items(selected_items):
+    global items, dict_items
+    to_remove = [item for item in items if item in selected_items]
+    for item in to_remove:
+        del dict_items[item]
+    items = [item for item in items if item in dict_items.keys()]
+    formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
+    logger.info(formatted_data)
+    return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
+        "删除文件(夹)路径成功!"
+    )
+
+
+def show_selected(options):
+    selected_options = ", ".join(options)
+    return f"你选中了: {selected_options}" if options else "你没有选中任何选项"
+
+
+def list_copy(list_file_path, method):
+    wav_root = data_pre_output
+    lst = []
+    with list_file_path.open("r", encoding="utf-8") as file:
+        for line in tqdm(file, desc="Processing audio/transcript"):
+            wav_path, speaker_name, language, text = line.strip().split("|")
+            original_wav_path = Path(wav_path)
+            target_wav_path = (
+                wav_root / original_wav_path.parent.name / original_wav_path.name
+            )
+            lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
+            if target_wav_path.is_file():
+                continue
+            target_wav_path.parent.mkdir(parents=True, exist_ok=True)
+            if method == "复制一份":
+                shutil.copy(original_wav_path, target_wav_path)
+            else:
+                shutil.move(original_wav_path, target_wav_path.parent)
+
+            original_lab_path = original_wav_path.with_suffix(".lab")
+            target_lab_path = (
+                wav_root
+                / original_wav_path.parent.name
+                / original_wav_path.with_suffix(".lab").name
+            )
+            if target_lab_path.is_file():
+                continue
+            if method == "复制一份":
+                shutil.copy(original_lab_path, target_lab_path)
+            else:
+                shutil.move(original_lab_path, target_lab_path.parent)
+
+    if method == "直接移动":
+        with list_file_path.open("w", encoding="utf-8") as file:
+            file.writelines("\n".join(lst))
+
+    del lst
+    return build_html_ok_message("使用filelist")
+
+
+def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
+    dict_to_language = {"中文": "ZH", "英文": "EN", "日文": "JP", "不打标": "WTF"}
+
+    global dict_items
+    data_path = Path(data_path)
+    for item, content in dict_items.items():
+        item_path = Path(item)
+        tar_path = data_path / item_path.name
+
+        if content["type"] == "folder" and item_path.is_dir():
+            cur_lang = dict_to_language[content["label_lang"]]
+            if cur_lang != "WTF":
+                try:
+                    subprocess.run(
+                        [
+                            PYTHON,
+                            "tools/whisper_asr.py",
+                            "--model-size",
+                            label_model,
+                            "--device",
+                            label_device,
+                            "--audio-dir",
+                            item_path,
+                            "--save-dir",
+                            item_path,
+                            "--language",
+                            cur_lang,
+                        ],
+                        env=env,
+                    )
+                except Exception:
+                    print("Transcription error occurred")
+
+            if content["method"] == "复制一份":
+                os.makedirs(tar_path, exist_ok=True)
+                shutil.copytree(
+                    src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
+                )
+            elif not tar_path.is_dir():
+                shutil.move(src=str(item_path), dst=str(tar_path))
+
+        elif content["type"] == "file" and item_path.is_file():
+            list_copy(item_path, content["method"])
+
+    return build_html_ok_message("文件移动完毕"), new_explorer(data_path, max_depth=max_depth)
+
+
+def train_process(
+    data_path: str,
+    option: str,
+    # vq-gan config
+    vqgan_lr,
+    vqgan_maxsteps,
+    vqgan_data_num_workers,
+    vqgan_data_batch_size,
+    vqgan_data_val_batch_size,
+    vqgan_precision,
+    vqgan_check_interval,
+    # llama config
+    llama_lr,
+    llama_maxsteps,
+    llama_limit_val_batches,
+    llama_data_num_workers,
+    llama_data_batch_size,
+    llama_data_max_length,
+    llama_precision,
+    llama_check_interval,
+):
+    backend = "nccl" if sys.platform == "linux" else "gloo"
+    if option == "VQGAN" or option == "all":
+        subprocess.run(
+            [
+                PYTHON,
+                "tools/vqgan/create_train_split.py",
+                str(data_pre_output.relative_to(cur_work_dir)),
+            ]
+        )
+        train_cmd = [
+            PYTHON,
+            "fish_speech/train.py",
+            "--config-name",
+            "vqgan_finetune",
+            f"trainer.strategy.process_group_backend={backend}",
+            f"model.optimizer.lr={vqgan_lr}",
+            f"trainer.max_steps={vqgan_maxsteps}",
+            f"data.num_workers={vqgan_data_num_workers}",
+            f"data.batch_size={vqgan_data_batch_size}",
+            f"data.val_batch_size={vqgan_data_val_batch_size}",
+            f"trainer.precision={vqgan_precision}",
+            f"trainer.val_check_interval={vqgan_check_interval}",
+            f"train_dataset.filelist={str(data_pre_output / 'vq_train_filelist.txt')}",
+            f"val_dataset.filelist={str(data_pre_output / 'vq_val_filelist.txt')}",
+        ]
+        logger.info(train_cmd)
+        subprocess.run(train_cmd)
+
+    if option == "LLAMA" or option == "all":
+        subprocess.run(
+            [
+                PYTHON,
+                "tools/vqgan/extract_vq.py",
+                str(data_pre_output),
+                "--num-workers",
+                "1",
+                "--batch-size",
+                "16",
+                "--config-name",
+                "vqgan_pretrain",
+                "--checkpoint-path",
+                "checkpoints/vq-gan-group-fsq-2x1024.pth",
+            ]
+        )
+
+        subprocess.run(
+            [
+                PYTHON,
+                "tools/llama/build_dataset.py",
+                "--input",
+                str(data_pre_output),
+                "--num-workers",
+                "16",
+            ]
+        )
+
+        protos_list = [
+            str(file) for file in Path("data/quantized-dataset-ft").glob("*.protos")
+        ]
+        train_cmd = [
+            PYTHON,
+            "fish_speech/train.py",
+            "--config-name",
+            "text2semantic_sft",
+            f"trainer.strategy.process_group_backend={backend}",
+            "model@model.model=dual_ar_2_codebook_medium",
+            "tokenizer.pretrained_model_name_or_path=checkpoints",
+            f"train_dataset.proto_files={str(protos_list)}",
+            f"val_dataset.proto_files={str(protos_list)}",
+            f"model.optimizer.lr={llama_lr}",
+            f"trainer.max_steps={llama_maxsteps}",
+            f"trainer.limit_val_batches={llama_limit_val_batches}",
+            f"data.num_workers={llama_data_num_workers}",
+            f"data.batch_size={llama_data_batch_size}",
+            f"max_length={llama_data_max_length}",
+            f"trainer.precision={llama_precision}",
+            f"trainer.val_check_interval={llama_check_interval}",
+        ]
+        logger.info(train_cmd)
+        subprocess.run(train_cmd)
+
+    return build_html_ok_message("训练终止")
+
+
+init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
+init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
+
+with gr.Blocks(
+    head="<style>\n" + css + "\n</style>",
+    js=js,
+    theme=seafoam,
+    analytics_enabled=False,
+    title="Fish-Speech 鱼语",
+) as demo:
+    with gr.Row():
+        with gr.Column():
+            with gr.Tab("\U0001F4D6 数据集准备"):
+                with gr.Row():
+                    textbox = gr.Textbox(
+                        label="\U0000270F 输入音频&转写源文件夹路径",
+                        info="音频装在一个以说话人命名的文件夹内作为区分",
+                        interactive=True,
+                    )
+                    transcript_path = gr.Textbox(
+                        label="\U0001F4DD 转写文本filelist所在路径",
+                        info="支持 Bert-Vits2 / GPT-SoVITS 格式",
+                        interactive=True,
+                    )
+                with gr.Row(equal_height=False):
+                    with gr.Column():
+                        output_radio = gr.Radio(
+                            label="\U0001F4C1 选择源文件(夹)处理方式",
+                            choices=["复制一份", "直接移动"],
+                            value="复制一份",
+                            interactive=True,
+                        )
+                    with gr.Column():
+                        error = gr.HTML(label="错误信息")
+                        if_label = gr.Checkbox(
+                            label="是否开启打标WebUI", scale=0, show_label=True
+                        )
+                with gr.Row():
+                    add_button = gr.Button("\U000027A1提交到处理区", variant="primary")
+                    remove_button = gr.Button("\U000026D4 取消所选内容")
+
+                with gr.Row():
+                    label_device = gr.Dropdown(
+                        label="打标设备",
+                        info="建议使用cuda, 实在是低配置再用cpu",
+                        choices=["cpu", "cuda"],
+                        value="cuda",
+                        interactive=True,
+                    )
+                    label_model = gr.Dropdown(
+                        label="打标模型大小",
+                        info="显存10G以上用large, 5G用medium, 2G用small",
+                        choices=["large", "medium", "small"],
+                        value="small",
+                        interactive=True,
+                    )
+                    label_radio = gr.Dropdown(
+                        label="(可选)打标语言",
+                        info="如果没有音频对应的文本,则进行辅助打标, 支持.txt或.lab格式",
+                        choices=["中文", "日文", "英文", "不打标"],
+                        value="不打标",
+                        interactive=True,
+                    )
+
+            with gr.Tab("\U0001F6E0 训练配置项"):  # hammer
+                with gr.Column():
+                    with gr.Row():
+                        model_type_radio = gr.Radio(
+                            label="选择要训练的模型类型",
+                            interactive=True,
+                            choices=["VQGAN", "LLAMA", "all"],
+                            value="all",
+                        )
+                    with gr.Row():
+                        with gr.Accordion("VQGAN配置项", open=False):
+                            with gr.Row(equal_height=False):
+                                vqgan_lr_slider = gr.Slider(
+                                    label="初始学习率",
+                                    interactive=True,
+                                    minimum=1e-5,
+                                    maximum=1e-4,
+                                    step=1e-5,
+                                    value=init_vqgan_yml["model"]["optimizer"]["lr"],
+                                )
+                                vqgan_maxsteps_slider = gr.Slider(
+                                    label="训练最大步数",
+                                    interactive=True,
+                                    minimum=1000,
+                                    maximum=100000,
+                                    step=1000,
+                                    value=init_vqgan_yml["trainer"]["max_steps"],
+                                )
+
+                            with gr.Row(equal_height=False):
+                                vqgan_data_num_workers_slider = gr.Slider(
+                                    label="num_workers",
+                                    interactive=True,
+                                    minimum=1,
+                                    maximum=16,
+                                    step=1,
+                                    value=init_vqgan_yml["data"]["num_workers"],
+                                )
+
+                                vqgan_data_batch_size_slider = gr.Slider(
+                                    label="batch_size",
+                                    interactive=True,
+                                    minimum=1,
+                                    maximum=32,
+                                    step=1,
+                                    value=init_vqgan_yml["data"]["batch_size"],
+                                )
+                            with gr.Row(equal_height=False):
+                                vqgan_data_val_batch_size_slider = gr.Slider(
+                                    label="val_batch_size",
+                                    interactive=True,
+                                    minimum=1,
+                                    maximum=32,
+                                    step=1,
+                                    value=init_vqgan_yml["data"]["val_batch_size"],
+                                )
+                                vqgan_precision_dropdown = gr.Dropdown(
+                                    label="训练精度",
+                                    interactive=True,
+                                    choices=["32", "bf16-true", "bf16-mixed"],
+                                    value=str(init_vqgan_yml["trainer"]["precision"]),
+                                )
+                            with gr.Row(equal_height=False):
+                                vqgan_check_interval_slider = gr.Slider(
+                                    label="每n步保存一个模型",
+                                    interactive=True,
+                                    minimum=500,
+                                    maximum=10000,
+                                    step=500,
+                                    value=init_vqgan_yml["trainer"][
+                                        "val_check_interval"
+                                    ],
+                                )
+
+                    with gr.Row():
+                        with gr.Accordion("LLAMA配置项", open=False):
+                            with gr.Row(equal_height=False):
+                                llama_lr_slider = gr.Slider(
+                                    label="初始学习率",
+                                    interactive=True,
+                                    minimum=1e-5,
+                                    maximum=1e-4,
+                                    step=1e-5,
+                                    value=init_llama_yml["model"]["optimizer"]["lr"],
+                                )
+                                llama_maxsteps_slider = gr.Slider(
+                                    label="训练最大步数",
+                                    interactive=True,
+                                    minimum=1000,
+                                    maximum=100000,
+                                    step=1000,
+                                    value=init_llama_yml["trainer"]["max_steps"],
+                                )
+                            with gr.Row(equal_height=False):
+                                llama_limit_val_batches_slider = gr.Slider(
+                                    label="limit_val_batches",
+                                    interactive=True,
+                                    minimum=1,
+                                    maximum=20,
+                                    step=1,
+                                    value=init_llama_yml["trainer"][
+                                        "limit_val_batches"
+                                    ],
+                                )
+                                llama_data_num_workers_slider = gr.Slider(
+                                    label="num_workers",
+                                    minimum=0,
+                                    maximum=16,
+                                    step=1,
+                                    value=init_llama_yml["data"]["num_workers"]
+                                    if sys.platform == "linux"
+                                    else 0,
+                                )
+                            with gr.Row(equal_height=False):
+                                llama_data_batch_size_slider = gr.Slider(
+                                    label="batch_size",
+                                    interactive=True,
+                                    minimum=1,
+                                    maximum=32,
+                                    step=1,
+                                    value=init_llama_yml["data"]["batch_size"],
+                                )
+                                llama_data_max_length_slider = gr.Slider(
+                                    label="max_length",
+                                    interactive=True,
+                                    minimum=1024,
+                                    maximum=4096,
+                                    step=128,
+                                    value=init_llama_yml["max_length"],
+                                )
+                            with gr.Row(equal_height=False):
+                                llama_precision_dropdown = gr.Dropdown(
+                                    label="训练精度",
+                                    interactive=True,
+                                    choices=["32", "bf16-true", "16-mixed"],
+                                    value="bf16-true",
+                                )
+                                llama_check_interval_slider = gr.Slider(
+                                    label="每n步保存一个模型",
+                                    interactive=True,
+                                    minimum=500,
+                                    maximum=10000,
+                                    step=500,
+                                    value=init_llama_yml["trainer"][
+                                        "val_check_interval"
+                                    ],
+                                )
+
+            with gr.Tab("\U0001F9E0 进入推理界面"):
+                with gr.Column():
+                    with gr.Row():
+                        with gr.Accordion(label="\U0001F5A5 推理服务器配置", open=False):
+                            with gr.Row():
+                                infer_host_textbox = gr.Textbox(
+                                    label="Webui启动服务器地址", value="127.0.0.1"
+                                )
+                                infer_port_textbox = gr.Textbox(
+                                    label="Webui启动服务器端口", value="7862"
+                                )
+                            with gr.Row():
+                                infer_vqgan_model = gr.Textbox(
+                                    label="VQGAN模型位置",
+                                    placeholder="填写pth/ckpt文件路径",
+                                    value="checkpoints/vq-gan-group-fsq-2x1024.pth",
+                                )
+                            with gr.Row():
+                                infer_llama_model = gr.Textbox(
+                                    label="LLAMA模型位置",
+                                    placeholder="填写pth/ckpt文件路径",
+                                    value="checkpoints/text2semantic-medium-v1-2k.pth",
+                                )
+                            with gr.Row():
+                                infer_compile = gr.Radio(
+                                    label="是否编译模型?", choices=["Yes", "No"], value="Yes"
+                                )
+
+                    with gr.Row():
+                        infer_checkbox = gr.Checkbox(label="是否打开推理界面")
+                        infer_error = gr.HTML(label="推理界面错误信息")
+
+        with gr.Column():
+            train_error = gr.HTML(label="训练时的报错信息")
+            checkbox_group = gr.CheckboxGroup(
+                label="\U0001F4CA 数据源列表",
+                info="左侧输入文件夹所在路径或filelist。无论是否勾选,在此列表中都会被用以后续训练。",
+                elem_classes=["data_src"],
+            )
+            train_box = gr.Textbox(
+                label="数据预处理文件夹路径", value=str(data_pre_output), interactive=False
+            )
+            model_box = gr.Textbox(
+                label="\U0001F4BE 模型输出路径",
+                value=str(default_model_output),
+                interactive=False,
+            )
+
+            with gr.Accordion(
+                "查看预处理文件夹状态 (滑块为显示深度大小)",
+                elem_classes=["scrollable-component"],
+                elem_id="file_accordion",
+            ):
+                tree_slider = gr.Slider(
+                    minimum=0,
+                    maximum=3,
+                    value=0,
+                    step=1,
+                    show_label=False,
+                    container=False,
+                )
+                file_markdown = new_explorer(str(data_pre_output), 0)
+            with gr.Row(equal_height=False):
+                admit_btn = gr.Button(
+                    "\U00002705 文件预处理", scale=0, min_width=160, variant="primary"
+                )
+                fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
+                help_button = gr.Button("\U00002753", scale=0, min_width=80)  # question
+                train_btn = gr.Button("训练启动!", variant="primary")
+
+    footer = load_data_in_raw("fish_speech/webui/html/footer.html")
+    footer = footer.format(
+        versions=versions_html(),
+        api_docs="https://speech.fish.audio/inference/#http-api",
+    )
+    gr.HTML(footer, elem_id="footer")
+
+    add_button.click(
+        fn=add_item,
+        inputs=[textbox, output_radio, transcript_path, label_radio],
+        outputs=[checkbox_group, error],
+    )
+    remove_button.click(
+        fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
+    )
+    checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
+    help_button.click(
+        fn=None,
+        js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
+        'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
+    )
+    if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
+    train_btn.click(
+        fn=train_process,
+        inputs=[
+            train_box,
+            model_type_radio,
+            # vq-gan config
+            vqgan_lr_slider,
+            vqgan_maxsteps_slider,
+            vqgan_data_num_workers_slider,
+            vqgan_data_batch_size_slider,
+            vqgan_data_val_batch_size_slider,
+            vqgan_precision_dropdown,
+            vqgan_check_interval_slider,
+            # llama config
+            llama_lr_slider,
+            llama_maxsteps_slider,
+            llama_limit_val_batches_slider,
+            llama_data_num_workers_slider,
+            llama_data_batch_size_slider,
+            llama_data_max_length_slider,
+            llama_precision_dropdown,
+            llama_check_interval_slider,
+        ],
+        outputs=[train_error],
+    )
+    admit_btn.click(
+        fn=check_files,
+        inputs=[train_box, tree_slider, label_model, label_device],
+        outputs=[error, file_markdown],
+    )
+    fresh_btn.click(
+        fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
+    )
+
+    infer_checkbox.change(
+        fn=change_infer,
+        inputs=[
+            infer_checkbox,
+            infer_host_textbox,
+            infer_port_textbox,
+            infer_vqgan_model,
+            infer_llama_model,
+            infer_compile,
+        ],
+        outputs=[infer_error],
+    )
+
+demo.launch(inbrowser=True)

+ 6 - 0
start.bat

@@ -0,0 +1,6 @@
+@echo off
+rem Windows launcher for the fish-speech WebUI manager.
+rem Switch the console code page to 65001 (UTF-8) so non-ASCII output renders.
+chcp 65001
+echo loading page...
+rem Put the repository root (directory of this script) on PYTHONPATH so the
+rem fish_speech package is importable.
+set PYTHONPATH=%~dp0
+rem Bypass any HTTP proxy for loopback addresses.
+rem NOTE(review): in batch, the quotes become part of the variable value
+rem (no_proxy="localhost, ...") — verify consumers tolerate the quoting.
+set no_proxy="localhost, 127.0.0.1, 0.0.0.0"
+python fish_speech\webui\manage.py

+ 1 - 1
tools/vqgan/extract_vq.py

@@ -73,7 +73,7 @@ def process_batch(files: list[Path], model) -> float:
     for file in files:
         try:
             wav, sr = torchaudio.load(
-                str(file), backend="sox"
+                str(file), backend="sox" if sys.platform == "linux" else "soundfile"
             )  # Need to install libsox-dev
         except Exception as e:
             logger.error(f"Error reading {file}: {e}")

+ 15 - 3
tools/webui.py

@@ -7,7 +7,6 @@ from pathlib import Path
 
 import gradio as gr
 import librosa
-import spaces
 import torch
 from loguru import logger
 from torchaudio import functional as AF
@@ -37,16 +36,29 @@ We are not responsible for any misuse of the model, please consider your local l
 
 TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
 
+# Use HuggingFace Spaces' GPU decorator when the optional `spaces` package
+# is installed; otherwise fall back to a no-op pass-through decorator so the
+# same code also runs outside of a Spaces deployment.
+try:
+    import spaces
+
+    GPU_DECORATOR = spaces.GPU
+except ImportError:
+
+    def GPU_DECORATOR(func):
+        # Identity decorator: forwards all positional/keyword arguments and
+        # returns the wrapped function's result unchanged.
+        # NOTE(review): functools.wraps is not applied, so the decorated
+        # function's __name__/__doc__ are replaced by `wrapper`'s — confirm
+        # this is acceptable for downstream introspection/logging.
+        def wrapper(*args, **kwargs):
+            return func(*args, **kwargs)
+
+        return wrapper
+
 
 def build_html_error_message(error):
     return f"""
-    <div style="color: red; font-weight: bold;">
+    <div style="color: red; 
+    font-weight: bold;">
         {html.escape(error)}
     </div>
     """
 
 
-@spaces.GPU
+@GPU_DECORATOR
 def inference(
     text,
     enable_reference_audio,