Просмотр исходного кода

Add vits-decoder UI support & Fix bugs (#167)

* Fix button height

* Streaming support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Convert to 1 channel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix Conversion bug

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix target path

* Add checkpoint selection

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix gpup decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add link for labeler

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Localize labeler

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add LoRA llama config

* Allow download stream audio

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* asr

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add cache auto recycling

* 多打了一个字母

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Check 'compile' available

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add vits-decoder UI support & Fix bugs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Configurable audio length, i18n

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Label Exception

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 год назад
Родитель
Commit
e2b0fb10b3

+ 2 - 1
fish_speech/configs/vits_decoder_finetune.yaml

@@ -10,7 +10,8 @@ resume_weights_only: true
 trainer:
   accelerator: gpu
   devices: auto
-  strategy: ddp_find_unused_parameters_true
+  strategy:
+    find_unused_parameters: true
   precision: 32
   max_steps: 100_000
   val_check_interval: 1000

+ 11 - 2
fish_speech/i18n/locale/en_US.json

@@ -14,6 +14,8 @@
     "Data Preprocessing": "Data Preprocessing",
     "Data Preprocessing Path": "Data Preprocessing Path",
     "Data Source": "Data Source",
+    "Decoder Model Config": "Decoder Model Config",
+    "Decoder Model Path": "Decoder Model Path",
     "Disabled": "Disabled",
     "Enable Reference Audio": "Enable Reference Audio",
     "English": "English",
@@ -39,12 +41,14 @@
     "LLAMA Model Path": "LLAMA Model Path",
     "Labeling Device": "Labeling Device",
     "LoRA Model to be merged": "LoRA Model to be merged",
+    "Maximum Audio Duration": "Maximum Audio Duration",
     "Maximum Length per Sample": "Maximum Length per Sample",
     "Maximum Training Steps": "Maximum Training Steps",
     "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
     "Merge": "Merge",
     "Merge LoRA": "Merge LoRA",
     "Merge successfully": "Merge successfully",
+    "Minimum Audio Duration": "Minimum Audio Duration",
     "Model Output Path": "Model Output Path",
     "Model Size": "Model Size",
     "Move": "Move",
@@ -70,6 +74,9 @@
     "Removed path successfully!": "Removed path successfully!",
     "Repetition Penalty": "Repetition Penalty",
     "Save model every n steps": "Save model every n steps",
+    "Select LLAMA ckpt": "Select LLAMA ckpt",
+    "Select VITS ckpt": "Select VITS ckpt",
+    "Select VQGAN ckpt": "Select VQGAN ckpt",
     "Select source file processing method": "Select source file processing method",
     "Select the model to be trained": "Select the model to be trained",
     "Selected: {}": "Selected: {}",
@@ -94,8 +101,8 @@
     "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
     "Use filelist": "Use filelist",
     "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
+    "VITS Configuration": "VITS Configuration",
     "VQGAN Configuration": "VQGAN Configuration",
-    "VQGAN Model Path": "VQGAN Model Path",
     "Validation Batch Size": "Validation Batch Size",
     "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
     "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
@@ -103,5 +110,7 @@
     "WebUI Port": "WebUI Port",
     "Whisper Model": "Whisper Model",
     "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
-    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
+    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
+    "latest": "latest",
+    "new": "new"
 }

+ 11 - 2
fish_speech/i18n/locale/es_ES.json

@@ -14,6 +14,8 @@
     "Data Preprocessing": "Preprocesamiento de Datos",
     "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
     "Data Source": "Fuente de Datos",
+    "Decoder Model Config": "Configuración del modelo decodificador",
+    "Decoder Model Path": "Ruta del modelo decodificador",
     "Disabled": "Desactivado",
     "Enable Reference Audio": "Habilitar Audio de Referencia",
     "English": "Inglés",
@@ -39,12 +41,14 @@
     "LLAMA Model Path": "Ruta del Modelo LLAMA",
     "Labeling Device": "Dispositivo de Etiquetado",
     "LoRA Model to be merged": "Modelo LoRA a fusionar",
