Przeglądaj źródła

Adapt finetuning (#314)

* Add Windows Setup Help

* Optimize documents/bootscripts for Windows User

* Correct some description

* Fix dependencies

* fish 1.2 webui & api

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix spelling

* Fix CUDA env

* Update api usage

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Adapt finetuning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
spicysama 1 rok temu
rodzic
commit
3d6d1d7863
1 zmienionych plików z 19 dodań i 36 usunięć
  1. 19 36
      fish_speech/webui/manage.py

+ 19 - 36
fish_speech/webui/manage.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
+import datetime
 import html
 import html
 import json
 import json
 import os
 import os
@@ -180,8 +181,6 @@ def change_infer(
                 infer_decoder_config,
                 infer_decoder_config,
                 "--llama-checkpoint-path",
                 "--llama-checkpoint-path",
                 infer_llama_model,
                 infer_llama_model,
-                "--tokenizer",
-                "checkpoints/fish-speech-1.2",
             ]
             ]
             + (["--compile"] if infer_compile == "Yes" else []),
             + (["--compile"] if infer_compile == "Yes" else []),
             env=env,
             env=env,
@@ -400,6 +399,12 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
     )
     )
 
 
 
 
+def generate_folder_name():
+    now = datetime.datetime.now()
+    folder_name = now.strftime("%Y%m%d_%H%M%S")
+    return folder_name
+
+
 def train_process(
 def train_process(
     data_path: str,
     data_path: str,
     option: str,
     option: str,
@@ -419,12 +424,6 @@ def train_process(
     llama_use_speaker,
     llama_use_speaker,
     llama_use_lora,
     llama_use_lora,
 ):
 ):
-    import datetime
-
-    def generate_folder_name():
-        now = datetime.datetime.now()
-        folder_name = now.strftime("%Y%m%d_%H%M%S")
-        return folder_name
 
 
     backend = "nccl" if sys.platform == "linux" else "gloo"
     backend = "nccl" if sys.platform == "linux" else "gloo"
 
 
@@ -464,14 +463,9 @@ def train_process(
                 "16",
                 "16",
             ]
             ]
         )
         )
-        ckpt_path = (
-            "text2semantic-sft-medium-v1.1-4k.pth"
-            if llama_base_config == "dual_ar_2_codebook_medium"
-            else "text2semantic-sft-large-v1.1-4k.pth"
-        )
+        ckpt_path = "checkpoints/fish-speech-1.2/model.pth"
         lora_prefix = "lora_" if llama_use_lora else ""
         lora_prefix = "lora_" if llama_use_lora else ""
