jch 1 week ago
commit
20509cbb3a

+ 4 - 0
.gitignore

@@ -0,0 +1,4 @@
+data/
+image/
+**/__pycache__
+.idea/

+ 19 - 0
readme.txt

@@ -0,0 +1,19 @@
+1. Fine-tune a gender model (nickname & avatar)
+    a. Download random data
+        python src/preprocess/download_user_info.py --num 4000 --output_file data/user_info.csv
+
+    b. Label the data
+        Using the nickname (nick_name) and avatar (avatar_url), label the gender column of data/user_info.csv as one of [男性|女性|未知] (male|female|unknown)
+
+    c. Merge the labeled data
+        Download the labeled data locally as csv files, then run:
+        python src/preprocess/merge_label_data.py --files data/微信昵称\&头像\ -\ 1-昌辉-完成.csv,data/微信昵称\&头像\ -\ 2-张博-完成.csv,data/微信昵称\&头像\ -\ 3-jh-完成.csv,data/微信昵称\&头像\ -\ 4-ln-完成.csv,data/微信昵称\&头像\ -\ 5-wz-完成.csv,data/微信昵称\&头像\ -\ 6-dm-完成.csv --output_file data/user_info_label.csv
+
+    d. Download the avatars
+        python src/preprocess/download_image.py --input_file data/user_info_label.csv --output_dir image
+
+    e. Format the data (see the sample row after this file)
+        python src/preprocess/format_user_info.py --input_file data/user_info_label.csv --image_dir image --output_file data/user_info_format.csv
+
+    f. Generate SFT sample data
+        python src/preprocess/generate_qw2_5_lora_sft_data.py --input_file data/user_info_format.csv --train_file data/train_sft.csv --test_file data/test_sft.csv --test_rate 0.5

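For orientation, a sketch of one row of data/user_info_format.csv after step (e). The columns match what format_user_info.py writes (all fields quoted via csv.QUOTE_ALL); the values below are hypothetical:

    "index","uid","nick_name","image","gender","valid","avatar_url"
    "1","10001","小明","image/10001.jpeg","男性","有效","https://wx.qlogo.cn/mmhead/abc/0"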
+ 49 - 0
src/preprocess/download_image.py

@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import argparse
+import os
+
+import pandas as pd
+import requests
+from tqdm import tqdm
+
+
+def download_image_from_url(image_url, image_save_path):
+    try:
+        # Stream the download; a timeout avoids hanging on dead URLs
+        response = requests.get(image_url, stream=True, timeout=30)
+        # Raise for non-200 status codes
+        response.raise_for_status()
+
+        # Write the response to disk in binary mode, chunk by chunk
+        with open(image_save_path, 'wb') as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+        print(f"Image saved to: {image_save_path}")
+    except requests.exceptions.RequestException as e:
+        print(f"Failed to download image: {image_save_path} {e}")
+
+
+def process(input_file, output_dir, suffix='jpeg'):
+    os.makedirs(output_dir, exist_ok=True)
+    df = pd.read_csv(input_file)
+    for index, row in tqdm(df.iterrows(), total=len(df)):
+        uid = row['uid']
+        avatar_url = row['avatar_url']
+        avatar_save_path = '%s/%s.%s' % (output_dir, uid, suffix)
+        # Skip files that already exist, so the script can be re-run after failures
+        if not os.path.exists(avatar_save_path):
+            download_image_from_url(avatar_url, avatar_save_path)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_file', required=True, help='input file')
+    parser.add_argument('--output_dir', required=True, help='output dir')
+    parser.add_argument('--suffix', default='jpeg', type=str, help='image suffix')
+    args = parser.parse_args()
+    print('\n\n')
+    print(args)
+    print('\n\n')
+
+    process(args.input_file, args.output_dir, args.suffix)

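A minimal usage sketch for a single file, assuming the module is importable from src/preprocess and using a placeholder URL:

    from download_image import download_image_from_url

    # Hypothetical avatar URL; the target directory must already exist
    download_image_from_url("https://wx.qlogo.cn/mmhead/example/0", "test.jpeg")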
+ 58 - 0
src/preprocess/download_user_info.py

@@ -0,0 +1,58 @@
+#!/usr/bin/env python
+# coding=utf-8
+import argparse
+
+from odps_module import ODPSClient
+
+odps_client = ODPSClient()
+
+
+# URLs ending in /132 are compressed thumbnails;
+# URLs ending in /0 are the original images.
+def get_sql(n):
+    sql = f"""
+    SELECT  t1.uid
+            ,t2.nick_name
+            ,t2.avatar_url
+    FROM    (
+                SELECT  DISTINCT uid
+                FROM    loghubods.dwd_recsys_alg_exposure_base_20250108
+                WHERE   dt >= '20250905'
+                AND     apptype NOT IN ('12')
+                AND     uid IS NOT NULL
+            ) t1
+    LEFT JOIN   (
+                SELECT  uid
+                        ,nick_name
+                        ,REGEXP_REPLACE(avatar_url,'132$','0') as avatar_url
+                FROM    videoods.wx_user_wechar_detail
+                WHERE   nick_name IS NOT NULL
+                AND     nick_name NOT IN ('','微信用户')
+                AND     avatar_url RLIKE '^https.*wx.qlogo.cn'
+            ) t2
+    ON      t1.uid = t2.uid
+    WHERE   t2.uid IS NOT NULL
+    CLUSTER BY rand()
+    LIMIT   {n};
+    """
+    return sql
+
+
+def download_from_odps(sql, output_file):
+    odps_client.execute_sql_result_save_file(sql, output_file)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--num', default=2000, type=int, help='download number')
+    parser.add_argument('--output_file', required=True, type=str, help='output file')
+    args = parser.parse_args()
+    print('\n\n')
+    print(args)
+    print('\n\n')
+
+    # sql
+    user_sql = get_sql(args.num)
+
+    # download
+    download_from_odps(user_sql, args.output_file)

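The SQL above rewrites compressed avatar URLs (ending in /132) to their originals (ending in /0) with REGEXP_REPLACE. The same substitution in Python, on a hypothetical URL:

    import re

    url = "https://wx.qlogo.cn/mmhead/abc/132"
    print(re.sub(r"132$", "0", url))  # https://wx.qlogo.cn/mmhead/abc/0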
+ 62 - 0
src/preprocess/format_user_info.py

@@ -0,0 +1,62 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import argparse
+import csv
+import os
+
+import pandas as pd
+from PIL import Image
+from tqdm import tqdm
+
+
+# Blank placeholder images ("temporarily unavailable") come back as
+# 120x120, so treat exactly that size as invalid.
+def validation_image(image_path, invalid_width=120, invalid_height=120):
+    with Image.open(image_path) as img:
+        img_width, img_height = img.size
+    if img_width == invalid_width and img_height == invalid_height:
+        return False
+    return True
+
+
+def process(input_file, image_dir, output_file, suffix='jpeg'):
+    row_list = []
+    df = pd.read_csv(input_file)
+    for index, row in tqdm(df.iterrows(), total=len(df)):
+        uid = row['uid']
+        nick_name = row['nick_name']
+        image_path = '%s/%s.%s' % (image_dir, uid, suffix)
+        gender = row['gender']
+        avatar_url = row['avatar_url']
+        if not os.path.exists(image_path):
+            continue
+        flag = '无效'  # "invalid"
+        if validation_image(image_path):
+            flag = '有效'  # "valid"
+        new_row = {
+            'uid': uid,
+            'nick_name': nick_name,
+            'image': image_path,
+            'gender': gender,
+            'valid': flag,
+            'avatar_url': avatar_url}
+        row_list.append(new_row)
+    if row_list:
+        new_df = pd.DataFrame(row_list)
+        new_df.index += 1
+        new_df.to_csv(output_file, encoding='utf-8', index_label='index', quoting=csv.QUOTE_ALL)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_file', required=True, help='input file')
+    parser.add_argument('--image_dir', required=True, help='image dir')
+    parser.add_argument('--output_file', required=True, help='output file')
+    parser.add_argument('--suffix', default='jpeg', type=str, help='image suffix')
+    args = parser.parse_args()
+    print('\n\n')
+    print(args)
+    print('\n\n')
+
+    process(args.input_file, args.image_dir, args.output_file, args.suffix)

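A quick sanity check of the 120x120 blank-image rule, assuming format_user_info is importable; it builds a synthetic placeholder-sized image with Pillow:

    from PIL import Image

    from format_user_info import validation_image

    Image.new("RGB", (120, 120)).save("blank_test.jpeg")
    print(validation_image("blank_test.jpeg"))  # False: placeholder size is treated as invalid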
+ 83 - 0
src/preprocess/generate_qw2_5_lora_sft_data.py

@@ -0,0 +1,83 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import argparse
+import json
+
+import pandas as pd
+from sklearn.model_selection import train_test_split
+from tqdm import tqdm
+
+image_tag = "<image>"
+
+prompt = "根据用户在社交平台的昵称和头像,判断用户的性别倾向,返回[男性|女性|未知];昵称:%s,头像:%s。"
+
+
+def get_user_message(content):
+    return {
+        "content": prompt % (content, image_tag),
+        "role": "user"}
+
+
+def get_assistant_message(content):
+    return {
+        "content": content,
+        "role": "assistant"}
+
+
+def get_conversation(row):
+    messages = [
+        get_user_message(row["nick_name"]),
+        get_assistant_message(row["gender"])
+    ]
+    return {
+        "messages": messages,
+        "images": row["image"].split(",")
+    }
+
+
+def split_train_test(df, rate):
+    train_df_ = pd.DataFrame()
+    test_df_ = pd.DataFrame()
+    if 0 < rate < 1:
+        train_df_, test_df_ = train_test_split(df, test_size=rate)
+    elif rate == 0:
+        # rate 0: everything goes to the train set
+        train_df_ = df
+    elif rate == 1:
+        # rate 1: everything goes to the test set
+        test_df_ = df
+    return train_df_, test_df_
+
+
+def process(df, output_file):
+    conversation_list = []
+    for index, row in tqdm(df.iterrows(), total=len(df)):
+        conversation_list.append(get_conversation(row))
+    if conversation_list:
+        # The output is a JSON array of conversations (despite the .csv suffix used in the readme)
+        with open(output_file, "w") as out_fp:
+            json.dump(conversation_list, out_fp, ensure_ascii=False, indent=4)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_file', required=True, help='input files')
+    parser.add_argument('--train_file', required=True, help='train file')
+    parser.add_argument('--test_file', required=True, help='test file')
+    parser.add_argument('--test_rate', default=0.5, type=float, help='test rate')
+    args = parser.parse_args()
+    print('\n\n')
+    print(args)
+    print('\n\n')
+
+    # load
+    all_df = pd.read_csv(args.input_file)
+
+    # split
+    train_df, test_df = split_train_test(all_df, args.test_rate)
+
+    # train
+    if len(train_df.index) > 0:
+        process(train_df, args.train_file)
+
+    # test
+    if len(test_df.index) > 0:
+        process(test_df, args.test_file)

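To see the shape of one SFT record, a sketch that calls get_conversation on a hypothetical row (assuming the module is importable):

    from generate_qw2_5_lora_sft_data import get_conversation

    row = {"nick_name": "小明", "gender": "男性", "image": "image/10001.jpeg"}
    print(get_conversation(row))
    # {'messages': [{'content': '根据用户在社交平台的昵称和头像,...;昵称:小明,头像:<image>。', 'role': 'user'},
    #               {'content': '男性', 'role': 'assistant'}],
    #  'images': ['image/10001.jpeg']}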
+ 49 - 0
src/preprocess/merge_label_data.py

@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import argparse
+import csv
+
+import pandas as pd
+from tqdm import tqdm
+
+
+def load_data(files):
+    df_list = []
+    for file in files.split(','):
+        df = pd.read_csv(file)
+        df["gender"] = df["gender"].fillna("未知")
+        df_list.append(df)
+    return pd.concat(df_list)
+
+
+def process(files, output_file):
+    df = load_data(files)
+    row_list = []
+    for index, row in tqdm(df.iterrows(), total=len(df)):
+        uid = row['uid']
+        nick_name = row['nick_name']
+        gender = row['gender']
+        avatar_url = row['avatar_url']
+        new_row = {
+            'uid': uid,
+            'nick_name': nick_name,
+            'gender': gender,
+            'avatar_url': avatar_url}
+        row_list.append(new_row)
+    if row_list:
+        new_df = pd.DataFrame(row_list)
+        new_df.index += 1
+        new_df.to_csv(output_file, encoding='utf-8', index_label='index', quoting=csv.QUOTE_ALL)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--files', required=True, help='input files')
+    parser.add_argument('--output_file', required=True, help='output file')
+    args = parser.parse_args()
+    print('\n\n')
+    print(args)
+    print('\n\n')
+
+    process(args.files, args.output_file)

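The merge step backfills missing labels with 未知 (unknown); a tiny sketch of that fillna behavior:

    import pandas as pd

    df = pd.DataFrame({"gender": ["男性", None, "女性"]})
    print(df["gender"].fillna("未知").tolist())  # ['男性', '未知', '女性']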
+ 31 - 0
src/preprocess/odps_module.py

@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import os
+
+from odps import ODPS
+
+
+class ODPSClient(object):
+    def __init__(self, project="loghubods"):
+        self.accessId = "LTAIWYUujJAm7CbH"
+        self.accessSecret = "RfSjdiWwED1sGFlsjXv0DlfTnZTG1P"
+        self.endpoint = "http://service.odps.aliyun.com/api"
+        self.tunnelUrl = "http://dt.cn-hangzhou.maxcompute.aliyun-inc.com"
+
+        self.odps = ODPS(
+            self.accessId,
+            self.accessSecret,
+            project,
+            self.endpoint
+        )
+
+    def execute_sql(self, sql: str):
+        hints = {
+            'odps.sql.submit.mode': 'script'
+        }
+        with self.odps.execute_sql(sql, hints=hints).open_reader(tunnel=True) as reader:
+            pd_df = reader.to_pandas()
+        return pd_df
+
+    def execute_sql_result_save_file(self, sql: str, output_file: str):
+        data_df = self.execute_sql(sql)
+        data_df.to_csv(output_file, index=False)
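
A hedged usage sketch, assuming ODPS_ACCESS_ID and ODPS_ACCESS_SECRET are set in the environment and the account can reach the project:

    from odps_module import ODPSClient

    client = ODPSClient()
    client.execute_sql_result_save_file("SELECT 1;", "probe.csv")  # smoke test: writes a one-cell CSV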