Forráskód Böngészése

调用score计算

xueyiming 2 napja
szülő
commit
c9788ba903

+ 1 - 0
pqai_agent/data_models/agent_test_task.py

@@ -13,6 +13,7 @@ class AgentTestTask(Base):
     create_user = Column(String(32), nullable=True, comment="创建用户")
     update_user = Column(String(32), nullable=True, comment="更新用户")
     dataset_ids = Column(Text, nullable=True, comment="数据集ids")
+    evaluate_type = Column(Integer, nullable=False, default=0, comment="评估类型(0:即时回复评估, 1:24小时后推送评估)")
     status = Column(Integer, nullable=True, comment="状态(0:未开始, 1:进行中, 2:已完成, 3:已取消)")
     create_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", comment="创建时间")
     update_time = Column(TIMESTAMP, nullable=False, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",

+ 5 - 1
pqai_agent_server/api_server.py

@@ -7,6 +7,7 @@ import werkzeug.exceptions
 from flask import Flask, request, jsonify
 from argparse import ArgumentParser
 
 from pyarrow.dataset import dataset
 from sqlalchemy.orm import sessionmaker
 
@@ -579,11 +580,14 @@ def create_test_task():
     req_data = request.json
     agent_id = req_data.get('agentId', None)
     module_id = req_data.get('moduleId', None)
+    evaluate_type = req_data.get('evaluateType', None)
     if not agent_id:
         return wrap_response(404, msg='agent id is required')
     if not module_id:
         return wrap_response(404, msg='module id is required')
-    app.task_manager.create_task(agent_id, module_id)
+    if evaluate_type is None:
+        return wrap_response(404, msg='evaluate type is required')
+    app.task_manager.create_task(agent_id, module_id, evaluate_type)
     return wrap_response(200)
 
 

+ 39 - 16
pqai_agent_server/task_server.py

@@ -2,6 +2,7 @@ import json
 import threading
 import concurrent.futures
 import time
+import traceback
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from typing import Dict
@@ -15,8 +16,12 @@ from pqai_agent.data_models.agent_configuration import AgentConfiguration
 from pqai_agent.data_models.agent_test_task import AgentTestTask
 from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
 from pqai_agent.data_models.service_module import ServiceModule
+from pqai_agent.utils.prompt_utils import format_agent_profile
 from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc
 from concurrent.futures import ThreadPoolExecutor
+
+from scripts.evaluate_agent import evaluate_agent
+
 logger = logging_service.logger
 
 
@@ -84,7 +89,7 @@ class TaskManager:
                 {
                     "id": agent_test_task_conversation.id,
                     "agentName": agent_configuration.name,
-                    "input":MultiModalChatAgent.compose_dialogue(json.loads(agent_test_task_conversation.input)),
+                    "input": MultiModalChatAgent.compose_dialogue(json.loads(agent_test_task_conversation.input)),
                     "output": agent_test_task_conversation.output,
                     "score": agent_test_task_conversation.score,
                     "statusName": get_test_task_status_desc(agent_test_task_conversation.status),
@@ -101,10 +106,10 @@ class TaskManager:
                 "list": response_data,
             }
 
-    def create_task(self, agent_id: int, module_id: int) -> Dict:
+    def create_task(self, agent_id: int, module_id: int, evaluate_type: int) -> Dict:
         """创建新任务"""
         with self.session_maker() as session:
-            agent_test_task = AgentTestTask(agent_id=agent_id, module_id=module_id,
+            agent_test_task = AgentTestTask(agent_id=agent_id, module_id=module_id, evaluate_type=evaluate_type,
                                             status=TestTaskStatus.CREATING.value)
             session.add(agent_test_task)
             session.commit()  # 显式提交
@@ -126,7 +131,8 @@ class TaskManager:
 
             for dataset_module in dataset_module_list:
                 # 获取对话数据列表
-                conversation_datas = self.dataset_service.get_conversation_data_list_by_dataset(dataset_module.dataset_id)
+                conversation_datas = self.dataset_service.get_conversation_data_list_by_dataset(
+                    dataset_module.dataset_id)
 
                 for conversation_data in conversation_datas:
                     # 创建子任务对象
@@ -337,8 +343,10 @@ class TaskManager:
             agent_configuration = self.get_agent_configuration_by_task_id(task_id)
             query_prompt_template = agent_configuration.task_prompt
 
+            task = self.get_task(task_id)
+
             # 使用线程池执行子任务
-            with ThreadPoolExecutor(max_workers=20) as executor:  # 可根据需要调整并发数
+            with ThreadPoolExecutor(max_workers=8) as executor:  # 可根据需要调整并发数
                 futures = {}
                 for task_conversation in task_conversations:
                     if self.task_events[task_id].is_set():
@@ -348,6 +356,7 @@ class TaskManager:
                     future = executor.submit(
                         self._process_single_conversation,
                         task_id,
+                        task,
                         task_conversation,
                         query_prompt_template,
                         agent_configuration
@@ -375,7 +384,8 @@ class TaskManager:
         finally:
             self._cleanup_task_resources(task_id)
 
-    def _process_single_conversation(self, task_id, task_conversation, query_prompt_template, agent_configuration):
+    def _process_single_conversation(self, task_id, task, task_conversation, query_prompt_template,
+                                     agent_configuration):
         """处理单个对话子任务(线程安全)"""
         # 检查任务是否被取消
         if self.task_events[task_id].is_set():
@@ -404,7 +414,7 @@ class TaskManager:
                 conversation_data.version_date.replace("-", ""))
             user_profile = json.loads(user_profile_data['profile_data_v1'])
             avatar = user_profile_data['iconurl']
-            staff_profile_data = self.dataset_service.get_staff_profile_data(
+            staff_profile = self.dataset_service.get_staff_profile_data(
                 conversation_data.staff_id).agent_profile
             conversations = self.dataset_service.get_chat_conversation_list_by_ids(
                 json.loads(conversation_data.conversation),
@@ -414,11 +424,17 @@ class TaskManager:
 
             # 生成推送消息
             last_timestamp = int(conversations[-1]["timestamp"])
-            push_time = int(last_timestamp / 1000) + 24 * 3600
-            push_dt = datetime.fromtimestamp(push_time).strftime('%Y-%m-%d %H:%M:%S')
-            push_message = agent._generate_message(
+            match task.evaluate_type:
+                case 0:
+                    send_timestamp = int(last_timestamp / 1000) + 10
+                case 1:
+                    send_timestamp = int(last_timestamp / 1000) + 24 * 3600
+                case _:
+                    raise ValueError("evaluate_type must be 0 or 1")
+            send_time = datetime.fromtimestamp(send_timestamp).strftime('%Y-%m-%d %H:%M:%S')
+            message = agent._generate_message(
                 context={
-                    "formatted_staff_profile": staff_profile_data,
+                    "formatted_staff_profile": staff_profile,
                     "nickname": user_profile['nickname'],
                     "name": user_profile['name'],
                     "avatar": avatar,
@@ -429,26 +445,33 @@ class TaskManager:
                     "health_conditions": user_profile['health_conditions'],
                     "medications": user_profile['medications'],
                     "interests": user_profile['interests'],
-                    "current_datetime": push_dt
+                    "current_datetime": send_time
                 },
                 dialogue_history=conversations,
                 query_prompt_template=query_prompt_template
             )
 
-            # 获取打分(TODO: 实际实现)
-            score = '{"score":0.05}'
+            param = {}
+            param["dialogue_history"] = conversations
+            param["message"] = message
+            param["send_time"] = send_time
+            logger.debug("staff_profile: %s", staff_profile)
+            param["agent_profile"] = json.loads(staff_profile)
+            param["user_profile"] = user_profile
+            score = evaluate_agent(param, task.evaluate_type)
 
             # 更新子任务结果
             self.update_task_conversations_res(
                 task_conversation.id,
                 TestTaskConversationsStatus.SUCCESS.value,
                 json.dumps(conversations, ensure_ascii=False),
-                json.dumps(push_message, ensure_ascii=False),
-                score
+                json.dumps(message, ensure_ascii=False),
+                json.dumps(score, ensure_ascii=False)
             )
 
         except Exception as e:
             logger.error(f"Subtask {task_conversation.id} failed: {str(e)}")
+            logger.error(traceback.format_exc())
             self.update_task_conversations_status(
                 task_conversation.id,
                 TestTaskConversationsStatus.FAILED.value