|
@@ -614,6 +614,8 @@ def train_process(
|
|
|
else "text2semantic-sft-large-v1.1-4k.pth"
|
|
else "text2semantic-sft-large-v1.1-4k.pth"
|
|
|
)
|
|
)
|
|
|
lora_prefix = "lora_" if llama_use_lora else ""
|
|
lora_prefix = "lora_" if llama_use_lora else ""
|
|
|
|
|
+ llama_size = "large_" if ("large" in llama_base_config) else "medium_"
|
|
|
|
|
+ llama_name = lora_prefix + "text2semantic_" + llama_size + new_project
|
|
|
latest = next(
|
|
latest = next(
|
|
|
iter(
|
|
iter(
|
|
|
sorted(
|
|
sorted(
|
|
@@ -624,14 +626,14 @@ def train_process(
|
|
|
reverse=True,
|
|
reverse=True,
|
|
|
)
|
|
)
|
|
|
),
|
|
),
|
|
|
- (lora_prefix + "text2semantic_" + new_project),
|
|
|
|
|
|
|
+ llama_name,
|
|
|
)
|
|
)
|
|
|
project = (
|
|
project = (
|
|
|
- (lora_prefix + "text2semantic_" + new_project)
|
|
|
|
|
|
|
+ llama_name
|
|
|
if llama_ckpt == i18n("new")
|
|
if llama_ckpt == i18n("new")
|
|
|
else (
|
|
else (
|
|
|
latest
|
|
latest
|
|
|
- if llama_ckpt == i18n("latest") + "(not lora)"
|
|
|
|
|
|
|
+ if llama_ckpt == i18n("latest")
|
|
|
else Path(llama_ckpt).relative_to("results")
|
|
else Path(llama_ckpt).relative_to("results")
|
|
|
)
|
|
)
|
|
|
)
|
|
)
|
|
@@ -678,7 +680,7 @@ def tensorboard_process(
|
|
|
)
|
|
)
|
|
|
prefix = ["tensorboard"]
|
|
prefix = ["tensorboard"]
|
|
|
if Path("fishenv").exists():
|
|
if Path("fishenv").exists():
|
|
|
- prefix = ["fishenv/python.exe", "fishenv/Scripts/tensorboard.exe"]
|
|
|
|
|
|
|
+ prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"]
|
|
|
|
|
|
|
|
p_tensorboard = subprocess.Popen(
|
|
p_tensorboard = subprocess.Popen(
|
|
|
prefix
|
|
prefix
|
|
@@ -727,6 +729,13 @@ def list_llama_models():
|
|
|
return choices
|
|
return choices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def list_lora_llama_models():
|
|
|
|
|
+ choices = [str(p) for p in Path("results").glob("lora*/**/*.ckpt")]
|
|
|
|
|
+ if not choices:
|
|
|
|
|
+ logger.warning("No LoRA LLaMA model found")
|
|
|
|
|
+ return choices
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def fresh_decoder_model():
|
|
def fresh_decoder_model():
|
|
|
return gr.Dropdown(choices=list_decoder_models())
|
|
return gr.Dropdown(choices=list_decoder_models())
|
|
|
|
|
|
|
@@ -745,11 +754,14 @@ def fresh_vits_ckpt():
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
-def fresh_llama_ckpt():
|
|
|
|
|
|
|
+def fresh_llama_ckpt(llama_use_lora):
|
|
|
return gr.Dropdown(
|
|
return gr.Dropdown(
|
|
|
choices=[i18n("latest"), i18n("new")]
|
|
choices=[i18n("latest"), i18n("new")]
|
|
|
- + [str(p) for p in Path("results").glob("text2sem*/")]
|
|
|
|
|
- + [str(p) for p in Path("results").glob("lora_*/")]
|
|
|
|
|
|
|
+ + (
|
|
|
|
|
+ [str(p) for p in Path("results").glob("text2sem*/")]
|
|
|
|
|
+ if not llama_use_lora
|
|
|
|
|
+ else [str(p) for p in Path("results").glob("lora_*/")]
|
|
|
|
|
+ )
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1063,13 +1075,13 @@ with gr.Blocks(
|
|
|
)
|
|
)
|
|
|
llama_ckpt = gr.Dropdown(
|
|
llama_ckpt = gr.Dropdown(
|
|
|
label=i18n("Select LLAMA ckpt"),
|
|
label=i18n("Select LLAMA ckpt"),
|
|
|
- choices=[i18n("latest") + "(not lora)", i18n("new")]
|
|
|
|
|
|
|
+ choices=[i18n("latest"), i18n("new")]
|
|
|
+ [
|
|
+ [
|
|
|
str(p)
|
|
str(p)
|
|
|
for p in Path("results").glob("text2sem*/")
|
|
for p in Path("results").glob("text2sem*/")
|
|
|
]
|
|
]
|
|
|
+ [str(p) for p in Path("results").glob("lora*/")],
|
|
+ [str(p) for p in Path("results").glob("lora*/")],
|
|
|
- value=i18n("latest") + "(not lora)",
|
|
|
|
|
|
|
+ value=i18n("latest"),
|
|
|
interactive=True,
|
|
interactive=True,
|
|
|
)
|
|
)
|
|
|
with gr.Row(equal_height=False):
|
|
with gr.Row(equal_height=False):
|
|
@@ -1100,13 +1112,13 @@ with gr.Blocks(
|
|
|
)
|
|
)
|
|
|
llama_data_num_workers_slider = gr.Slider(
|
|
llama_data_num_workers_slider = gr.Slider(
|
|
|
label=i18n("Number of Workers"),
|
|
label=i18n("Number of Workers"),
|
|
|
- minimum=0,
|
|
|
|
|
|
|
+ minimum=1,
|
|
|
maximum=16,
|
|
maximum=16,
|
|
|
step=1,
|
|
step=1,
|
|
|
value=(
|
|
value=(
|
|
|
init_llama_yml["data"]["num_workers"]
|
|
init_llama_yml["data"]["num_workers"]
|
|
|
if sys.platform == "linux"
|
|
if sys.platform == "linux"
|
|
|
- else 0
|
|
|
|
|
|
|
+ else 1
|
|
|
),
|
|
),
|
|
|
)
|
|
)
|
|
|
with gr.Row(equal_height=False):
|
|
with gr.Row(equal_height=False):
|
|
@@ -1177,7 +1189,10 @@ with gr.Blocks(
|
|
|
info=i18n(
|
|
info=i18n(
|
|
|
"Type the path or select from the dropdown"
|
|
"Type the path or select from the dropdown"
|
|
|
),
|
|
),
|
|
|
- choices=[init_llama_yml["ckpt_path"]],
|
|
|
|
|
|
|
+ choices=[
|
|
|
|
|
+ "checkpoints/text2semantic-sft-large-v1.1-4k.pth",
|
|
|
|
|
+ "checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
|
|
|
|
|
+ ],
|
|
|
value=init_llama_yml["ckpt_path"],
|
|
value=init_llama_yml["ckpt_path"],
|
|
|
allow_custom_value=True,
|
|
allow_custom_value=True,
|
|
|
interactive=True,
|
|
interactive=True,
|
|
@@ -1390,14 +1405,7 @@ with gr.Blocks(
|
|
|
'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
|
|
'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
|
|
|
)
|
|
)
|
|
|
if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
|
|
if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
|
|
|
- infer_decoder_model.change(
|
|
|
|
|
- fn=change_decoder_config,
|
|
|
|
|
- inputs=[infer_decoder_model],
|
|
|
|
|
- outputs=[infer_decoder_config],
|
|
|
|
|
- )
|
|
|
|
|
- infer_llama_model.change(
|
|
|
|
|
- fn=change_llama_config, inputs=[infer_llama_model], outputs=[infer_llama_config]
|
|
|
|
|
- )
|
|
|
|
|
|
|
+
|
|
|
train_btn.click(
|
|
train_btn.click(
|
|
|
fn=train_process,
|
|
fn=train_process,
|
|
|
inputs=[
|
|
inputs=[
|
|
@@ -1445,6 +1453,14 @@ with gr.Blocks(
|
|
|
outputs=[train_error],
|
|
outputs=[train_error],
|
|
|
)
|
|
)
|
|
|
tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
|
|
tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
|
|
|
|
|
+ infer_decoder_model.change(
|
|
|
|
|
+ fn=change_decoder_config,
|
|
|
|
|
+ inputs=[infer_decoder_model],
|
|
|
|
|
+ outputs=[infer_decoder_config],
|
|
|
|
|
+ )
|
|
|
|
|
+ infer_llama_model.change(
|
|
|
|
|
+ fn=change_llama_config, inputs=[infer_llama_model], outputs=[infer_llama_config]
|
|
|
|
|
+ )
|
|
|
infer_decoder_model.change(
|
|
infer_decoder_model.change(
|
|
|
fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
|
|
fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
|
|
|
)
|
|
)
|
|
@@ -1462,7 +1478,20 @@ with gr.Blocks(
|
|
|
)
|
|
)
|
|
|
vqgan_ckpt.change(fn=fresh_vqgan_ckpt, inputs=[], outputs=[vqgan_ckpt])
|
|
vqgan_ckpt.change(fn=fresh_vqgan_ckpt, inputs=[], outputs=[vqgan_ckpt])
|
|
|
vits_ckpt.change(fn=fresh_vits_ckpt, inputs=[], outputs=[vits_ckpt])
|
|
vits_ckpt.change(fn=fresh_vits_ckpt, inputs=[], outputs=[vits_ckpt])
|
|
|
- llama_ckpt.change(fn=fresh_llama_ckpt, inputs=[], outputs=[llama_ckpt])
|
|
|
|
|
|
|
+ llama_use_lora.change(
|
|
|
|
|
+ fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
|
|
|
|
|
+ )
|
|
|
|
|
+ llama_ckpt.change(
|
|
|
|
|
+ fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
|
|
|
|
|
+ )
|
|
|
|
|
+ lora_weight.change(
|
|
|
|
|
+ fn=change_llama_config, inputs=[lora_weight], outputs=[lora_llama_config]
|
|
|
|
|
+ )
|
|
|
|
|
+ lora_weight.change(
|
|
|
|
|
+ fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
|
|
|
|
|
+ inputs=[],
|
|
|
|
|
+ outputs=[lora_weight],
|
|
|
|
|
+ )
|
|
|
llama_lora_merge_btn.click(
|
|
llama_lora_merge_btn.click(
|
|
|
fn=llama_lora_merge,
|
|
fn=llama_lora_merge,
|
|
|
inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
|
|
inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
|