Przeglądaj źródła

增加agent调用 生成对话内容

xueyiming 3 dni temu
rodzic
commit
fda54a40bb

+ 3 - 3
pqai_agent_server/api_server.py

@@ -624,7 +624,7 @@ def resume_test_task():
 @app.route("/api/getDatasetList", methods=["GET"])
 def get_dataset_list():
     """
-       获取单元测试任务列表
+       获取数据集列表
        :return:
     """
     page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
@@ -641,7 +641,7 @@ def get_dataset_list():
 @app.route("/api/getConversationDataList", methods=["GET"])
 def get_conversation_data_list():
     """
-       获取单元测试任务列表
+       获取对话列表
        :return:
     """
     dataset_id = request.args.get("datasetId", None)
@@ -697,7 +697,7 @@ if __name__ == '__main__':
         chat_history_table=chat_history_db_config['table']
     )
     app.session_manager = session_manager
-    agent_db_engine = create_ai_agent_db_engine(config['database']['ai_agent'])
+    agent_db_engine = create_ai_agent_db_engine()
     app.session_maker = sessionmaker(bind=agent_db_engine)
 
     dataset_service = DatasetService(session_maker=sessionmaker(bind=agent_db_engine))

+ 16 - 2
pqai_agent_server/dataset_service.py

@@ -34,7 +34,7 @@ class DatasetService:
             return result_df.iloc[0].to_dict()  # 获取第一行
         return None
 
-    def get_dataset_list_by_module(self, module_id: int):
+    def get_dataset_module_list_by_module(self, module_id: int):
         with self.session_maker() as session:
             return session.query(DatasetModule).filter(DatasetModule.module_id == module_id).filter(
                 DatasetModule.is_delete == 0).all()
@@ -59,7 +59,8 @@ class DatasetService:
 
     def get_conversation_list_by_ids(self, conversation_ids: List[int]):
         with self.session_maker() as session:
-            conversations = session.query(QywxChatHistory).filter(QywxChatHistory.id.in_(conversation_ids)).all()
+            conversations = session.query(QywxChatHistory).filter(QywxChatHistory.id.in_(conversation_ids)).order_by(
+                QywxChatHistory.id.asc()).all()
             result = []
             for conversation in conversations:
                 data = {}
@@ -73,6 +74,19 @@ class DatasetService:
                 result.append(data)
         return result
 
+    def get_chat_conversation_list_by_ids(self, conversation_ids: List[int], staff_id):
+        result = self.get_conversation_list_by_ids(conversation_ids)
+        conversations = [
+            {
+                "content": conversation['content'],
+                "role": "assistant" if conversation['sender'] == staff_id else "user",
+                "timestamp": conversation['sendtime']
+            } for conversation in result
+        ]
+        return conversations
+
+
+
     def get_dataset_list(self, page_num: int, page_size: int):
         with self.session_maker() as session:
             # 计算偏移量

+ 65 - 15
pqai_agent_server/task_server.py

@@ -8,9 +8,12 @@ from typing import Dict
 from sqlalchemy import func
 
 from pqai_agent import logging_service
+from pqai_agent.agents.message_push_agent import MessagePushAgent
+from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
 from pqai_agent.data_models.agent_configuration import AgentConfiguration
 from pqai_agent.data_models.agent_test_task import AgentTestTask
 from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
+from pqai_agent.data_models.service_module import ServiceModule
 from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc
 
 logger = logging_service.logger
@@ -80,7 +83,7 @@ class TaskManager:
                 {
                     "id": agent_test_task_conversation.id,
                     "agentName": agent_configuration.name,
-                    "input": agent_test_task_conversation.input,
+                    "input":MultiModalChatAgent.compose_dialogue(json.loads(agent_test_task_conversation.input)),
                     "output": agent_test_task_conversation.output,
                     "score": agent_test_task_conversation.score,
                     "statusName": get_test_task_status_desc(agent_test_task_conversation.status),
@@ -114,22 +117,22 @@ class TaskManager:
         """异步生成子任务"""
         try:
             # 获取数据集列表
-            datasets_list = self.dataset_service.get_dataset_list_by_module(module_id)
+            dataset_module_list = self.dataset_service.get_dataset_module_list_by_module(module_id)
 
             # 批量处理数据集 - 减少数据库交互
             batch_size = 100  # 每批处理100个子任务
             agent_test_task_conversation_batch = []
 
-            for dataset in datasets_list:
+            for dataset_module in dataset_module_list:
                 # 获取对话数据列表
-                conversation_datas = self.dataset_service.get_conversation_data_list_by_dataset(dataset.id)
+                conversation_datas = self.dataset_service.get_conversation_data_list_by_dataset(dataset_module.dataset_id)
 
                 for conversation_data in conversation_datas:
                     # 创建子任务对象
                     agent_test_task_conversation = AgentTestTaskConversations(
                         task_id=task_id,
                         agent_id=agent_id,
-                        dataset_id=dataset.id,
+                        dataset_id=dataset_module.dataset_id,
                         conversation_id=conversation_data.id,
                         status=TestTaskConversationsStatus.PENDING.value
                     )
@@ -164,6 +167,22 @@ class TaskManager:
         except Exception as e:
             logger.error(e)
 
+    def get_agent_configuration_by_task_id(self, task_id: int):
+        """获取指定任务ID对应的Agent配置信息"""
+        with self.session_maker() as session:
+            return session.query(AgentConfiguration) \
+                .join(AgentTestTask, AgentTestTask.agent_id == AgentConfiguration.id) \
+                .filter(AgentTestTask.id == task_id) \
+                .one_or_none()  # 返回单个对象或None(如果未找到)
+
+    def get_service_module_by_task_id(self, task_id: int):
+        """获取指定任务ID对应的Agent配置信息"""
+        with self.session_maker() as session:
+            return session.query(ServiceModule) \
+                .join(AgentTestTask, AgentTestTask.module_id == ServiceModule.id) \
+                .filter(AgentTestTask.id == task_id) \
+                .one_or_none()  # 返回单个对象或None(如果未找到)
+
     def get_task(self, task_id: int):
         """获取任务信息"""
         with self.session_maker() as session:
@@ -215,12 +234,13 @@ class TaskManager:
                 {"status": status, "update_time": datetime.now()})
             session.commit()
 
