# agent_task_server.py (9.3 KB)


import concurrent.futures
import json
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Dict

from sqlalchemy import func

from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
from pqai_agent.data_models import agent_task_detail
from pqai_agent.data_models.agent_configuration import AgentConfiguration
from pqai_agent.data_models.agent_task import AgentTask
from pqai_agent.data_models.agent_task_detail import AgentTaskDetail
from pqai_agent.data_models.agent_test_task import AgentTestTask
from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
from pqai_agent.data_models.service_module import ServiceModule
from pqai_agent.logging import logger
from pqai_agent.toolkit import get_tools
from pqai_agent_server.const.status_enum import (
    TestTaskConversationsStatus,
    TestTaskStatus,
    get_test_task_status_desc,
    AgentTaskStatus,
    get_agent_task_detail_status_desc,
    AgentTaskDetailStatus,
)
from scripts.evaluate_agent import evaluate_agent


  22. class AgentTaskManager:
  23. """任务管理器"""
  24. def __init__(self, session_maker):
  25. self.session_maker = session_maker
  26. self.task_events = {} # 任务ID -> Event (用于取消任务)
  27. self.task_locks = {} # 任务ID -> Lock (用于任务状态同步)
  28. self.running_tasks = set()
  29. self.executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='AgentTaskWorker')
  30. self.task_futures = {} # 任务ID -> Future
  31. def get_agent_task(self, agent_task_id: int):
  32. """获取任务信息"""
  33. with self.session_maker() as session:
  34. return session.query(AgentTask).filter(AgentTask.id == agent_task_id).one()
  35. def get_agent_config(self, agent_id: int):
  36. """获取任务信息"""
  37. with self.session_maker() as session:
  38. return session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).one()
  39. def get_in_progress_task(self):
  40. """获取执行中任务"""
  41. with self.session_maker() as session:
  42. return session.query(AgentTask).filter(AgentTask.status.in_([
  43. AgentTaskStatus.NOT_STARTED.value,
  44. AgentTaskStatus.IN_PROGRESS.value
  45. ])).all()
  46. def get_agent_task_details(self, task_id):
  47. """更新任务状态"""
  48. with self.session_maker() as session:
  49. return session.query(AgentTaskDetail).filter(AgentTaskDetail.agent_task_id == task_id).all()
  50. def save_agent_task_details_batch(self, agent_task_details: list, agent_task_id: int, message: str):
  51. """批量保存子任务到数据库"""
  52. try:
  53. with self.session_maker() as session:
  54. with session.begin():
  55. session.add_all(agent_task_details)
  56. session.query(AgentTask).filter(
  57. AgentTask.id == agent_task_id).update(
  58. {"status": AgentTaskStatus.COMPLETED.value, "output": message, "update_time": datetime.now()})
  59. session.commit()
  60. except Exception as e:
  61. logger.error(e)
  62. raise Exception(e)
  63. def update_task_failed(self, task_id, error_message: str):
  64. """更新任务状态"""
  65. with self.session_maker() as session:
  66. session.query(AgentTask).filter(AgentTask.id == task_id).update(
  67. {"status": AgentTaskStatus.FAILED, "error_message": error_message, "update_time": datetime.now()})
  68. session.commit()
  69. def update_task_status(self, task_id, status):
  70. """更新任务状态"""
  71. with self.session_maker() as session:
  72. session.query(AgentTask).filter(AgentTask.id == task_id).update(
  73. {"status": status, "update_time": datetime.now()})
  74. session.commit()
  75. def get_agent_task_list(self, page_num: int, page_size: int) -> Dict:
  76. with self.session_maker() as session:
  77. # 计算偏移量
  78. offset = (page_num - 1) * page_size
  79. # 查询分页数据
  80. result = (session.query(AgentTestTask, AgentConfiguration)
  81. .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
  82. .limit(page_size).offset(offset).all())
  83. # 查询总记录数
  84. total = session.query(func.count(AgentTestTask.id)).scalar()
  85. total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
  86. total_page = 1 if total_page <= 0 else total_page
  87. response_data = [
  88. {
  89. "id": agent_test_task.id,
  90. "agentName": agent_configuration.name,
  91. "createUser": agent_test_task.create_user,
  92. "updateUser": agent_test_task.update_user,
  93. "statusName": get_test_task_status_desc(agent_test_task.status),
  94. "createTime": agent_test_task.create_time.strftime("%Y-%m-%d %H:%M:%S"),
  95. "updateTime": agent_test_task.update_time.strftime("%Y-%m-%d %H:%M:%S")
  96. }
  97. for agent_test_task, agent_configuration in result
  98. ]
  99. return {
  100. "currentPage": page_num,
  101. "pageSize": page_size,
  102. "totalSize": total_page,
  103. "total": total,
  104. "list": response_data,
  105. }
  106. def create_task(self, agent_id: int, task_prompt: str):
  107. """创建新任务"""
  108. with self.session_maker() as session:
  109. agent_config = session.get(AgentConfiguration, agent_id)
  110. agent_task = AgentTask(agent_id=agent_id,
  111. status=AgentTaskStatus.NOT_STARTED.value,
  112. start_time=datetime.now(),
  113. input=task_prompt,
  114. tools=agent_config.tools)
  115. session.add(agent_task)
  116. session.commit() # 显式提交
  117. task_id = agent_task.id
  118. # 异步执行创建任务
  119. self.executor.submit(self._execute_task, task_id)
  120. def _process_task(self, task_id: int):
  121. try:
  122. self.update_task_status(task_id, AgentTaskStatus.IN_PROGRESS.value)
  123. agent_task = self.get_agent_task(task_id)
  124. agent_config = self.get_agent_config(agent_task.agent_id)
  125. tools = get_tools(json.loads(agent_config.tools))
  126. chat_agent = SimpleOpenAICompatibleChatAgent(model=agent_config.execution_model,
  127. system_prompt=agent_config.system_prompt,
  128. tools=tools)
  129. message = chat_agent.run(agent_task.input)
  130. agent_task_details = chat_agent.get_agent_task_details()
  131. for agent_task_detail in agent_task_details:
  132. agent_task_detail.agent_task_id = task_id
  133. agent_task_detail.status = AgentTaskDetailStatus.SUCCESS.value
  134. self.save_agent_task_details_batch(agent_task_details, task_id, message)
  135. except Exception as e:
  136. logger.error(e)
  137. self.update_task_failed(task_id, str(e))
  138. def recover_tasks(self):
  139. """服务启动时恢复未完成的任务"""
  140. in_progress_tasks = self.get_in_progress_task()
  141. for task in in_progress_tasks:
  142. task_id = task.id
  143. # 重新提交任务
  144. self._execute_task(task_id)
  145. def _execute_task(self, task_id: int):
  146. """提交任务到线程池执行"""
  147. # 确保任务状态一致性
  148. if task_id in self.running_tasks:
  149. return
  150. # 创建任务事件和锁
  151. if task_id not in self.task_events:
  152. self.task_events[task_id] = threading.Event()
  153. if task_id not in self.task_locks:
  154. self.task_locks[task_id] = threading.Lock()
  155. # 提交到线程池
  156. future = self.executor.submit(self._process_task, task_id)
  157. self.task_futures[task_id] = future
  158. # 标记任务为运行中
  159. with self.task_locks[task_id]:
  160. self.running_tasks.add(task_id)
  161. def get_agent_task_detail(self, agent_task_id):
  162. agent_task = self.get_agent_task(agent_task_id)
  163. agent_task_details = self.get_agent_task_details(agent_task_id)
  164. agent_task_detail_datas = []
  165. for agent_task_detail in agent_task_details:
  166. data = {}
  167. data["id"] = agent_task_detail.id
  168. data["executorType"] = agent_task_detail.executor_type
  169. data["status"] = get_agent_task_detail_status_desc(agent_task_detail.status)
  170. data["inputData"] = agent_task_detail.input_data
  171. data["executorName"] = agent_task_detail.executor_name
  172. data["reasoning"] = agent_task_detail.reasoning
  173. data["outputData"] = agent_task_detail.output_data
  174. data["errorMessage"] = agent_task_detail.error_message
  175. data["createTime"]: agent_task_detail.create_time.strftime("%Y-%m-%d %H:%M:%S")
  176. data["updateTime"]: agent_task_detail.update_time.strftime("%Y-%m-%d %H:%M:%S")
  177. agent_task_detail_datas.append(data)
  178. return {
  179. "input": agent_task.input,
  180. "tools": agent_task.tools,
  181. "agentTaskDetails": agent_task_detail_datas
  182. }