+    "Maximum Audio Duration": "Duración máxima de audio",
     "Maximum Length per Sample": "Longitud Máxima por Muestra",
     "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
     "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
     "Merge": "Fusionar",
     "Merge LoRA": "Fusionar LoRA",
     "Merge successfully": "Fusionado exitosamente",
+    "Minimum Audio Duration": "Duración mínima de audio",
     "Model Output Path": "Ruta de Salida del Modelo",
     "Model Size": "Tamaño del Modelo",
     "Move": "Mover",
@@ -70,6 +74,9 @@
     "Removed path successfully!": "¡Ruta eliminada exitosamente!",
     "Repetition Penalty": "Penalización por Repetición",
     "Save model every n steps": "Guardar modelo cada n pasos",
+    "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
+    "Select VITS ckpt": "Seleccionar punto de control VITS",
+    "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
     "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
     "Select the model to be trained": "Seleccione el modelo a ser entrenado",
     "Selected: {}": "Seleccionado: {}",
@@ -94,8 +101,8 @@
     "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
     "Use filelist": "Usar lista de archivos",
     "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
+    "VITS Configuration": "Configuración de VITS",
     "VQGAN Configuration": "Configuración de VQGAN",
-    "VQGAN Model Path": "Ruta del Modelo VQGAN",
     "Validation Batch Size": "Tamaño del Lote de Validación",
     "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
     "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
@@ -103,5 +110,7 @@
     "WebUI Port": "Puerto de WebUI",
     "Whisper Model": "Modelo Whisper",
     "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
-    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+"
+    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
+    "latest": "más reciente",
+    "new": "nuevo"
 }

+ 12 - 3
fish_speech/i18n/locale/ja_JP.json

@@ -14,6 +14,8 @@
     "Data Preprocessing": "データ前処理",
     "Data Preprocessing Path": "データ前処理パス",
     "Data Source": "データソース",
+    "Decoder Model Config": "デコーダーモデルの構成",
+    "Decoder Model Path": "デコーダーモデルのパス",
     "Disabled": "無効",
     "Enable Reference Audio": "リファレンスオーディオを有効にする",
     "English": "英語",
@@ -39,12 +41,14 @@
     "LLAMA Model Path": "LLAMAモデルパス",
     "Labeling Device": "ラベリングデバイス",
     "LoRA Model to be merged": "マージするLoRAモデル",
+    "Maximum Audio Duration": "最大オーディオの長さ",
     "Maximum Length per Sample": "サンプルあたりの最大長",
     "Maximum Training Steps": "最大トレーニングステップ数",
     "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
     "Merge": "マージ",
     "Merge LoRA": "LoRAのマージ",
     "Merge successfully": "マージに成功しました",
+    "Minimum Audio Duration": "最小オーディオの長さ",
     "Model Output Path": "モデル出力パス",
     "Model Size": "モデルサイズ",
     "Move": "移動",
@@ -70,6 +74,9 @@
     "Removed path successfully!": "パスの削除に成功しました!",
     "Repetition Penalty": "反復ペナルティ",
     "Save model every n steps": "nステップごとにモデルを保存",
+    "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
+    "Select VITS ckpt": "VITS チェックポイントを選択",
+    "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
     "Select source file processing method": "ソースファイルの処理方法を選択",
     "Select the model to be trained": "トレーニングするモデルを選択",
     "Selected: {}": "選択済み: {}",
@@ -94,8 +101,8 @@
     "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
     "Use filelist": "ファイルリストを使用",
     "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
-    "VQGAN Configuration": "VQGAN設定",
-    "VQGAN Model Path": "VQGANモデルパス",
+    "VITS Configuration": "VITS の構成",
+    "VQGAN Configuration": "VQGAN の構成",
     "Validation Batch Size": "検証バッチサイズ",
     "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
     "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
@@ -103,5 +110,7 @@
     "WebUI Port": "WebUIポート",
     "Whisper Model": "Whisperモデル",
     "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
-    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします"
+    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
+    "latest": "最新",
+    "new": "新規"
 }

+ 11 - 2
fish_speech/i18n/locale/zh_CN.json

