|
@@ -0,0 +1,206 @@
|
|
|
+import concurrent.futures
|
|
|
+import json
|
|
|
+import threading
|
|
|
+from concurrent.futures import ThreadPoolExecutor
|
|
|
+from datetime import datetime
|
|
|
+from typing import Dict
|
|
|
+
|
|
|
+from sqlalchemy import func
|
|
|
+
|
|
|
+from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
|
|
|
+from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
|
|
|
+from pqai_agent.data_models import agent_task_detail
|
|
|
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
|
|
|
+from pqai_agent.data_models.agent_task import AgentTask
|
|
|
+from pqai_agent.data_models.agent_task_detail import AgentTaskDetail
|
|
|
+from pqai_agent.data_models.agent_test_task import AgentTestTask
|
|
|
+from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
|
|
|
+from pqai_agent.data_models.service_module import ServiceModule
|
|
|
+from pqai_agent.logging import logger
|
|
|
+from pqai_agent.toolkit import get_tools
|
|
|
+from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc, \
|
|
|
+ AgentTaskStatus, get_agent_task_detail_status_desc, AgentTaskDetailStatus
|
|
|
+from scripts.evaluate_agent import evaluate_agent
|
|
|
+
|
|
|
+
|
|
|
+class AgentTaskManager:
|
|
|
+ """任务管理器"""
|
|
|
+
|
|
|
+ def __init__(self, session_maker):
|
|
|
+ self.session_maker = session_maker
|
|
|
+ self.task_events = {} # 任务ID -> Event (用于取消任务)
|
|
|
+ self.task_locks = {} # 任务ID -> Lock (用于任务状态同步)
|
|
|
+ self.running_tasks = set()
|
|
|
+ self.executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='AgentTaskWorker')
|
|
|
+ self.task_futures = {} # 任务ID -> Future
|
|
|
+
|
|
|
+ def get_agent_task(self, agent_task_id: int):
|
|
|
+ """获取任务信息"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ return session.query(AgentTask).filter(AgentTask.id == agent_task_id).one()
|
|
|
+
|
|
|
+ def get_agent_config(self, agent_id: int):
|
|
|
+ """获取任务信息"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ return session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).one()
|
|
|
+
|
|
|
+ def get_in_progress_task(self):
|
|
|
+ """获取执行中任务"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ return session.query(AgentTask).filter(AgentTask.status.in_([
|
|
|
+ AgentTaskStatus.NOT_STARTED.value,
|
|
|
+ AgentTaskStatus.IN_PROGRESS.value
|
|
|
+ ])).all()
|
|
|
+
|
|
|
+ def get_agent_task_details(self, task_id):
|
|
|
+ """更新任务状态"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ return session.query(AgentTaskDetail).filter(AgentTaskDetail.agent_task_id == task_id).all()
|
|
|
+
|
|
|
+ def save_agent_task_details_batch(self, agent_task_details: list, agent_task_id: int, message: str):
|
|
|
+ """批量保存子任务到数据库"""
|
|
|
+ try:
|
|
|
+ with self.session_maker() as session:
|
|
|
+ with session.begin():
|
|
|
+ session.add_all(agent_task_details)
|
|
|
+ session.query(AgentTask).filter(
|
|
|
+ AgentTask.id == agent_task_id).update(
|
|
|
+ {"status": AgentTaskStatus.COMPLETED.value, "output": message, "update_time": datetime.now()})
|
|
|
+ session.commit()
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(e)
|
|
|
+ raise Exception(e)
|
|
|
+
|
|
|
+ def update_task_failed(self, task_id, error_message: str):
|
|
|
+ """更新任务状态"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ session.query(AgentTask).filter(AgentTask.id == task_id).update(
|
|
|
+ {"status": AgentTaskStatus.FAILED, "error_message": error_message, "update_time": datetime.now()})
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+ def update_task_status(self, task_id, status):
|
|
|
+ """更新任务状态"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ session.query(AgentTask).filter(AgentTask.id == task_id).update(
|
|
|
+ {"status": status, "update_time": datetime.now()})
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+ def get_agent_task_list(self, page_num: int, page_size: int) -> Dict:
|
|
|
+ with self.session_maker() as session:
|
|
|
+ # 计算偏移量
|
|
|
+ offset = (page_num - 1) * page_size
|
|
|
+ # 查询分页数据
|
|
|
+ result = (session.query(AgentTestTask, AgentConfiguration)
|
|
|
+ .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
|
|
|
+ .limit(page_size).offset(offset).all())
|
|
|
+ # 查询总记录数
|
|
|
+ total = session.query(func.count(AgentTestTask.id)).scalar()
|
|
|
+
|
|
|
+ total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
|
|
|
+ total_page = 1 if total_page <= 0 else total_page
|
|
|
+ response_data = [
|
|
|
+ {
|
|
|
+ "id": agent_test_task.id,
|
|
|
+ "agentName": agent_configuration.name,
|
|
|
+ "createUser": agent_test_task.create_user,
|
|
|
+ "updateUser": agent_test_task.update_user,
|
|
|
+ "statusName": get_test_task_status_desc(agent_test_task.status),
|
|
|
+ "createTime": agent_test_task.create_time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
|
+ "updateTime": agent_test_task.update_time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+ }
|
|
|
+ for agent_test_task, agent_configuration in result
|
|
|
+ ]
|
|
|
+ return {
|
|
|
+ "currentPage": page_num,
|
|
|
+ "pageSize": page_size,
|
|
|
+ "totalSize": total_page,
|
|
|
+ "total": total,
|
|
|
+ "list": response_data,
|
|
|
+ }
|
|
|
+
|
|
|
+ def create_task(self, agent_id: int, task_prompt: str):
|
|
|
+ """创建新任务"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ agent_config = session.get(AgentConfiguration, agent_id)
|
|
|
+ agent_task = AgentTask(agent_id=agent_id,
|
|
|
+ status=AgentTaskStatus.NOT_STARTED.value,
|
|
|
+ start_time=datetime.now(),
|
|
|
+ input=task_prompt,
|
|
|
+ tools=agent_config.tools)
|
|
|
+ session.add(agent_task)
|
|
|
+ session.commit() # 显式提交
|
|
|
+ task_id = agent_task.id
|
|
|
+ # 异步执行创建任务
|
|
|
+ self.executor.submit(self._execute_task, task_id)
|
|
|
+
|
|
|
+ def _process_task(self, task_id: int):
|
|
|
+ try:
|
|
|
+ self.update_task_status(task_id, AgentTaskStatus.IN_PROGRESS.value)
|
|
|
+ agent_task = self.get_agent_task(task_id)
|
|
|
+ agent_config = self.get_agent_config(agent_task.agent_id)
|
|
|
+ tools = get_tools(json.loads(agent_config.tools))
|
|
|
+ chat_agent = SimpleOpenAICompatibleChatAgent(model=agent_config.execution_model,
|
|
|
+ system_prompt=agent_config.system_prompt,
|
|
|
+ tools=tools)
|
|
|
+ message = chat_agent.run(agent_task.input)
|
|
|
+ agent_task_details = chat_agent.get_agent_task_details()
|
|
|
+ for agent_task_detail in agent_task_details:
|
|
|
+ agent_task_detail.agent_task_id = task_id
|
|
|
+ agent_task_detail.status = AgentTaskDetailStatus.SUCCESS.value
|
|
|
+ self.save_agent_task_details_batch(agent_task_details, task_id, message)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(e)
|
|
|
+ self.update_task_failed(task_id, str(e))
|
|
|
+
|
|
|
+ def recover_tasks(self):
|
|
|
+ """服务启动时恢复未完成的任务"""
|
|
|
+
|
|
|
+ in_progress_tasks = self.get_in_progress_task()
|
|
|
+
|
|
|
+ for task in in_progress_tasks:
|
|
|
+ task_id = task.id
|
|
|
+ # 重新提交任务
|
|
|
+ self._execute_task(task_id)
|
|
|
+
|
|
|
+ def _execute_task(self, task_id: int):
|
|
|
+ """提交任务到线程池执行"""
|
|
|
+ # 确保任务状态一致性
|
|
|
+ if task_id in self.running_tasks:
|
|
|
+ return
|
|
|
+
|
|
|
+ # 创建任务事件和锁
|
|
|
+ if task_id not in self.task_events:
|
|
|
+ self.task_events[task_id] = threading.Event()
|
|
|
+ if task_id not in self.task_locks:
|
|
|
+ self.task_locks[task_id] = threading.Lock()
|
|
|
+
|
|
|
+ # 提交到线程池
|
|
|
+ future = self.executor.submit(self._process_task, task_id)
|
|
|
+ self.task_futures[task_id] = future
|
|
|
+
|
|
|
+ # 标记任务为运行中
|
|
|
+ with self.task_locks[task_id]:
|
|
|
+ self.running_tasks.add(task_id)
|
|
|
+
|
|
|
+ def get_agent_task_detail(self, agent_task_id):
|
|
|
+ agent_task = self.get_agent_task(agent_task_id)
|
|
|
+ agent_task_details = self.get_agent_task_details(agent_task_id)
|
|
|
+ agent_task_detail_datas = []
|
|
|
+ for agent_task_detail in agent_task_details:
|
|
|
+ data = {}
|
|
|
+ data["id"] = agent_task_detail.id
|
|
|
+ data["executorType"] = agent_task_detail.executor_type
|
|
|
+ data["status"] = get_agent_task_detail_status_desc(agent_task_detail.status)
|
|
|
+ data["inputData"] = agent_task_detail.input_data
|
|
|
+ data["executorName"] = agent_task_detail.executor_name
|
|
|
+ data["reasoning"] = agent_task_detail.reasoning
|
|
|
+ data["outputData"] = agent_task_detail.output_data
|
|
|
+ data["errorMessage"] = agent_task_detail.error_message
|
|
|
+ data["createTime"]: agent_task_detail.create_time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+ data["updateTime"]: agent_task_detail.update_time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+ agent_task_detail_datas.append(data)
|
|
|
+ return {
|
|
|
+ "input": agent_task.input,
|
|
|
+ "tools": agent_task.tools,
|
|
|
+ "agentTaskDetails": agent_task_detail_datas
|
|
|
+ }
|