|
|
@@ -0,0 +1,117 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# coding=utf-8
|
|
|
+
|
|
|
+import argparse
|
|
|
+import base64
|
|
|
+
|
|
|
+import pandas as pd
|
|
|
+from openai import OpenAI
|
|
|
+from tqdm import tqdm
|
|
|
+
|
|
|
# Model identifier sent to the OpenAI-compatible serving endpoint.
model_name_or_path = 'Qwen/Qwen2.5-VL-7B-Instruct'

# System prompt (Chinese, runtime data — do not alter): instructs the model to
# infer the gender of a middle-aged/elderly social-platform user from the
# nickname first, falling back to the avatar; children/teens in the avatar must
# not be used as evidence; the reply must be exactly one of 男性/女性/未知.
system = """
根据中老年用户在社交平台上的姓名/昵称和头像,分析其性别倾向。
先基于姓名/昵称分析,不能判断时,再结合头像分析;
头像中的儿童/青少年,不能作为判断依据;
当头像不可用时,仅基于姓名/昵称分析。
直接返回[男性|女性|未知],不要返回多余的信息。
"""
# Collapse the multi-line literal into one line before sending it to the API.
system = system.replace("\n", "")
# %-style template for the user turn; filled with the nickname, the avatar
# image is attached as a separate content part (see get_user_message).
prompt = "姓名/昵称:%s, 头像:"
|
|
|
+
|
|
|
+
|
|
|
class ModelService:
    """Thin handle on an OpenAI-compatible chat endpoint.

    Bundles the configured client and the model name so callers can pass a
    single object around (see ``inference``).
    """

    def __init__(self, host, port, model_name):
        # The local serving endpoint does not validate API keys; "0" is a
        # placeholder required by the client constructor.
        endpoint = f"http://{host}:{port}/v1"
        self.model_name = model_name
        self.client = OpenAI(api_key="0", base_url=endpoint)
|
|
|
+
|
|
|
+
|
|
|
def get_image_base64(image_path):
    """Read the file at *image_path* and return it as a jpeg data-URI string."""
    with open(image_path, "rb") as fh:
        raw_bytes = fh.read()
    b64_text = base64.b64encode(raw_bytes).decode("utf-8")
    return "data:image/jpeg;base64," + b64_text
|
|
|
+
|
|
|
+
|
|
|
def get_system_message():
    """Build the system chat turn carrying the module-level instruction text."""
    return {"role": "system", "content": system}
|
|
|
+
|
|
|
+
|
|
|
def get_user_message(nick_name, image_path):
    """Build the user chat turn: nickname text plus the avatar as a data-URI."""
    text_part = {"type": "text", "text": prompt % nick_name}
    image_part = {
        "type": "image_url",
        "image_url": {"url": get_image_base64(image_path)},
    }
    return {"role": "user", "content": [text_part, image_part]}
|
|
|
+
|
|
|
+
|
|
|
def get_assistant_message(content):
    """Wrap *content* as an assistant chat turn."""
    return dict(role="assistant", content=content)
|
|
|
+
|
|
|
+
|
|
|
def get_messages(row):
    """Assemble the chat history (system + user turn) for one dataframe row.

    Expects *row* to provide at least "nick_name" and "image" keys.
    """
    # NOTE: a gold assistant turn built from row["gender"] could be appended
    # here when constructing few-shot / training conversations.
    return [
        get_system_message(),
        get_user_message(row["nick_name"], row["image"]),
    ]
|
|
|
+
|
|
|
+
|
|
|
def inference(model_service, messages):
    """Send *messages* to the chat endpoint and return the reply text.

    Returns the literal string 'error' on any failure so the calling loop can
    continue with the remaining rows instead of aborting the whole run.

    Fix: the original assigned ``top_p = 0.5`` but never passed it to the API
    call — a dead local that misleadingly suggested nucleus sampling was
    configured. Removed; decoding is controlled solely by ``extra_body``.
    """
    try:
        # Greedy decoding for reproducible labels (vLLM-style extension flag).
        extra_body = {"do_sample": False}
        result = model_service.client.chat.completions.create(
            messages=messages,
            model=model_service.model_name,
            extra_body=extra_body)
        return result.choices[0].message.content
    except Exception as e:
        # Best-effort: log and signal failure with a sentinel value.
        print(f"获取结果时发生错误:{messages} {e}")
        return 'error'
|
|
|
+
|
|
|
+
|
|
|
def process(model_service, df, output_file):
    """Run gender inference on every row of *df* and write results as CSV.

    Each output record copies the identifying columns from the input row and
    adds the model prediction under 'predict'. Writes nothing when *df* is
    empty.
    """
    records = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        prediction = inference(model_service, get_messages(row))
        records.append({
            'uid': row['uid'],
            'nick_name': row['nick_name'],
            'gender': row['gender'],
            'predict': prediction,
            'valid': row['valid'],
            'avatar_url': row['avatar_url'],
        })
    if records:
        pd.DataFrame(records).to_csv(output_file, encoding='utf-8', index=False)
|
|
|
+
|
|
|
+
|
|
|
if __name__ == '__main__':
    # CLI: where the model server lives, plus input/output CSV paths.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--host', default='0.0.0.0', type=str, help='service ip')
    arg_parser.add_argument('--port', default=8000, type=int, help='service port')
    arg_parser.add_argument('--input_file', required=True, help='input files')
    arg_parser.add_argument('--output_file', required=True, help='output file')
    cli_args = arg_parser.parse_args()

    # Echo the run configuration, padded for visibility in long logs.
    print('\n\n')
    print(cli_args)
    print('\n\n')

    # Load the rows to label.
    input_df = pd.read_csv(cli_args.input_file)

    # Connect to the OpenAI-compatible serving endpoint.
    service = ModelService(cli_args.host, cli_args.port, model_name_or_path)

    # Label every row and persist the predictions.
    process(service, input_df, cli_args.output_file)
|