Browse Source

同步prompt

jch 1 month ago
parent
commit
f482000757
2 changed files with 121 additions and 3 deletions
  1. 117 0
      src/preprocess/qw_api_local_inference.py
  2. 4 3
      src/preprocess/qw_api_url_inference.py

+ 117 - 0
src/preprocess/qw_api_local_inference.py

@@ -0,0 +1,117 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import argparse
+import base64
+
+import pandas as pd
+from openai import OpenAI
+from tqdm import tqdm
+
+# Model identifier sent as the `model` field of every chat-completion request.
+model_name_or_path = 'Qwen/Qwen2.5-VL-7B-Instruct'
+
+# System prompt (written in Chinese). It asks the model to infer the gender
+# tendency of a middle-aged/elderly social-platform user: analyse the
+# name/nickname first, use the avatar only when the name is inconclusive,
+# ignore children/teenagers shown in the avatar, fall back to the name alone
+# when the avatar is unavailable, and answer with exactly one of
+# [男性|女性|未知] (male / female / unknown) and nothing else.
+system = """
+根据中老年用户在社交平台上的姓名/昵称和头像,分析其性别倾向。
+先基于姓名/昵称分析,不能判断时,再结合头像分析;
+头像中的儿童/青少年,不能作为判断依据;
+当头像不可用时,仅基于姓名/昵称分析。
+直接返回[男性|女性|未知],不要返回多余的信息。
+"""
+# Collapse the multi-line literal into a single line by dropping the newlines.
+system = system.replace("\n", "")
+# User-turn text template: %s receives the nickname; the avatar image is
+# attached as a separate content part after this text.
+prompt = "姓名/昵称:%s, 头像:"
+
+
+class ModelService:
+    """Pair an OpenAI-compatible client with the model name to request."""
+
+    def __init__(self, host, port, model_name):
+        """Build a client for the endpoint ``http://<host>:<port>/v1``.
+
+        Args:
+            host: Host name or IP address of the inference service.
+            port: Integer TCP port of the inference service.
+            model_name: Value stored for later use as the request ``model``.
+        """
+        endpoint = 'http://%s:%d/v1' % (host, port)
+        self.model_name = model_name
+        # A placeholder API key ("0") is passed to the client constructor.
+        self.client = OpenAI(base_url=endpoint, api_key="0")
+
+
+def get_image_base64(image_path):
+    """Read an image file and return it as a base64 ``data:`` URL.
+
+    The MIME type is guessed from the file extension so that PNG/GIF/other
+    avatars are not mislabelled as JPEG. ``image/jpeg`` remains the fallback
+    when the extension is missing, unknown, or not an image type.
+
+    Args:
+        image_path: Path of the image file to encode.
+
+    Returns:
+        A string of the form ``data:<mime>;base64,<payload>``.
+
+    Raises:
+        OSError: If the file cannot be opened or read.
+    """
+    import mimetypes  # local import so the file's import header stays untouched
+
+    mime_type, _ = mimetypes.guess_type(str(image_path))
+    if not mime_type or not mime_type.startswith("image/"):
+        mime_type = "image/jpeg"
+    with open(image_path, "rb") as f:
+        encoded_image_text = base64.b64encode(f.read()).decode("utf-8")
+    return f"data:{mime_type};base64,{encoded_image_text}"
+
+
+def get_system_message():
+    """Return the system-role chat message carrying the module-level prompt."""
+    message = dict(role="system", content=system)
+    return message
+
+
+def get_user_message(nick_name, image_path):
+    """Build the user-role chat message for one person.
+
+    The message content has two parts: the nickname formatted into the
+    module-level ``prompt`` template, followed by the avatar image embedded
+    as a base64 data URL.
+
+    Args:
+        nick_name: Value substituted into the text template.
+        image_path: Path of the avatar image file to embed.
+    """
+    text_part = {"type": "text", "text": prompt % nick_name}
+    image_part = {
+        "type": "image_url",
+        "image_url": {"url": get_image_base64(image_path)},
+    }
+    return {"role": "user", "content": [text_part, image_part]}
+
+
+def get_assistant_message(content):
+    """Wrap ``content`` in an assistant-role chat message dict."""
+    reply = {"role": "assistant"}
+    reply["content"] = content
+    return reply
+
+
+def get_messages(row):
+    """Assemble the chat history for one input row.
+
+    Reads ``row["nick_name"]`` and ``row["image"]`` and returns a two-item
+    list: the system message followed by the user message. No assistant turn
+    is included.
+    """
+    system_turn = get_system_message()
+    user_turn = get_user_message(row["nick_name"], row["image"])
+    return [system_turn, user_turn]
+
+
+def inference(model_service, messages):
+    """Send one chat-completion request and return the reply text.
+
+    Args:
+        model_service: Object exposing ``client`` (OpenAI-style client) and
+            ``model_name``.
+        messages: Chat messages to send.
+
+    Returns:
+        The content of the first choice, or the sentinel string ``'error'``
+        when the request or the response handling raises.
+    """
+    # Sampling is switched off via the extra request body, so no sampling
+    # parameter (such as top_p) is sent with the request.
+    extra_body = {"do_sample": False}
+    try:
+        result = model_service.client.chat.completions.create(
+            messages=messages,
+            model=model_service.model_name,
+            extra_body=extra_body,
+        )
+        return result.choices[0].message.content
+    except Exception as e:
+        # Best-effort batch job: report the failure and let the caller record
+        # the sentinel instead of aborting the run.
+        print(f"获取结果时发生错误:{messages} {e}")
+    return 'error'
+
+
+def process(model_service, df, output_file):
+    """Run inference for every row of ``df`` and write the results as CSV.
+
+    Each output row copies ``uid``, ``nick_name``, ``gender``, ``valid`` and
+    ``avatar_url`` from the input row and adds the model answer as
+    ``predict``. Nothing is written when ``df`` has no rows.
+
+    A row whose avatar cannot be read gets ``predict='error'`` (the same
+    sentinel ``inference`` returns on failure) instead of aborting the whole
+    batch before any result is saved.
+
+    Args:
+        model_service: Service object passed through to ``inference``.
+        df: Input DataFrame; must provide the columns listed above plus
+            ``image``.
+        output_file: Destination path of the UTF-8 CSV file.
+    """
+    row_list = []
+    for index, row in tqdm(df.iterrows(), total=len(df)):
+        try:
+            messages = get_messages(row)
+        except (OSError, TypeError) as e:
+            # OSError: avatar file missing/unreadable. TypeError: the image
+            # cell is not a usable path (e.g. an empty CSV cell).
+            print(f"failed to build messages for row {index}: {e}")
+            result = 'error'
+        else:
+            result = inference(model_service, messages)
+        new_row = {
+            'uid': row['uid'],
+            'nick_name': row['nick_name'],
+            'gender': row['gender'],
+            'predict': result,
+            'valid': row['valid'],
+            'avatar_url': row['avatar_url']}
+        row_list.append(new_row)
+    if row_list:
+        new_df = pd.DataFrame(row_list)
+        new_df.to_csv(output_file, encoding='utf-8', index=False)
+
+
+if __name__ == '__main__':
+    # Command-line options: service address plus input/output CSV paths.
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--host', type=str, default='0.0.0.0', help='service ip')
+    parser.add_argument('--port', type=int, default=8000, help='service port')
+    parser.add_argument('--input_file', help='input files', required=True)
+    parser.add_argument('--output_file', help='output file', required=True)
+    args = parser.parse_args()
+
+    # Echo the parsed arguments, padded with blank lines above and below.
+    print('\n\n', args, '\n\n', sep='\n')
+
+    # Load the input table.
+    all_df = pd.read_csv(args.input_file)
+
+    # Create the client wrapper for the inference service.
+    model = ModelService(args.host, args.port, model_name_or_path)
+
+    # Run inference over every row and write the output CSV.
+    process(model, all_df, args.output_file)

+ 4 - 3
src/preprocess/qw_api_url_inference.py

@@ -10,9 +10,10 @@ from tqdm import tqdm
 model_name_or_path = 'Qwen/Qwen2.5-VL-7B-Instruct'
 
 system = """
-根据用户在社交平台上的姓名/昵称和头像,分析用户的性别倾向;
-先基于姓名/昵称分析,再基于头像分析;
-当头像不可用时,仅基于姓名/昵称分析;
+根据中老年用户在社交平台上的姓名/昵称和头像,分析其性别倾向。
+先基于姓名/昵称分析,不能判断时,再结合头像分析;
+头像中的儿童/青少年,不能作为判断依据;
+当头像不可用时,仅基于姓名/昵称分析。
 直接返回[男性|女性|未知],不要返回多余的信息。
 """
 system = system.replace("\n", "")