Browse Source

修改取消恢复任务

xueyiming 3 days ago
parent
commit
4992ae4ab6
1 changed files with 7 additions and 34 deletions
  1. 7 34
      pqai_agent_server/task_server.py

+ 7 - 34
pqai_agent_server/task_server.py

@@ -24,12 +24,10 @@ class TaskManager:
     def __init__(self, session_maker, dataset_service):
         self.session_maker = session_maker
         self.dataset_service = dataset_service
-        self.task_events = {}  # 任务ID -> Event (用于取消任务)
         self.task_locks = {}  # 任务ID -> Lock (用于任务状态同步)
         self.running_tasks = set()
         self.executor = ThreadPoolExecutor(max_workers=20, thread_name_prefix='TaskWorker')
         self.create_task_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='CreateTaskWorker')
-        self.task_futures = {}  # 任务ID -> Future
 
     def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
         with self.session_maker() as session:
@@ -229,6 +227,11 @@ class TaskManager:
                     TestTaskConversationsStatus.PENDING.value,
                     TestTaskConversationsStatus.RUNNING.value
                 ])).all()
+    def get_task_conversation(self, task_conversation_id: int):
+        """获取待处理的子任务"""
+        with self.session_maker() as session:
+            return session.query(AgentTestTaskConversations).filter(
+                AgentTestTaskConversations.id == task_conversation_id).one()
 
     def update_task_status(self, task_id: int, status: int):
         """更新任务状态"""
@@ -255,15 +258,6 @@ class TaskManager:
             session.commit()
 
     def cancel_task(self, task_id: int, user: str):
-        """取消任务(带事务支持)"""
-        # 设置取消事件
-        # 1. 设置取消事件(通知任务内部)
-        if task_id in self.task_events:
-            self.task_events[task_id].set()
-        # 如果任务正在执行,尝试取消Future
-        if task_id in self.task_futures:
-            self.task_futures[task_id].cancel()
-
         with self.session_maker() as session:
             with session.begin():
                 session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
@@ -276,7 +270,6 @@ class TaskManager:
         self._cleanup_task_resources(task_id)
 
     def resume_task(self, task_id: int, user: str) -> bool:
-        """恢复已取消的任务"""
         task = self.get_task(task_id)
         if not task or task.status != TestTaskStatus.CANCELLED.value:
             return False
@@ -289,7 +282,6 @@ class TaskManager:
                     AgentTestTaskConversations.status == TestTaskConversationsStatus.CANCELLED.value).update(
                     {"status": TestTaskConversationsStatus.PENDING.value, "update_time": datetime.now()})
                 session.commit()
-
         # 重新执行任务
         self._execute_task(task_id)
         logger.info(f"Resumed task {task_id}")
@@ -324,14 +316,11 @@ class TaskManager:
             return
 
         # 创建任务事件和锁
-        if task_id not in self.task_events:
-            self.task_events[task_id] = threading.Event()
         if task_id not in self.task_locks:
             self.task_locks[task_id] = threading.Lock()
 
         # 提交到线程池
-        future = self.executor.submit(self._process_task, task_id)
-        self.task_futures[task_id] = future
+        self.executor.submit(self._process_task, task_id)
 
         # 标记任务为运行中
         with self.task_locks[task_id]:
@@ -390,19 +379,7 @@ class TaskManager:
 
     def _process_single_conversation(self, task_id, task, task_conversation, query_prompt_template,
                                      agent_configuration):
-        """处理单个对话子任务(线程安全)"""
-        # 获取锁(避免竞态条件)
-        task_lock = self.task_locks.get(task_id, threading.Lock())
-        with task_lock:
-            # 检查任务是否被取消或不存在
-            if task_id not in self.task_events:
-                logger.warning(f"Task {task_id} not found in task_events")
-                return
-
-            if self.task_events[task_id].is_set():
-                logger.info(f"Task {task_id} already cancelled")
-                return
-
+        task_conversation = self.get_task_conversation(task_conversation.id)
         # 更新子任务状态
         if task_conversation.status == TestTaskConversationsStatus.PENDING.value:
             self.update_task_conversations_status(
@@ -548,10 +525,6 @@ class TaskManager:
         with self.task_locks[task_id]:
             if task_id in self.running_tasks:
                 self.running_tasks.remove(task_id)
-            if task_id in self.task_events:
-                del self.task_events[task_id]
-            if task_id in self.task_futures:
-                del self.task_futures[task_id]
 
     def shutdown(self):
         """关闭执行器"""