jch 1 周之前
父节点
当前提交
e32ec68c84
共有 2 个文件被更改,包括 21 次插入11 次删除
  1. 1 1
      README.md
  2. 20 10
      src/preprocess/qw_api_url_inference.py

+ 1 - 1
README.md

@@ -15,7 +15,7 @@
 ## 1.5 数据格式化
 - python src/preprocess/format_user_info.py --input_file data/user_info_label.csv --image_dir image --output_file data/user_info_format.csv
 
-## 1.6 拆分训练和测试数据
+## 1.6 拆分训练和测试
 - python src/preprocess/split_train_test.py --input_file data/user_info_format.csv --train_file data/user_info_format_train.csv --test_file data/user_info_format_test.csv
 
 ## 1.7 生成训练数据

+ 20 - 10
src/preprocess/qw_api_url_inference.py

@@ -7,14 +7,17 @@ import pandas as pd
 from openai import OpenAI
 from tqdm import tqdm
 
+model_name_or_path = 'Qwen/Qwen2.5-VL-7B-Instruct'
+
 system = "根据用户在社交平台的昵称和头像,判断用户的性别倾向;当头像不可用时,仅使用昵称进行判断;直接返回[男性|女性|未知],不要返回多余的信息。"
 prompt = "昵称:%s, 头像:"
 
-host = "117.50.199.192"
-port = 8000
-base_url = 'http://%s:%d/v1' % (host, port)
-client = OpenAI(api_key="0", base_url=base_url)
-model_name_or_path = 'Qwen/Qwen2.5-VL-7B-Instruct'
+
+class ModelService:
+    def __init__(self, host, port, model_name):
+        base_url = 'http://%s:%d/v1' % (host, port)
+        self.client = OpenAI(api_key="0", base_url=base_url)
+        self.model_name = model_name
 
 
 def get_system_message():
@@ -48,20 +51,20 @@ def get_messages(row):
     return messages
 
 
-def inference(messages):
+def inference(model_service, messages):
     try:
-        result = client.chat.completions.create(messages=messages, model=model_name_or_path)
+        result = model_service.client.chat.completions.create(messages=messages, model=model_service.model_name)
         return result.choices[0].message.content
     except Exception as e:
         print(f"获取结果时发生错误:{messages} {e}")
     return 'error'
 
 
-def process(df, output_file):
+def process(model_service, df, output_file):
     row_list = []
     for index, row in tqdm(df.iterrows(), total=len(df)):
         messages = get_messages(row)
-        result = inference(messages)
+        result = inference(model_service, messages)
         new_row = {
             'uid': row['uid'],
             'nick_name': row['nick_name'],
@@ -77,6 +80,8 @@ def process(df, output_file):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
+    parser.add_argument('--host', default='0.0.0.0', type=str, help='service ip')
+    parser.add_argument('--port', default=8000, type=int, help='service port')
     parser.add_argument('--input_file', required=True, help='input files')
     parser.add_argument('--output_file', required=True, help='output file')
     args = parser.parse_args()
@@ -86,4 +91,9 @@ if __name__ == '__main__':
 
     # load
     all_df = pd.read_csv(args.input_file)
-    process(all_df, args.output_file)
+
+    # model
+    model = ModelService(args.host, args.port, model_name_or_path)
+
+    # process
+    process(model, all_df, args.output_file)