|
|
@@ -555,16 +555,16 @@ def fresh_tb_dir():
|
|
|
|
|
|
|
|
|
def list_decoder_models():
|
|
|
- paths = [str(p) for p in Path("checkpoints").glob("vq*.*")] + [
|
|
|
- str(p) for p in Path("results").glob("vqgan*/**/*.ckpt")
|
|
|
- ]
|
|
|
+ paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
|
|
|
if not paths:
|
|
|
logger.warning("No decoder model found")
|
|
|
return paths
|
|
|
|
|
|
|
|
|
def list_llama_models():
|
|
|
- choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*.pth")]
|
|
|
+ choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
|
|
|
+ choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
|
|
|
+ choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
|
|
|
if not choices:
|
|
|
logger.warning("No LLaMA model found")
|
|
|
return choices
|
|
|
@@ -593,11 +593,7 @@ def fresh_llama_ckpt(llama_use_lora):
|
|
|
|
|
|
|
|
|
def fresh_llama_model():
|
|
|
- choices = [
|
|
|
- str(p).replace("\\", "/") for p in Path("checkpoints").glob("text2sem*.*")
|
|
|
- ]
|
|
|
- choices += [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")]
|
|
|
- return gr.Dropdown(choices=choices)
|
|
|
+ return gr.Dropdown(choices=list_llama_models())
|
|
|
|
|
|
|
|
|
def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
|
|
|
@@ -627,6 +623,39 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
|
|
|
return build_html_ok_message(i18n("Merge successfully"))
|
|
|
|
|
|
|
|
|
+def llama_quantify(llama_weight, quantify_mode):
|
|
|
+ if llama_weight is None or not Path(llama_weight).exists():
|
|
|
+ return build_html_error_message(
|
|
|
+ i18n(
|
|
|
+ "Path error, please check the model file exists in the corresponding path"
|
|
|
+ )
|
|
|
+ )
|
|
|
+ now = generate_folder_name()
|
|
|
+ quantify_cmd = [
|
|
|
+ PYTHON,
|
|
|
+ "tools/llama/quantize.py",
|
|
|
+ "--checkpoint-path",
|
|
|
+ llama_weight,
|
|
|
+ "--mode",
|
|
|
+ quantify_mode,
|
|
|
+ "--timestamp",
|
|
|
+ now,
|
|
|
+ ]
|
|
|
+ logger.info(quantify_cmd)
|
|
|
+ subprocess.run(quantify_cmd)
|
|
|
+ if quantify_mode == "int8":
|
|
|
+ quantize_path = str(
|
|
|
+ Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}"
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ quantize_path = str(
|
|
|
+ Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}"
|
|
|
+ )
|
|
|
+ return build_html_ok_message(
|
|
|
+ i18n("Quantify successfully") + f"Path: {quantize_path}"
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
|
|
|
init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
|
|
|
|
|
|
@@ -907,7 +936,34 @@ with gr.Blocks(
|
|
|
value=i18n("Merge"), variant="primary"
|
|
|
)
|
|
|
|
|
|
- with gr.Tab(label="Tensorboard", id=5):
|
|
|
+ with gr.Tab(label=i18n("Model Quantization"), id=5):
|
|
|
+ with gr.Row(equal_height=False):
|
|
|
+ llama_weight_to_quantify = gr.Dropdown(
|
|
|
+ label=i18n("Base LLAMA Model"),
|
|
|
+ info=i18n(
|
|
|
+ "Type the path or select from the dropdown"
|
|
|
+ ),
|
|
|
+ choices=list_llama_models(),
|
|
|
+ value="checkpoints/fish-speech-1.2",
|
|
|
+ allow_custom_value=True,
|
|
|
+ interactive=True,
|
|
|
+ )
|
|
|
+ quantify_mode = gr.Dropdown(
|
|
|
+ label=i18n("Post-quantification Precision"),
|
|
|
+ info=i18n(
|
|
|
+ "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase"
|
|
|
+ ),
|
|
|
+ choices=["int8", "int4"],
|
|
|
+ value="int8",
|
|
|
+ allow_custom_value=False,
|
|
|
+ interactive=True,
|
|
|
+ )
|
|
|
+ with gr.Row(equal_height=False):
|
|
|
+ llama_quantify_btn = gr.Button(
|
|
|
+ value=i18n("Quantify"), variant="primary"
|
|
|
+ )
|
|
|
+
|
|
|
+ with gr.Tab(label="Tensorboard", id=6):
|
|
|
with gr.Row(equal_height=False):
|
|
|
tb_host = gr.Textbox(
|
|
|
label=i18n("Tensorboard Host"), value="127.0.0.1"
|
|
|
@@ -1122,6 +1178,11 @@ with gr.Blocks(
|
|
|
inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
|
|
|
outputs=[train_error],
|
|
|
)
|
|
|
+ llama_quantify_btn.click(
|
|
|
+ fn=llama_quantify,
|
|
|
+ inputs=[llama_weight_to_quantify, quantify_mode],
|
|
|
+ outputs=[train_error],
|
|
|
+ )
|
|
|
infer_checkbox.change(
|
|
|
fn=change_infer,
|
|
|
inputs=[
|