|
|
@@ -1,5 +1,6 @@
|
|
|
from __future__ import annotations
|
|
|
|
|
|
+import datetime
|
|
|
import html
|
|
|
import json
|
|
|
import os
|
|
|
@@ -180,8 +181,6 @@ def change_infer(
|
|
|
infer_decoder_config,
|
|
|
"--llama-checkpoint-path",
|
|
|
infer_llama_model,
|
|
|
- "--tokenizer",
|
|
|
- "checkpoints/fish-speech-1.2",
|
|
|
]
|
|
|
+ (["--compile"] if infer_compile == "Yes" else []),
|
|
|
env=env,
|
|
|
@@ -400,6 +399,12 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
|
|
|
)
|
|
|
|
|
|
|
|
|
+def generate_folder_name():
|
|
|
+ now = datetime.datetime.now()
|
|
|
+ folder_name = now.strftime("%Y%m%d_%H%M%S")
|
|
|
+ return folder_name
|
|
|
+
|
|
|
+
|
|
|
def train_process(
|
|
|
data_path: str,
|
|
|
option: str,
|
|
|
@@ -419,12 +424,6 @@ def train_process(
|
|
|
llama_use_speaker,
|
|
|
llama_use_lora,
|
|
|
):
|
|
|
- import datetime
|
|
|
-
|
|
|
- def generate_folder_name():
|
|
|
- now = datetime.datetime.now()
|
|
|
- folder_name = now.strftime("%Y%m%d_%H%M%S")
|
|
|
- return folder_name
|
|
|
|
|
|
backend = "nccl" if sys.platform == "linux" else "gloo"
|
|
|
|
|
|
@@ -464,14 +463,9 @@ def train_process(
|
|
|
"16",
|
|
|
]
|
|
|
)
|
|
|
- ckpt_path = (
|
|
|
- "text2semantic-sft-medium-v1.1-4k.pth"
|
|
|
- if llama_base_config == "dual_ar_2_codebook_medium"
|
|
|
- else "text2semantic-sft-large-v1.1-4k.pth"
|
|
|
- )
|
|
|
+ ckpt_path = "checkpoints/fish-speech-1.2/model.pth"
|
|
|
lora_prefix = "lora_" if llama_use_lora else ""
|
|
|
- llama_size = "large_" if ("large" in llama_base_config) else "medium_"
|
|
|
- llama_name = lora_prefix + "text2semantic_" + llama_size + new_project
|
|
|
+ llama_name = lora_prefix + "text2semantic_" + new_project
|
|
|
latest = next(
|
|
|
iter(
|
|
|
sorted(
|
|
|
@@ -500,10 +494,7 @@ def train_process(
|
|
|
"--config-name",
|
|
|
"text2semantic_finetune",
|
|
|
f"project={project}",
|
|
|
- f"ckpt_path=checkpoints/{ckpt_path}",
|
|
|
f"trainer.strategy.process_group_backend={backend}",
|
|
|
- f"model@model.model={llama_base_config}",
|
|
|
- "tokenizer.pretrained_model_name_or_path=checkpoints",
|
|
|
f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
|
|
|
f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
|
|
|
f"model.optimizer.lr={llama_lr}",
|
|
|
@@ -514,8 +505,8 @@ def train_process(
|
|
|
f"trainer.precision={llama_precision}",
|
|
|
f"trainer.val_check_interval={llama_check_interval}",
|
|
|
f"trainer.accumulate_grad_batches={llama_grad_batches}",
|
|
|
- f"train_dataset.use_speaker={llama_use_speaker}",
|
|
|
- ] + ([f"+lora@model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
|
|
|
+ f"train_dataset.interactive_prob={llama_use_speaker}",
|
|
|
+ ] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
|
|
|
logger.info(train_cmd)
|
|
|
subprocess.run(train_cmd)
|
|
|
|
|
|
@@ -573,10 +564,7 @@ def list_decoder_models():
|
|
|
|
|
|
|
|
|
def list_llama_models():
|
|
|
- choices = [
|
|
|
- str(p).replace("\\", "/") for p in Path("checkpoints").glob("text2sem*.*")
|
|
|
- ]
|
|
|
- choices += [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")]
|
|
|
+ choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*.pth")]
|
|
|
if not choices:
|
|
|
logger.warning("No LLaMA model found")
|
|
|
return choices
|
|
|
@@ -627,16 +615,12 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
|
|
|
merge_cmd = [
|
|
|
PYTHON,
|
|
|
"tools/llama/merge_lora.py",
|
|
|
- "--llama-config",
|
|
|
- lora_llama_config,
|
|
|
"--lora-config",
|
|
|
"r_8_alpha_16",
|
|
|
- "--llama-weight",
|
|
|
- llama_weight,
|
|
|
"--lora-weight",
|
|
|
lora_weight,
|
|
|
"--output",
|
|
|
- llama_lora_output,
|
|
|
+ llama_lora_output + "_" + generate_folder_name(),
|
|
|
]
|
|
|
logger.info(merge_cmd)
|
|
|
subprocess.run(merge_cmd)
|
|
|
@@ -759,6 +743,7 @@ with gr.Blocks(
|
|
|
"Use LoRA can save GPU memory, but may reduce the quality of the model"
|
|
|
),
|
|
|
value=True,
|
|
|
+ interactive=False,
|
|
|
)
|
|
|
llama_ckpt = gr.Dropdown(
|
|
|
label=i18n("Select LLAMA ckpt"),
|
|
|
@@ -792,7 +777,6 @@ with gr.Blocks(
|
|
|
llama_base_config = gr.Dropdown(
|
|
|
label=i18n("Model Size"),
|
|
|
choices=[
|
|
|
- "text2semantic_agent",
|
|
|
"text2semantic_finetune",
|
|
|
],
|
|
|
value="text2semantic_finetune",
|
|
|
@@ -865,7 +849,7 @@ with gr.Blocks(
|
|
|
maximum=1.0,
|
|
|
step=0.05,
|
|
|
value=init_llama_yml["train_dataset"][
|
|
|
- "use_speaker"
|
|
|
+ "interactive_prob"
|
|
|
],
|
|
|
)
|
|
|
|
|
|
@@ -879,7 +863,7 @@ with gr.Blocks(
|
|
|
choices=[
|
|
|
"checkpoints/fish-speech-1.2/model.pth",
|
|
|
],
|
|
|
- value=init_llama_yml["ckpt_path"],
|
|
|
+ value="checkpoints/fish-speech-1.2/model.pth",
|
|
|
allow_custom_value=True,
|
|
|
interactive=True,
|
|
|
)
|
|
|
@@ -902,10 +886,9 @@ with gr.Blocks(
|
|
|
"Type the path or select from the dropdown"
|
|
|
),
|
|
|
choices=[
|
|
|
- "text2semantic_agent",
|
|
|
"text2semantic_finetune",
|
|
|
],
|
|
|
- value="text2semantic_agent",
|
|
|
+ value="text2semantic_finetune",
|
|
|
allow_custom_value=True,
|
|
|
)
|
|
|
with gr.Row(equal_height=False):
|
|
|
@@ -914,8 +897,8 @@ with gr.Blocks(
|
|
|
info=i18n(
|
|
|
"Type the path or select from the dropdown"
|
|
|
),
|
|
|
- value="checkpoints/merged.ckpt",
|
|
|
- choices=["checkpoints/merged.ckpt"],
|
|
|
+ value="checkpoints/merged",
|
|
|
+ choices=["checkpoints/merged"],
|
|
|
allow_custom_value=True,
|
|
|
interactive=True,
|
|
|
)
|