-    def update_task_conversations_res(self, task_conversations_id: int, status: int, score: str):
+    def update_task_conversations_res(self, task_conversations_id: int, status: int, input: str, output: str,
+                                      score: str):
         """更新子任务结果"""
         with self.session_maker() as session:
             session.query(AgentTestTaskConversations).filter(
                 AgentTestTaskConversations.id == task_conversations_id).update(
-                {"status": status, "score": score, "update_time": datetime.now()})
+                {"status": status, "input": input, "output": output, "score": score, "update_time": datetime.now()})
             session.commit()
 
     def cancel_task(self, task_id: int):
@@ -312,6 +332,11 @@ class TaskManager:
             # 获取所有待处理的子任务
             task_conversations = self.get_pending_task_conversations(task_id)
 
+            agent_configuration = self.get_agent_configuration_by_task_id(task_id)
+            query_prompt_template = agent_configuration.task_prompt
+            agent = MultiModalChatAgent(model=agent_configuration.execution_model,
+                                        system_prompt=agent_configuration.system_prompt,
+                                        tools=json.loads(agent_configuration.tools))
             # 执行每个子任务
             for task_conversation in task_conversations:
                 # 检查任务是否被取消
@@ -327,19 +352,44 @@ class TaskManager:
                         task_conversation.conversation_id)
                     user_profile_data = self.dataset_service.get_user_profile_data(conversation_data.user_id,
                                                                                    conversation_data.version_date.replace(
-                                                                                       "-", ""))['profile_data_v1']
+                                                                                       "-", ""))
+                    user_profile = json.loads(user_profile_data['profile_data_v1'])
+                    avatar = user_profile_data['iconurl']
                     staff_profile_data = self.dataset_service.get_staff_profile_data(
                         conversation_data.staff_id).agent_profile
-                    conversations = self.dataset_service.get_conversation_list_by_ids(
-                        json.loads(conversation_data.conversation))
-
-                    # 模拟任务执行 - 在实际应用中替换为实际业务逻辑
-                    # TODO 后续改成实际任务执行
-                    time.sleep(1)
+                    conversations = self.dataset_service.get_chat_conversation_list_by_ids(
+                        json.loads(conversation_data.conversation), conversation_data.staff_id)
+                    conversations = sorted(conversations, key=lambda i: i['timestamp'], reverse=False)
+
+                    last_timestamp = int(conversations[-1]["timestamp"])
+                    push_time = int(last_timestamp / 1000) + 24 * 3600
+                    push_dt = datetime.fromtimestamp(push_time).strftime('%Y-%m-%d %H:%M:%S')
+                    push_message = agent._generate_message(
+                        context={
+                            "formatted_staff_profile": staff_profile_data,
+                            "nickname": user_profile['nickname'],
+                            "name": user_profile['name'],
+                            "avatar": avatar,
+                            "preferred_nickname": user_profile['preferred_nickname'],
+                            "gender": user_profile['gender'],
+                            "age": user_profile['age'],
+                            "region": user_profile['region'],
+                            "health_conditions": user_profile['health_conditions'],
+                            "medications": user_profile['medications'],
+                            "interests": user_profile['interests'],
+                            "current_datetime": push_dt
+                        },
+                        dialogue_history=conversations,
+                        query_prompt_template=query_prompt_template
+                    )
+                    # TODO 获取打分
                     score = '{"score":0.05}'
                     # 更新子任务状态为已完成
                     self.update_task_conversations_res(task_conversation.id,
-                                                       TestTaskConversationsStatus.SUCCESS.value, score)
+                                                       TestTaskConversationsStatus.SUCCESS.value,
+                                                       json.dumps(conversations, ensure_ascii=False),
+                                                       json.dumps(push_message, ensure_ascii=False),
+                                                       score)
                 except Exception as e:
                     logger.error(f"Error executing task {task_id}: {str(e)}")
                     self.update_task_conversations_status(task_conversation.id,