瀏覽代碼

merge label

jch 2 月之前
父節點
當前提交
71af44a2f0

+ 1 - 1
README.md

@@ -7,7 +7,7 @@
 
 ## 1.3 合并标注后的数据
 - 将[标注完成的数据](https://w42nne6hzg.feishu.cn/wiki/BEMpwfvMriHrNakjwQncr3l6nUc?open_in_browser=true&sheet=rLAglD)下载到本地,并保存为csv格式<br>
-- python src/preprocess/merge_label_data.py --files data/微信昵称\&头像\ -\ 1-昌辉-完成.csv,data/微信昵称\&头像\ -\ 2-张博-完成.csv,data/微信昵称\&头像\ -\ 3-jh-完成.csv,data/微信昵称\&头像\ -\ 4-ln-完成.csv,data/微信昵称\&头像\ -\ 5-wz-完成.csv,data/微信昵称\&头像\ -\ 6-dm-完成.csv --output data/user_info_label.csv
+- python src/preprocess/merge_label_data.py --files data/微信昵称\&头像\ -\ 1-昌辉-完成.csv,data/微信昵称\&头像\ -\ 2-张博-完成.csv,data/微信昵称\&头像\ -\ 3-jh-完成.csv,data/微信昵称\&头像\ -\ 4-ln-完成.csv,data/微信昵称\&头像\ -\ 5-wz-完成.csv,data/微信昵称\&头像\ -\ 6-dm-完成.csv --output_file data/user_info_label.csv
 
 ## 1.4 下载头像
 - python src/preprocess/download_image.py --input_file data/user_info_label.csv --image_dir image

+ 1 - 1
src/preprocess/download_image.py

@@ -46,4 +46,4 @@ if __name__ == '__main__':
     print(args)
     print('\n\n')
 
-    process(args.input_file, args.output_dir, args.suffix)
+    process(args.input_file, args.image_dir, args.suffix)

+ 3 - 1
src/preprocess/format_user_info.py

@@ -27,7 +27,9 @@ def process(input_file, image_dir, output_file, suffix='jpeg'):
         uid = row['uid']
         nick_name = row['nick_name']
         image_path = '%s/%s.%s' % (image_dir, uid, suffix)
-        gender = row['gender']
+        gender = ''
+        if 'gender' in row.index:
+            gender = row['gender']
         avatar_url = row['avatar_url']
         if not os.path.exists(image_path):
             continue

+ 6 - 5
src/preprocess/generate_qw2_5_lora_sft_json.py

@@ -11,13 +11,13 @@ from tqdm import tqdm
 image_tag = "<image>"
 
 system = """
-根据用户在社交平台上的姓名/昵称和头像,分析用户的性别倾向;
+根据中老年用户在社交平台上的姓名/昵称和头像,分析性别倾向;
 先基于姓名/昵称分析,再基于头像分析;
 当头像不可用时,仅基于姓名/昵称分析;
 直接返回[男性|女性|未知],不要返回多余的信息。
 """
 system = system.replace("\n", "")
-prompt = "姓名/昵称:%s头像:%s"
+prompt = "姓名/昵称:%s, 头像:%s"
 
 
 # LLaMA-Factory 训练模版
@@ -54,10 +54,10 @@ def get_conversation(row):
     }
 
 
-def process(input_file, output_file):
+def process(input_file, output_file, random_state):
     conversation_list = []
     df = pd.read_csv(input_file)
-    df = shuffle(df)
+    df = shuffle(df, random_state=random_state)
     for index, row in tqdm(df.iterrows(), total=len(df)):
         conversation_list.append(get_conversation(row))
     if conversation_list:
@@ -69,10 +69,11 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--input_file', required=True, help='input files')
     parser.add_argument('--output_file', required=True, help='output file')
+    parser.add_argument('--random_state', default=42, type=int, help='random state')
     args = parser.parse_args()
     print('\n\n')
     print(args)
     print('\n\n')
 
     #
-    process(args.input_file, args.output_file)
+    process(args.input_file, args.output_file, args.random_state)

+ 1 - 1
src/preprocess/merge_label_data.py

@@ -46,4 +46,4 @@ if __name__ == '__main__':
     print(args)
     print('\n\n')
 
-    process(args.files, args.output)
+    process(args.files, args.output_file)