@@ -14,6 +14,8 @@
     "Data Preprocessing": "数据预处理",
     "Data Preprocessing Path": "数据预处理路径",
     "Data Source": "数据源",
+    "Decoder Model Config": "解码器模型配置",
+    "Decoder Model Path": "解码器模型路径",
     "Disabled": "禁用",
     "Enable Reference Audio": "启用参考音频",
     "English": "英文",
@@ -39,12 +41,14 @@
     "LLAMA Model Path": "LLAMA 模型路径",
     "Labeling Device": "标注加速设备",
     "LoRA Model to be merged": "要合并的 LoRA 模型",
+    "Maximum Audio Duration": "最大音频时长",
     "Maximum Length per Sample": "每个样本的最大长度",
     "Maximum Training Steps": "最大训练步数",
     "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
     "Merge": "合并",
     "Merge LoRA": "合并 LoRA",
     "Merge successfully": "合并成功",
+    "Minimum Audio Duration": "最小音频时长",
     "Model Output Path": "模型输出路径",
     "Model Size": "模型规模",
     "Move": "移动",
@@ -70,6 +74,9 @@
     "Removed path successfully!": "移除路径成功!",
     "Repetition Penalty": "重复惩罚",
     "Save model every n steps": "每 n 步保存模型",
+    "Select LLAMA ckpt": "选择 LLAMA 检查点",
+    "Select VITS ckpt": "选择 VITS 检查点",
+    "Select VQGAN ckpt": "选择 VQGAN 检查点",
     "Select source file processing method": "选择源文件处理方法",
     "Select the model to be trained": "选择要训练的模型",
     "Selected: {}": "已选择: {}",
@@ -94,8 +101,8 @@
     "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
     "Use filelist": "使用文件列表",
     "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
+    "VITS Configuration": "VITS 配置",
     "VQGAN Configuration": "VQGAN 配置",
-    "VQGAN Model Path": "VQGAN 模型路径",
     "Validation Batch Size": "验证批次大小",
     "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
     "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
@@ -103,5 +110,7 @@
     "WebUI Port": "WebUI 端口",
     "Whisper Model": "Whisper 模型",
     "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
-    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed"
+    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
+    "latest": "最近的检查点",
+    "new": "创建新的检查点"
 }

+ 279 - 64
fish_speech/webui/manage.py

@@ -27,6 +27,7 @@ print("You are in ", str(cur_work_dir))
 config_path = cur_work_dir / "fish_speech" / "configs"
 vqgan_yml_path = config_path / "vqgan_finetune.yaml"
 llama_yml_path = config_path / "text2semantic_finetune.yaml"
+vits_yml_path = config_path / "vits_decoder_finetune.yaml"
 
 env = os.environ.copy()
 env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
@@ -105,15 +106,19 @@ def change_label(if_label):
     if if_label == True and p_label is None:
         url = "http://localhost:3000"
         remote_url = "https://text-labeler.pages.dev/"
