@@ -358,6 +358,8 @@ def train_process(
     llama_data_max_length,
     llama_precision,
     llama_check_interval,
+    llama_grad_batches,
+    llama_use_speaker,
 ):
     backend = "nccl" if sys.platform == "linux" else "gloo"
     if option == "VQGAN" or option == "all":
@@ -410,14 +412,13 @@ def train_process(
                 "tools/llama/build_dataset.py",
                 "--input",
                 str(data_pre_output),
+                "--text-extension",
+                ".lab",
                 "--num-workers",
                 "16",
             ]
         )

-        protos_list = [
-            str(file) for file in Path("data/quantized-dataset-ft").glob("*.protos")
-        ]
         train_cmd = [
             PYTHON,
             "fish_speech/train.py",
@@ -426,8 +427,8 @@ def train_process(
             f"trainer.strategy.process_group_backend={backend}",
             "model@model.model=dual_ar_2_codebook_medium",
             "tokenizer.pretrained_model_name_or_path=checkpoints",
-            f"train_dataset.proto_files={str(protos_list)}",
-            f"val_dataset.proto_files={str(protos_list)}",
+            f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
+            f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
             f"model.optimizer.lr={llama_lr}",
             f"trainer.max_steps={llama_maxsteps}",
             f"trainer.limit_val_batches={llama_limit_val_batches}",
@@ -436,6 +437,8 @@ def train_process(
             f"max_length={llama_data_max_length}",
             f"trainer.precision={llama_precision}",
             f"trainer.val_check_interval={llama_check_interval}",
+            f"trainer.accumulate_grad_batches={llama_grad_batches}",
+            f"train_dataset.use_speaker={llama_use_speaker}",
         ]
         logger.info(train_cmd)
         subprocess.run(train_cmd)
@@ -654,6 +657,27 @@ with gr.Blocks(
                                         "val_check_interval"
                                     ],
                                 )
+                            with gr.Row(equal_height=False):
+                                llama_grad_batches = gr.Slider(
+                                    label="accumulate_grad_batches",
+                                    interactive=True,
+                                    minimum=1,
+                                    maximum=20,
+                                    step=1,
+                                    value=init_llama_yml["trainer"][
+                                        "accumulate_grad_batches"
+                                    ],
+                                )
+                                llama_use_speaker = gr.Slider(
+                                    label="use_speaker_ratio",
+                                    interactive=True,
+                                    minimum=0.1,
+                                    maximum=1.0,
+                                    step=0.05,
+                                    value=init_llama_yml["train_dataset"][
+                                        "use_speaker"
+                                    ],
+                                )

             with gr.Tab("\U0001F9E0 进入推理界面"):
                 with gr.Column():
@@ -769,6 +793,8 @@ with gr.Blocks(
             llama_data_max_length_slider,
             llama_precision_dropdown,
             llama_check_interval_slider,
+            llama_grad_batches,
+            llama_use_speaker,
         ],
         outputs=[train_error],
     )