Jelajahi Sumber

修改数据库操作 改成ORM

xueyiming 2 minggu lalu
induk
melakukan
85975ba201

+ 0 - 6
pqai_agent/configs/dev.yaml

@@ -41,12 +41,6 @@ storage:
   push_record:
     database: ai_agent
     table: agent_push_record_dev
-  agent_configuration:
-    table: agent_configuration
-  test_task:
-    table: agent_test_task
-  test_task_conversations:
-    table: agent_test_task_conversations
 
 agent_behavior:
   message_aggregation_sec: 3

+ 19 - 0
pqai_agent/data_models/agent_test_task.py

@@ -0,0 +1,19 @@
+from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
+class AgentTestTask(Base):
+    __tablename__ = "agent_test_task"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
+    agent_id = Column(BigInteger, nullable=False, comment="agent主键")
+    model_id = Column(BigInteger, nullable=False, comment="model主键")
+    create_user = Column(String(32), nullable=True, comment="创建用户")
+    update_user = Column(String(32), nullable=True, comment="更新用户")
+    dataset_ids = Column(Text, nullable=False, comment="数据集ids")
+    status = Column(Integer, default=0, nullable=False, comment="状态(0:未开始, 1:进行中, 2:已完成, 3:已取消)")
+    create_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", comment="创建时间")
+    update_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
+                         comment="更新时间")

+ 22 - 0
pqai_agent/data_models/agent_test_task_conversations.py

@@ -0,0 +1,22 @@
+from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
+class AgentTestTaskConversations(Base):
+    __tablename__ = "agent_test_task_conversations"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
+    task_id = Column(BigInteger, nullable=False, comment="任务主键")
+    agent_id = Column(BigInteger, nullable=False, comment="agent主键")
+    dataset_id = Column(BigInteger, nullable=False, comment="数据集主键")
+    conversation_id = Column(BigInteger, nullable=False, comment="对话主键")
+    input = Column(Text, nullable=False, comment="输入内容")
+    output = Column(Text, nullable=False, comment="输出内容")
+    score = Column(Text, nullable=False, comment="得分")
+    status = Column(Integer, default=0, nullable=False,
+                    comment="状态(0:待执行, 1:执行中, 2:执行成功, 3:执行失败, 4:已取消)")
+    create_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", comment="创建时间")
+    update_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP",
+                         comment="更新时间")

+ 7 - 12
pqai_agent_server/api_server.py

@@ -492,6 +492,7 @@ def delete_native_agent_configuration():
         return wrap_response(200, msg='Agent configuration deleted successfully')
 
 
+
 @app.route("/api/getModuleList", methods=["GET"])
 def get_module_list():
     """
@@ -642,6 +643,7 @@ def get_test_task_list():
     response = app.task_manager.get_test_task_list(page_num, page_size)
     return wrap_response(200, data=response)
 
+
 @app.route("/api/getTestTaskConversations", methods=["GET"])
 def get_test_task_conversations():
     """
@@ -690,7 +692,7 @@ def stop_test_task():
     if not task_id:
         return wrap_response(400, msg='task id is required')
     task = app.task_manager.get_task(task_id)
-    if task['status'] not in (TestTaskStatus.NOT_STARTED.value, TestTaskStatus.IN_PROGRESS.value):
+    if task.status not in (TestTaskStatus.NOT_STARTED.value, TestTaskStatus.IN_PROGRESS.value):
         return wrap_response(400, msg='task status is invalid')
     app.task_manager.cancel_task(task_id)
     return wrap_response(200)
@@ -707,7 +709,7 @@ def resume_test_task():
     if not task_id:
         return wrap_response(400, msg='task id is required')
     task = app.task_manager.get_task(task_id)
-    if task['status'] != TestTaskStatus.CANCELLED.value:
+    if task.status != TestTaskStatus.CANCELLED.value:
         return wrap_response(400, msg='task status is invalid')
     app.task_manager.resume_task(task_id)
     return wrap_response(200)
@@ -738,21 +740,11 @@ if __name__ == '__main__':
     staff_db_config = config['storage']['staff']
     agent_state_db_config = config['storage']['agent_state']
     chat_history_db_config = config['storage']['chat_history']
