# agent_task_server.py
import json
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Dict, Optional

from sqlalchemy import func, select

from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
from pqai_agent.data_models.agent_configuration import AgentConfiguration
from pqai_agent.data_models.agent_task import AgentTask
from pqai_agent.data_models.agent_task_detail import AgentTaskDetail
from pqai_agent.logging import logger
from pqai_agent.toolkit import get_tools
from pqai_agent_server.const.status_enum import AgentTaskStatus, get_agent_task_detail_status_desc, \
    AgentTaskDetailStatus, get_agent_task_status_desc


  15. class AgentTaskManager:
  16. """任务管理器"""
  17. def __init__(self, session_maker):
  18. self.session_maker = session_maker
  19. self.task_events = {} # 任务ID -> Event (用于取消任务)
  20. self.task_locks = {} # 任务ID -> Lock (用于任务状态同步)
  21. self.running_tasks = set()
  22. self.executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='AgentTaskWorker')
  23. self.task_futures = {} # 任务ID -> Future
  24. def get_agent_task(self, agent_task_id: int):
  25. """获取任务信息"""
  26. with self.session_maker() as session:
  27. return session.query(AgentTask).filter(AgentTask.id == agent_task_id).one()
  28. def get_agent_config(self, agent_id: int):
  29. """获取任务信息"""
  30. with self.session_maker() as session:
  31. return session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).one()
  32. def get_in_progress_task(self):
  33. """获取执行中任务"""
  34. with self.session_maker() as session:
  35. return session.query(AgentTask).filter(AgentTask.status.in_([
  36. AgentTaskStatus.NOT_STARTED.value,
  37. AgentTaskStatus.IN_PROGRESS.value
  38. ])).all()
  39. def get_agent_task_details(self, task_id, parent_execution_id):
  40. """获取任务详情"""
  41. with self.session_maker() as session:
  42. # 创建子查询:统计每个父节点的子节点数量
  43. subquery = (
  44. select(
  45. AgentTaskDetail.parent_execution_id.label('parent_id'),
  46. func.count('*').label('child_count')
  47. )
  48. .where(AgentTaskDetail.parent_execution_id.isnot(None))
  49. .group_by(AgentTaskDetail.parent_execution_id)
  50. .subquery()
  51. )
  52. # 主查询:关联子查询,判断是否有子节点
  53. # 修正连接条件:使用parent_execution_id关联
  54. query = (
  55. select(
  56. AgentTaskDetail,
  57. (func.coalesce(subquery.c.child_count, 0) > 0).label('has_children')
  58. )
  59. .outerjoin(
  60. subquery,
  61. AgentTaskDetail.id == subquery.c.parent_id # 使用当前记录的id匹配子查询的parent_id
  62. )
  63. ).where(AgentTaskDetail.agent_task_id == task_id).where(
  64. AgentTaskDetail.parent_execution_id == parent_execution_id)
  65. # 执行查询
  66. return session.execute(query).all()
  67. def save_agent_task_details_batch(self, agent_task_details: list, agent_task_id: int, message: str):
  68. """批量保存子任务到数据库"""
  69. try:
  70. with self.session_maker() as session:
  71. with session.begin():
  72. session.add_all(agent_task_details)
  73. session.query(AgentTask).filter(
  74. AgentTask.id == agent_task_id).update(
  75. {"status": AgentTaskStatus.COMPLETED.value, "output": message, "update_time": datetime.now()})
  76. session.commit()
  77. except Exception as e:
  78. logger.error(e)
  79. raise Exception(e)
  80. def update_task_failed(self, task_id, error_message: str):
  81. """更新任务状态"""
  82. with self.session_maker() as session:
  83. session.query(AgentTask).filter(AgentTask.id == task_id).update(
  84. {"status": AgentTaskStatus.FAILED, "error_message": error_message, "update_time": datetime.now()})
  85. session.commit()
  86. def update_task_status(self, task_id, status):
  87. """更新任务状态"""
  88. with self.session_maker() as session:
  89. session.query(AgentTask).filter(AgentTask.id == task_id).update(
  90. {"status": status, "update_time": datetime.now()})
  91. session.commit()
  92. def get_agent_task_list(self, page_num: int, page_size: int) -> Dict:
  93. with self.session_maker() as session:
  94. # 计算偏移量
  95. offset = (page_num - 1) * page_size
  96. # 查询分页数据
  97. result = (session.query(AgentTask, AgentConfiguration)
  98. .outerjoin(AgentConfiguration, AgentTask.agent_id == AgentConfiguration.id)
  99. .limit(page_size).offset(offset).all())
  100. # 查询总记录数
  101. total = session.query(func.count(AgentTask.id)).scalar()
  102. total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
  103. total_page = 1 if total_page <= 0 else total_page
  104. response_data = [
  105. {
  106. "id": agent_task.id,
  107. "agentName": agent_configuration.name,
  108. "statusName": get_agent_task_status_desc(agent_task.status),
  109. "startTime": agent_task.start_time.strftime("%Y-%m-%d %H:%M:%S"),
  110. "endTime": agent_task.start_time.strftime(
  111. "%Y-%m-%d %H:%M:%S") if agent_task.start_time is not None else None,
  112. "createUser": agent_task.create_user,
  113. "input": agent_task.input,
  114. "output": agent_task.output,
  115. "createTime": agent_task.create_time.strftime("%Y-%m-%d %H:%M:%S"),
  116. "updateTime": agent_task.update_time.strftime("%Y-%m-%d %H:%M:%S")
  117. }
  118. for agent_task, agent_configuration in result
  119. ]
  120. return {
  121. "currentPage": page_num,
  122. "pageSize": page_size,
  123. "totalSize": total_page,
  124. "total": total,
  125. "list": response_data,
  126. }
  127. def create_task(self, agent_id: int, task_prompt: str):
  128. """创建新任务"""
  129. with self.session_maker() as session:
  130. agent_config = session.get(AgentConfiguration, agent_id)
  131. agent_task = AgentTask(agent_id=agent_id,
  132. status=AgentTaskStatus.NOT_STARTED.value,
  133. start_time=datetime.now(),
  134. input=task_prompt,
  135. tools=agent_config.tools)
  136. session.add(agent_task)
  137. session.commit() # 显式提交
  138. task_id = agent_task.id
  139. # 异步执行创建任务
  140. self.executor.submit(self._execute_task, task_id)
  141. def _process_task(self, task_id: int):
  142. try:
  143. self.update_task_status(task_id, AgentTaskStatus.IN_PROGRESS.value)
  144. agent_task = self.get_agent_task(task_id)
  145. agent_config = self.get_agent_config(agent_task.agent_id)
  146. tools = get_tools(json.loads(agent_config.tools))
  147. chat_agent = SimpleOpenAICompatibleChatAgent(model=agent_config.execution_model,
  148. system_prompt=agent_config.system_prompt,
  149. tools=tools)
  150. message = chat_agent.run(agent_task.input)
  151. agent_task_details = chat_agent.get_agent_task_details()
  152. for agent_task_detail in agent_task_details:
  153. agent_task_detail.agent_task_id = task_id
  154. agent_task_detail.status = AgentTaskDetailStatus.SUCCESS.value
  155. self.save_agent_task_details_batch(agent_task_details, task_id, message)
  156. except Exception as e:
  157. logger.error(e)
  158. self.update_task_failed(task_id, str(e))
  159. def recover_tasks(self):
  160. """服务启动时恢复未完成的任务"""
  161. in_progress_tasks = self.get_in_progress_task()
  162. for task in in_progress_tasks:
  163. task_id = task.id
  164. # 重新提交任务
  165. self._execute_task(task_id)
  166. def _execute_task(self, task_id: int):
  167. """提交任务到线程池执行"""
  168. # 确保任务状态一致性
  169. if task_id in self.running_tasks:
  170. return
  171. # 创建任务事件和锁
  172. if task_id not in self.task_events:
  173. self.task_events[task_id] = threading.Event()
  174. if task_id not in self.task_locks:
  175. self.task_locks[task_id] = threading.Lock()
  176. # 提交到线程池
  177. future = self.executor.submit(self._process_task, task_id)
  178. self.task_futures[task_id] = future
  179. # 标记任务为运行中
  180. with self.task_locks[task_id]:
  181. self.running_tasks.add(task_id)
  182. def get_agent_task_detail(self, agent_task_id, parent_execution_id):
  183. agent_task = self.get_agent_task(agent_task_id)
  184. agent_task_details = self.get_agent_task_details(agent_task_id, parent_execution_id)
  185. agent_task_detail_datas = [
  186. {
  187. "id": agent_task_detail.id,
  188. "agentTaskId": agent_task_detail.agent_task_id,
  189. "executorType": agent_task_detail.executor_type,
  190. "statusName": get_agent_task_detail_status_desc(agent_task_detail.status),
  191. "inputData": agent_task_detail.input_data,
  192. "executorName": agent_task_detail.executor_name,
  193. "reasoning": agent_task_detail.reasoning,
  194. "outputData": agent_task_detail.output_data,
  195. "errorMessage": agent_task_detail.error_message,
  196. "hasChildren": has_children,
  197. "createTime": agent_task_detail.create_time.strftime("%Y-%m-%d %H:%M:%S"),
  198. "updateTime": agent_task_detail.update_time.strftime("%Y-%m-%d %H:%M:%S")
  199. }
  200. for agent_task_detail, has_children in agent_task_details
  201. ]
  202. return {
  203. "input": agent_task.input,
  204. "tools": agent_task.tools,
  205. "agentTaskDetails": agent_task_detail_datas
  206. }