-        llama_size = "large_" if ("large" in llama_base_config) else "medium_"
-        llama_name = lora_prefix + "text2semantic_" + llama_size + new_project
+        llama_name = lora_prefix + "text2semantic_" + new_project
         latest = next(
         latest = next(
             iter(
             iter(
                 sorted(
                 sorted(
@@ -500,10 +494,7 @@ def train_process(
             "--config-name",
             "--config-name",
             "text2semantic_finetune",
             "text2semantic_finetune",
             f"project={project}",
             f"project={project}",
-            f"ckpt_path=checkpoints/{ckpt_path}",
             f"trainer.strategy.process_group_backend={backend}",
             f"trainer.strategy.process_group_backend={backend}",
-            f"model@model.model={llama_base_config}",
-            "tokenizer.pretrained_model_name_or_path=checkpoints",
             f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
             f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
             f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
             f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
             f"model.optimizer.lr={llama_lr}",
             f"model.optimizer.lr={llama_lr}",
@@ -514,8 +505,8 @@ def train_process(
             f"trainer.precision={llama_precision}",
             f"trainer.precision={llama_precision}",
             f"trainer.val_check_interval={llama_check_interval}",
             f"trainer.val_check_interval={llama_check_interval}",
             f"trainer.accumulate_grad_batches={llama_grad_batches}",
             f"trainer.accumulate_grad_batches={llama_grad_batches}",
-            f"train_dataset.use_speaker={llama_use_speaker}",
-        ] + ([f"+lora@model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
+            f"train_dataset.interactive_prob={llama_use_speaker}",
+        ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
         logger.info(train_cmd)
         logger.info(train_cmd)
         subprocess.run(train_cmd)
         subprocess.run(train_cmd)
 
 
@@ -573,10 +564,7 @@ def list_decoder_models():
 
 
 
 
 def list_llama_models():
 def list_llama_models():
-    choices = [
-        str(p).replace("\\", "/") for p in Path("checkpoints").glob("text2sem*.*")
-    ]
-    choices += [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")]
+    choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*.pth")]
     if not choices:
     if not choices:
         logger.warning("No LLaMA model found")
         logger.warning("No LLaMA model found")
     return choices
     return choices
@@ -627,16 +615,12 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
     merge_cmd = [
     merge_cmd = [
         PYTHON,
         PYTHON,
         "tools/llama/merge_lora.py",
         "tools/llama/merge_lora.py",
-        "--llama-config",
-        lora_llama_config,
         "--lora-config",
         "--lora-config",
         "r_8_alpha_16",
         "r_8_alpha_16",
-        "--llama-weight",
-        llama_weight,
         "--lora-weight",
         "--lora-weight",
         lora_weight,
         lora_weight,
         "--output",
         "--output",
-        llama_lora_output,
+        llama_lora_output + "_" + generate_folder_name(),
     ]
     ]
     logger.info(merge_cmd)
     logger.info(merge_cmd)
     subprocess.run(merge_cmd)
     subprocess.run(merge_cmd)
@@ -759,6 +743,7 @@ with gr.Blocks(
                                         "Use LoRA can save GPU memory, but may reduce the quality of the model"
                                         "Use LoRA can save GPU memory, but may reduce the quality of the model"
                                     ),
                                     ),
                                     value=True,
                                     value=True,
+                                    interactive=False,
                                 )
                                 )
                                 llama_ckpt = gr.Dropdown(
                                 llama_ckpt = gr.Dropdown(
                                     label=i18n("Select LLAMA ckpt"),
                                     label=i18n("Select LLAMA ckpt"),
@@ -792,7 +777,6 @@ with gr.Blocks(
                                 llama_base_config = gr.Dropdown(
                                 llama_base_config = gr.Dropdown(
                                     label=i18n("Model Size"),
                                     label=i18n("Model Size"),
                                     choices=[
                                     choices=[
-                                        "text2semantic_agent",
                                         "text2semantic_finetune",
                                         "text2semantic_finetune",
                                     ],
                                     ],
                                     value="text2semantic_finetune",
                                     value="text2semantic_finetune",
@@ -865,7 +849,7 @@ with gr.Blocks(
                                     maximum=1.0,
                                     maximum=1.0,
                                     step=0.05,
                                     step=0.05,
                                     value=init_llama_yml["train_dataset"][
                                     value=init_llama_yml["train_dataset"][
-                                        "use_speaker"
+                                        "interactive_prob"
                                     ],
                                     ],
                                 )
                                 )
 
 
@@ -879,7 +863,7 @@ with gr.Blocks(
                                     choices=[
                                     choices=[
                                         "checkpoints/fish-speech-1.2/model.pth",
                                         "checkpoints/fish-speech-1.2/model.pth",
                                     ],
                                     ],
-                                    value=init_llama_yml["ckpt_path"],
+                                    value="checkpoints/fish-speech-1.2/model.pth",
                                     allow_custom_value=True,
                                     allow_custom_value=True,
                                     interactive=True,
                                     interactive=True,
                                 )
                                 )
@@ -902,10 +886,9 @@ with gr.Blocks(
                                         "Type the path or select from the dropdown"
                                         "Type the path or select from the dropdown"
                                     ),
                                     ),
                                     choices=[
                                     choices=[
-                                        "text2semantic_agent",
                                         "text2semantic_finetune",
                                         "text2semantic_finetune",
                                     ],
                                     ],
-                                    value="text2semantic_agent",
+                                    value="text2semantic_finetune",
                                     allow_custom_value=True,
                                     allow_custom_value=True,
                                 )
                                 )
                             with gr.Row(equal_height=False):
                             with gr.Row(equal_height=False):
@@ -914,8 +897,8 @@ with gr.Blocks(
                                     info=i18n(
                                     info=i18n(
                                         "Type the path or select from the dropdown"
                                         "Type the path or select from the dropdown"
                                     ),
                                     ),
-                                    value="checkpoints/merged.ckpt",
-                                    choices=["checkpoints/merged.ckpt"],
+                                    value="checkpoints/merged",
+                                    choices=["checkpoints/merged"],
                                     allow_custom_value=True,
                                     allow_custom_value=True,
                                     interactive=True,
                                     interactive=True,
                                 )
                                 )