jch 1 week ago
parent
commit
179f4665cb

+ 11 - 5
README.md

@@ -10,15 +10,21 @@
 - python src/preprocess/merge_label_data.py --files data/微信昵称\&头像\ -\ 1-昌辉-完成.csv,data/微信昵称\&头像\ -\ 2-张博-完成.csv,data/微信昵称\&头像\ -\ 3-jh-完成.csv,data/微信昵称\&头像\ -\ 4-ln-完成.csv,data/微信昵称\&头像\ -\ 5-wz-完成.csv,data/微信昵称\&头像\ -\ 6-dm-完成.csv --output data/user_info_label.csv
 
 ## 1.4 下载头像
-- python src/preprocess/download_image.py --input_file data/user_info_label.csv --output_dir image
+- python src/preprocess/download_image.py --input_file data/user_info_label.csv --image_dir image
 
-## 1.5 格式化数据
+## 1.5 数据格式化
 - python src/preprocess/format_user_info.py --input_file data/user_info_label.csv --image_dir image --output_file data/user_info_format.csv
 
-## 1.6 生成样本数据
-- python src/preprocess/generate_qw2_5_lora_sft_data.py --input_file data/user_info_format.csv --train_file data/train_sft.json --test_file data/test_sft.json --test_rate 0.5
+## 1.6 拆分训练和测试数据
+- python src/preprocess/split_train_test.py --input_file data/user_info_format.csv --train_file data/user_info_format_train.csv --test_file data/user_info_format_test.csv
 
