|
- import concurrent.futures
- import json
- import threading
- from concurrent.futures import ThreadPoolExecutor
- from datetime import datetime
- from typing import Dict
- from sqlalchemy import func
- from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
- from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
- from pqai_agent.data_models import agent_task_detail
- from pqai_agent.data_models.agent_configuration import AgentConfiguration
- from pqai_agent.data_models.agent_task import AgentTask
- from pqai_agent.data_models.agent_task_detail import AgentTaskDetail
- from pqai_agent.data_models.agent_test_task import AgentTestTask
- from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
- from pqai_agent.data_models.service_module import ServiceModule
- from pqai_agent.logging import logger
- from pqai_agent.toolkit import get_tools
- from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc, \
- AgentTaskStatus, get_agent_task_detail_status_desc, AgentTaskDetailStatus
- from scripts.evaluate_agent import evaluate_agent
class AgentTaskManager:
    """Create, run and query agent tasks.

    Task rows are persisted through ``session_maker``; execution happens
    asynchronously on an internal ``ThreadPoolExecutor``.
    """

    # Timestamp format used in every API response produced by this class.
    _DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, session_maker, max_workers: int = 10):
        """Create a manager.

        Args:
            session_maker: Zero-argument factory returning a SQLAlchemy session
                that can be used as a context manager.
            max_workers: Size of the worker pool that runs tasks (defaults to
                10, the previously hard-coded value).
        """
        self.session_maker = session_maker
        self.task_events = {}  # task id -> threading.Event, intended for cancellation (nothing in this class sets it)
        self.task_locks = {}  # task id -> threading.Lock, for per-task state synchronisation
        self.running_tasks = set()  # ids of tasks submitted to the pool and not yet finished
        self.executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='AgentTaskWorker')
        self.task_futures = {}  # task id -> Future returned by the executor
        # Guards running_tasks and the task_* dicts: they are mutated both by
        # callers of _execute_task and by done-callbacks on worker threads.
        self._registry_lock = threading.Lock()

    def get_agent_task(self, agent_task_id: int):
        """Return the ``AgentTask`` row with id ``agent_task_id``.

        Raises whatever ``Query.one()`` raises when there is not exactly one
        matching row.
        """
        with self.session_maker() as session:
            return session.query(AgentTask).filter(AgentTask.id == agent_task_id).one()

    def get_agent_config(self, agent_id: int):
        """Return the ``AgentConfiguration`` row with id ``agent_id``.

        Raises whatever ``Query.one()`` raises when there is not exactly one
        matching row.
        """
        with self.session_maker() as session:
            return session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).one()

    def get_in_progress_task(self):
        """Return all unfinished tasks, i.e. those NOT_STARTED or IN_PROGRESS."""
        with self.session_maker() as session:
            return session.query(AgentTask).filter(AgentTask.status.in_([
                AgentTaskStatus.NOT_STARTED.value,
                AgentTaskStatus.IN_PROGRESS.value
            ])).all()

    def get_agent_task_details(self, task_id):
        """Return every ``AgentTaskDetail`` row belonging to task ``task_id``."""
        with self.session_maker() as session:
            return session.query(AgentTaskDetail).filter(AgentTaskDetail.agent_task_id == task_id).all()

    def save_agent_task_details_batch(self, agent_task_details: list, agent_task_id: int, message: str):
        """Store sub-task details and mark the parent task COMPLETED in one transaction.

        Args:
            agent_task_details: ``AgentTaskDetail`` objects to insert.
            agent_task_id: Id of the parent ``AgentTask``.
            message: Final output written to the task's ``output`` column.

        Raises:
            Exception: the original database error, re-raised after logging.
        """
        try:
            # session.begin() commits when the block exits normally and rolls
            # back on error, so no explicit commit() is needed inside it.
            with self.session_maker() as session, session.begin():
                session.add_all(agent_task_details)
                session.query(AgentTask).filter(AgentTask.id == agent_task_id).update(
                    {"status": AgentTaskStatus.COMPLETED.value, "output": message, "update_time": datetime.now()})
        except Exception as e:
            logger.error(f"failed to save details of agent task {agent_task_id}: {e}")
            # Bare raise keeps the original exception type and traceback.
            raise

    def _update_task(self, task_id, values: dict):
        """Apply ``values`` to one ``AgentTask`` row and stamp ``update_time``."""
        with self.session_maker() as session:
            session.query(AgentTask).filter(AgentTask.id == task_id).update(
                {**values, "update_time": datetime.now()})
            session.commit()

    def update_task_failed(self, task_id, error_message: str):
        """Mark a task FAILED and record ``error_message``."""
        # Store the enum's .value, like every other status write in this class.
        self._update_task(task_id, {"status": AgentTaskStatus.FAILED.value, "error_message": error_message})

    def update_task_status(self, task_id, status):
        """Set a task's ``status`` column to ``status``."""
        self._update_task(task_id, {"status": status})

    @staticmethod
    def _total_pages(total: int, page_size: int) -> int:
        """Return how many pages ``total`` rows need; at least 1, even for no rows."""
        return max(1, -(-total // page_size))

    @classmethod
    def _format_datetime(cls, value):
        """Format ``value`` with ``_DATETIME_FORMAT``; ``None`` is returned unchanged."""
        return value.strftime(cls._DATETIME_FORMAT) if value is not None else None

    def get_agent_task_list(self, page_num: int, page_size: int) -> Dict:
        """Return one page of test tasks joined with their agent configuration.

        NOTE(review): despite its name this pages over ``AgentTestTask`` rows
        and uses the test-task status descriptions. Confirm this is intended.

        Args:
            page_num: 1-based page index.
            page_size: Number of rows per page.

        Returns:
            Dict with ``currentPage``, ``pageSize``, ``totalSize`` (number of
            pages), ``total`` (number of rows) and ``list`` (rows of the page).

        Raises:
            ValueError: if ``page_num`` or ``page_size`` is smaller than 1.
        """
        if page_num < 1 or page_size < 1:
            raise ValueError(f"page_num and page_size must be >= 1, got {page_num} and {page_size}")
        offset = (page_num - 1) * page_size
        with self.session_maker() as session:
            rows = (session.query(AgentTestTask, AgentConfiguration)
                    .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
                    # LIMIT/OFFSET paging is only stable with an explicit order.
                    .order_by(AgentTestTask.id)
                    .limit(page_size).offset(offset).all())
            total = session.query(func.count(AgentTestTask.id)).scalar()
            response_data = [
                {
                    "id": agent_test_task.id,
                    # Outer join: the configuration may be missing.
                    "agentName": agent_configuration.name if agent_configuration is not None else None,
                    "createUser": agent_test_task.create_user,
                    "updateUser": agent_test_task.update_user,
                    "statusName": get_test_task_status_desc(agent_test_task.status),
                    "createTime": self._format_datetime(agent_test_task.create_time),
                    "updateTime": self._format_datetime(agent_test_task.update_time),
                }
                for agent_test_task, agent_configuration in rows
            ]
        return {
            "currentPage": page_num,
            "pageSize": page_size,
            "totalSize": self._total_pages(total, page_size),
            "total": total,
            "list": response_data,
        }

    def create_task(self, agent_id: int, task_prompt: str):
        """Persist a new NOT_STARTED task for ``agent_id`` and schedule it.

        Args:
            agent_id: Id of the ``AgentConfiguration`` to run.
            task_prompt: Input passed to the agent.

        Returns:
            The id of the newly created ``AgentTask``.

        Raises:
            ValueError: if no ``AgentConfiguration`` with ``agent_id`` exists.
        """
        with self.session_maker() as session:
            agent_config = session.get(AgentConfiguration, agent_id)
            if agent_config is None:
                raise ValueError(f"agent configuration {agent_id} does not exist")
            agent_task = AgentTask(agent_id=agent_id,
                                   status=AgentTaskStatus.NOT_STARTED.value,
                                   start_time=datetime.now(),
                                   input=task_prompt,
                                   tools=agent_config.tools)
            session.add(agent_task)
            session.commit()
            task_id = agent_task.id
        # _execute_task only registers and submits, so calling it inline does
        # not block; submitting it to the pool would waste a worker slot.
        self._execute_task(task_id)
        return task_id

    def _process_task(self, task_id: int):
        """Run one task on a worker thread and persist its outcome.

        On success the task's details are stored and the task is marked
        COMPLETED; on any error the task is marked FAILED.
        """
        try:
            self.update_task_status(task_id, AgentTaskStatus.IN_PROGRESS.value)
            agent_task = self.get_agent_task(task_id)
            agent_config = self.get_agent_config(agent_task.agent_id)
            tools = get_tools(json.loads(agent_config.tools))
            chat_agent = SimpleOpenAICompatibleChatAgent(model=agent_config.execution_model,
                                                         system_prompt=agent_config.system_prompt,
                                                         tools=tools)
            message = chat_agent.run(agent_task.input)
            details = chat_agent.get_agent_task_details()
            for detail in details:
                detail.agent_task_id = task_id
                detail.status = AgentTaskDetailStatus.SUCCESS.value
            self.save_agent_task_details_batch(details, task_id, message)
        except Exception as e:
            logger.error(f"agent task {task_id} failed: {e}")
            try:
                self.update_task_failed(task_id, str(e))
            except Exception as db_error:
                # Nobody reads this Future's result, so log instead of raising.
                logger.error(f"could not mark agent task {task_id} as failed: {db_error}")

    def recover_tasks(self):
        """Resubmit unfinished tasks (NOT_STARTED / IN_PROGRESS) at service start-up."""
        for task in self.get_in_progress_task():
            self._execute_task(task.id)

    def _forget_task(self, task_id: int):
        """Drop all in-memory bookkeeping for a task that is no longer running."""
        with self._registry_lock:
            self.running_tasks.discard(task_id)
            self.task_futures.pop(task_id, None)
            self.task_events.pop(task_id, None)
            self.task_locks.pop(task_id, None)

    def _execute_task(self, task_id: int):
        """Submit a task to the worker pool unless it is already running.

        The task is registered as running before submission and unregistered
        by a done-callback, so ``running_tasks`` always reflects live work.
        """
        with self._registry_lock:
            # Check and add under one lock so concurrent callers cannot both
            # submit the same task.
            if task_id in self.running_tasks:
                return
            self.running_tasks.add(task_id)
            if task_id not in self.task_events:
                self.task_events[task_id] = threading.Event()
            if task_id not in self.task_locks:
                self.task_locks[task_id] = threading.Lock()
        try:
            future = self.executor.submit(self._process_task, task_id)
        except BaseException:
            # Submission failed (e.g. executor already shut down): undo the
            # registration so the task can be submitted again later.
            self._forget_task(task_id)
            raise
        with self._registry_lock:
            self.task_futures[task_id] = future
        # Registered after task_futures is populated so the cleanup never runs
        # before the entry it removes exists; if the future is already done the
        # callback runs immediately in this thread (the lock is not held here).
        future.add_done_callback(lambda _done, tid=task_id: self._forget_task(tid))

    def get_agent_task_detail(self, agent_task_id):
        """Return a task's input, tools and formatted sub-task details.

        Raises whatever ``get_agent_task`` raises when the task does not exist.
        """
        agent_task = self.get_agent_task(agent_task_id)
        agent_task_detail_datas = [
            {
                "id": detail.id,
                "executorType": detail.executor_type,
                "status": get_agent_task_detail_status_desc(detail.status),
                "inputData": detail.input_data,
                "executorName": detail.executor_name,
                "reasoning": detail.reasoning,
                "outputData": detail.output_data,
                "errorMessage": detail.error_message,
                # Previously written as `data["createTime"]: ...` (an annotation,
                # not an assignment), so these keys were never set.
                "createTime": self._format_datetime(detail.create_time),
                "updateTime": self._format_datetime(detail.update_time),
            }
            for detail in self.get_agent_task_details(agent_task_id)
        ]
        return {
            "input": agent_task.input,
            "tools": agent_task.tools,
            "agentTaskDetails": agent_task_detail_datas
        }
|