task_server.py 14 KB


  1. import threading
  2. import threading
  3. import time
  4. from concurrent.futures import ThreadPoolExecutor
  5. from typing import Dict
  6. from pyarrow.dataset import dataset
  7. from sqlalchemy import func
  8. from pqai_agent import logging_service
  9. from pqai_agent.data_models.agent_configuration import AgentConfiguration
  10. from pqai_agent.data_models.agent_test_task import AgentTestTask
  11. from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
  12. from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc
# Module-level logger obtained from the project's logging_service; used by TaskManager below.
logger = logging_service.logger
  14. class TaskManager:
  15. """任务管理器"""
  16. def __init__(self, session_maker, dataset_server, max_workers: int = 10):
  17. self.session_maker = session_maker
  18. self.dataset_server = dataset_server
  19. self.task_events = {} # 任务ID -> Event (用于取消任务)
  20. self.task_locks = {} # 任务ID -> Lock (用于任务状态同步)
  21. self.running_tasks = set()
  22. self.executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='TaskWorker')
  23. self.task_futures = {} # 任务ID -> Future
  24. def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
  25. with self.session_maker() as session:
  26. # 计算偏移量
  27. offset = (page_num - 1) * page_size
  28. # 查询分页数据
  29. result = (session.query(AgentTestTask, AgentConfiguration)
  30. .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
  31. .limit(page_size).offset(offset).all())
  32. # 查询总记录数
  33. total = session.query(func.count(AgentTestTask.id)).scalar()
  34. total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
  35. total_page = 1 if total_page <= 0 else total_page
  36. response_data = [
  37. {
  38. "id": agent_test_task.id,
  39. "agentName": agent_configuration.name,
  40. "createUser": agent_test_task.create_user,
  41. "updateUser": agent_test_task.update_user,
  42. "statusName": get_test_task_status_desc(agent_test_task.status),
  43. "createTime": agent_test_task.create_time.strftime("%Y-%m-%d %H:%M:%S"),
  44. "updateTime": agent_test_task.update_time.strftime("%Y-%m-%d %H:%M:%S")
  45. }
  46. for agent_test_task, agent_configuration in result
  47. ]
  48. return {
  49. "currentPage": page_num,
  50. "pageSize": page_size,
  51. "totalSize": total_page,
  52. "total": total,
  53. "list": response_data,
  54. }
  55. def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
  56. with self.session_maker() as session:
  57. # 计算偏移量
  58. offset = (page_num - 1) * page_size
  59. # 查询分页数据
  60. result = (session.query(AgentTestTaskConversations, AgentConfiguration)
  61. .outerjoin(AgentConfiguration, AgentTestTaskConversations.agent_id == AgentConfiguration.id)
  62. .filter(AgentTestTaskConversations.task_id == task_id)
  63. .limit(page_size).offset(offset).all())
  64. # 查询总记录数
  65. total = session.query(func.count(AgentTestTaskConversations.id)).scalar()
  66. total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
  67. total_page = 1 if total_page <= 0 else total_page
  68. response_data = [
  69. {
  70. "id": agent_test_task_conversation.id,
  71. "agentName": agent_configuration.name,
  72. "input": agent_test_task_conversation.input,
  73. "output": agent_test_task_conversation.output,
  74. "score": agent_test_task_conversation.score,
  75. "statusName": get_test_task_status_desc(agent_test_task_conversation.status),
  76. "createTime": agent_test_task_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S"),
  77. "updateTime": agent_test_task_conversation.update_time.strftime("%Y-%m-%d %H:%M:%S")
  78. }
  79. for agent_test_task_conversation, agent_configuration in result
  80. ]
  81. return {
  82. "currentPage": page_num,
  83. "pageSize": page_size,
  84. "totalSize": total_page,
  85. "total": total,
  86. "list": response_data,
  87. }
  88. def create_task(self, agent_id: int, module_id: int) -> Dict:
  89. """创建新任务"""
  90. with (self.session_maker() as session):
  91. with session.begin():
  92. agent_test_task = AgentTestTask(agent_id=agent_id, module_id=module_id)
  93. session.add(agent_test_task)
  94. session.flush() # 强制SQL执行,但不提交事务
  95. task_id = agent_test_task.id
  96. agent_test_task_conversations = []
  97. datasets_list = self.dataset_server.get_dataset_list_by_module(module_id)
  98. for datasets in datasets_list:
  99. conversation_datas = self.dataset_server.get_conversation_data_list_by_dataset(datasets.id)
  100. for conversation_data in conversation_datas:
  101. agent_test_task_conversation = AgentTestTaskConversations(task_id=task_id, agent_id=agent_id,
  102. dataset_id=datasets.id,
  103. conversation_id=conversation_data.id)
  104. agent_test_task_conversations.append(agent_test_task_conversation)
  105. session.add_all(agent_test_task_conversations)
  106. # 异步执行任务
  107. self._execute_task(task_id)
  108. return self.get_task(task_id)
  109. def get_task(self, task_id: int):
  110. """获取任务信息"""
  111. with self.session_maker() as session:
  112. return session.query(AgentTestTask).filter(AgentTestTask.id == task_id).one()
  113. def get_task_conversations(self, task_id: int):
  114. """获取任务的所有子任务"""
  115. with self.session_maker() as session:
  116. return session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).all()
  117. def get_pending_task_conversations(self, task_id: int):
  118. """获取待处理的子任务"""
  119. with self.session_maker() as session:
  120. return session.query(AgentTestTaskConversations).filter(
  121. AgentTestTaskConversations.task_id == task_id).filter(
  122. AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value).all()
  123. def update_task_status(self, task_id: int, status: int):
  124. """更新任务状态"""
  125. with self.session_maker() as session:
  126. session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update({"status": status})
  127. session.commit()
  128. def update_task_conversations_status(self, task_conversations_id: int, status: int):
  129. """更新子任务状态"""
  130. with self.session_maker() as session:
  131. session.query(AgentTestTaskConversations).filter(
  132. AgentTestTaskConversations.id == task_conversations_id).update({"status": status})
  133. session.commit()
  134. def update_task_conversations_res(self, task_conversations_id: int, status: int, score: str):
  135. """更新子任务结果"""
  136. with self.session_maker() as session:
  137. session.query(AgentTestTaskConversations).filter(
  138. AgentTestTaskConversations.id == task_conversations_id).update({"status": status, "score": score})
  139. session.commit()
  140. def cancel_task(self, task_id: int):
  141. """取消任务(带事务支持)"""
  142. # 设置取消事件
  143. if task_id in self.task_events:
  144. self.task_events[task_id].set()
  145. # 如果任务正在执行,尝试取消Future
  146. if task_id in self.task_futures:
  147. self.task_futures[task_id].cancel()
  148. with self.session_maker() as session:
  149. with session.begin():
  150. session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
  151. {"status": TestTaskStatus.CANCELLED.value})
  152. session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).filter(
  153. AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value).update(
  154. {"status": TestTaskConversationsStatus.CANCELLED.value})
  155. session.commit()
  156. def resume_task(self, task_id: int) -> bool:
  157. """恢复已取消的任务"""
  158. task = self.get_task(task_id)
  159. if not task or task.status != TestTaskStatus.CANCELLED.value:
  160. return False
  161. with self.session_maker() as session:
  162. with session.begin():
  163. session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
  164. {"status": TestTaskStatus.NOT_STARTED.value})
  165. session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).filter(
  166. AgentTestTaskConversations.status == TestTaskConversationsStatus.CANCELLED.value).update(
  167. {"status": TestTaskConversationsStatus.PENDING.value})
  168. session.commit()
  169. # 重新执行任务
  170. self._execute_task(task_id)
  171. logger.info(f"Resumed task {task_id}")
  172. return True
  173. def _execute_task(self, task_id: int):
  174. """提交任务到线程池执行"""
  175. # 确保任务状态一致性
  176. if task_id in self.running_tasks:
  177. return
  178. # 创建任务事件和锁
  179. if task_id not in self.task_events:
  180. self.task_events[task_id] = threading.Event()
  181. if task_id not in self.task_locks:
  182. self.task_locks[task_id] = threading.Lock()
  183. # 提交到线程池
  184. future = self.executor.submit(self._process_task, task_id)
  185. self.task_futures[task_id] = future
  186. # 标记任务为运行中
  187. with self.task_locks[task_id]:
  188. self.running_tasks.add(task_id)
  189. def _process_task(self, task_id: int):
  190. """处理任务的所有子任务"""
  191. try:
  192. # 更新任务状态为运行中
  193. self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
  194. # 获取所有待处理的子任务
  195. task_conversations = self.get_pending_task_conversations(task_id)
  196. # 执行每个子任务
  197. for task_conversation in task_conversations:
  198. # 检查任务是否被取消
  199. if self.task_events[task_id].is_set():
  200. break
  201. # 更新子任务状态为运行中
  202. self.update_task_conversations_status(task_conversation.id,
  203. TestTaskConversationsStatus.RUNNING.value)
  204. try:
  205. conversation_data = self.dataset_server.get_conversation_data_by_id(
  206. task_conversation.conversation_id)
  207. user_profile_data = self.dataset_server.get_user_profile_data(conversation_data.user_id)
  208. staff_profile_data = self.dataset_server.get_staff_profile_data(conversation_data.staff_id)
  209. # 模拟任务执行 - 在实际应用中替换为实际业务逻辑
  210. # TODO 后续改成实际任务执行
  211. time.sleep(1)
  212. score = '{"score":0.05}'
  213. # 更新子任务状态为已完成
  214. self.update_task_conversations_res(task_conversation.id,
  215. TestTaskConversationsStatus.SUCCESS.value, score)
  216. except Exception as e:
  217. logger.error(f"Error executing task {task_id}: {str(e)}")
  218. self.update_task_conversations_status(task_conversation.id,
  219. TestTaskConversationsStatus.FAILED.value)
  220. # 检查任务是否完成
  221. task_conversations = self.get_task_conversations(task_id)
  222. all_completed = all(task_conversation.status in
  223. (TestTaskConversationsStatus.SUCCESS.value, TestTaskConversationsStatus.FAILED.value)
  224. for task_conversation in task_conversations)
  225. any_pending = any(task_conversation.status in
  226. (TestTaskConversationsStatus.PENDING.value, TestTaskConversationsStatus.RUNNING.value)
  227. for task_conversation in task_conversations)
  228. if all_completed:
  229. self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
  230. logger.info(f"Task {task_id} completed")
  231. elif not any_pending:
  232. # 没有待处理子任务但未全部完成(可能是取消了)
  233. current_status = self.get_task(task_id).status
  234. if current_status != TestTaskStatus.CANCELLED.value:
  235. self.update_task_status(task_id, TestTaskStatus.COMPLETED.value
  236. if all_completed else TestTaskStatus.CANCELLED.value)
  237. except Exception as e:
  238. logger.error(f"Error executing task {task_id}: {str(e)}")
  239. self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
  240. finally:
  241. # 清理资源
  242. with self.task_locks[task_id]:
  243. if task_id in self.running_tasks:
  244. self.running_tasks.remove(task_id)
  245. if task_id in self.task_events:
  246. del self.task_events[task_id]
  247. if task_id in self.task_futures:
  248. del self.task_futures[task_id]
  249. def shutdown(self):
  250. """关闭执行器"""
  251. self.executor.shutdown(wait=False)
  252. logger.info("Task executor shutdown")