|
@@ -0,0 +1,83 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# coding=utf-8
|
|
|
+
|
|
|
+import argparse
|
|
|
+import json
|
|
|
+
|
|
|
+import pandas as pd
|
|
|
+from sklearn.model_selection import train_test_split
|
|
|
+from tqdm import tqdm
|
|
|
+
|
|
|
+image_tag = "<image>"
|
|
|
+
|
|
|
+prompt = "根据用户在社交平台的昵称和头像,判断用户的性别倾向,返回[男性|女性|未知];昵称:%s,头像:%s。"
|
|
|
+
|
|
|
+
|
|
|
+def get_user_message(content):
|
|
|
+ return {
|
|
|
+ "content": prompt % (content, image_tag),
|
|
|
+ "role": "user"}
|
|
|
+
|
|
|
+
|
|
|
+def get_assistant_message(content):
|
|
|
+ return {
|
|
|
+ "content": content,
|
|
|
+ "role": "assistant"}
|
|
|
+
|
|
|
+
|
|
|
+def get_conversation(row):
|
|
|
+ messages = [
|
|
|
+ get_user_message(row["nick_name"]),
|
|
|
+ get_assistant_message(row["gender"])
|
|
|
+ ]
|
|
|
+ return {
|
|
|
+ "messages": messages,
|
|
|
+ "images": row["image"].split(",")
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+def split_train_test(df, rate):
|
|
|
+ train_df_ = pd.DataFrame()
|
|
|
+ test_df_ = pd.DataFrame()
|
|
|
+ if 0 < rate < 1:
|
|
|
+ train_df_, test_df_ = train_test_split(df, test_size=rate)
|
|
|
+ elif 0 == rate:
|
|
|
+ train_df_ = all_df
|
|
|
+ elif 1 == rate:
|
|
|
+ test_df_ = all_df
|
|
|
+ return train_df_, test_df_
|
|
|
+
|
|
|
+
|
|
|
+def process(df, output_file):
|
|
|
+ conversation_list = []
|
|
|
+ for index, row in tqdm(df.iterrows(), total=len(df)):
|
|
|
+ conversation_list.append(get_conversation(row))
|
|
|
+ if conversation_list:
|
|
|
+ with open(output_file, "w") as out_fp:
|
|
|
+ json.dump(conversation_list, out_fp, ensure_ascii=False, indent=4)
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ parser = argparse.ArgumentParser()
|
|
|
+ parser.add_argument('--input_file', required=True, help='input files')
|
|
|
+ parser.add_argument('--train_file', required=True, help='train file')
|
|
|
+ parser.add_argument('--test_file', required=True, help='test file')
|
|
|
+ parser.add_argument('--test_rate', default=0.5, type=float, help='test rate')
|
|
|
+ args = parser.parse_args()
|
|
|
+ print('\n\n')
|
|
|
+ print(args)
|
|
|
+ print('\n\n')
|
|
|
+
|
|
|
+ # load
|
|
|
+ all_df = pd.read_csv(args.input_file)
|
|
|
+
|
|
|
+ # split
|
|
|
+ train_df, test_df = split_train_test(all_df, args.test_rate)
|
|
|
+
|
|
|
+ # train
|
|
|
+ if len(train_df.index) > 0:
|
|
|
+ process(train_df, args.train_file)
|
|
|
+
|
|
|
+ # test
|
|
|
+ if len(test_df.index) > 0:
|
|
|
+ process(test_df, args.test_file)
|