-    agent_configuration_db_config = config['storage']['agent_configuration']
-    test_task_db_config = config['storage']['test_task']
-    test_task_conversations_db_config = config['storage']['test_task_conversations']
 
     # init user manager
     user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
     app.user_manager = user_manager
 
-    task_manager = TaskManager(
-        db_config=user_db_config['mysql'],
-        agent_configuration_table=agent_configuration_db_config['table'],
-        test_task_table=test_task_db_config['table'],
-        test_task_conversations_table=test_task_conversations_db_config['table'])
-    app.task_manager = task_manager
-
     # init session manager
     session_manager = MySQLSessionManager(
         db_config=agent_db_config,
@@ -765,6 +757,9 @@ if __name__ == '__main__':
     agent_db_engine = create_ai_agent_db_engine()
     app.session_maker = sessionmaker(bind=agent_db_engine)
 
+    task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine))
+    app.task_manager = task_manager
+
     wecom_db_config = config['storage']['user_relation']
     user_relation_manager = MySQLUserRelationManager(
         agent_db_config, growth_db_config,

+ 134 - 278
pqai_agent_server/task_server.py

@@ -8,114 +8,23 @@ from typing import List, Dict, Optional
 import pymysql
 from pymysql import Connection
 from pymysql.cursors import DictCursor
+from sqlalchemy import func
 
 from pqai_agent import logging_service
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
+from pqai_agent.data_models.agent_test_task import AgentTestTask
+from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
 from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc, \
     get_test_task_conversations_status_desc
 
 logger = logging_service.logger
-class Database:
-    """数据库操作类"""
-
-    def __init__(self, db_config):
-        self.db_config = db_config
-        self.connection_pool = Queue(maxsize=10)
-        self._initialize_pool()
-
-    def _initialize_pool(self):
-        """初始化数据库连接池"""
-        for _ in range(5):
-            conn = pymysql.connect(**self.db_config)
-            self.connection_pool.put(conn)
-        logger.info("Database connection pool initialized with 5 connections")
-
-    def get_connection(self) -> Connection:
-        """从连接池获取数据库连接"""
-        return self.connection_pool.get()
-
-    def release_connection(self, conn: Connection):
-        """释放数据库连接回连接池"""
-        self.connection_pool.put(conn)
-
-    def execute(self, query: str, args: tuple = (), many: bool = False) -> int:
-        """执行SQL语句并返回影响的行数"""
-        conn = self.get_connection()
-        try:
-            with conn.cursor() as cursor:
-                if many:
-                    cursor.executemany(query, args)
-                else:
-                    cursor.execute(query, args)
-                conn.commit()
-                return cursor.rowcount
-        except Exception as e:
-            logger.error(f"Database error: {str(e)}")
-            conn.rollback()
-            raise
-        finally:
-            self.release_connection(conn)
-
-    def insert(self, insert: str, args: tuple = (), many: bool = False) -> int:
-        """执行插入SQL语句并主键"""
-        conn = self.get_connection()
-        try:
-            with conn.cursor() as cursor:
-                if many:
-                    cursor.executemany(insert, args)
-                else:
-                    cursor.execute(insert, args)
-                conn.commit()
-                return cursor.lastrowid
-        except Exception as e:
-            logger.error(f"Database error: {str(e)}")
-            conn.rollback()
-            raise
-        finally:
-            self.release_connection(conn)
-
-    def fetch(self, query: str, args: tuple = ()) -> List[Dict]:
-        """执行SQL查询并返回结果列表"""
-        conn = self.get_connection()
-        try:
-            with conn.cursor(DictCursor) as cursor:
-                cursor.execute(query, args)
-                return cursor.fetchall()
-        except Exception as e:
-            logger.error(f"Database error: {str(e)}")
-            raise
-        finally:
-            self.release_connection(conn)
-
-    def fetch_one(self, query: str, args: tuple = ()) -> Optional[Dict]:
-        """执行SQL查询并返回单行结果"""
-        conn = self.get_connection()
-        try:
-            with conn.cursor(DictCursor) as cursor:
-                cursor.execute(query, args)
-                return cursor.fetchone()
-        except Exception as e:
-            logger.error(f"Database error: {str(e)}")
-            raise
-        finally:
-            self.release_connection(conn)
-
-    def close_all(self):
-        """关闭所有数据库连接"""
-        while not self.connection_pool.empty():
-            conn = self.connection_pool.get()
-            conn.close()
-        logger.info("All database connections closed")
 
 
 class TaskManager:
     """任务管理器"""
 
-    def __init__(self, db_config, agent_configuration_table, test_task_table, test_task_conversations_table,
-                 max_workers: int = 10):
-        self.db = Database(db_config)
-        self.agent_configuration_table = agent_configuration_table
-        self.test_task_table = test_task_table
-        self.test_task_conversations_table = test_task_conversations_table
+    def __init__(self, session_maker, max_workers: int = 10):
+        self.session_maker = session_maker
         self.task_events = {}  # 任务ID -> Event (用于取消任务)
         self.task_locks = {}  # 任务ID -> Lock (用于任务状态同步)
         self.running_tasks = set()
@@ -123,154 +32,131 @@ class TaskManager:
         self.task_futures = {}  # 任务ID -> Future
 
     def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
-        fetch_query = f"""
-            select t1.id, t2.name, t1.create_user, t1.update_user, t1.status, t1.create_time, t1.update_time
-            from {self.test_task_table} t1
-            left join {self.agent_configuration_table} t2 on t1.agent_id = t2.id
-            order by create_time desc
-            limit %s, %s;
-        """
-        messages = self.db.fetch(fetch_query, ((page_num - 1) * page_size, page_size))
-        total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_table}""")
-
-        total = total_size["count"]
-        total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
-        total_page = 1 if total_page <= 0 else total_page
-        response_data = [
-            {
-                "id": message["id"],
-                "agentName": message["name"],
-                "createUser": message["create_user"],
-                "updateUser": message["update_user"],
-                "statusName": get_test_task_status_desc(message["status"]),
-                "createTime": message["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
-                "updateTime": message["update_time"].strftime("%Y-%m-%d %H:%M:%S")
+        with self.session_maker() as session:
+            # 计算偏移量
+            offset = (page_num - 1) * page_size
+            # 查询分页数据
+            result = (session.query(AgentTestTask, AgentConfiguration)
+                      .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
+                      .limit(page_size).offset(offset).all())
+            # 查询总记录数
+            total = session.query(func.count(AgentTestTask.id)).scalar()
+
+            total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
+            total_page = 1 if total_page <= 0 else total_page
+            response_data = [
+                {
+                    "id": agent_test_task.id,
+                    "agentName": agent_configuration.name,
+                    "createUser": agent_test_task.create_user,
+                    "updateUser": agent_test_task.update_user,
+                    "statusName": get_test_task_status_desc(agent_test_task.status),
+                    "createTime": agent_test_task.create_time.strftime("%Y-%m-%d %H:%M:%S"),
+                    "updateTime": agent_test_task.update_time.strftime("%Y-%m-%d %H:%M:%S")
+                }
+                for agent_test_task, agent_configuration in result
+            ]
+            return {
+                "currentPage": page_num,
+                "pageSize": page_size,
+                "totalSize": total_page,
+                "total": total,
+                "list": response_data,
             }
-            for message in messages
-        ]
-        return {
-            "currentPage": page_num,
-            "pageSize": page_size,
-            "totalSize": total_page,
-            "total": total,
-            "list": response_data,
-        }
 
     def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
-        fetch_query = f"""
-            select t1.id, t2.name, t3.create_user, t1.input, t1.output, t1.score, t1.status, t1.create_time, t1.update_time
-            from {self.test_task_conversations_table} t1
-            left join {self.agent_configuration_table} t2 on t1.agent_id = t2.id
-            left join {self.test_task_table} t3 on t1.task_id = t3.id
-            where t1.task_id = %s
-            order by create_time desc
-            limit %s, %s;
-        """
-        messages = self.db.fetch(fetch_query, (task_id, (page_num - 1) * page_size, page_size))
-        total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_conversations_table} where task_id = %s""",
-                                       (task_id,))
-
-        total = total_size["count"]
-        total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
-        total_page = 1 if total_page <= 0 else total_page
-        response_data = [
-            {
-                "id": message["id"],
-                "agentName": message["name"],
-                "createUser": message["create_user"],
-                "input": message["input"],
-                "output": message["output"],
-                "score": message["score"],
-                "statusName": get_test_task_conversations_status_desc(message["status"]),
-                "createTime": message["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
-                "updateTime": message["update_time"].strftime("%Y-%m-%d %H:%M:%S")
+        with self.session_maker() as session:
+            # 计算偏移量
+            offset = (page_num - 1) * page_size
+            # 查询分页数据
+            result = (session.query(AgentTestTaskConversations, AgentConfiguration)
+                      .outerjoin(AgentConfiguration, AgentTestTaskConversations.agent_id == AgentConfiguration.id)
+                      .filter(AgentTestTaskConversations.task_id == task_id)
+                      .limit(page_size).offset(offset).all())
+            # 查询总记录数
+            total = session.query(func.count(AgentTestTaskConversations.id)).scalar()
+
+            total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
+            total_page = 1 if total_page <= 0 else total_page
+            response_data = [
+                {
+                    "id": agent_test_task_conversation.id,
+                    "agentName": agent_configuration.name,
+                    "input": agent_test_task_conversation.input,
+                    "output": agent_test_task_conversation.output,
+                    "score": agent_test_task_conversation.score,
+                    "statusName": get_test_task_status_desc(agent_test_task_conversation.status),
+                    "createTime": agent_test_task_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S"),
+                    "updateTime": agent_test_task_conversation.update_time.strftime("%Y-%m-%d %H:%M:%S")
+                }
+                for agent_test_task_conversation, agent_configuration in result
+            ]
+            return {
+                "currentPage": page_num,
+                "pageSize": page_size,
+                "totalSize": total_page,
+                "total": total,
+                "list": response_data,
             }
-            for message in messages
-        ]
-        return {
-            "currentPage": page_num,
-            "pageSize": page_size,
-            "totalSize": total_page,
-            "total": total,
-            "list": response_data,
-        }
-
-    def create_task(self, agent_id: int, model_id: int ) -> Dict:
-        """创建新任务"""
 
-        conn = self.db.get_connection()
-        try:
-            conn.begin()
-            # TODO 插入任务 当前测试模拟
-            with conn.cursor() as cursor:
-                cursor.execute(
-                    f"""INSERT INTO {self.test_task_table} (agent_id, status, create_user, update_user) VALUES (%s, %s, %s, %s)""",
-                    (agent_id, 0, 'xueyiming', 'xueyiming')
-                )
-            task_id = cursor.lastrowid
-            task_conversations = []
-            # TODO 查询具体的数据集信息后插入
-            i = 0
-            for _ in range(30):
-                i = i + 1
-                task_conversations.append((
-                    task_id, agent_id, i, i, "输入", "输出", 0
-                ))
-
-            with conn.cursor() as cursor:
-                cursor.executemany(
-                    f"""INSERT INTO {self.test_task_conversations_table} (task_id, agent_id, dataset_id, conversation_id, input, output, status) 
-                    VALUES (%s, %s, %s, %s, %s, %s, %s)""",
-                    task_conversations
-                )
-
-            conn.commit()
-        except Exception as e:
-            conn.rollback()
-            logger.error(f"Failed to create task agent_id {agent_id}: {str(e)}")
-            raise
-        finally:
-            self.db.release_connection(conn)
-        logger.info(f"Created task {task_id} with 100 task_conversations")
-        # 异步执行任务
+    def create_task(self, agent_id: int, model_id: int) -> Dict:
+        """创建新任务"""
+        with self.session_maker() as session:
+            with session.begin():
+                agent_test_task = AgentTestTask(agent_id=agent_id, model_id=model_id)
+                session.add(agent_test_task)
+                session.flush()  # 强制SQL执行,但不提交事务
+                task_id = agent_test_task.id
+                agent_test_task_conversations = []
+                # TODO 查询具体的数据集信息后插入
+                i = 0
+                for _ in range(30):
+                    i = i + 1
+                    agent_test_task_conversation = AgentTestTaskConversations(task_id=task_id, agent_id=agent_id,
+                                                                              input='输入', output='输出',
+                                                                              dataset_id=i, conversation_id=i)
+                    agent_test_task_conversations.append(agent_test_task_conversation)
+                session.add_all(agent_test_task_conversations)
+                # 异步执行任务
         self._execute_task(task_id)
         return self.get_task(task_id)
 
-    def get_task(self, task_id: int) -> Optional[Dict]:
+    def get_task(self, task_id: int):
         """获取任务信息"""
-        return self.db.fetch_one(f"""SELECT * FROM {self.test_task_table} WHERE id = %s""", (task_id,))
+        with self.session_maker() as session:
+            return session.query(AgentTestTask).filter(AgentTestTask.id == task_id).one()
 
-    def get_task_conversations(self, task_id: int) -> List[Dict]:
+    def get_task_conversations(self, task_id: int):
         """获取任务的所有子任务"""
-        return self.db.fetch(f"""SELECT * FROM {self.test_task_conversations_table} WHERE task_id = %s""", (task_id,))
+        with self.session_maker() as session:
+            return session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).all()
 
-    def get_pending_task_conversations(self, task_id: int) -> List[Dict]:
+    def get_pending_task_conversations(self, task_id: int):
         """获取待处理的子任务"""
-        return self.db.fetch(
-            f"""SELECT * FROM {self.test_task_conversations_table} WHERE task_id = %s AND status = %s""",
-            (task_id, TestTaskConversationsStatus.PENDING.value)
-        )
+        with self.session_maker() as session:
+            return session.query(AgentTestTaskConversations).filter(
+                AgentTestTaskConversations.task_id == task_id).filter(
+                AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value).all()
 
     def update_task_status(self, task_id: int, status: int):
         """更新任务状态"""
-        self.db.execute(
-            f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
-            (status, task_id)
-        )
+        with self.session_maker() as session:
+            session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update({"status": status})
+            session.commit()
 
     def update_task_conversations_status(self, task_conversations_id: int, status: int):
         """更新子任务状态"""
-        self.db.execute(
-            f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE id = %s""",
-            (status, task_conversations_id)
-        )
+        with self.session_maker() as session:
+            session.query(AgentTestTaskConversations).filter(
+                AgentTestTaskConversations.id == task_conversations_id).update({"status": status})
+            session.commit()
 
     def update_task_conversations_res(self, task_conversations_id: int, status: int, score: str):
-        """更新子任务状态"""
-        self.db.execute(
-            f"""UPDATE {self.test_task_conversations_table} SET status = %s, score = %s WHERE id = %s""",
-            (status, score, task_conversations_id)
-        )
+        """更新子任务结果"""
+        with self.session_maker() as session:
+            session.query(AgentTestTaskConversations).filter(
+                AgentTestTaskConversations.id == task_conversations_id).update({"status": status, "score": score})
+            session.commit()
 
     def cancel_task(self, task_id: int):
         """取消任务(带事务支持)"""
@@ -281,63 +167,32 @@ class TaskManager:
         if task_id in self.task_futures:
             self.task_futures[task_id].cancel()
 
-        conn = self.db.get_connection()
-        try:
-            conn.begin()
-            # 更新任务状态为取消
-            with conn.cursor() as cursor:
-                cursor.execute(
-                    f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
-                    (TestTaskStatus.CANCELLED.value, task_id)
-                )
-
-            # 取消所有待处理的子任务
-            with conn.cursor() as cursor:
-                cursor.execute(
-                    f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE task_id = %s AND status = %s""",
-                    (TestTaskConversationsStatus.CANCELLED.value, task_id, TestTaskConversationsStatus.PENDING.value)
-                )
-            conn.commit()
-            logger.info(f"Cancelled task {task_id} and its pending {self.test_task_conversations_table}")
-        except Exception as e:
-            conn.rollback()
-            logger.error(f"Failed to cancel task {task_id}: {str(e)}")
-        finally:
-            self.db.release_connection(conn)
+        with self.session_maker() as session:
+            with session.begin():
+                session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
+                    {"status": TestTaskStatus.CANCELLED.value})
+                session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).filter(
+                    AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value).update(
+                    {"status": TestTaskConversationsStatus.CANCELLED.value})
+                session.commit()
 
     def resume_task(self, task_id: int) -> bool:
         """恢复已取消的任务"""
         task = self.get_task(task_id)
