import json
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Dict

from sqlalchemy import func, select

from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
from pqai_agent.data_models.agent_configuration import AgentConfiguration
from pqai_agent.data_models.agent_task import AgentTask
from pqai_agent.data_models.agent_task_detail import AgentTaskDetail
from pqai_agent.logging import logger
from pqai_agent.toolkit import get_tools
from pqai_agent_server.const.status_enum import (
    AgentTaskStatus,
    get_agent_task_detail_status_desc,
    AgentTaskDetailStatus,
    get_agent_task_status_desc,
)


class AgentTaskManager:
    """Task manager: persists agent tasks and executes them asynchronously on a thread pool."""

    # Timestamp format used for every time field in the API responses built here.
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, session_maker):
        """
        Args:
            session_maker: factory whose return value is used as a context-managed
                SQLAlchemy session (``with self.session_maker() as session``).
        """
        self.session_maker = session_maker
        self.task_events = {}  # task id -> threading.Event (intended for cancelling a task)
        self.task_locks = {}  # task id -> threading.Lock (intended for task state synchronisation)
        self.running_tasks = set()  # ids of tasks currently submitted to / running on the pool
        self.executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='AgentTaskWorker')
        self.task_futures = {}  # task id -> Future
        # Guards the bookkeeping containers above. They are mutated both by callers of
        # _execute_task and by worker threads (through the Future done-callback).
        self._registry_lock = threading.Lock()

    @classmethod
    def _format_time(cls, value):
        """Format a datetime with ``_TIME_FORMAT``; return None when *value* is None."""
        return value.strftime(cls._TIME_FORMAT) if value is not None else None

    def get_agent_task(self, agent_task_id: int):
        """Return the AgentTask with the given id (``Query.one()`` raises if it does not exist)."""
        with self.session_maker() as session:
            return session.query(AgentTask).filter(AgentTask.id == agent_task_id).one()

    def get_agent_config(self, agent_id: int):
        """Return the AgentConfiguration with the given id (``Query.one()`` raises if missing)."""
        with self.session_maker() as session:
            return session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).one()

    def get_in_progress_task(self):
        """Return all tasks that are NOT_STARTED or IN_PROGRESS (i.e. not yet finished)."""
        with self.session_maker() as session:
            return session.query(AgentTask).filter(AgentTask.status.in_([
                AgentTaskStatus.NOT_STARTED.value,
                AgentTaskStatus.IN_PROGRESS.value,
            ])).all()

    def get_agent_task_details(self, task_id, parent_execution_id):
        """Return the detail rows of *task_id* whose ``parent_execution_id`` equals the argument.

        Each result row is an ``(AgentTaskDetail, has_children)`` pair; ``has_children`` is true
        when some detail row has this row's ``id`` as its ``parent_execution_id``.
        Passing ``parent_execution_id=None`` selects top-level rows (SQLAlchemy renders
        ``== None`` as ``IS NULL``).
        """
        with self.session_maker() as session:
            # Subquery: number of child rows per parent id.
            child_counts = (
                select(
                    AgentTaskDetail.parent_execution_id.label('parent_id'),
                    func.count().label('child_count'),
                )
                .where(AgentTaskDetail.parent_execution_id.isnot(None))
                .group_by(AgentTaskDetail.parent_execution_id)
                .subquery()
            )
            # Main query: outer join so rows without children still appear
            # (their child_count is NULL, coalesced to 0). The join matches this row's id
            # against the children's parent_execution_id.
            query = (
                select(
                    AgentTaskDetail,
                    (func.coalesce(child_counts.c.child_count, 0) > 0).label('has_children'),
                )
                .outerjoin(child_counts, AgentTaskDetail.id == child_counts.c.parent_id)
                .where(AgentTaskDetail.agent_task_id == task_id)
                .where(AgentTaskDetail.parent_execution_id == parent_execution_id)
            )
            return session.execute(query).all()

    def save_agent_task_details_batch(self, agent_task_details: list, agent_task_id: int, message: str):
        """Store the detail rows and mark the task COMPLETED with *message* as output.

        Both writes happen in one transaction.

        Raises:
            Exception: whatever the database layer raised; it is logged and re-raised unchanged.
        """
        try:
            with self.session_maker() as session:
                # session.begin() commits on normal exit and rolls back on error, so the
                # details and the status update are written atomically; no explicit commit.
                with session.begin():
                    session.add_all(agent_task_details)
                    session.query(AgentTask).filter(AgentTask.id == agent_task_id).update(
                        {"status": AgentTaskStatus.COMPLETED.value,
                         "output": message,
                         "update_time": datetime.now()})
        except Exception as e:
            logger.error(e)
            # Bare raise keeps the original exception type and traceback.
            raise

    def update_task_failed(self, task_id, error_message: str):
        """Mark a task FAILED and record *error_message*."""
        with self.session_maker() as session:
            session.query(AgentTask).filter(AgentTask.id == task_id).update(
                {"status": AgentTaskStatus.FAILED.value,
                 "error_message": error_message,
                 "update_time": datetime.now()})
            session.commit()

    def update_task_status(self, task_id, status):
        """Set a task's status value and refresh its update_time."""
        with self.session_maker() as session:
            session.query(AgentTask).filter(AgentTask.id == task_id).update(
                {"status": status, "update_time": datetime.now()})
            session.commit()

    def get_agent_task_list(self, page_num: int, page_size: int) -> Dict:
        """Return one page of tasks joined with their agent configuration.

        Args:
            page_num: 1-based page index.
            page_size: number of tasks per page (expected to be positive).

        Returns:
            dict with ``currentPage``, ``pageSize``, ``totalSize`` (number of pages, at least 1),
            ``total`` (number of tasks) and ``list`` (the serialised tasks).
        """
        with self.session_maker() as session:
            offset = (page_num - 1) * page_size
            # ORDER BY makes LIMIT/OFFSET paging deterministic across requests.
            rows = (session.query(AgentTask, AgentConfiguration)
                    .outerjoin(AgentConfiguration, AgentTask.agent_id == AgentConfiguration.id)
                    .order_by(AgentTask.id)
                    .limit(page_size).offset(offset).all())

            total = session.query(func.count(AgentTask.id)).scalar()
            # Ceiling division, with an empty table still reporting one page.
            total_page = max(1, -(-total // page_size))

            response_data = [
                {
                    "id": agent_task.id,
                    # Outer join: the configuration may be missing (e.g. deleted agent).
                    "agentName": agent_configuration.name if agent_configuration is not None else None,
                    "statusName": get_agent_task_status_desc(agent_task.status),
                    "startTime": self._format_time(agent_task.start_time),
                    # NOTE(review): endTime is derived from start_time, as in the original code;
                    # it probably should come from an end-time column — confirm against the model.
                    "endTime": self._format_time(agent_task.start_time),
                    "createUser": agent_task.create_user,
                    "input": agent_task.input,
                    "output": agent_task.output,
                    "createTime": self._format_time(agent_task.create_time),
                    "updateTime": self._format_time(agent_task.update_time),
                }
                for agent_task, agent_configuration in rows
            ]
            return {
                "currentPage": page_num,
                "pageSize": page_size,
                "totalSize": total_page,
                "total": total,
                "list": response_data,
            }

    def create_task(self, agent_id: int, task_prompt: str):
        """Create a NOT_STARTED task for *agent_id* and schedule it for execution.

        Returns:
            The id of the newly created task.

        Raises:
            ValueError: if no AgentConfiguration exists for *agent_id*.
        """
        with self.session_maker() as session:
            agent_config = session.get(AgentConfiguration, agent_id)
            if agent_config is None:
                raise ValueError(f"agent configuration {agent_id} not found")
            agent_task = AgentTask(agent_id=agent_id,
                                   status=AgentTaskStatus.NOT_STARTED.value,
                                   start_time=datetime.now(),
                                   input=task_prompt,
                                   tools=agent_config.tools)
            session.add(agent_task)
            session.commit()  # commit so the generated id is available
            task_id = agent_task.id
        # _execute_task already submits to the pool; calling it directly avoids tying up a
        # worker thread just to submit another job.
        self._execute_task(task_id)
        return task_id

    def _process_task(self, task_id: int):
        """Run one task end-to-end (executed on a worker thread).

        Marks the task IN_PROGRESS, builds a chat agent from the task's agent configuration,
        runs it on the task input, then stores the produced details and final message.
        Any exception marks the task FAILED with the exception text.
        """
        try:
            self.update_task_status(task_id, AgentTaskStatus.IN_PROGRESS.value)
            agent_task = self.get_agent_task(task_id)
            agent_config = self.get_agent_config(agent_task.agent_id)
            tools = get_tools(json.loads(agent_config.tools))
            chat_agent = SimpleOpenAICompatibleChatAgent(model=agent_config.execution_model,
                                                         system_prompt=agent_config.system_prompt,
                                                         tools=tools)
            message = chat_agent.run(agent_task.input)
            agent_task_details = chat_agent.get_agent_task_details()
            for agent_task_detail in agent_task_details:
                agent_task_detail.agent_task_id = task_id
                agent_task_detail.status = AgentTaskDetailStatus.SUCCESS.value
            self.save_agent_task_details_batch(agent_task_details, task_id, message)
        except Exception as e:
            logger.error(e)
            self.update_task_failed(task_id, str(e))

    def recover_tasks(self):
        """On service start-up, resubmit every task that has not finished yet."""
        for task in self.get_in_progress_task():
            self._execute_task(task.id)

    def _execute_task(self, task_id: int):
        """Submit a task to the thread pool unless it is already submitted or running."""
        with self._registry_lock:
            # Check-and-mark happens under one lock so concurrent callers cannot
            # submit the same task twice.
            if task_id in self.running_tasks:
                return
            if task_id not in self.task_events:
                self.task_events[task_id] = threading.Event()
            if task_id not in self.task_locks:
                self.task_locks[task_id] = threading.Lock()
            self.running_tasks.add(task_id)
            try:
                future = self.executor.submit(self._process_task, task_id)
            except Exception:
                # e.g. the executor was shut down: undo the "running" mark before propagating.
                self.running_tasks.discard(task_id)
                raise
            self.task_futures[task_id] = future
        # Registered outside the lock: if the future is already done, the callback runs
        # synchronously here and must be able to acquire _registry_lock.
        future.add_done_callback(lambda fut: self._on_task_done(task_id, fut))

    def _on_task_done(self, task_id: int, future):
        """Future done-callback: release the task's bookkeeping and log any escaped error."""
        with self._registry_lock:
            self.running_tasks.discard(task_id)
            self.task_futures.pop(task_id, None)
            self.task_events.pop(task_id, None)
            self.task_locks.pop(task_id, None)
        # _process_task handles its own errors, but update_task_failed inside its except
        # clause can still raise; without this the exception would stay hidden in the Future.
        if not future.cancelled() and future.exception() is not None:
            logger.error(f"agent task {task_id} worker raised: {future.exception()}")

    def get_agent_task_detail(self, agent_task_id, parent_execution_id):
        """Return a task's input/tools plus its detail rows under *parent_execution_id*."""
        agent_task = self.get_agent_task(agent_task_id)
        agent_task_details = self.get_agent_task_details(agent_task_id, parent_execution_id)
        agent_task_detail_datas = [
            {
                "id": agent_task_detail.id,
                "agentTaskId": agent_task_detail.agent_task_id,
                "executorType": agent_task_detail.executor_type,
                "statusName": get_agent_task_detail_status_desc(agent_task_detail.status),
                "inputData": agent_task_detail.input_data,
                "executorName": agent_task_detail.executor_name,
                "reasoning": agent_task_detail.reasoning,
                "outputData": agent_task_detail.output_data,
                "errorMessage": agent_task_detail.error_message,
                "hasChildren": has_children,
                "createTime": self._format_time(agent_task_detail.create_time),
                "updateTime": self._format_time(agent_task_detail.update_time),
            }
            for agent_task_detail, has_children in agent_task_details
        ]
        return {
            "input": agent_task.input,
            "tools": agent_task.tools,
            "agentTaskDetails": agent_task_detail_datas,
        }