Explorar el Código

Fix BUG (#139)

* init package

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix spelling

* Decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Specify the backend on the command line

* Fix Encoding Error

* Add start configuration

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama hace 1 año
padre
commit
d71272817f
Se han modificado 4 ficheros con 43 adiciones y 8 borrados
  1. .gitignore (+1 -0)
  2. fish_speech/webui/manage.py (+31 -5)
  3. start.bat (+6 -1)
  4. tools/llama/build_dataset.py (+5 -2)

+ 1 - 0
.gitignore

@@ -20,3 +20,4 @@ filelists
 ffmpeg.exe
 asr-label-win-x64.exe
 /.cache
+/fishenv

+ 31 - 5
fish_speech/webui/manage.py

@@ -358,6 +358,8 @@ def train_process(
     llama_data_max_length,
     llama_precision,
     llama_check_interval,
+    llama_grad_batches,
+    llama_use_speaker,
 ):
     backend = "nccl" if sys.platform == "linux" else "gloo"
     if option == "VQGAN" or option == "all":
@@ -410,14 +412,13 @@ def train_process(
                 "tools/llama/build_dataset.py",
                 "--input",
                 str(data_pre_output),
+                "--text-extension",
+                ".lab",
                 "--num-workers",
                 "16",
             ]
         )
 
-        protos_list = [
-            str(file) for file in Path("data/quantized-dataset-ft").glob("*.protos")
-        ]
         train_cmd = [
             PYTHON,
             "fish_speech/train.py",
@@ -426,8 +427,8 @@ def train_process(
             f"trainer.strategy.process_group_backend={backend}",
             "model@model.model=dual_ar_2_codebook_medium",
             "tokenizer.pretrained_model_name_or_path=checkpoints",
-            f"train_dataset.proto_files={str(protos_list)}",
-            f"val_dataset.proto_files={str(protos_list)}",
+            f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
+            f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
             f"model.optimizer.lr={llama_lr}",
             f"trainer.max_steps={llama_maxsteps}",
             f"trainer.limit_val_batches={llama_limit_val_batches}",
@@ -436,6 +437,8 @@ def train_process(
             f"max_length={llama_data_max_length}",
             f"trainer.precision={llama_precision}",
             f"trainer.val_check_interval={llama_check_interval}",
+            f"trainer.accumulate_grad_batches={llama_grad_batches}",
+            f"train_dataset.use_speaker={llama_use_speaker}",
         ]
         logger.info(train_cmd)
         subprocess.run(train_cmd)
@@ -654,6 +657,27 @@ with gr.Blocks(
                                         "val_check_interval"
                                     ],
                                 )
+                            with gr.Row(equal_height=False):
+                                llama_grad_batches = gr.Slider(
+                                    label="accumulate_grad_batches",
+                                    interactive=True,
+                                    minimum=1,
+                                    maximum=20,
+                                    step=1,
+                                    value=init_llama_yml["trainer"][
+                                        "accumulate_grad_batches"
+                                    ],
+                                )
+                                llama_use_speaker = gr.Slider(
+                                    label="use_speaker_ratio",
+                                    interactive=True,
+                                    minimum=0.1,
+                                    maximum=1.0,
+                                    step=0.05,
+                                    value=init_llama_yml["train_dataset"][
+                                        "use_speaker"
+                                    ],
+                                )
 
             with gr.Tab("\U0001F9E0 进入推理界面"):
                 with gr.Column():
@@ -769,6 +793,8 @@ with gr.Blocks(
             llama_data_max_length_slider,
             llama_precision_dropdown,
             llama_check_interval_slider,
+            llama_grad_batches,
+            llama_use_speaker,
         ],
         outputs=[train_error],
     )

+ 6 - 1
start.bat

@@ -3,4 +3,9 @@ chcp 65001
 echo loading page...
 set PYTHONPATH=%~dp0
 set no_proxy="localhost, 127.0.0.1, 0.0.0.0"
-python fish_speech\webui\manage.py
+
+if exist ".\fishenv\" (
+    .\fishenv\python fish_speech\webui\manage.py
+) else (
+    python fish_speech\webui\manage.py
+)

+ 5 - 2
tools/llama/build_dataset.py

@@ -31,9 +31,12 @@ def task_generator_folder(root: Path, text_extension: str):
 
         try:
             if isinstance(text_extension, str):
-                texts = [file.with_suffix(text_extension).read_text()]
+                texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
             else:
-                texts = [file.with_suffix(ext).read_text() for ext in text_extension]
+                texts = [
+                    file.with_suffix(ext).read_text(encoding="utf-8")
+                    for ext in text_extension
+                ]
         except Exception as e:
             logger.error(f"Failed to read text {file}: {e}")
             continue