xueyiming преди 1 седмица
родител
ревизия
6af2c4a51d
променени са 5 файла, в които са добавени 607 реда и са изтрити 1 реда
  1. 6 0
      pqai_agent/configs/dev.yaml
  2. 6 0
      pqai_agent/configs/prod.yaml
  3. 106 1
      pqai_agent_server/api_server.py
  4. 56 0
      pqai_agent_server/const/status_enum.py
  5. 433 0
      pqai_agent_server/task_server.py

+ 6 - 0
pqai_agent/configs/dev.yaml

@@ -36,6 +36,12 @@ storage:
     table: qywx_chat_history
   push_record:
     table: agent_push_record_dev
+  agent_configuration:
+    table: agent_configuration
+  test_task:
+    table: agent_test_task
+  test_task_conversations:
+    table: agent_test_task_conversations
 
 agent_behavior:
   message_aggregation_sec: 3

+ 6 - 0
pqai_agent/configs/prod.yaml

@@ -36,6 +36,12 @@ storage:
     table: qywx_chat_history
   push_record:
     table: agent_push_record_dev
+  agent_configuration:
+    table: agent_configuration
+  test_task:
+    table: agent_test_task
+  test_task_conversations:
+    table: agent_test_task_conversations
 
 chat_api:
   coze:

+ 106 - 1
pqai_agent_server/api_server.py

@@ -20,7 +20,9 @@ from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
 from pqai_agent.utils.db_utils import create_sql_engine
 from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
 from pqai_agent_server.const import AgentApiConst
+from pqai_agent_server.const.status_enum import TestTaskStatus
 from pqai_agent_server.models import MySQLSessionManager
+from pqai_agent_server.task_server import TaskManager
 from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
 from pqai_agent_server.utils import (
     run_extractor_prompt,
@@ -32,6 +34,7 @@ app = Flask('agent_api_server')
 logger = logging_service.logger
 const = AgentApiConst()
 
+
 @app.route('/api/listStaffs', methods=['GET'])
 def list_staffs():
     staff_data = app.user_relation_manager.list_staffs()
@@ -178,6 +181,7 @@ def run_prompt():
         logger.error(e)
         return wrap_response(500, msg='Error: {}'.format(e))
 
+
 @app.route('/api/formatForPrompt', methods=['POST'])
 def format_data_for_prompt():
     try:
@@ -312,6 +316,7 @@ def quit_human_interventions_status():
 
     return wrap_response(200, data=response)
 
+
 ## Agent管理接口
 @app.route("/api/getNativeAgentList", methods=["GET"])
 def get_native_agent_list():
@@ -348,6 +353,7 @@ def get_native_agent_list():
     ]
     return wrap_response(200, data=ret_data)
 
+
 @app.route("/api/getNativeAgentConfiguration", methods=["GET"])
 def get_native_agent_configuration():
     """
@@ -379,6 +385,7 @@ def get_native_agent_configuration():
         }
         return wrap_response(200, data=data)
 
+
 @app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
 def save_native_agent_configuration():
     """
@@ -432,6 +439,7 @@ def save_native_agent_configuration():
         session.commit()
         return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
 
+
 @app.route("/api/getModuleList", methods=["GET"])
 def get_module_list():
     """
@@ -456,6 +464,7 @@ def get_module_list():
     ]
     return wrap_response(200, data=ret_data)
 
+
 @app.route("/api/getModuleConfiguration", methods=["GET"])
 def get_module_configuration():
     """
@@ -482,6 +491,7 @@ def get_module_configuration():
         }
         return wrap_response(200, data=data)
 
