agent_task_server.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. import json
  2. import threading
  3. from concurrent.futures import ThreadPoolExecutor
  4. from datetime import datetime
  5. from typing import Dict
  6. from sqlalchemy import func
  7. from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
  8. from pqai_agent.data_models.agent_configuration import AgentConfiguration
  9. from pqai_agent.data_models.agent_task import AgentTask
  10. from pqai_agent.data_models.agent_task_detail import AgentTaskDetail
  11. from pqai_agent.logging import logger
  12. from pqai_agent.toolkit import get_tools
  13. from pqai_agent_server.const.status_enum import AgentTaskStatus, get_agent_task_detail_status_desc, \
  14. AgentTaskDetailStatus, get_agent_task_status_desc
  15. class AgentTaskManager:
  16. """任务管理器"""
  17. def __init__(self, session_maker):
  18. self.session_maker = session_maker
  19. self.task_events = {} # 任务ID -> Event (用于取消任务)
  20. self.task_locks = {} # 任务ID -> Lock (用于任务状态同步)
  21. self.running_tasks = set()
  22. self.executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='AgentTaskWorker')
  23. self.task_futures = {} # 任务ID -> Future
  24. def get_agent_task(self, agent_task_id: int):
  25. """获取任务信息"""
  26. with self.session_maker() as session:
  27. return session.query(AgentTask).filter(AgentTask.id == agent_task_id).one()
  28. def get_agent_config(self, agent_id: int):
  29. """获取任务信息"""
  30. with self.session_maker() as session:
  31. return session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).one()
  32. def get_in_progress_task(self):
  33. """获取执行中任务"""
  34. with self.session_maker() as session:
  35. return session.query(AgentTask).filter(AgentTask.status.in_([
  36. AgentTaskStatus.NOT_STARTED.value,
  37. AgentTaskStatus.IN_PROGRESS.value
  38. ])).all()
  39. def get_agent_task_details(self, task_id):
  40. """获取任务详情"""
  41. with self.session_maker() as session:
  42. return session.query(AgentTaskDetail).filter(AgentTaskDetail.agent_task_id == task_id).filter(
  43. AgentTaskDetail.parent_execution_id == None).all()
  44. def save_agent_task_details_batch(self, agent_task_details: list, agent_task_id: int, message: str):
  45. """批量保存子任务到数据库"""
  46. try:
  47. with self.session_maker() as session:
  48. with session.begin():
  49. session.add_all(agent_task_details)
  50. session.query(AgentTask).filter(
  51. AgentTask.id == agent_task_id).update(
  52. {"status": AgentTaskStatus.COMPLETED.value, "output": message, "update_time": datetime.now()})
  53. session.commit()
  54. except Exception as e:
  55. logger.error(e)
  56. raise Exception(e)
  57. def update_task_failed(self, task_id, error_message: str):
  58. """更新任务状态"""
  59. with self.session_maker() as session:
  60. session.query(AgentTask).filter(AgentTask.id == task_id).update(
  61. {"status": AgentTaskStatus.FAILED, "error_message": error_message, "update_time": datetime.now()})
  62. session.commit()
  63. def update_task_status(self, task_id, status):
  64. """更新任务状态"""
  65. with self.session_maker() as session:
  66. session.query(AgentTask).filter(AgentTask.id == task_id).update(
  67. {"status": status, "update_time": datetime.now()})
  68. session.commit()
  69. def get_agent_task_list(self, page_num: int, page_size: int) -> Dict:
  70. with self.session_maker() as session:
  71. # 计算偏移量
  72. offset = (page_num - 1) * page_size
  73. # 查询分页数据
  74. result = (session.query(AgentTask, AgentConfiguration)
  75. .outerjoin(AgentConfiguration, AgentTask.agent_id == AgentConfiguration.id)
  76. .limit(page_size).offset(offset).all())
  77. # 查询总记录数
  78. total = session.query(func.count(AgentTask.id)).scalar()
  79. total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
  80. total_page = 1 if total_page <= 0 else total_page
  81. response_data = [
  82. {
  83. "id": agent_task.id,
  84. "agentName": agent_configuration.name,
  85. "createUser": agent_task.create_user,
  86. "updateUser": agent_task.update_user,
  87. "statusName": get_agent_task_status_desc(agent_task.status),
  88. "createTime": agent_task.create_time.strftime("%Y-%m-%d %H:%M:%S"),
  89. "updateTime": agent_task.update_time.strftime("%Y-%m-%d %H:%M:%S")
  90. }
  91. for agent_task, agent_configuration in result
  92. ]
  93. return {
  94. "currentPage": page_num,
  95. "pageSize": page_size,
  96. "totalSize": total_page,
  97. "total": total,
  98. "list": response_data,
  99. }
  100. def create_task(self, agent_id: int, task_prompt: str):
  101. """创建新任务"""
  102. with self.session_maker() as session:
  103. agent_config = session.get(AgentConfiguration, agent_id)
  104. agent_task = AgentTask(agent_id=agent_id,
  105. status=AgentTaskStatus.NOT_STARTED.value,
  106. start_time=datetime.now(),
  107. input=task_prompt,
  108. tools=agent_config.tools)
  109. session.add(agent_task)
  110. session.commit() # 显式提交
  111. task_id = agent_task.id
  112. # 异步执行创建任务
  113. self.executor.submit(self._execute_task, task_id)
  114. def _process_task(self, task_id: int):
  115. try:
  116. self.update_task_status(task_id, AgentTaskStatus.IN_PROGRESS.value)
  117. agent_task = self.get_agent_task(task_id)
  118. agent_config = self.get_agent_config(agent_task.agent_id)
  119. tools = get_tools(json.loads(agent_config.tools))
  120. chat_agent = SimpleOpenAICompatibleChatAgent(model=agent_config.execution_model,
  121. system_prompt=agent_config.system_prompt,
  122. tools=tools)
  123. message = chat_agent.run(agent_task.input)
  124. agent_task_details = chat_agent.get_agent_task_details()
  125. for agent_task_detail in agent_task_details:
  126. agent_task_detail.agent_task_id = task_id
  127. agent_task_detail.status = AgentTaskDetailStatus.SUCCESS.value
  128. self.save_agent_task_details_batch(agent_task_details, task_id, message)
  129. except Exception as e:
  130. logger.error(e)
  131. self.update_task_failed(task_id, str(e))
  132. def recover_tasks(self):
  133. """服务启动时恢复未完成的任务"""
  134. in_progress_tasks = self.get_in_progress_task()
  135. for task in in_progress_tasks:
  136. task_id = task.id
  137. # 重新提交任务
  138. self._execute_task(task_id)
  139. def _execute_task(self, task_id: int):
  140. """提交任务到线程池执行"""
  141. # 确保任务状态一致性
  142. if task_id in self.running_tasks:
  143. return
  144. # 创建任务事件和锁
  145. if task_id not in self.task_events:
  146. self.task_events[task_id] = threading.Event()
  147. if task_id not in self.task_locks:
  148. self.task_locks[task_id] = threading.Lock()
  149. # 提交到线程池
  150. future = self.executor.submit(self._process_task, task_id)
  151. self.task_futures[task_id] = future
  152. # 标记任务为运行中
  153. with self.task_locks[task_id]:
  154. self.running_tasks.add(task_id)
  155. def get_agent_task_detail(self, agent_task_id):
  156. agent_task = self.get_agent_task(agent_task_id)
  157. agent_task_details = self.get_agent_task_details(agent_task_id)
  158. agent_task_detail_datas = []
  159. for agent_task_detail in agent_task_details:
  160. data = {}
  161. data["id"] = agent_task_detail.id
  162. data["executorType"] = agent_task_detail.executor_type
  163. data["status"] = get_agent_task_detail_status_desc(agent_task_detail.status)
  164. data["inputData"] = agent_task_detail.input_data
  165. data["executorName"] = agent_task_detail.executor_name
  166. data["reasoning"] = agent_task_detail.reasoning
  167. data["outputData"] = agent_task_detail.output_data
  168. data["errorMessage"] = agent_task_detail.error_message
  169. data["createTime"]: agent_task_detail.create_time.strftime("%Y-%m-%d %H:%M:%S")
  170. data["updateTime"]: agent_task_detail.update_time.strftime("%Y-%m-%d %H:%M:%S")
  171. agent_task_detail_datas.append(data)
  172. return {
  173. "input": agent_task.input,
  174. "tools": agent_task.tools,
  175. "agentTaskDetails": agent_task_detail_datas
  176. }