|
|
@@ -251,7 +251,13 @@ def new_explorer(data_path, max_depth):
|
|
|
)
|
|
|
|
|
|
|
|
|
-def add_item(folder: str, method: str, label_lang: str):
|
|
|
+def add_item(
|
|
|
+ folder: str,
|
|
|
+ method: str,
|
|
|
+ label_lang: str,
|
|
|
+ if_initial_prompt: bool,
|
|
|
+ initial_prompt: str | None,
|
|
|
+):
|
|
|
folder = folder.strip(" ").strip('"')
|
|
|
|
|
|
folder_path = Path(folder)
|
|
|
@@ -260,7 +266,10 @@ def add_item(folder: str, method: str, label_lang: str):
|
|
|
if folder_path.is_dir():
|
|
|
items.append(folder)
|
|
|
dict_items[folder] = dict(
|
|
|
- type="folder", method=method, label_lang=label_lang
|
|
|
+ type="folder",
|
|
|
+ method=method,
|
|
|
+ label_lang=label_lang,
|
|
|
+ initial_prompt=initial_prompt if if_initial_prompt else None,
|
|
|
)
|
|
|
elif folder:
|
|
|
err = folder
|
|
|
@@ -269,7 +278,8 @@ def add_item(folder: str, method: str, label_lang: str):
|
|
|
)
|
|
|
|
|
|
formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
|
|
|
- logger.info(formatted_data)
|
|
|
+ logger.info("After Adding: " + formatted_data)
|
|
|
+ gr.Info(formatted_data)
|
|
|
return gr.Checkboxgroup(choices=items), build_html_ok_message(
|
|
|
i18n("Added path successfully!")
|
|
|
)
|
|
|
@@ -283,6 +293,7 @@ def remove_items(selected_items):
|
|
|
items = [item for item in items if item in dict_items.keys()]
|
|
|
formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
|
|
|
logger.info(formatted_data)
|
|
|
+ gr.Warning("After Removing: " + formatted_data)
|
|
|
return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
|
|
|
i18n("Removed path successfully!")
|
|
|
)
|
|
|
@@ -351,6 +362,7 @@ def list_copy(list_file_path, method):
|
|
|
def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
|
|
|
global dict_items
|
|
|
data_path = Path(data_path)
|
|
|
+ gr.Warning("Pre-processing begins...")
|
|
|
for item, content in dict_items.items():
|
|
|
item_path = Path(item)
|
|
|
tar_path = data_path / item_path.name
|
|
|
@@ -369,23 +381,31 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
|
|
|
convert_to_mono_in_place(audio_path)
|
|
|
|
|
|
cur_lang = content["label_lang"]
|
|
|
+ initial_prompt = content["initial_prompt"]
|
|
|
+
|
|
|
+ transcribe_cmd = [
|
|
|
+ PYTHON,
|
|
|
+ "tools/whisper_asr.py",
|
|
|
+ "--model-size",
|
|
|
+ label_model,
|
|
|
+ "--device",
|
|
|
+ label_device,
|
|
|
+ "--audio-dir",
|
|
|
+ tar_path,
|
|
|
+ "--save-dir",
|
|
|
+ tar_path,
|
|
|
+ "--language",
|
|
|
+ cur_lang,
|
|
|
+ ]
|
|
|
+
|
|
|
+ if initial_prompt is not None:
|
|
|
+ transcribe_cmd += ["--initial-prompt", initial_prompt]
|
|
|
+
|
|
|
if cur_lang != "IGNORE":
|
|
|
try:
|
|
|
+ gr.Warning("Begin To Transcribe")
|
|
|
subprocess.run(
|
|
|
- [
|
|
|
- PYTHON,
|
|
|
- "tools/whisper_asr.py",
|
|
|
- "--model-size",
|
|
|
- label_model,
|
|
|
- "--device",
|
|
|
- label_device,
|
|
|
- "--audio-dir",
|
|
|
- tar_path,
|
|
|
- "--save-dir",
|
|
|
- tar_path,
|
|
|
- "--language",
|
|
|
- cur_lang,
|
|
|
- ],
|
|
|
+ transcribe_cmd,
|
|
|
env=env,
|
|
|
)
|
|
|
except Exception:
|
|
|
@@ -408,8 +428,6 @@ def generate_folder_name():
|
|
|
def train_process(
|
|
|
data_path: str,
|
|
|
option: str,
|
|
|
- min_duration: float,
|
|
|
- max_duration: float,
|
|
|
# llama config
|
|
|
llama_ckpt,
|
|
|
llama_base_config,
|
|
|
@@ -428,13 +446,17 @@ def train_process(
|
|
|
backend = "nccl" if sys.platform == "linux" else "gloo"
|
|
|
|
|
|
new_project = generate_folder_name()
|
|
|
-
|
|
|
print("New Project Name: ", new_project)
|
|
|
|
|
|
- if min_duration > max_duration:
|
|
|
- min_duration, max_duration = max_duration, min_duration
|
|
|
+ if option == "VQGAN":
|
|
|
+ msg = "Skipped VQGAN Training."
|
|
|
+ gr.Warning(msg)
|
|
|
+ logger.info(msg)
|
|
|
|
|
|
if option == "LLAMA":
|
|
|
+ msg = "LLAMA Training begins..."
|
|
|
+ gr.Warning(msg)
|
|
|
+ logger.info(msg)
|
|
|
subprocess.run(
|
|
|
[
|
|
|
PYTHON,
|
|
|
@@ -565,13 +587,16 @@ def list_llama_models():
|
|
|
choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
|
|
|
choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
|
|
|
choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
|
|
|
+ choices = sorted(choices, reverse=True)
|
|
|
if not choices:
|
|
|
logger.warning("No LLaMA model found")
|
|
|
return choices
|
|
|
|
|
|
|
|
|
def list_lora_llama_models():
|
|
|
- choices = [str(p) for p in Path("results").glob("lora*/**/*.ckpt")]
|
|
|
+ choices = sorted(
|
|
|
+ [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True
|
|
|
+ )
|
|
|
if not choices:
|
|
|
logger.warning("No LoRA LLaMA model found")
|
|
|
return choices
|
|
|
@@ -607,7 +632,7 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
|
|
|
"Path error, please check the model file exists in the corresponding path"
|
|
|
)
|
|
|
)
|
|
|
-
|
|
|
+ gr.Warning("Merging begins...")
|
|
|
merge_cmd = [
|
|
|
PYTHON,
|
|
|
"tools/llama/merge_lora.py",
|
|
|
@@ -630,6 +655,9 @@ def llama_quantify(llama_weight, quantify_mode):
|
|
|
"Path error, please check the model file exists in the corresponding path"
|
|
|
)
|
|
|
)
|
|
|
+
|
|
|
+ gr.Warning("Quantifying begins...")
|
|
|
+
|
|
|
now = generate_folder_name()
|
|
|
quantify_cmd = [
|
|
|
PYTHON,
|
|
|
@@ -690,30 +718,6 @@ with gr.Blocks(
|
|
|
if_label = gr.Checkbox(
|
|
|
label=i18n("Open Labeler WebUI"), scale=0, show_label=True
|
|
|
)
|
|
|
- with gr.Row():
|
|
|
- min_duration = gr.Slider(
|
|
|
- label=i18n("Minimum Audio Duration"),
|
|
|
- value=1.5,
|
|
|
- step=0.1,
|
|
|
- minimum=0.4,
|
|
|
- maximum=30,
|
|
|
- )
|
|
|
- max_duration = gr.Slider(
|
|
|
- label=i18n("Maximum Audio Duration"),
|
|
|
- value=30,
|
|
|
- step=0.1,
|
|
|
- minimum=0.4,
|
|
|
- maximum=30,
|
|
|
- )
|
|
|
-
|
|
|
- with gr.Row():
|
|
|
- add_button = gr.Button(
|
|
|
- "\U000027A1 " + i18n("Add to Processing Area"),
|
|
|
- variant="primary",
|
|
|
- )
|
|
|
- remove_button = gr.Button(
|
|
|
- "\U000026D4 " + i18n("Remove Selected Data")
|
|
|
- )
|
|
|
|
|
|
with gr.Row():
|
|
|
label_device = gr.Dropdown(
|
|
|
@@ -728,9 +732,9 @@ with gr.Blocks(
|
|
|
label_model = gr.Dropdown(
|
|
|
label=i18n("Whisper Model"),
|
|
|
info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
|
|
|
- choices=["large-v3"],
|
|
|
+ choices=["large-v3", "medium"],
|
|
|
value="large-v3",
|
|
|
- interactive=False,
|
|
|
+ interactive=True,
|
|
|
)
|
|
|
label_radio = gr.Dropdown(
|
|
|
label=i18n("Optional Label Language"),
|
|
|
@@ -738,9 +742,9 @@ with gr.Blocks(
|
|
|
"If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
|
|
|
),
|
|
|
choices=[
|
|
|
- (i18n("Chinese"), "ZH"),
|
|
|
- (i18n("English"), "EN"),
|
|
|
- (i18n("Japanese"), "JA"),
|
|
|
+ (i18n("Chinese"), "zh"),
|
|
|
+ (i18n("English"), "en"),
|
|
|
+ (i18n("Japanese"), "ja"),
|
|
|
(i18n("Disabled"), "IGNORE"),
|
|
|
(i18n("auto"), "auto"),
|
|
|
],
|
|
|
@@ -748,6 +752,31 @@ with gr.Blocks(
|
|
|
interactive=True,
|
|
|
)
|
|
|
|
|
|
+ with gr.Row():
|
|
|
+ if_initial_prompt = gr.Checkbox(
|
|
|
+ value=False,
|
|
|
+ label=i18n("Enable Initial Prompt"),
|
|
|
+ min_width=120,
|
|
|
+ scale=0,
|
|
|
+ )
|
|
|
+ initial_prompt = gr.Textbox(
|
|
|
+ label=i18n("Initial Prompt"),
|
|
|
+ info=i18n(
|
|
|
+ "Initial prompt can provide contextual or vocabulary-specific guidance to the model."
|
|
|
+ ),
|
|
|
+ placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.",
|
|
|
+ interactive=False,
|
|
|
+ )
|
|
|
+
|
|
|
+ with gr.Row():
|
|
|
+ add_button = gr.Button(
|
|
|
+ "\U000027A1 " + i18n("Add to Processing Area"),
|
|
|
+ variant="primary",
|
|
|
+ )
|
|
|
+ remove_button = gr.Button(
|
|
|
+ "\U000026D4 " + i18n("Remove Selected Data")
|
|
|
+ )
|
|
|
+
|
|
|
with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
|
|
|
with gr.Row():
|
|
|
model_type_radio = gr.Radio(
|
|
|
@@ -1103,7 +1132,7 @@ with gr.Blocks(
|
|
|
llama_page.select(lambda: "LLAMA", None, model_type_radio)
|
|
|
add_button.click(
|
|
|
fn=add_item,
|
|
|
- inputs=[textbox, output_radio, label_radio],
|
|
|
+ inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt],
|
|
|
outputs=[checkbox_group, error],
|
|
|
)
|
|
|
remove_button.click(
|
|
|
@@ -1116,14 +1145,16 @@ with gr.Blocks(
|
|
|
'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
|
|
|
)
|
|
|
if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
|
|
|
-
|
|
|
+ if_initial_prompt.change(
|
|
|
+ fn=lambda x: gr.Textbox(value="", interactive=x),
|
|
|
+ inputs=[if_initial_prompt],
|
|
|
+ outputs=[initial_prompt],
|
|
|
+ )
|
|
|
train_btn.click(
|
|
|
fn=train_process,
|
|
|
inputs=[
|
|
|
train_box,
|
|
|
model_type_radio,
|
|
|
- min_duration,
|
|
|
- max_duration,
|
|
|
# llama config
|
|
|
llama_ckpt,
|
|
|
llama_base_config,
|