+
 @app.route("/api/saveModuleConfiguration", methods=["POST"])
 def save_module_configuration():
     """
@@ -520,6 +530,91 @@ def save_module_configuration():
         session.commit()
         return wrap_response(200, msg='Module configuration saved successfully', data={'id': module.id})
 
+
+@app.route("/api/getTestTaskList", methods=["GET"])
+def get_test_task_list():
+    """
+       获取单元测试任务列表
+       :return:
+    """
+    page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
+    page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
+    try:
+        page_num = int(page_num)
+        page_size = int(page_size)
+    except Exception as e:
+        return wrap_response(404, msg="Invalid parameter: {}".format(e))
+    response = app.task_manager.get_test_task_list(page_num, page_size)
+    return wrap_response(200, data=response)
+
+@app.route("/api/getTestTaskConversations", methods=["GET"])
+def get_test_task_conversations():
+    """
+       获取单元测试对话任务列表
+       :return:
+    """
+    task_id = request.args.get("taskId", None)
+    if not task_id:
+        return wrap_response(404, msg='task_id is required')
+    page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
+    page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
+    try:
+        page_num = int(page_num)
+        page_size = int(page_size)
+    except Exception as e:
+        return wrap_response(404, msg="Invalid parameter: {}".format(e))
+    response = app.task_manager.get_test_task_conversations(int(task_id), page_num, page_size)
+    return wrap_response(200, data=response)
+
+
+@app.route("/api/createTestTask", methods=["POST"])
+def create_test_task():
+    """
+       创建单元测试任务
+       :return:
+    """
+    req_data = request.json
+    agent_id = req_data.get('agentId', None)
+    if not agent_id:
+        return wrap_response(400, msg='agent id is required')
+    app.task_manager.create_task(agent_id)
+    return wrap_response(200)
+
+
+@app.route("/api/stopTestTask", methods=["POST"])
+def stop_test_task():
+    """
+       停止单元测试任务
+       :return:
+    """
+    req_data = request.json
+    task_id = req_data.get('taskId', None)
+    if not task_id:
+        return wrap_response(400, msg='task id is required')
+    task = app.task_manager.get_task(task_id)
+    if task['status'] not in (TestTaskStatus.NOT_STARTED.value, TestTaskStatus.IN_PROGRESS.value):
+        return wrap_response(400, msg='task status is invalid')
+    app.task_manager.cancel_task(task_id)
+    return wrap_response(200)
+
+
+@app.route("/api/resumeTestTask", methods=["POST"])
+def resume_test_task():
+    """
+       恢复停止的单元测试任务
+       :return:
+    """
+    req_data = request.json
+    task_id = req_data.get('taskId', None)
+    if not task_id:
+        return wrap_response(400, msg='task id is required')
+    task = app.task_manager.get_task(task_id)
+    if task['status'] != TestTaskStatus.CANCELLED.value:
+        return wrap_response(400, msg='task status is invalid')
+    app.task_manager.resume_task(task_id)
+    return wrap_response(200)
+
+
 @app.errorhandler(werkzeug.exceptions.BadRequest)
 def handle_bad_request(e):
     logger.error(e)
@@ -543,18 +638,28 @@ if __name__ == '__main__':
     staff_db_config = config['storage']['staff']
     agent_state_db_config = config['storage']['agent_state']
     chat_history_db_config = config['storage']['chat_history']
+    agent_configuration_db_config = config['storage']['agent_configuration']
+    test_task_db_config = config['storage']['test_task']
+    test_task_conversations_db_config = config['storage']['test_task_conversations']
 
     # init user manager
     user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
     app.user_manager = user_manager
 
+    task_manager = TaskManager(
+        db_config=user_db_config['mysql'],
+        agent_configuration_table=agent_configuration_db_config['table'],
+        test_task_table=test_task_db_config['table'],
+        test_task_conversations_table=test_task_conversations_db_config['table'])
+    app.task_manager = task_manager
+
     # init session manager
     session_manager = MySQLSessionManager(
         db_config=user_db_config['mysql'],
         staff_table=staff_db_config['table'],
         user_table=user_db_config['table'],
         agent_state_table=agent_state_db_config['table'],
-        chat_history_table=chat_history_db_config['table']
+        chat_history_table=agent_configuration_db_config['table']
     )
     app.session_manager = session_manager
     agent_db_engine = create_sql_engine(config['storage']['agent_state']['mysql'])

+ 56 - 0
pqai_agent_server/const/status_enum.py

@@ -0,0 +1,56 @@
+from enum import Enum
+
+
+class TestTaskStatus(Enum):
+    NOT_STARTED = 0
+    IN_PROGRESS = 1
+    COMPLETED = 2
+    CANCELLED = 3
+
+    @property
+    def description(self):
+        descriptions = {
+            self.NOT_STARTED: "未开始",
+            self.IN_PROGRESS: "进行中",
+            self.COMPLETED: "已完成",
+            self.CANCELLED: "已取消"
+        }
+        return descriptions.get(self)
+
+class TestTaskConversationsStatus(Enum):
+    """任务状态枚举类"""
+    PENDING = 0  # 待执行
+    RUNNING = 1  # 执行中
+    SUCCESS = 2  # 执行成功
+    FAILED = 3  # 执行失败
+    CANCELLED = 4  # 已取消
+
+    @property
+    def description(self):
+        descriptions = {
+            self.PENDING: "待执行",
+            self.RUNNING: "执行中",
+            self.SUCCESS: "执行成功",
+            self.FAILED: "执行失败",
+            self.CANCELLED: "已取消"
+        }
+        return descriptions.get(self)
+
+
+# 使用示例
+def get_test_task_status_desc(status_code):
+    try:
+        status = TestTaskStatus(status_code)
+        return status.description
+    except ValueError:
+        return f"未知状态: {status_code}"
+
+# 使用示例
+def get_test_task_conversations_status_desc(status_code):
+    try:
+        status = TestTaskConversationsStatus(status_code)
+        return status.description
+    except ValueError:
+        return f"未知状态: {status_code}"
+
+

+ 433 - 0
pqai_agent_server/task_server.py

@@ -0,0 +1,433 @@
+import random
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor
+from queue import Queue
+from typing import List, Dict, Optional
+
+import pymysql
+from pymysql import Connection
+from pymysql.cursors import DictCursor
+
+from pqai_agent import logging_service
+from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc, \
+    get_test_task_conversations_status_desc
+
+logger = logging_service.logger
+class Database:
+    """数据库操作类"""
+
+    def __init__(self, db_config):
+        self.db_config = db_config
+        self.connection_pool = Queue(maxsize=10)
+        self._initialize_pool()
+
+    def _initialize_pool(self):
+        """初始化数据库连接池"""
+        for _ in range(5):
+            conn = pymysql.connect(**self.db_config)
+            self.connection_pool.put(conn)
+        logger.info("Database connection pool initialized with 5 connections")
+
+    def get_connection(self) -> Connection:
+        """从连接池获取数据库连接"""
+        return self.connection_pool.get()
+
+    def release_connection(self, conn: Connection):
+        """释放数据库连接回连接池"""
+        self.connection_pool.put(conn)
+
+    def execute(self, query: str, args: tuple = (), many: bool = False) -> int:
+        """执行SQL语句并返回影响的行数"""
+        conn = self.get_connection()
+        try:
+            with conn.cursor() as cursor:
+                if many:
+                    cursor.executemany(query, args)
+                else:
+                    cursor.execute(query, args)
+                conn.commit()
+                return cursor.rowcount
+        except Exception as e:
+            logger.error(f"Database error: {str(e)}")
+            conn.rollback()
+            raise
+        finally:
+            self.release_connection(conn)
+
+    def insert(self, insert: str, args: tuple = (), many: bool = False) -> int:
+        """执行插入SQL语句并主键"""
+        conn = self.get_connection()
+        try:
+            with conn.cursor() as cursor:
+                if many:
+                    cursor.executemany(insert, args)
+                else:
+                    cursor.execute(insert, args)
+                conn.commit()
+                return cursor.lastrowid
+        except Exception as e:
+            logger.error(f"Database error: {str(e)}")
+            conn.rollback()
+            raise
+        finally:
+            self.release_connection(conn)
+
+    def fetch(self, query: str, args: tuple = ()) -> List[Dict]:
+        """执行SQL查询并返回结果列表"""
+        conn = self.get_connection()
+        try:
+            with conn.cursor(DictCursor) as cursor:
+                cursor.execute(query, args)
+                return cursor.fetchall()
+        except Exception as e:
+            logger.error(f"Database error: {str(e)}")
+            raise
+        finally:
+            self.release_connection(conn)
+
+    def fetch_one(self, query: str, args: tuple = ()) -> Optional[Dict]:
+        """执行SQL查询并返回单行结果"""
+        conn = self.get_connection()
+        try:
+            with conn.cursor(DictCursor) as cursor:
+                cursor.execute(query, args)
+                return cursor.fetchone()
+        except Exception as e:
+            logger.error(f"Database error: {str(e)}")
+            raise
+        finally:
+            self.release_connection(conn)
+
+    def close_all(self):
+        """关闭所有数据库连接"""
+        while not self.connection_pool.empty():
+            conn = self.connection_pool.get()
+            conn.close()
+        logger.info("All database connections closed")
+
+
+class TaskManager:
+    """任务管理器"""
+
+    def __init__(self, db_config, agent_configuration_table, test_task_table, test_task_conversations_table,
+                 max_workers: int = 10):
+        self.db = Database(db_config)
+        self.agent_configuration_table = agent_configuration_table
+        self.test_task_table = test_task_table
+        self.test_task_conversations_table = test_task_conversations_table
+        self.task_events = {}  # 任务ID -> Event (用于取消任务)
+        self.task_locks = {}  # 任务ID -> Lock (用于任务状态同步)
+        self.running_tasks = set()
+        self.executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='TaskWorker')
+        self.task_futures = {}  # 任务ID -> Future
+
+    def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
+        fetch_query = f"""
+            select t1.id, t2.name, t1.create_user, t1.update_user, t1.status, t1.create_time, t1.update_time
+            from {self.test_task_table} t1
+            left join {self.agent_configuration_table} t2 on t1.agent_id = t2.id
+            order by create_time desc
+            limit %s, %s;
+        """
+        messages = self.db.fetch(fetch_query, ((page_num - 1) * page_size, page_size))
+        total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_table}""")
+
+        total = total_size["count"]
+        total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
+        total_page = 1 if total_page <= 0 else total_page
+        response_data = [
+            {
+                "id": message["id"],
+                "agentName": message["name"],
+                "createUser": message["create_user"],
+                "updateUser": message["update_user"],
+                "statusName": get_test_task_status_desc(message["status"]),
+                "createTime": message["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
+                "updateTime": message["update_time"].strftime("%Y-%m-%d %H:%M:%S")
+            }
+            for message in messages
+        ]
+        return {
+            "currentPage": page_num,
+            "pageSize": page_size,
+            "totalSize": total_page,
+            "total": total,
+            "list": response_data,
+        }
+
+    def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
+        fetch_query = f"""
+            select t1.id, t2.name, t3.create_user, t1.input, t1.output, t1.score, t1.status, t1.create_time, t1.update_time
+            from {self.test_task_conversations_table} t1
+            left join {self.agent_configuration_table} t2 on t1.agent_id = t2.id
+            left join {self.test_task_table} t3 on t1.task_id = t3.id
+            where t1.task_id = %s
+            order by create_time desc
+            limit %s, %s;
+        """
+        messages = self.db.fetch(fetch_query, (task_id, (page_num - 1) * page_size, page_size))
+        total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_conversations_table} where task_id = %s""",
+                                       (task_id,))
+
+        total = total_size["count"]
+        total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
+        total_page = 1 if total_page <= 0 else total_page
+        response_data = [
+            {
+                "id": message["id"],
+                "agentName": message["name"],
+                "createUser": message["create_user"],
+                "input": message["input"],
+                "output": message["output"],
+                "score": message["score"],
+                "statusName": get_test_task_conversations_status_desc(message["status"]),
+                "createTime": message["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
+                "updateTime": message["update_time"].strftime("%Y-%m-%d %H:%M:%S")
+            }
+            for message in messages
+        ]
+        return {
+            "currentPage": page_num,
+            "pageSize": page_size,
+            "totalSize": total_page,
+            "total": total,
+            "list": response_data,
+        }
+
+    def create_task(self, agent_id: int) -> Dict:
+        """创建新任务并添加100个子任务"""
+
+        conn = self.db.get_connection()
+        try:
+            conn.begin()
+            # 插入任务
+            with conn.cursor() as cursor:
+                cursor.execute(
+                    f"""INSERT INTO {self.test_task_table} (agent_id, status, create_user, update_user) VALUES (%s, %s, %s, %s)""",
+                    (agent_id, 0, 'xueyiming', 'xueyiming')
+                )
+            task_id = cursor.lastrowid
+            task_conversations = []
+            i = 0
+            for _ in range(30):
+                i = i + 1
+                task_conversations.append((
+                    task_id, agent_id, i, i, "输入", "输出", 0
+                ))
+
+            with conn.cursor() as cursor:
+                cursor.executemany(
+                    f"""INSERT INTO {self.test_task_conversations_table} (task_id, agent_id, dataset_id, conversation_id, input, output, status) 
+                    VALUES (%s, %s, %s, %s, %s, %s, %s)""",
+                    task_conversations
+                )
+
+            conn.commit()
+        except Exception as e:
+            conn.rollback()
+            logger.error(f"Failed to create task agent_id {agent_id}: {str(e)}")
+            raise
+        finally:
+            self.db.release_connection(conn)
+        logger.info(f"Created task {task_id} with 100 task_conversations")
+        # 异步执行任务
+        self._execute_task(task_id)
+        return self.get_task(task_id)
+
+    def get_task(self, task_id: int) -> Optional[Dict]:
+        """获取任务信息"""
+        return self.db.fetch_one(f"""SELECT * FROM {self.test_task_table} WHERE id = %s""", (task_id,))
+
+    def get_task_conversations(self, task_id: int) -> List[Dict]:
+        """获取任务的所有子任务"""
+        return self.db.fetch(f"""SELECT * FROM {self.test_task_conversations_table} WHERE task_id = %s""", (task_id,))
+
+    def get_pending_task_conversations(self, task_id: int) -> List[Dict]:
+        """获取待处理的子任务"""
+        return self.db.fetch(
+            f"""SELECT * FROM {self.test_task_conversations_table} WHERE task_id = %s AND status = %s""",
+            (task_id, TestTaskConversationsStatus.PENDING.value)
+        )
+
+    def update_task_status(self, task_id: int, status: int):
+        """更新任务状态"""
+        self.db.execute(
+            f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
+            (status, task_id)
+        )
+
+    def update_task_conversations_status(self, task_conversations_id: int, status: int):
+        """更新子任务状态"""
+        self.db.execute(
+            f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE id = %s""",
+            (status, task_conversations_id)
+        )
+
+    def update_task_conversations_res(self, task_conversations_id: int, status: int, score: float):
+        """更新子任务状态"""
+        self.db.execute(
+            f"""UPDATE {self.test_task_conversations_table} SET status = %s, score = %s WHERE id = %s""",
+            (status, score, task_conversations_id)
+        )
+
+    def cancel_task(self, task_id: int):
+        """取消任务(带事务支持)"""
+        # 设置取消事件
+        if task_id in self.task_events:
+            self.task_events[task_id].set()
+        # 如果任务正在执行,尝试取消Future
+        if task_id in self.task_futures:
+            self.task_futures[task_id].cancel()
+
+        conn = self.db.get_connection()
+        try:
+            conn.begin()
+            # 更新任务状态为取消
+            with conn.cursor() as cursor:
+                cursor.execute(
+                    f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
+                    (TestTaskStatus.CANCELLED.value, task_id)
+                )
+
+            # 取消所有待处理的子任务
+            with conn.cursor() as cursor:
+                cursor.execute(
+                    f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE task_id = %s AND status = %s""",
+                    (TestTaskConversationsStatus.CANCELLED.value, task_id, TestTaskConversationsStatus.PENDING.value)
+                )
+            conn.commit()
+            logger.info(f"Cancelled task {task_id} and its pending {self.test_task_conversations_table}")
+        except Exception as e:
+            conn.rollback()
+            logger.error(f"Failed to cancel task {task_id}: {str(e)}")
+        finally:
+            self.db.release_connection(conn)
+
+    def resume_task(self, task_id: int) -> bool:
+        """恢复已取消的任务"""
+        task = self.get_task(task_id)
+        if not task or task['status'] != TestTaskStatus.CANCELLED.value:
+            return False
+
+        conn = self.db.get_connection()
+        try:
+            conn.begin()
+            # 更新任务状态为待开始
+            with conn.cursor() as cursor:
+                cursor.execute(
+                    f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
+                    (TestTaskStatus.NOT_STARTED.value, task_id)
+                )
+
+            # 恢复所有已取消的子任务
+            with conn.cursor() as cursor:
+                cursor.execute(
+                    f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE task_id = %s AND status = %s""",
+                    (TestTaskConversationsStatus.PENDING.value, task_id, TestTaskConversationsStatus.CANCELLED.value)
+                )
+            conn.commit()
+            logger.info(f"Cancelled task {task_id} and its pending {self.test_task_conversations_table}")
+        except Exception as e:
+            conn.rollback()
+            logger.error(f"Failed to cancel task {task_id}: {str(e)}")
+        finally:
+            self.db.release_connection(conn)
+
+        # 重新执行任务
+        self._execute_task(task_id)
+
+        logger.info(f"Resumed task {task_id}")
+        return True
+
+    def _execute_task(self, task_id: int):
+        """提交任务到线程池执行"""
+        # 确保任务状态一致性
+        if task_id in self.running_tasks:
+            return
+
+        # 创建任务事件和锁
+        if task_id not in self.task_events:
+            self.task_events[task_id] = threading.Event()
+        if task_id not in self.task_locks:
+            self.task_locks[task_id] = threading.Lock()
+
+        # 提交到线程池
+        future = self.executor.submit(self._process_task, task_id)
+        self.task_futures[task_id] = future
+
+        # 标记任务为运行中
+        with self.task_locks[task_id]:
+            self.running_tasks.add(task_id)
+
+    def _process_task(self, task_id: int):
+        """处理任务的所有子任务"""
+        try:
+            # 更新任务状态为运行中
+            self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
+
+            # 获取所有待处理的子任务
+            task_conversations = self.get_pending_task_conversations(task_id)
+
+            # 执行每个子任务
+            for task_conversation in task_conversations:
+                # 检查任务是否被取消
+                if self.task_events[task_id].is_set():
+                    break
+
+                # 更新子任务状态为运行中
+                self.update_task_conversations_status(task_conversation['id'],
+                                                      TestTaskConversationsStatus.RUNNING.value)
+
+                try:
+                    # 模拟任务执行 - 在实际应用中替换为实际业务逻辑
+
+                    # 再次检查任务是否被取消
+                    # if self.task_events[task_id].is_set():
+                    #     self.update_task_conversations_status(subtask['id'], 'cancelled')
+                    #     break
+                    # TODO 实际任务执行
+                    time.sleep(1)
+                    score = random.random()
+                    # 更新子任务状态为已完成
+                    self.update_task_conversations_res(task_conversation['id'],
+                                                       TestTaskConversationsStatus.SUCCESS.value, score)
+                except Exception as e:
+                    logger.error(f"Error executing task {task_id}: {str(e)}")
+                    self.update_task_conversations_status(task_conversation['id'],
+                                                          TestTaskConversationsStatus.FAILED.value)
+            # 检查任务是否完成
+            task_conversations = self.get_task_conversations(task_id)
+            all_completed = all(task_conversation['status'] in
+                                (TestTaskConversationsStatus.SUCCESS.value, TestTaskConversationsStatus.FAILED.value)
+                                for task_conversation in task_conversations)
+            any_pending = any(task_conversation['status'] in
+                              (TestTaskConversationsStatus.PENDING.value, TestTaskConversationsStatus.RUNNING.value)
+                              for task_conversation in task_conversations)
+
+            if all_completed:
+                self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
+                logger.info(f"Task {task_id} completed")
+            elif not any_pending:
+                # 没有待处理子任务但未全部完成(可能是取消了)
+                current_status = self.get_task(task_id)['status']
+                if current_status != TestTaskStatus.CANCELLED.value:
+                    self.update_task_status(task_id, TestTaskStatus.COMPLETED.value
+                    if all_completed else TestTaskStatus.CANCELLED.value)
+        except Exception as e:
+            logger.error(f"Error executing task {task_id}: {str(e)}")
+            self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
+        finally:
+            # 清理资源
+            with self.task_locks[task_id]:
+                if task_id in self.running_tasks:
+                    self.running_tasks.remove(task_id)
+                if task_id in self.task_events:
+                    del self.task_events[task_id]
+                if task_id in self.task_futures:
+                    del self.task_futures[task_id]
+
+    def shutdown(self):
+        """关闭执行器"""
+        self.executor.shutdown(wait=False)
+        logger.info("Task executor shutdown")