-        p_label = subprocess.Popen(
-            [
-                (
-                    "asr-label-linux-x64"
-                    if sys.platform == "linux"
-                    else "asr-label-win-x64.exe"
-                )
-            ]
-        )
+        try:
+            p_label = subprocess.Popen(
+                [
+                    (
+                        "asr-label-linux-x64"
+                        if sys.platform == "linux"
+                        else "asr-label-win-x64.exe"
+                    )
+                ]
+            )
+        except FileNotFoundError:
+            logger.warning("asr-label execution not found!")
+
         yield build_html_href(
             link=remote_url,
             desc=i18n("Optional online ver"),
@@ -146,7 +151,8 @@ def change_infer(
     if_infer,
     host,
     port,
-    infer_vqgan_model,
+    infer_decoder_model,
+    infer_decoder_config,
     infer_llama_model,
     infer_llama_config,
     infer_compile,
@@ -169,8 +175,10 @@ def change_infer(
             [
                 PYTHON,
                 "tools/webui.py",
-                "--vqgan-checkpoint-path",
-                infer_vqgan_model,
+                "--decoder-checkpoint-path",
+                infer_decoder_model,
+                "--decoder-config-name",
+                infer_decoder_config,
                 "--llama-checkpoint-path",
                 infer_llama_model,
                 "--llama-config-name",
@@ -398,6 +406,8 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
 def train_process(
     data_path: str,
     option: str,
+    min_duration: float,
+    max_duration: float,
     # vq-gan config
     vqgan_ckpt,
     vqgan_lr,
@@ -407,6 +417,15 @@ def train_process(
     vqgan_data_val_batch_size,
     vqgan_precision,
     vqgan_check_interval,
+    # vits config
+    vits_ckpt,
+    vits_lr,
+    vits_maxsteps,
+    vits_data_num_workers,
+    vits_data_batch_size,
+    vits_data_val_batch_size,
+    vits_precision,
+    vits_check_interval,
     # llama config
     llama_ckpt,
     llama_base_config,
@@ -434,27 +453,39 @@ def train_process(
 
     print("New Project Name: ", new_project)
 
-    if option == "VQGAN" or option == "all":
+    if min_duration > max_duration:
+        min_duration, max_duration = max_duration, min_duration
+
+    if option == "VQGAN" or option == "VITS":
         subprocess.run(
             [
                 PYTHON,
                 "tools/vqgan/create_train_split.py",
                 str(data_pre_output.relative_to(cur_work_dir)),
+                "--min-duration",
+                str(min_duration),
+                "--max-duration",
+                str(max_duration),
             ]
         )
-        latest = list(
-            sorted(
-                [
-                    str(p.relative_to("results"))
-                    for p in Path("results").glob("vqgan_*/")
-                ],
-                reverse=True,
-            )
-        )[0]
+
+    if option == "VQGAN":
+        latest = next(
+            iter(
+                sorted(
+                    [
+                        str(p.relative_to("results"))
+                        for p in Path("results").glob("vqgan_*/")
+                    ],
+                    reverse=True,
+                )
+            ),
+            ("vqgan_" + new_project),
+        )
         project = (
             ("vqgan_" + new_project)
-            if vqgan_ckpt == "new"
-            else latest if vqgan_ckpt == "latest" else vqgan_ckpt
+            if vqgan_ckpt == i18n("new")
+            else latest if vqgan_ckpt == i18n("latest") else vqgan_ckpt
         )
         logger.info(project)
         train_cmd = [
@@ -477,7 +508,49 @@ def train_process(
         logger.info(train_cmd)
         subprocess.run(train_cmd)
 
-    if option == "LLAMA" or option == "all":
+    if option == "VITS":
+        latest = next(
+            iter(
+                sorted(
+                    [
+                        str(p.relative_to("results"))
+                        for p in Path("results").glob("vits_*/")
+                    ],
+                    reverse=True,
+                )
+            ),
+            ("vits_" + new_project),
+        )
+        project = (
+            ("vits_" + new_project)
+            if vits_ckpt == i18n("new")
+            else latest if vits_ckpt == i18n("latest") else vits_ckpt
+        )
+        ckpt_path = str(Path("checkpoints/vits_decoder_v1.1.ckpt"))
+        logger.info(project)
+        train_cmd = [
+            PYTHON,
+            "fish_speech/train.py",
+            "--config-name",
+            "vits_decoder_finetune",
+            f"project={project}",
+            f"ckpt_path={ckpt_path}",
+            f"trainer.strategy.process_group_backend={backend}",
+            "tokenizer.pretrained_model_name_or_path=checkpoints",
+            f"model.optimizer.lr={vits_lr}",
+            f"trainer.max_steps={vits_maxsteps}",
+            f"data.num_workers={vits_data_num_workers}",
+            f"data.batch_size={vits_data_batch_size}",
+            f"data.val_batch_size={vits_data_val_batch_size}",
+            f"trainer.precision={vits_precision}",
+            f"trainer.val_check_interval={vits_check_interval}",
+            f"train_dataset.filelist={str(data_pre_output / 'vq_train_filelist.txt')}",
+            f"val_dataset.filelist={str(data_pre_output / 'vq_val_filelist.txt')}",
+        ]
+        logger.info(train_cmd)
+        subprocess.run(train_cmd)
+
+    if option == "LLAMA":
         subprocess.run(
             [
                 PYTHON,
@@ -507,24 +580,27 @@ def train_process(
             ]
         )
         ckpt_path = (
-            "text2semantic-pretrain-medium-2k-v1.pth"
+            "text2semantic-sft-medium-v1.1-4k.pth"
             if llama_base_config == "dual_ar_2_codebook_medium"
-            else "text2semantic-sft-medium-v1-4k.pth"
+            else "text2semantic-sft-large-v1.1-4k.pth"
         )
 
-        latest = list(
-            sorted(
-                [
-                    str(p.relative_to("results"))
-                    for p in Path("results").glob("text2sem*/")
-                ],
-                reverse=True,
-            )
-        )[0]
+        latest = next(
+            iter(
+                sorted(
+                    [
+                        str(p.relative_to("results"))
+                        for p in Path("results").glob("text2sem*/")
+                    ],
+                    reverse=True,
+                )
+            ),
+            ("text2semantic_" + new_project),
+        )
         project = (
             ("text2semantic_" + new_project)
-            if llama_ckpt == "new"
-            else latest if llama_ckpt == "latest" else llama_ckpt
+            if llama_ckpt == i18n("new")
+            else latest if llama_ckpt == i18n("latest") else llama_ckpt
         )
         logger.info(project)
         train_cmd = [
@@ -596,22 +672,33 @@ def fresh_tb_dir():
     )
 
 
-def fresh_vqgan_model():
+def fresh_decoder_model():
     return gr.Dropdown(
         choices=[init_vqgan_yml["ckpt_path"]]
+        + [str(Path("checkpoints/vits_decoder_v1.1.ckpt"))]
         + [str(p) for p in Path("results").glob("vqgan*/**/*.ckpt")]
+        + [str(p) for p in Path("results").glob("vits*/**/*.ckpt")]
     )
 
 
 def fresh_vqgan_ckpt():
     return gr.Dropdown(
-        choices=["latest", "new"] + [str(p) for p in Path("results").glob("vqgan_*/")]
+        choices=[i18n("latest"), i18n("new")]
+        + [str(p) for p in Path("results").glob("vqgan_*/")]
+    )
+
+
+def fresh_vits_ckpt():
+    return gr.Dropdown(
+        choices=[i18n("latest"), i18n("new")]
+        + [str(p) for p in Path("results").glob("vits_*/")]
     )
 
 
 def fresh_llama_ckpt():
     return gr.Dropdown(
-        choices=["latest", "new"] + [str(p) for p in Path("results").glob("text2sem*/")]
+        choices=[i18n("latest"), i18n("new")]
+        + [str(p) for p in Path("results").glob("text2sem*/")]
     )
 
 
@@ -655,6 +742,7 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
 
 init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
 init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
+init_vits_yml = load_yaml_data_in_fact(vits_yml_path)
 
 with gr.Blocks(
     head="<style>\n" + css + "\n</style>",
@@ -687,6 +775,22 @@ with gr.Blocks(
                         if_label = gr.Checkbox(
                             label=i18n("Open Labeler WebUI"), scale=0, show_label=True
                         )
+                with gr.Row():
+                    min_duration = gr.Slider(
+                        label=i18n("Minimum Audio Duration"),
+                        value=1.5,
+                        step=0.1,
+                        minimum=0.4,
+                        maximum=30,
+                    )
+                    max_duration = gr.Slider(
+                        label=i18n("Maximum Audio Duration"),
+                        value=30,
+                        step=0.1,
+                        minimum=0.4,
+                        maximum=30,
+                    )
+
                 with gr.Row():
                     add_button = gr.Button(
                         "\U000027A1 " + i18n("Add to Processing Area"),
@@ -735,17 +839,17 @@ with gr.Blocks(
                     model_type_radio = gr.Radio(
                         label=i18n("Select the model to be trained"),
                         interactive=True,
-                        choices=["VQGAN", "LLAMA", "all"],
-                        value="all",
+                        choices=["VQGAN", "VITS", "LLAMA"],
+                        value="VITS",
                     )
                 with gr.Row():
                     with gr.Tab(label=i18n("VQGAN Configuration")):
                         with gr.Row(equal_height=False):
                             vqgan_ckpt = gr.Dropdown(
-                                label="Select VQGAN ckpt",
-                                choices=["latest", "new"]
+                                label=i18n("Select VQGAN ckpt"),
+                                choices=[i18n("latest"), i18n("new")]
                                 + [str(p) for p in Path("results").glob("vqgan_*/")],
-                                value="latest",
+                                value=i18n("latest"),
                                 interactive=True,
                             )
                         with gr.Row(equal_height=False):
@@ -812,6 +916,79 @@ with gr.Blocks(
                                 value=init_vqgan_yml["trainer"]["val_check_interval"],
                             )
 
+                    with gr.Tab(label=i18n("VITS Configuration")):
+                        with gr.Row(equal_height=False):
+                            vits_ckpt = gr.Dropdown(
+                                label=i18n("Select VITS ckpt"),
+                                choices=[i18n("latest"), i18n("new")]
+                                + [str(p) for p in Path("results").glob("vits_*/")],
+                                value=i18n("latest"),
+                                interactive=True,
+                            )
+                        with gr.Row(equal_height=False):
+                            vits_lr_slider = gr.Slider(
+                                label=i18n("Initial Learning Rate"),
+                                interactive=True,
+                                minimum=1e-5,
+                                maximum=1e-4,
+                                step=1e-5,
+                                value=init_vits_yml["model"]["optimizer"]["lr"],
+                            )
+                            vits_maxsteps_slider = gr.Slider(
+                                label=i18n("Maximum Training Steps"),
+                                interactive=True,
+                                minimum=1000,
+                                maximum=100000,
+                                step=1000,
+                                value=init_vits_yml["trainer"]["max_steps"],
+                            )
+
+                        with gr.Row(equal_height=False):
+                            vits_data_num_workers_slider = gr.Slider(
+                                label=i18n("Number of Workers"),
+                                interactive=True,
+                                minimum=1,
+                                maximum=16,
+                                step=1,
+                                value=init_vits_yml["data"]["num_workers"],
+                            )
+
+                            vits_data_batch_size_slider = gr.Slider(
+                                label=i18n("Batch Size"),
+                                interactive=True,
+                                minimum=1,
+                                maximum=32,
+                                step=1,
+                                value=init_vits_yml["data"]["batch_size"],
+                            )
+                        with gr.Row(equal_height=False):
+                            vits_data_val_batch_size_slider = gr.Slider(
+                                label=i18n("Validation Batch Size"),
+                                interactive=True,
+                                minimum=1,
+                                maximum=32,
+                                step=1,
+                                value=init_vits_yml["data"]["val_batch_size"],
+                            )
+                            vits_precision_dropdown = gr.Dropdown(
+                                label=i18n("Precision"),
+                                interactive=True,
+                                choices=["32", "bf16-true", "bf16-mixed"],
+                                info=i18n(
+                                    "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
+                                ),
+                                value=str(init_vits_yml["trainer"]["precision"]),
+                            )
+                        with gr.Row(equal_height=False):
+                            vits_check_interval_slider = gr.Slider(
+                                label=i18n("Save model every n steps"),
+                                interactive=True,
+                                minimum=500,
+                                maximum=10000,
+                                step=500,
+                                value=init_vits_yml["trainer"]["val_check_interval"],
+                            )
+
                     with gr.Tab(label=i18n("LLAMA Configuration")):
                         with gr.Row(equal_height=False):
                             llama_use_lora = gr.Checkbox(
@@ -822,10 +999,10 @@ with gr.Blocks(
                                 value=True,
                             )
                             llama_ckpt = gr.Dropdown(
-                                label="Select LLAMA ckpt",
-                                choices=["latest", "new"]
+                                label=i18n("Select LLAMA ckpt"),
+                                choices=[i18n("latest"), i18n("new")]
                                 + [str(p) for p in Path("results").glob("text2sem*/")],
-                                value="latest",
+                                value=i18n("latest"),
                                 interactive=True,
                             )
                         with gr.Row(equal_height=False):
@@ -943,6 +1120,7 @@ with gr.Blocks(
                             )
                             lora_llama_config = gr.Dropdown(
                                 label=i18n("LLAMA Model Config"),
+                                info=i18n("Type the path or select from the dropdown"),
                                 choices=[
                                     "dual_ar_2_codebook_large",
                                     "dual_ar_2_codebook_medium",
@@ -1004,18 +1182,39 @@ with gr.Blocks(
                                     label=i18n("WebUI Port"), value="7862"
                                 )
                             with gr.Row():
-                                infer_vqgan_model = gr.Dropdown(
-                                    label=i18n("VQGAN Model Path"),
+                                infer_decoder_model = gr.Dropdown(
+                                    label=i18n("Decoder Model Path"),
                                     info=i18n(
                                         "Type the path or select from the dropdown"
                                     ),
-                                    value=init_vqgan_yml["ckpt_path"],
+                                    value=str(
+                                        Path("checkpoints/vits_decoder_v1.1.ckpt")
+                                    ),
                                     choices=[init_vqgan_yml["ckpt_path"]]
+                                    + [str(Path("checkpoints/vits_decoder_v1.1.ckpt"))]
                                     + [
                                         str(p)
                                         for p in Path("results").glob(
                                             "vqgan*/**/*.ckpt"
                                         )
+                                    ]
+                                    + [
+                                        str(p)
+                                        for p in Path("results").glob("vits*/**/*.ckpt")
+                                    ],
+                                    allow_custom_value=True,
+                                )
+                                infer_decoder_config = gr.Dropdown(
+                                    label=i18n("Decoder Model Config"),
+                                    info=i18n(
+                                        "Type the path or select from the dropdown"
+                                    ),
+                                    value="vits_decoder_finetune",
+                                    choices=[
+                                        "vits_decoder_finetune",
+                                        "vits_decoder_pretrain",
+                                        "vqgan_finetune",
+                                        "vqgan_pretrain",
                                     ],
                                     allow_custom_value=True,
                                 )
@@ -1035,6 +1234,18 @@ with gr.Blocks(
                                     ],
                                     allow_custom_value=True,
                                 )
+                                infer_llama_config = gr.Dropdown(
+                                    label=i18n("LLAMA Model Config"),
+                                    info=i18n(
+                                        "Type the path or select from the dropdown"
+                                    ),
+                                    choices=[
+                                        "dual_ar_2_codebook_large",
+                                        "dual_ar_2_codebook_medium",
+                                    ],
+                                    value="dual_ar_2_codebook_large",
+                                    allow_custom_value=True,
+                                )
                             with gr.Row():
                                 infer_compile = gr.Radio(
                                     label=i18n("Compile Model"),
@@ -1052,15 +1263,6 @@ with gr.Blocks(
                                     ),
                                     interactive=is_module_installed("triton"),
                                 )
-                                infer_llama_config = gr.Dropdown(
-                                    label=i18n("LLAMA Model Config"),
-                                    choices=[
-                                        "dual_ar_2_codebook_large",
-                                        "dual_ar_2_codebook_medium",
-                                    ],
-                                    value="dual_ar_2_codebook_large",
-                                    allow_custom_value=True,
-                                )
 
                     with gr.Row():
                         infer_checkbox = gr.Checkbox(
@@ -1140,6 +1342,8 @@ with gr.Blocks(
         inputs=[
             train_box,
             model_type_radio,
+            min_duration,
+            max_duration,
             # vq-gan config
             vqgan_ckpt,
             vqgan_lr_slider,
@@ -1149,6 +1353,15 @@ with gr.Blocks(
             vqgan_data_val_batch_size_slider,
             vqgan_precision_dropdown,
             vqgan_check_interval_slider,
+            # vits config
+            vits_ckpt,
+            vits_lr_slider,
+            vits_maxsteps_slider,
+            vits_data_num_workers_slider,
+            vits_data_batch_size_slider,
+            vits_data_val_batch_size_slider,
+            vits_precision_dropdown,
+            vits_check_interval_slider,
             # llama config
             llama_ckpt,
             llama_base_config,
@@ -1171,8 +1384,8 @@ with gr.Blocks(
         outputs=[train_error],
     )
     tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
-    infer_vqgan_model.change(
-        fn=fresh_vqgan_model, inputs=[], outputs=[infer_vqgan_model]
+    infer_decoder_model.change(
+        fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
     )
     infer_llama_model.change(
         fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
@@ -1187,6 +1400,7 @@ with gr.Blocks(
         fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
     )
     vqgan_ckpt.change(fn=fresh_vqgan_ckpt, inputs=[], outputs=[vqgan_ckpt])
+    vits_ckpt.change(fn=fresh_vits_ckpt, inputs=[], outputs=[vits_ckpt])
     llama_ckpt.change(fn=fresh_llama_ckpt, inputs=[], outputs=[llama_ckpt])
     llama_lora_merge_btn.click(
         fn=llama_lora_merge,
@@ -1199,7 +1413,8 @@ with gr.Blocks(
             infer_checkbox,
             infer_host_textbox,
             infer_port_textbox,
-            infer_vqgan_model,
+            infer_decoder_model,
+            infer_decoder_config,
             infer_llama_model,
             infer_llama_config,
             infer_compile,

+ 20 - 9
tools/vqgan/create_train_split.py

@@ -4,6 +4,7 @@ from random import Random
 
 import click
 from loguru import logger
+from pydub import AudioSegment
 from tqdm import tqdm
 
 from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
@click.option("--min-duration", default=0.2, type=float)
@click.option("--max-duration", default=30, type=float)
def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
    """Build train/validation filelists for VQ-GAN training.

    Collects audio files either from an explicit ``--filelist`` or by
    recursively scanning *root*, drops files whose duration (seconds) falls
    outside ``[min_duration, max_duration]``, shuffles the survivors with a
    fixed seed, and writes ``vq_train_filelist.txt`` / ``vq_val_filelist.txt``
    under *root*.

    Validation sizing: ``--val-count`` (absolute) and ``--val-ratio``
    (fraction) are mutually exclusive; when neither is given, the split uses
    ``min(20% of files, 100)``.
    """
    if filelist:
        # load_filelist yields tuples; the first element is the audio path.
        files = [i[0] for i in load_filelist(filelist)]
    else:
        files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)

    # Duration-filter pass. NOTE: AudioSegment.from_file decodes the whole
    # file, so this is I/O-heavy but format-robust.
    filtered_files = []
    for file in tqdm(files):
        try:
            audio = AudioSegment.from_file(str(file))
            duration = len(audio) / 1000.0  # pydub reports length in ms
            if min_duration <= duration <= max_duration:
                filtered_files.append(str(file.relative_to(root)))
        except Exception as e:
            # Best effort: one corrupt/unsupported file must not abort the
            # whole split — but it warrants a warning, not an info line.
            logger.warning(f"Error processing {file}: {e}")

    logger.info(f"Found {len(files)} files | Got Filtered {len(filtered_files)} files")
    # Fixed seed -> reproducible train/val split across runs.
    Random(42).shuffle(filtered_files)

    if val_count is None and val_ratio is None:
        logger.info("Validation ratio and count not specified, using min(20%, 100)")
        val_size = min(100, math.ceil(len(filtered_files) * 0.2))
    elif val_count is not None and val_ratio is not None:
        logger.error("Cannot specify both val_count and val_ratio")
        return
    elif val_count is not None:
        if val_count < 1 or val_count > len(filtered_files):
            logger.error("val_count must be between 1 and number of files")
            return
        val_size = val_count
    else:
        val_size = math.ceil(len(filtered_files) * val_ratio)

    logger.info(f"Using {val_size} files for validation")

    with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(filtered_files[val_size:]))

    with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(filtered_files[:val_size]))

    logger.info("Done")