@@ -7,14 +7,17 @@ import pandas as pd
 from openai import OpenAI
 from tqdm import tqdm
 
+model_name_or_path = 'Qwen/Qwen2.5-VL-7B-Instruct'
+
 system = "根据用户在社交平台的昵称和头像,判断用户的性别倾向;当头像不可用时,仅使用昵称进行判断;直接返回[男性|女性|未知],不要返回多余的信息。"
 prompt = "昵称:%s, 头像:"
 
-host = "117.50.199.192"
-port = 8000
-base_url = 'http://%s:%d/v1' % (host, port)
-client = OpenAI(api_key="0", base_url=base_url)
-model_name_or_path = 'Qwen/Qwen2.5-VL-7B-Instruct'
+
+class ModelService:
+    def __init__(self, host, port, model_name):
+        base_url = 'http://%s:%d/v1' % (host, port)
+        self.client = OpenAI(api_key="0", base_url=base_url)
+        self.model_name = model_name
 
 
 def get_system_message():
@@ -48,20 +51,20 @@ def get_messages(row):
     return messages
 
 
-def inference(messages):
+def inference(model_service, messages):
     try:
-        result = client.chat.completions.create(messages=messages, model=model_name_or_path)
+        result = model_service.client.chat.completions.create(messages=messages, model=model_service.model_name)
         return result.choices[0].message.content
     except Exception as e:
         print(f"获取结果时发生错误:{messages} {e}")
         return 'error'
 
 
-def process(df, output_file):
+def process(model_service, df, output_file):
     row_list = []
     for index, row in tqdm(df.iterrows(), total=len(df)):
         messages = get_messages(row)
-        result = inference(messages)
+        result = inference(model_service, messages)
         new_row = {
             'uid': row['uid'],
             'nick_name': row['nick_name'],
@@ -77,6 +80,8 @@ def process(df, output_file):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
+    parser.add_argument('--host', default='0.0.0.0', type=str, help='service ip')
+    parser.add_argument('--port', default=8000, type=int, help='service port')
     parser.add_argument('--input_file', required=True, help='input files')
     parser.add_argument('--output_file', required=True, help='output file')
     args = parser.parse_args()
@@ -86,4 +91,9 @@ if __name__ == '__main__':
 
     # load
     all_df = pd.read_csv(args.input_file)
-    process(all_df, args.output_file)
+
+    # model
+    model = ModelService(args.host, args.port, model_name_or_path)
+
+    # process
+    process(model, all_df, args.output_file)
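
After this change the endpoint is configured from the command line and the OpenAI client lives inside ModelService, so the module can be imported and exercised without touching module-level connection state. Below is a minimal smoke-test sketch, assuming the patched file is importable under the hypothetical name gender_infer and that an OpenAI-compatible server (for example vLLM serving Qwen/Qwen2.5-VL-7B-Instruct) is already listening on 127.0.0.1:8000; the nickname is illustrative, and in the real script the full multimodal payload is produced by get_messages(row).

# Smoke-test sketch, not part of the patch. The module name "gender_infer" is
# hypothetical; host, port, and the sample nickname are assumptions.
from gender_infer import ModelService, inference, system, prompt

# Point the client at a locally served OpenAI-compatible endpoint.
service = ModelService('127.0.0.1', 8000, 'Qwen/Qwen2.5-VL-7B-Instruct')

# Text-only message, i.e. the "avatar unavailable" case described in the system prompt.
messages = [
    {"role": "system", "content": system},
    {"role": "user", "content": prompt % 'example_nick'},
]

# Prints one of 男性 / 女性 / 未知 on success, or 'error' if the request fails.
print(inference(service, messages))

Encapsulating the client this way is also what lets process() take the service as an argument, so --host and --port can stay plain CLI options instead of hard-coded globals.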