Jelajahi Sumber

Adapt finetuning (#314)

* Add Windows Setup Help

* Optimize documents/bootscripts for Windows User

* Correct some description

* Fix dependecies

* fish 1.2 webui & api

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix spelling

* Fix CUDA env

* Update api usage

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Adapt finetuning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 tahun lalu
induk
melakukan
3d6d1d7863
1 mengubah file dengan 19 tambahan dan 36 penghapusan
  1. 19 36
      fish_speech/webui/manage.py

+ 19 - 36
fish_speech/webui/manage.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import datetime
 import html
 import json
 import os
@@ -180,8 +181,6 @@ def change_infer(
                 infer_decoder_config,
                 "--llama-checkpoint-path",
                 infer_llama_model,
-                "--tokenizer",
-                "checkpoints/fish-speech-1.2",
             ]
             + (["--compile"] if infer_compile == "Yes" else []),
             env=env,
@@ -400,6 +399,12 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
     )
 
 
+def generate_folder_name():
+    now = datetime.datetime.now()
+    folder_name = now.strftime("%Y%m%d_%H%M%S")
+    return folder_name
+
+
 def train_process(
     data_path: str,
     option: str,
@@ -419,12 +424,6 @@ def train_process(
     llama_use_speaker,
     llama_use_lora,
 ):
-    import datetime
-
-    def generate_folder_name():
-        now = datetime.datetime.now()
-        folder_name = now.strftime("%Y%m%d_%H%M%S")
-        return folder_name
 
     backend = "nccl" if sys.platform == "linux" else "gloo"
 
@@ -464,14 +463,9 @@ def train_process(
                 "16",
             ]
         )
-        ckpt_path = (
-            "text2semantic-sft-medium-v1.1-4k.pth"
-            if llama_base_config == "dual_ar_2_codebook_medium"
-            else "text2semantic-sft-large-v1.1-4k.pth"
-        )
+        ckpt_path = "checkpoints/fish-speech-1.2/model.pth"
         lora_prefix = "lora_" if llama_use_lora else ""
-        llama_size = "large_" if ("large" in llama_base_config) else "medium_"
-        llama_name = lora_prefix + "text2semantic_" + llama_size + new_project
+        llama_name = lora_prefix + "text2semantic_" + new_project
         latest = next(
             iter(
                 sorted(
@@ -500,10 +494,7 @@ def train_process(
             "--config-name",
             "text2semantic_finetune",
             f"project={project}",
-            f"ckpt_path=checkpoints/{ckpt_path}",
             f"trainer.strategy.process_group_backend={backend}",
-            f"model@model.model={llama_base_config}",
-            "tokenizer.pretrained_model_name_or_path=checkpoints",
             f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
             f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
             f"model.optimizer.lr={llama_lr}",
@@ -514,8 +505,8 @@ def train_process(
             f"trainer.precision={llama_precision}",
             f"trainer.val_check_interval={llama_check_interval}",
             f"trainer.accumulate_grad_batches={llama_grad_batches}",
-            f"train_dataset.use_speaker={llama_use_speaker}",
-        ] + ([f"+lora@model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
+            f"train_dataset.interactive_prob={llama_use_speaker}",
+        ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
         logger.info(train_cmd)
         subprocess.run(train_cmd)
 
@@ -573,10 +564,7 @@ def list_decoder_models():
 
 
 def list_llama_models():
-    choices = [
-        str(p).replace("\\", "/") for p in Path("checkpoints").glob("text2sem*.*")
-    ]
-    choices += [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")]
+    choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*.pth")]
     if not choices:
         logger.warning("No LLaMA model found")
     return choices
@@ -627,16 +615,12 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
     merge_cmd = [
         PYTHON,
         "tools/llama/merge_lora.py",
-        "--llama-config",
-        lora_llama_config,
         "--lora-config",
         "r_8_alpha_16",
-        "--llama-weight",
-        llama_weight,
         "--lora-weight",
         lora_weight,
         "--output",
-        llama_lora_output,
+        llama_lora_output + "_" + generate_folder_name(),
     ]
     logger.info(merge_cmd)
     subprocess.run(merge_cmd)
@@ -759,6 +743,7 @@ with gr.Blocks(
                                         "Use LoRA can save GPU memory, but may reduce the quality of the model"
                                     ),
                                     value=True,
+                                    interactive=False,
                                 )
                                 llama_ckpt = gr.Dropdown(
                                     label=i18n("Select LLAMA ckpt"),
@@ -792,7 +777,6 @@ with gr.Blocks(
                                 llama_base_config = gr.Dropdown(
                                     label=i18n("Model Size"),
                                     choices=[
-                                        "text2semantic_agent",
                                         "text2semantic_finetune",
                                     ],
                                     value="text2semantic_finetune",
@@ -865,7 +849,7 @@ with gr.Blocks(
                                     maximum=1.0,
                                     step=0.05,
                                     value=init_llama_yml["train_dataset"][
-                                        "use_speaker"
+                                        "interactive_prob"
                                     ],
                                 )
 
@@ -879,7 +863,7 @@ with gr.Blocks(
                                     choices=[
                                         "checkpoints/fish-speech-1.2/model.pth",
                                     ],
-                                    value=init_llama_yml["ckpt_path"],
+                                    value="checkpoints/fish-speech-1.2/model.pth",
                                     allow_custom_value=True,
                                     interactive=True,
                                 )
@@ -902,10 +886,9 @@ with gr.Blocks(
                                         "Type the path or select from the dropdown"
                                     ),
                                     choices=[
-                                        "text2semantic_agent",
                                         "text2semantic_finetune",
                                     ],
-                                    value="text2semantic_agent",
+                                    value="text2semantic_finetune",
                                     allow_custom_value=True,
                                 )
                             with gr.Row(equal_height=False):
@@ -914,8 +897,8 @@ with gr.Blocks(
                                     info=i18n(
                                         "Type the path or select from the dropdown"
                                     ),
-                                    value="checkpoints/merged.ckpt",
-                                    choices=["checkpoints/merged.ckpt"],
+                                    value="checkpoints/merged",
+                                    choices=["checkpoints/merged"],
                                     allow_custom_value=True,
                                     interactive=True,
                                 )