Selaa lähdekoodia

增加agent提交执行和查询任务功能

xueyiming 2 päivää sitten
vanhempi
commit
6c7cc2d8a1
2 muutettua tiedostoa jossa 36 lisäystä ja 9 poistoa
  1. 33 7
      pqai_agent_server/agent_task_server.py
  2. 3 2
      pqai_agent_server/api_server.py

+ 33 - 7
pqai_agent_server/agent_task_server.py

@@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from typing import Dict
 
-from sqlalchemy import func
+from sqlalchemy import func, select
 
 from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
 from pqai_agent.data_models.agent_configuration import AgentConfiguration
@@ -45,11 +45,35 @@ class AgentTaskManager:
                 AgentTaskStatus.IN_PROGRESS.value
             ])).all()
 
-    def get_agent_task_details(self, task_id):
+    def get_agent_task_details(self, task_id, parent_execution_id):
         """获取任务详情"""
         with self.session_maker() as session:
-            return session.query(AgentTaskDetail).filter(AgentTaskDetail.agent_task_id == task_id).filter(
-                AgentTaskDetail.parent_execution_id == None).all()
+            # 创建子查询:统计每个父节点的子节点数量
+            subquery = (
+                select(
+                    AgentTaskDetail.parent_execution_id.label('parent_id'),
+                    func.count('*').label('child_count')
+                )
+                .where(AgentTaskDetail.parent_execution_id.isnot(None))
+                .group_by(AgentTaskDetail.parent_execution_id)
+                .subquery()
+            )
+
+            # 主查询:关联子查询,判断是否有子节点
+            # 修正连接条件:使用parent_execution_id关联
+            query = (
+                select(
+                    AgentTaskDetail,
+                    (func.coalesce(subquery.c.child_count, 0) > 0).label('has_children')
+                )
+                .outerjoin(
+                    subquery,
+                    AgentTaskDetail.id == subquery.c.parent_id  # 使用当前记录的id匹配子查询的parent_id
+                )
+            ).where(AgentTaskDetail.agent_task_id == task_id).where(
+                AgentTaskDetail.parent_execution_id == parent_execution_id)
+            # 执行查询
+            return session.execute(query).all()
 
     def save_agent_task_details_batch(self, agent_task_details: list, agent_task_id: int, message: str):
         """批量保存子任务到数据库"""
@@ -180,12 +204,13 @@ class AgentTaskManager:
         with self.task_locks[task_id]:
             self.running_tasks.add(task_id)
 
-    def get_agent_task_detail(self, agent_task_id):
+    def get_agent_task_detail(self, agent_task_id, parent_execution_id):
         agent_task = self.get_agent_task(agent_task_id)
-        agent_task_details = self.get_agent_task_details(agent_task_id)
+        agent_task_details = self.get_agent_task_details(agent_task_id, parent_execution_id)
         agent_task_detail_datas = [
             {
                 "id": agent_task_detail.id,
+                "agentTaskId": agent_task_detail.agent_task_id,
                 "executorType": agent_task_detail.executor_type,
                 "statusName": get_agent_task_detail_status_desc(agent_task_detail.status),
                 "inputData": agent_task_detail.input_data,
@@ -193,10 +218,11 @@ class AgentTaskManager:
                 "reasoning": agent_task_detail.reasoning,
                 "outputData": agent_task_detail.output_data,
                 "errorMessage": agent_task_detail.error_message,
+                "hasChildren": has_children,
                 "createTime": agent_task_detail.create_time.strftime("%Y-%m-%d %H:%M:%S"),
                 "updateTime": agent_task_detail.update_time.strftime("%Y-%m-%d %H:%M:%S")
             }
-            for agent_task_detail in agent_task_details
+            for agent_task_detail, has_children in agent_task_details
         ]
         return {
             "input": agent_task.input,

+ 3 - 2
pqai_agent_server/api_server.py

@@ -797,6 +797,7 @@ def get_agent_task_list():
     response = app.agent_task_manager.get_agent_task_list(page_num, page_size)
     return wrap_response(200, data=response)
 
+
 @app.route("/api/getAgentTaskDetail", methods=["GET"])
 def get_agent_task_detail():
     """
@@ -806,10 +807,10 @@ def get_agent_task_detail():
     agent_task_id = request.args.get("agentTaskId", None)
     if not agent_task_id:
         return wrap_response(404, msg='agent_task_id is required')
-    response = app.agent_task_manager.get_agent_task_detail(int(agent_task_id))
+    parent_execution_id = request.args.get("parentExecutionId", None)
+    response = app.agent_task_manager.get_agent_task_detail(int(agent_task_id), parent_execution_id)
     return wrap_response(200, data=response)
 
-
 @app.errorhandler(werkzeug.exceptions.BadRequest)
 def handle_bad_request(e):
     logger.error(e)