-        if not task or task['status'] != TestTaskStatus.CANCELLED.value:
+        if not task or task.status != TestTaskStatus.CANCELLED.value:
             return False
 
-        conn = self.db.get_connection()
-        try:
-            conn.begin()
-            # 更新任务状态为待开始
-            with conn.cursor() as cursor:
-                cursor.execute(
-                    f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
-                    (TestTaskStatus.NOT_STARTED.value, task_id)
-                )
-
-            # 恢复所有已取消的子任务
-            with conn.cursor() as cursor:
-                cursor.execute(
-                    f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE task_id = %s AND status = %s""",
-                    (TestTaskConversationsStatus.PENDING.value, task_id, TestTaskConversationsStatus.CANCELLED.value)
-                )
-            conn.commit()
-            logger.info(f"Cancelled task {task_id} and its pending {self.test_task_conversations_table}")
-        except Exception as e:
-            conn.rollback()
-            logger.error(f"Failed to cancel task {task_id}: {str(e)}")
-        finally:
-            self.db.release_connection(conn)
+        with self.session_maker() as session:
+            with session.begin():
+                session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
+                    {"status": TestTaskStatus.NOT_STARTED.value})
+                session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).filter(
+                    AgentTestTaskConversations.status == TestTaskConversationsStatus.CANCELLED.value).update(
+                    {"status": TestTaskConversationsStatus.PENDING.value})
+                session.commit()
 
         # 重新执行任务
         self._execute_task(task_id)
-
         logger.info(f"Resumed task {task_id}")
         return True
 
@@ -377,7 +232,7 @@ class TaskManager:
                     break
 
                 # 更新子任务状态为运行中
-                self.update_task_conversations_status(task_conversation['id'],
+                self.update_task_conversations_status(task_conversation.id,
                                                       TestTaskConversationsStatus.RUNNING.value)
                 try:
                     # 模拟任务执行 - 在实际应用中替换为实际业务逻辑
@@ -385,18 +240,19 @@ class TaskManager:
                     time.sleep(1)
                     score = '{"score":0.05}'
                     # 更新子任务状态为已完成
-                    self.update_task_conversations_res(task_conversation['id'],
+                    self.update_task_conversations_res(task_conversation.id,
                                                        TestTaskConversationsStatus.SUCCESS.value, score)
                 except Exception as e:
                     logger.error(f"Error executing task {task_id}: {str(e)}")
-                    self.update_task_conversations_status(task_conversation['id'],
+                    self.update_task_conversations_status(task_conversation.id,
                                                           TestTaskConversationsStatus.FAILED.value)
+
             # 检查任务是否完成
             task_conversations = self.get_task_conversations(task_id)
-            all_completed = all(task_conversation['status'] in
+            all_completed = all(task_conversation.status in
                                 (TestTaskConversationsStatus.SUCCESS.value, TestTaskConversationsStatus.FAILED.value)
                                 for task_conversation in task_conversations)
-            any_pending = any(task_conversation['status'] in
+            any_pending = any(task_conversation.status in
                               (TestTaskConversationsStatus.PENDING.value, TestTaskConversationsStatus.RUNNING.value)
                               for task_conversation in task_conversations)
 
@@ -405,7 +261,7 @@ class TaskManager:
                 logger.info(f"Task {task_id} completed")
             elif not any_pending:
                 # 没有待处理子任务但未全部完成(可能是取消了)
-                current_status = self.get_task(task_id)['status']
+                current_status = self.get_task(task_id).status
                 if current_status != TestTaskStatus.CANCELLED.value:
                     self.update_task_status(task_id, TestTaskStatus.COMPLETED.value
                     if all_completed else TestTaskStatus.CANCELLED.value)