-## 1.7 finetune
+## 1.7 生成训练数据
+- python src/preprocess/generate_qw2_5_lora_sft_json.py --input_file data/user_info_format_train.csv --output_file data/train_sft.json
+
+## 1.8 finetune
 - [qwen2_5vl_lora_sft](https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/README_zh.md#%E5%A4%9A%E6%A8%A1%E6%80%81%E6%8C%87%E4%BB%A4%E7%9B%91%E7%9D%A3%E5%BE%AE%E8%B0%83)
 - [qwen2_5vl_lora_dpo](https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/README_zh.md#%E5%A4%9A%E6%A8%A1%E6%80%81-dpoorposimpo-%E8%AE%AD%E7%BB%83)
 - [qwen2_5vl_full_sft](https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/README_zh.md#%E5%A4%9A%E6%A8%A1%E6%80%81%E6%8C%87%E4%BB%A4%E7%9B%91%E7%9D%A3%E5%BE%AE%E8%B0%83-1)
+
+## 1.9 API推理
+- python src/preprocess/qw_api_url_inference.py --input_file data/user_info_format_test.csv --output_file eval.csv

+ 1 - 1
src/preprocess/download_image.py

@@ -39,7 +39,7 @@ def process(input_file, output_dir, suffix='jpeg'):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--input_file', required=True, help='input file')
-    parser.add_argument('--output_dir', required=True, help='output dir')
+    parser.add_argument('--image_dir', required=True, help='image dir')
     parser.add_argument('--suffix', default='jpeg', type=str, help='image suffix')
     args = parser.parse_args()
     print('\n\n')

+ 7 - 30
src/preprocess/generate_qw2_5_lora_sft_data.py → src/preprocess/generate_qw2_5_lora_sft_json.py

@@ -5,7 +5,6 @@ import argparse
 import json
 
 import pandas as pd
-from sklearn.model_selection import train_test_split
 from tqdm import tqdm
 
 image_tag = "<image>"
@@ -13,6 +12,8 @@ image_tag = "<image>"
 prompt = "根据用户在社交平台的昵称和头像,判断用户的性别倾向;当头像不可用时,仅使用昵称进行判断;直接返回[男性|女性|未知],不要返回多余的信息;昵称:%s,头像:%s"
 
 
+# LLaMA-Factory 训练模版
+# https://github.com/hiyouga/LLaMA-Factory/blob/main/examples/train_lora/qwen2_5vl_lora_sft.yaml
 def get_user_message(content):
     return {
         "content": prompt % (content, image_tag),
@@ -36,20 +37,9 @@ def get_conversation(row):
     }
 
 
-def split_train_test(df, rate):
-    train_df_ = pd.DataFrame()
-    test_df_ = pd.DataFrame()
-    if 0 < rate < 1:
-        train_df_, test_df_ = train_test_split(df, test_size=rate)
-    elif 0 == rate:
-        train_df_ = all_df
-    elif 1 == rate:
-        test_df_ = all_df
-    return train_df_, test_df_
-
-
-def process(df, output_file):
+def process(input_file, output_file):
     conversation_list = []
+    df = pd.read_csv(input_file)
     for index, row in tqdm(df.iterrows(), total=len(df)):
         conversation_list.append(get_conversation(row))
     if conversation_list:
@@ -60,24 +50,11 @@ def process(df, output_file):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--input_file', required=True, help='input files')
-    parser.add_argument('--train_file', required=True, help='train file')
-    parser.add_argument('--test_file', required=True, help='test file')
-    parser.add_argument('--test_rate', default=0.5, type=float, help='test rate')
+    parser.add_argument('--output_file', required=True, help='output file')
     args = parser.parse_args()
     print('\n\n')
     print(args)
     print('\n\n')
 
-    # load
-    all_df = pd.read_csv(args.input_file)
-
-    # split
-    train_df, test_df = split_train_test(all_df, args.test_rate)
-
-    # train
-    if len(train_df.index) > 0:
-        process(train_df, args.train_file)
-
-    # test
-    if len(test_df.index) > 0:
-        process(test_df, args.test_file)
+    #
+    process(args.input_file, args.output_file)

+ 89 - 0
src/preprocess/qw_api_url_inference.py

@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import argparse
+
+import pandas as pd
+from openai import OpenAI
+from tqdm import tqdm
+
+system = "根据用户在社交平台的昵称和头像,判断用户的性别倾向;当头像不可用时,仅使用昵称进行判断;直接返回[男性|女性|未知],不要返回多余的信息。"
+prompt = "昵称:%s, 头像:"
+
+host = "117.50.199.192"
+port = 8000
+base_url = 'http://%s:%d/v1' % (host, port)
+client = OpenAI(api_key="0", base_url=base_url)
+model_name_or_path = 'Qwen/Qwen2.5-VL-7B-Instruct'
+
+
+def get_system_message():
+    # Build the system-role message carrying the fixed gender-classification instruction.
+    return {
+        "role": "system",
+        "content": system}
+
+
+def get_user_message(nick_name, image_url):
+    # Build the user-role message: a text part with the nickname substituted into
+    # the prompt template, plus an image_url part pointing at the avatar.
+    return {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": prompt % nick_name},
+            {"type": "image_url", "image_url": {"url": image_url}}
+        ]
+    }
+
+
+def get_assistant_message(content):
+    # Build an assistant-role message holding a ground-truth answer.
+    # NOTE(review): currently unused — its only call site in get_messages is commented out.
+    return {
+        "role": "assistant",
+        "content": content}
+
+
+def get_messages(row):
+    # Assemble the chat payload for one dataframe row: system instruction + user turn
+    # (nickname and avatar URL). The assistant turn is intentionally left out for inference.
+    messages = [
+        get_system_message(),
+        get_user_message(row["nick_name"], row["avatar_url"])
+        # get_assistant_message(row["gender"])
+    ]
+    return messages
+
+
+def inference(messages):
+    # Call the chat-completions endpoint once and return the model's reply text.
+    # Best-effort by design: any failure (network, API, empty choices) is printed
+    # and the sentinel string 'error' is returned so batch processing continues.
+    try:
+        result = client.chat.completions.create(messages=messages, model=model_name_or_path)
+        return result.choices[0].message.content
+    except Exception as e:
+        print(f"获取结果时发生错误:{messages} {e}")
+    return 'error'
+
+
+def process(df, output_file):
+    # Run API inference for every row of df and write a CSV with the model's
+    # 'predict' column alongside the original label columns.
+    row_list = []
+    for index, row in tqdm(df.iterrows(), total=len(df)):
+        messages = get_messages(row)
+        result = inference(messages)
+        # Keep input columns so predictions can be compared against 'gender' later.
+        new_row = {
+            'uid': row['uid'],
+            'nick_name': row['nick_name'],
+            'gender': row['gender'],
+            'predict': result,
+            'valid': row['valid'],
+            'avatar_url': row['avatar_url']}
+        row_list.append(new_row)
+    # Skip writing entirely when the input frame was empty.
+    if row_list:
+        new_df = pd.DataFrame(row_list)
+        new_df.to_csv(output_file, encoding='utf-8', index=False)
+
+
+# CLI entry: read the labelled test CSV, run API inference per row, write eval CSV.
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_file', required=True, help='input files')
+    parser.add_argument('--output_file', required=True, help='output file')
+    args = parser.parse_args()
+    print('\n\n')
+    print(args)
+    print('\n\n')
+
+    # load
+    all_df = pd.read_csv(args.input_file)
+    process(all_df, args.output_file)

+ 46 - 0
src/preprocess/split_train_test.py

@@ -0,0 +1,46 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import argparse
+
+import pandas as pd
+from sklearn.model_selection import train_test_split
+
+
+def split_train_test(df, rate, random_state):
+    train_df_ = pd.DataFrame()
+    test_df_ = pd.DataFrame()
+    if 0 < rate < 1:
+        train_df_, test_df_ = train_test_split(df, test_size=rate, random_state=random_state)
+    elif 0 == rate:
+        train_df_ = all_df
+    elif 1 == rate:
+        test_df_ = all_df
+    return train_df_, test_df_
+
+
+# CLI entry: read the formatted CSV, split it by --test_rate, and write the
+# non-empty train/test partitions to their respective CSV paths.
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_file', required=True, help='input files')
+    parser.add_argument('--train_file', required=True, help='train file')
+    parser.add_argument('--test_file', required=True, help='test file')
+    parser.add_argument('--test_rate', default=0.5, type=float, help='test rate')
+    parser.add_argument('--random_state', default=42, type=int, help='random state')
+    args = parser.parse_args()
+    print('\n\n')
+    print(args)
+    print('\n\n')
+
+    # load
+    all_df = pd.read_csv(args.input_file)
+
+    # split
+    train_df, test_df = split_train_test(all_df, args.test_rate, args.random_state)
+
+    # train: only write when the partition is non-empty
+    if len(train_df.index) > 0:
+        train_df.to_csv(args.train_file, encoding='utf-8', index=False)
+
+    # test: only write when the partition is non-empty
+    if len(test_df.index) > 0:
+        test_df.to_csv(args.test_file, encoding='utf-8', index=False)