|
@@ -0,0 +1,433 @@
|
|
|
|
+import random
|
|
|
|
+import threading
|
|
|
|
+import time
|
|
|
|
+from concurrent.futures import ThreadPoolExecutor
|
|
|
|
+from queue import Queue
|
|
|
|
+from typing import List, Dict, Optional
|
|
|
|
+
|
|
|
|
+import pymysql
|
|
|
|
+from pymysql import Connection
|
|
|
|
+from pymysql.cursors import DictCursor
|
|
|
|
+
|
|
|
|
+from pqai_agent import logging_service
|
|
|
|
+from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc, \
|
|
|
|
+ get_test_task_conversations_status_desc
|
|
|
|
+
|
|
|
|
+logger = logging_service.logger
|
|
|
|
+class Database:
|
|
|
|
+ """数据库操作类"""
|
|
|
|
+
|
|
|
|
+ def __init__(self, db_config):
|
|
|
|
+ self.db_config = db_config
|
|
|
|
+ self.connection_pool = Queue(maxsize=10)
|
|
|
|
+ self._initialize_pool()
|
|
|
|
+
|
|
|
|
+ def _initialize_pool(self):
|
|
|
|
+ """初始化数据库连接池"""
|
|
|
|
+ for _ in range(5):
|
|
|
|
+ conn = pymysql.connect(**self.db_config)
|
|
|
|
+ self.connection_pool.put(conn)
|
|
|
|
+ logger.info("Database connection pool initialized with 5 connections")
|
|
|
|
+
|
|
|
|
+ def get_connection(self) -> Connection:
|
|
|
|
+ """从连接池获取数据库连接"""
|
|
|
|
+ return self.connection_pool.get()
|
|
|
|
+
|
|
|
|
+ def release_connection(self, conn: Connection):
|
|
|
|
+ """释放数据库连接回连接池"""
|
|
|
|
+ self.connection_pool.put(conn)
|
|
|
|
+
|
|
|
|
+ def execute(self, query: str, args: tuple = (), many: bool = False) -> int:
|
|
|
|
+ """执行SQL语句并返回影响的行数"""
|
|
|
|
+ conn = self.get_connection()
|
|
|
|
+ try:
|
|
|
|
+ with conn.cursor() as cursor:
|
|
|
|
+ if many:
|
|
|
|
+ cursor.executemany(query, args)
|
|
|
|
+ else:
|
|
|
|
+ cursor.execute(query, args)
|
|
|
|
+ conn.commit()
|
|
|
|
+ return cursor.rowcount
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Database error: {str(e)}")
|
|
|
|
+ conn.rollback()
|
|
|
|
+ raise
|
|
|
|
+ finally:
|
|
|
|
+ self.release_connection(conn)
|
|
|
|
+
|
|
|
|
+ def insert(self, insert: str, args: tuple = (), many: bool = False) -> int:
|
|
|
|
+ """执行插入SQL语句并主键"""
|
|
|
|
+ conn = self.get_connection()
|
|
|
|
+ try:
|
|
|
|
+ with conn.cursor() as cursor:
|
|
|
|
+ if many:
|
|
|
|
+ cursor.executemany(insert, args)
|
|
|
|
+ else:
|
|
|
|
+ cursor.execute(insert, args)
|
|
|
|
+ conn.commit()
|
|
|
|
+ return cursor.lastrowid
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Database error: {str(e)}")
|
|
|
|
+ conn.rollback()
|
|
|
|
+ raise
|
|
|
|
+ finally:
|
|
|
|
+ self.release_connection(conn)
|
|
|
|
+
|
|
|
|
+ def fetch(self, query: str, args: tuple = ()) -> List[Dict]:
|
|
|
|
+ """执行SQL查询并返回结果列表"""
|
|
|
|
+ conn = self.get_connection()
|
|
|
|
+ try:
|
|
|
|
+ with conn.cursor(DictCursor) as cursor:
|
|
|
|
+ cursor.execute(query, args)
|
|
|
|
+ return cursor.fetchall()
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Database error: {str(e)}")
|
|
|
|
+ raise
|
|
|
|
+ finally:
|
|
|
|
+ self.release_connection(conn)
|
|
|
|
+
|
|
|
|
+ def fetch_one(self, query: str, args: tuple = ()) -> Optional[Dict]:
|
|
|
|
+ """执行SQL查询并返回单行结果"""
|
|
|
|
+ conn = self.get_connection()
|
|
|
|
+ try:
|
|
|
|
+ with conn.cursor(DictCursor) as cursor:
|
|
|
|
+ cursor.execute(query, args)
|
|
|
|
+ return cursor.fetchone()
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Database error: {str(e)}")
|
|
|
|
+ raise
|
|
|
|
+ finally:
|
|
|
|
+ self.release_connection(conn)
|
|
|
|
+
|
|
|
|
+ def close_all(self):
|
|
|
|
+ """关闭所有数据库连接"""
|
|
|
|
+ while not self.connection_pool.empty():
|
|
|
|
+ conn = self.connection_pool.get()
|
|
|
|
+ conn.close()
|
|
|
|
+ logger.info("All database connections closed")
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class TaskManager:
|
|
|
|
+ """任务管理器"""
|
|
|
|
+
|
|
|
|
+ def __init__(self, db_config, agent_configuration_table, test_task_table, test_task_conversations_table,
|
|
|
|
+ max_workers: int = 10):
|
|
|
|
+ self.db = Database(db_config)
|
|
|
|
+ self.agent_configuration_table = agent_configuration_table
|
|
|
|
+ self.test_task_table = test_task_table
|
|
|
|
+ self.test_task_conversations_table = test_task_conversations_table
|
|
|
|
+ self.task_events = {} # 任务ID -> Event (用于取消任务)
|
|
|
|
+ self.task_locks = {} # 任务ID -> Lock (用于任务状态同步)
|
|
|
|
+ self.running_tasks = set()
|
|
|
|
+ self.executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='TaskWorker')
|
|
|
|
+ self.task_futures = {} # 任务ID -> Future
|
|
|
|
+
|
|
|
|
+ def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
|
|
|
|
+ fetch_query = f"""
|
|
|
|
+ select t1.id, t2.name, t1.create_user, t1.update_user, t1.status, t1.create_time, t1.update_time
|
|
|
|
+ from {self.test_task_table} t1
|
|
|
|
+ left join {self.agent_configuration_table} t2 on t1.agent_id = t2.id
|
|
|
|
+ order by create_time desc
|
|
|
|
+ limit %s, %s;
|
|
|
|
+ """
|
|
|
|
+ messages = self.db.fetch(fetch_query, ((page_num - 1) * page_size, page_size))
|
|
|
|
+ total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_table}""")
|
|
|
|
+
|
|
|
|
+ total = total_size["count"]
|
|
|
|
+ total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
|
|
|
|
+ total_page = 1 if total_page <= 0 else total_page
|
|
|
|
+ response_data = [
|
|
|
|
+ {
|
|
|
|
+ "id": message["id"],
|
|
|
|
+ "agentName": message["name"],
|
|
|
|
+ "createUser": message["create_user"],
|
|
|
|
+ "updateUser": message["update_user"],
|
|
|
|
+ "statusName": get_test_task_status_desc(message["status"]),
|
|
|
|
+ "createTime": message["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
|
|
|
|
+ "updateTime": message["update_time"].strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
+ }
|
|
|
|
+ for message in messages
|
|
|
|
+ ]
|
|
|
|
+ return {
|
|
|
|
+ "currentPage": page_num,
|
|
|
|
+ "pageSize": page_size,
|
|
|
|
+ "totalSize": total_page,
|
|
|
|
+ "total": total,
|
|
|
|
+ "list": response_data,
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
|
|
|
|
+ fetch_query = f"""
|
|
|
|
+ select t1.id, t2.name, t3.create_user, t1.input, t1.output, t1.score, t1.status, t1.create_time, t1.update_time
|
|
|
|
+ from {self.test_task_conversations_table} t1
|
|
|
|
+ left join {self.agent_configuration_table} t2 on t1.agent_id = t2.id
|
|
|
|
+ left join {self.test_task_table} t3 on t1.task_id = t3.id
|
|
|
|
+ where t1.task_id = %s
|
|
|
|
+ order by create_time desc
|
|
|
|
+ limit %s, %s;
|
|
|
|
+ """
|
|
|
|
+ messages = self.db.fetch(fetch_query, (task_id, (page_num - 1) * page_size, page_size))
|
|
|
|
+ total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_conversations_table} where task_id = %s""",
|
|
|
|
+ (task_id,))
|
|
|
|
+
|
|
|
|
+ total = total_size["count"]
|
|
|
|
+ total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
|
|
|
|
+ total_page = 1 if total_page <= 0 else total_page
|
|
|
|
+ response_data = [
|
|
|
|
+ {
|
|
|
|
+ "id": message["id"],
|
|
|
|
+ "agentName": message["name"],
|
|
|
|
+ "createUser": message["create_user"],
|
|
|
|
+ "input": message["input"],
|
|
|
|
+ "output": message["output"],
|
|
|
|
+ "score": message["score"],
|
|
|
|
+ "statusName": get_test_task_conversations_status_desc(message["status"]),
|
|
|
|
+ "createTime": message["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
|
|
|
|
+ "updateTime": message["update_time"].strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
+ }
|
|
|
|
+ for message in messages
|
|
|
|
+ ]
|
|
|
|
+ return {
|
|
|
|
+ "currentPage": page_num,
|
|
|
|
+ "pageSize": page_size,
|
|
|
|
+ "totalSize": total_page,
|
|
|
|
+ "total": total,
|
|
|
|
+ "list": response_data,
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ def create_task(self, agent_id: int) -> Dict:
|
|
|
|
+ """创建新任务并添加100个子任务"""
|
|
|
|
+
|
|
|
|
+ conn = self.db.get_connection()
|
|
|
|
+ try:
|
|
|
|
+ conn.begin()
|
|
|
|
+ # 插入任务
|
|
|
|
+ with conn.cursor() as cursor:
|
|
|
|
+ cursor.execute(
|
|
|
|
+ f"""INSERT INTO {self.test_task_table} (agent_id, status, create_user, update_user) VALUES (%s, %s, %s, %s)""",
|
|
|
|
+ (agent_id, 0, 'xueyiming', 'xueyiming')
|
|
|
|
+ )
|
|
|
|
+ task_id = cursor.lastrowid
|
|
|
|
+ task_conversations = []
|
|
|
|
+ i = 0
|
|
|
|
+ for _ in range(30):
|
|
|
|
+ i = i + 1
|
|
|
|
+ task_conversations.append((
|
|
|
|
+ task_id, agent_id, i, i, "输入", "输出", 0
|
|
|
|
+ ))
|
|
|
|
+
|
|
|
|
+ with conn.cursor() as cursor:
|
|
|
|
+ cursor.executemany(
|
|
|
|
+ f"""INSERT INTO {self.test_task_conversations_table} (task_id, agent_id, dataset_id, conversation_id, input, output, status)
|
|
|
|
+ VALUES (%s, %s, %s, %s, %s, %s, %s)""",
|
|
|
|
+ task_conversations
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ conn.commit()
|
|
|
|
+ except Exception as e:
|
|
|
|
+ conn.rollback()
|
|
|
|
+ logger.error(f"Failed to create task agent_id {agent_id}: {str(e)}")
|
|
|
|
+ raise
|
|
|
|
+ finally:
|
|
|
|
+ self.db.release_connection(conn)
|
|
|
|
+ logger.info(f"Created task {task_id} with 100 task_conversations")
|
|
|
|
+ # 异步执行任务
|
|
|
|
+ self._execute_task(task_id)
|
|
|
|
+ return self.get_task(task_id)
|
|
|
|
+
|
|
|
|
+ def get_task(self, task_id: int) -> Optional[Dict]:
|
|
|
|
+ """获取任务信息"""
|
|
|
|
+ return self.db.fetch_one(f"""SELECT * FROM {self.test_task_table} WHERE id = %s""", (task_id,))
|
|
|
|
+
|
|
|
|
+ def get_task_conversations(self, task_id: int) -> List[Dict]:
|
|
|
|
+ """获取任务的所有子任务"""
|
|
|
|
+ return self.db.fetch(f"""SELECT * FROM {self.test_task_conversations_table} WHERE task_id = %s""", (task_id,))
|
|
|
|
+
|
|
|
|
+ def get_pending_task_conversations(self, task_id: int) -> List[Dict]:
|
|
|
|
+ """获取待处理的子任务"""
|
|
|
|
+ return self.db.fetch(
|
|
|
|
+ f"""SELECT * FROM {self.test_task_conversations_table} WHERE task_id = %s AND status = %s""",
|
|
|
|
+ (task_id, TestTaskConversationsStatus.PENDING.value)
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ def update_task_status(self, task_id: int, status: int):
|
|
|
|
+ """更新任务状态"""
|
|
|
|
+ self.db.execute(
|
|
|
|
+ f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
|
|
|
|
+ (status, task_id)
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ def update_task_conversations_status(self, task_conversations_id: int, status: int):
|
|
|
|
+ """更新子任务状态"""
|
|
|
|
+ self.db.execute(
|
|
|
|
+ f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE id = %s""",
|
|
|
|
+ (status, task_conversations_id)
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ def update_task_conversations_res(self, task_conversations_id: int, status: int, score: float):
|
|
|
|
+ """更新子任务状态"""
|
|
|
|
+ self.db.execute(
|
|
|
|
+ f"""UPDATE {self.test_task_conversations_table} SET status = %s, score = %s WHERE id = %s""",
|
|
|
|
+ (status, score, task_conversations_id)
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ def cancel_task(self, task_id: int):
|
|
|
|
+ """取消任务(带事务支持)"""
|
|
|
|
+ # 设置取消事件
|
|
|
|
+ if task_id in self.task_events:
|
|
|
|
+ self.task_events[task_id].set()
|
|
|
|
+ # 如果任务正在执行,尝试取消Future
|
|
|
|
+ if task_id in self.task_futures:
|
|
|
|
+ self.task_futures[task_id].cancel()
|
|
|
|
+
|
|
|
|
+ conn = self.db.get_connection()
|
|
|
|
+ try:
|
|
|
|
+ conn.begin()
|
|
|
|
+ # 更新任务状态为取消
|
|
|
|
+ with conn.cursor() as cursor:
|
|
|
|
+ cursor.execute(
|
|
|
|
+ f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
|
|
|
|
+ (TestTaskStatus.CANCELLED.value, task_id)
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ # 取消所有待处理的子任务
|
|
|
|
+ with conn.cursor() as cursor:
|
|
|
|
+ cursor.execute(
|
|
|
|
+ f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE task_id = %s AND status = %s""",
|
|
|
|
+ (TestTaskConversationsStatus.CANCELLED.value, task_id, TestTaskConversationsStatus.PENDING.value)
|
|
|
|
+ )
|
|
|
|
+ conn.commit()
|
|
|
|
+ logger.info(f"Cancelled task {task_id} and its pending {self.test_task_conversations_table}")
|
|
|
|
+ except Exception as e:
|
|
|
|
+ conn.rollback()
|
|
|
|
+ logger.error(f"Failed to cancel task {task_id}: {str(e)}")
|
|
|
|
+ finally:
|
|
|
|
+ self.db.release_connection(conn)
|
|
|
|
+
|
|
|
|
+ def resume_task(self, task_id: int) -> bool:
|
|
|
|
+ """恢复已取消的任务"""
|
|
|
|
+ task = self.get_task(task_id)
|
|
|
|
+ if not task or task['status'] != TestTaskStatus.CANCELLED.value:
|
|
|
|
+ return False
|
|
|
|
+
|
|
|
|
+ conn = self.db.get_connection()
|
|
|
|
+ try:
|
|
|
|
+ conn.begin()
|
|
|
|
+ # 更新任务状态为待开始
|
|
|
|
+ with conn.cursor() as cursor:
|
|
|
|
+ cursor.execute(
|
|
|
|
+ f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
|
|
|
|
+ (TestTaskStatus.NOT_STARTED.value, task_id)
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ # 恢复所有已取消的子任务
|
|
|
|
+ with conn.cursor() as cursor:
|
|
|
|
+ cursor.execute(
|
|
|
|
+ f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE task_id = %s AND status = %s""",
|
|
|
|
+ (TestTaskConversationsStatus.PENDING.value, task_id, TestTaskConversationsStatus.CANCELLED.value)
|
|
|
|
+ )
|
|
|
|
+ conn.commit()
|
|
|
|
+ logger.info(f"Cancelled task {task_id} and its pending {self.test_task_conversations_table}")
|
|
|
|
+ except Exception as e:
|
|
|
|
+ conn.rollback()
|
|
|
|
+ logger.error(f"Failed to cancel task {task_id}: {str(e)}")
|
|
|
|
+ finally:
|
|
|
|
+ self.db.release_connection(conn)
|
|
|
|
+
|
|
|
|
+ # 重新执行任务
|
|
|
|
+ self._execute_task(task_id)
|
|
|
|
+
|
|
|
|
+ logger.info(f"Resumed task {task_id}")
|
|
|
|
+ return True
|
|
|
|
+
|
|
|
|
+ def _execute_task(self, task_id: int):
|
|
|
|
+ """提交任务到线程池执行"""
|
|
|
|
+ # 确保任务状态一致性
|
|
|
|
+ if task_id in self.running_tasks:
|
|
|
|
+ return
|
|
|
|
+
|
|
|
|
+ # 创建任务事件和锁
|
|
|
|
+ if task_id not in self.task_events:
|
|
|
|
+ self.task_events[task_id] = threading.Event()
|
|
|
|
+ if task_id not in self.task_locks:
|
|
|
|
+ self.task_locks[task_id] = threading.Lock()
|
|
|
|
+
|
|
|
|
+ # 提交到线程池
|
|
|
|
+ future = self.executor.submit(self._process_task, task_id)
|
|
|
|
+ self.task_futures[task_id] = future
|
|
|
|
+
|
|
|
|
+ # 标记任务为运行中
|
|
|
|
+ with self.task_locks[task_id]:
|
|
|
|
+ self.running_tasks.add(task_id)
|
|
|
|
+
|
|
|
|
+ def _process_task(self, task_id: int):
|
|
|
|
+ """处理任务的所有子任务"""
|
|
|
|
+ try:
|
|
|
|
+ # 更新任务状态为运行中
|
|
|
|
+ self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
|
|
|
|
+
|
|
|
|
+ # 获取所有待处理的子任务
|
|
|
|
+ task_conversations = self.get_pending_task_conversations(task_id)
|
|
|
|
+
|
|
|
|
+ # 执行每个子任务
|
|
|
|
+ for task_conversation in task_conversations:
|
|
|
|
+ # 检查任务是否被取消
|
|
|
|
+ if self.task_events[task_id].is_set():
|
|
|
|
+ break
|
|
|
|
+
|
|
|
|
+ # 更新子任务状态为运行中
|
|
|
|
+ self.update_task_conversations_status(task_conversation['id'],
|
|
|
|
+ TestTaskConversationsStatus.RUNNING.value)
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ # 模拟任务执行 - 在实际应用中替换为实际业务逻辑
|
|
|
|
+
|
|
|
|
+ # 再次检查任务是否被取消
|
|
|
|
+ # if self.task_events[task_id].is_set():
|
|
|
|
+ # self.update_task_conversations_status(subtask['id'], 'cancelled')
|
|
|
|
+ # break
|
|
|
|
+ # TODO 实际任务执行
|
|
|
|
+ time.sleep(1)
|
|
|
|
+ score = random.random()
|
|
|
|
+ # 更新子任务状态为已完成
|
|
|
|
+ self.update_task_conversations_res(task_conversation['id'],
|
|
|
|
+ TestTaskConversationsStatus.SUCCESS.value, score)
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Error executing task {task_id}: {str(e)}")
|
|
|
|
+ self.update_task_conversations_status(task_conversation['id'],
|
|
|
|
+ TestTaskConversationsStatus.FAILED.value)
|
|
|
|
+ # 检查任务是否完成
|
|
|
|
+ task_conversations = self.get_task_conversations(task_id)
|
|
|
|
+ all_completed = all(task_conversation['status'] in
|
|
|
|
+ (TestTaskConversationsStatus.SUCCESS.value, TestTaskConversationsStatus.FAILED.value)
|
|
|
|
+ for task_conversation in task_conversations)
|
|
|
|
+ any_pending = any(task_conversation['status'] in
|
|
|
|
+ (TestTaskConversationsStatus.PENDING.value, TestTaskConversationsStatus.RUNNING.value)
|
|
|
|
+ for task_conversation in task_conversations)
|
|
|
|
+
|
|
|
|
+ if all_completed:
|
|
|
|
+ self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
|
|
|
|
+ logger.info(f"Task {task_id} completed")
|
|
|
|
+ elif not any_pending:
|
|
|
|
+ # 没有待处理子任务但未全部完成(可能是取消了)
|
|
|
|
+ current_status = self.get_task(task_id)['status']
|
|
|
|
+ if current_status != TestTaskStatus.CANCELLED.value:
|
|
|
|
+ self.update_task_status(task_id, TestTaskStatus.COMPLETED.value
|
|
|
|
+ if all_completed else TestTaskStatus.CANCELLED.value)
|
|
|
|
+ except Exception as e:
|
|
|
|
+ logger.error(f"Error executing task {task_id}: {str(e)}")
|
|
|
|
+ self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
|
|
|
|
+ finally:
|
|
|
|
+ # 清理资源
|
|
|
|
+ with self.task_locks[task_id]:
|
|
|
|
+ if task_id in self.running_tasks:
|
|
|
|
+ self.running_tasks.remove(task_id)
|
|
|
|
+ if task_id in self.task_events:
|
|
|
|
+ del self.task_events[task_id]
|
|
|
|
+ if task_id in self.task_futures:
|
|
|
|
+ del self.task_futures[task_id]
|
|
|
|
+
|
|
|
|
+ def shutdown(self):
|
|
|
|
+ """关闭执行器"""
|
|
|
|
+ self.executor.shutdown(wait=False)
|
|
|
|
+ logger.info("Task executor shutdown")
|