|
@@ -5,7 +5,6 @@ import argparse
|
|
|
import json
|
|
|
|
|
|
import pandas as pd
|
|
|
-from sklearn.model_selection import train_test_split
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
image_tag = "<image>"
|
|
@@ -13,6 +12,8 @@ image_tag = "<image>"
|
|
|
prompt = "根据用户在社交平台的昵称和头像,判断用户的性别倾向;当头像不可用时,仅使用昵称进行判断;直接返回[男性|女性|未知],不要返回多余的信息;昵称:%s,头像:%s"
|
|
|
|
|
|
|
|
|
+# LLaMA-Factory training template
|
|
|
+# https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/train_lora/qwen2_5vl_lora_sft.yaml
|
|
|
def get_user_message(content):
|
|
|
return {
|
|
|
"content": prompt % (content, image_tag),
|
|
@@ -36,20 +37,9 @@ def get_conversation(row):
|
|
|
}
|
|
|
|
|
|
|
|
|
-def split_train_test(df, rate):
|
|
|
- train_df_ = pd.DataFrame()
|
|
|
- test_df_ = pd.DataFrame()
|
|
|
- if 0 < rate < 1:
|
|
|
- train_df_, test_df_ = train_test_split(df, test_size=rate)
|
|
|
- elif 0 == rate:
|
|
|
- train_df_ = all_df
|
|
|
- elif 1 == rate:
|
|
|
- test_df_ = all_df
|
|
|
- return train_df_, test_df_
|
|
|
-
|
|
|
-
|
|
|
-def process(df, output_file):
|
|
|
+def process(input_file, output_file):
|
|
|
conversation_list = []
|
|
|
+ df = pd.read_csv(input_file)
|
|
|
for index, row in tqdm(df.iterrows(), total=len(df)):
|
|
|
conversation_list.append(get_conversation(row))
|
|
|
if conversation_list:
|
|
@@ -60,24 +50,11 @@ def process(df, output_file):
|
|
|
if __name__ == '__main__':
|
|
|
parser = argparse.ArgumentParser()
|
|
|
parser.add_argument('--input_file', required=True, help='input files')
|
|
|
- parser.add_argument('--train_file', required=True, help='train file')
|
|
|
- parser.add_argument('--test_file', required=True, help='test file')
|
|
|
- parser.add_argument('--test_rate', default=0.5, type=float, help='test rate')
|
|
|
+ parser.add_argument('--output_file', required=True, help='output file')
|
|
|
args = parser.parse_args()
|
|
|
print('\n\n')
|
|
|
print(args)
|
|
|
print('\n\n')
|
|
|
|
|
|
- # load
|
|
|
- all_df = pd.read_csv(args.input_file)
|
|
|
-
|
|
|
- # split
|
|
|
- train_df, test_df = split_train_test(all_df, args.test_rate)
|
|
|
-
|
|
|
- # train
|
|
|
- if len(train_df.index) > 0:
|
|
|
- process(train_df, args.train_file)
|
|
|
-
|
|
|
- # test
|
|
|
- if len(test_df.index) > 0:
|
|
|
- process(test_df, args.test_file)
|
|
|
+ # read the CSV at input_file, build the conversation list, and pass output_file to process()
|
|
|
+ process(args.input_file, args.output_file)
|