123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469 |
- import random
- import threading
- import time
- from concurrent.futures import ThreadPoolExecutor
- from queue import Queue
- from typing import List, Dict, Optional
- import pymysql
- from pymysql import Connection
- from pymysql.cursors import DictCursor
- from sqlalchemy import func
- from pqai_agent import logging_service
- from pqai_agent.data_models.agent_configuration import AgentConfiguration
- from pqai_agent.data_models.agent_test_task import AgentTestTask
- from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc, \
- get_test_task_conversations_status_desc
- logger = logging_service.logger
class Database:
    """Blocking pool of PyMySQL connections with one-statement helpers.

    Five connections are opened eagerly into a queue bounded at ten
    entries. ``get_connection`` blocks while the queue is empty, so every
    borrowed connection must be handed back with ``release_connection``.
    The ``execute`` / ``insert`` / ``fetch`` / ``fetch_one`` helpers do the
    borrow, run, end-transaction, return cycle themselves.
    """

    def __init__(self, db_config):
        """Store the connection settings and fill the pool.

        Args:
            db_config: Keyword arguments forwarded to ``pymysql.connect``.
        """
        self.db_config = db_config
        self.connection_pool = Queue(maxsize=10)
        self._initialize_pool()

    def _initialize_pool(self):
        """Open the initial five connections and queue them."""
        for _ in range(5):
            self.connection_pool.put(pymysql.connect(**self.db_config))
        logger.info("Database connection pool initialized with 5 connections")

    def get_connection(self) -> Connection:
        """Borrow a live connection from the pool, blocking until one is free.

        The connection is pinged (reconnecting if needed) before it is
        returned, so a connection the server dropped while it sat idle in
        the pool is revived instead of failing on the caller's first query.

        Raises:
            Exception: Whatever ``ping`` raises when the server cannot be
                reached; the connection is put back in the pool first so
                the pool does not shrink.
        """
        conn = self.connection_pool.get()
        try:
            conn.ping(reconnect=True)
        except Exception:
            self.connection_pool.put(conn)
            raise
        return conn

    def release_connection(self, conn: Connection):
        """Return a borrowed connection to the pool."""
        self.connection_pool.put(conn)

    def _run_statement(self, query: str, args, extract, many: bool = False, cursor_class=None):
        """Run one statement on a pooled connection and end its transaction.

        Args:
            query: SQL text with ``%s`` placeholders.
            args: Parameters for ``execute``, or a sequence of parameter
                tuples when *many* is true.
            extract: Callable that receives the cursor after execution and
                returns the value to hand back to the caller.
            many: Use ``executemany`` instead of ``execute``.
            cursor_class: Optional cursor class; ``None`` selects the
                connection's default cursor.

        Returns:
            Whatever *extract* returns.

        Raises:
            Exception: Re-raises any error after logging it and rolling
                the transaction back.
        """
        conn = self.get_connection()
        try:
            with conn.cursor(cursor_class) as cursor:
                if many:
                    cursor.executemany(query, args)
                else:
                    cursor.execute(query, args)
                result = extract(cursor)
            # Commit after reads as well: with autocommit off, a pooled
            # connection would otherwise stay inside its first read
            # snapshot and keep returning stale rows on later queries.
            conn.commit()
            return result
        except Exception as e:
            logger.error(f"Database error: {str(e)}")
            try:
                conn.rollback()
            except Exception as rollback_error:
                # Do not let a failed rollback hide the original error.
                logger.error(f"Database rollback failed: {str(rollback_error)}")
            raise
        finally:
            self.release_connection(conn)

    def execute(self, query: str, args: tuple = (), many: bool = False) -> int:
        """Run a write statement, commit, and return the affected row count."""
        return self._run_statement(query, args, lambda cursor: cursor.rowcount, many=many)

    def insert(self, insert: str, args: tuple = (), many: bool = False) -> int:
        """Run an INSERT, commit, and return the cursor's ``lastrowid``."""
        return self._run_statement(insert, args, lambda cursor: cursor.lastrowid, many=many)

    def fetch(self, query: str, args: tuple = ()) -> List[Dict]:
        """Run a query and return every row as a dict."""
        return self._run_statement(query, args, lambda cursor: cursor.fetchall(),
                                   cursor_class=DictCursor)

    def fetch_one(self, query: str, args: tuple = ()) -> Optional[Dict]:
        """Run a query and return the first row as a dict, or ``None``."""
        return self._run_statement(query, args, lambda cursor: cursor.fetchone(),
                                   cursor_class=DictCursor)

    def close_all(self):
        """Close every connection currently sitting in the pool.

        Connections that are borrowed at the time of the call are not
        affected.
        """
        while not self.connection_pool.empty():
            conn = self.connection_pool.get()
            conn.close()
        logger.info("All database connections closed")
class TaskManager:
    """Create, run, cancel and resume agent test tasks.

    A task is one row in ``test_task_table`` plus a set of conversation
    rows in ``test_task_conversations_table``. Tasks are processed on a
    thread pool; each running task has an ``Event`` used to request
    cancellation, a ``Future``, and a per-task ``Lock`` that guards this
    bookkeeping.
    """

    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, session_maker, db_config, agent_configuration_table, test_task_table,
                 test_task_conversations_table,
                 max_workers: int = 10):
        """Set up the database pool, the worker pool and the bookkeeping maps.

        Args:
            session_maker: Callable returning an ORM session usable as a
                context manager.
            db_config: Keyword arguments for the PyMySQL connection pool.
            agent_configuration_table: Name of the agent configuration table.
            test_task_table: Name of the test task table.
            test_task_conversations_table: Name of the conversations table.
            max_workers: Size of the worker thread pool.
        """
        self.session_maker = session_maker
        self.db = Database(db_config)
        self.agent_configuration_table = agent_configuration_table
        self.test_task_table = test_task_table
        self.test_task_conversations_table = test_task_conversations_table
        self.task_events = {}  # task id -> Event, set to request cancellation
        self.task_locks = {}  # task id -> Lock guarding that task's bookkeeping
        self.running_tasks = set()
        self.executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='TaskWorker')
        self.task_futures = {}  # task id -> Future

    @staticmethod
    def _total_pages(total: int, page_size: int) -> int:
        """Return how many pages *total* rows occupy; never less than one."""
        pages, remainder = divmod(total, page_size)
        if remainder > 0:
            pages += 1
        return max(pages, 1)

    @staticmethod
    def _build_page(page_num: int, page_size: int, total: int, items: list) -> Dict:
        """Assemble the paginated response envelope shared by the list APIs."""
        return {
            "currentPage": page_num,
            "pageSize": page_size,
            "totalSize": TaskManager._total_pages(total, page_size),
            "total": total,
            "list": items,
        }

    def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
        """Return one page of test tasks, newest first.

        Args:
            page_num: 1-based page index.
            page_size: Rows per page.

        Returns:
            The pagination envelope with one entry per task. ``agentName``
            is ``None`` when the task's agent configuration row is missing.
        """
        offset = (page_num - 1) * page_size
        with self.session_maker() as session:
            # A stable ORDER BY is required for LIMIT/OFFSET paging to
            # return consistent pages.
            rows = (session.query(AgentTestTask, AgentConfiguration)
                    .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
                    .order_by(AgentTestTask.create_time.desc())
                    .limit(page_size).offset(offset).all())
            total = session.query(func.count(AgentTestTask.id)).scalar()
            items = [
                {
                    "id": agent_test_task.id,
                    # The outer join yields None when no configuration matches.
                    "agentName": agent_configuration.name if agent_configuration is not None else None,
                    "createUser": agent_test_task.create_user,
                    "updateUser": agent_test_task.update_user,
                    "statusName": get_test_task_status_desc(agent_test_task.status),
                    "createTime": agent_test_task.create_time.strftime(self._TIME_FORMAT),
                    "updateTime": agent_test_task.update_time.strftime(self._TIME_FORMAT),
                }
                for agent_test_task, agent_configuration in rows
            ]
        return self._build_page(page_num, page_size, total, items)

    def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
        """Return one page of a task's conversations, newest first.

        Args:
            task_id: Id of the owning test task.
            page_num: 1-based page index.
            page_size: Rows per page.

        Returns:
            The pagination envelope with one entry per conversation.
        """
        fetch_query = f"""
            select t1.id, t2.name, t3.create_user, t1.input, t1.output, t1.score, t1.status, t1.create_time, t1.update_time
            from {self.test_task_conversations_table} t1
            left join {self.agent_configuration_table} t2 on t1.agent_id = t2.id
            left join {self.test_task_table} t3 on t1.task_id = t3.id
            where t1.task_id = %s
            order by create_time desc
            limit %s, %s;
            """
        messages = self.db.fetch(fetch_query, (task_id, (page_num - 1) * page_size, page_size))
        total_size = self.db.fetch_one(
            f"""select count(*) as `count` from {self.test_task_conversations_table} where task_id = %s""",
            (task_id,))
        total = total_size["count"]
        items = [
            {
                "id": message["id"],
                "agentName": message["name"],
                "createUser": message["create_user"],
                "input": message["input"],
                "output": message["output"],
                "score": message["score"],
                "statusName": get_test_task_conversations_status_desc(message["status"]),
                "createTime": message["create_time"].strftime(self._TIME_FORMAT),
                "updateTime": message["update_time"].strftime(self._TIME_FORMAT),
            }
            for message in messages
        ]
        return self._build_page(page_num, page_size, total, items)

    def create_task(self, agent_id: int, model_id: int) -> Dict:
        """Insert a task with its conversations in one transaction and start it.

        Args:
            agent_id: Agent configuration the task belongs to.
            model_id: Currently unused by this method.

        Returns:
            The freshly inserted task row.

        Raises:
            Exception: Re-raises any database error after rolling back.
        """
        conn = self.db.get_connection()
        try:
            conn.begin()
            # TODO: the task row is inserted with placeholder user values.
            with conn.cursor() as cursor:
                cursor.execute(
                    f"""INSERT INTO {self.test_task_table} (agent_id, status, create_user, update_user) VALUES (%s, %s, %s, %s)""",
                    (agent_id, 0, 'xueyiming', 'xueyiming')
                )
                task_id = cursor.lastrowid
            # TODO: replace the 30 placeholder rows with real dataset entries.
            task_conversations = [
                (task_id, agent_id, index, index, "输入", "输出", 0)
                for index in range(1, 31)
            ]
            with conn.cursor() as cursor:
                cursor.executemany(
                    f"""INSERT INTO {self.test_task_conversations_table} (task_id, agent_id, dataset_id, conversation_id, input, output, status)
                    VALUES (%s, %s, %s, %s, %s, %s, %s)""",
                    task_conversations
                )
            conn.commit()
        except Exception as e:
            conn.rollback()
            logger.error(f"Failed to create task agent_id {agent_id}: {str(e)}")
            raise
        finally:
            self.db.release_connection(conn)
        logger.info(f"Created task {task_id} with {len(task_conversations)} task_conversations")
        # Run the task asynchronously on the worker pool.
        self._execute_task(task_id)
        return self.get_task(task_id)

    def get_task(self, task_id: int) -> Optional[Dict]:
        """Return the task row for *task_id*, or ``None`` if it does not exist."""
        return self.db.fetch_one(f"""SELECT * FROM {self.test_task_table} WHERE id = %s""", (task_id,))

    def get_task_conversations(self, task_id: int) -> List[Dict]:
        """Return every conversation row belonging to *task_id*."""
        return self.db.fetch(f"""SELECT * FROM {self.test_task_conversations_table} WHERE task_id = %s""", (task_id,))

    def get_pending_task_conversations(self, task_id: int) -> List[Dict]:
        """Return the conversation rows of *task_id* that are still pending."""
        return self.db.fetch(
            f"""SELECT * FROM {self.test_task_conversations_table} WHERE task_id = %s AND status = %s""",
            (task_id, TestTaskConversationsStatus.PENDING.value)
        )

    def update_task_status(self, task_id: int, status: int):
        """Set the status column of one task."""
        self.db.execute(
            f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
            (status, task_id)
        )

    def update_task_conversations_status(self, task_conversations_id: int, status: int):
        """Set the status column of one conversation."""
        self.db.execute(
            f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE id = %s""",
            (status, task_conversations_id)
        )

    def update_task_conversations_res(self, task_conversations_id: int, status: int, score: str):
        """Set the status and score columns of one conversation."""
        self.db.execute(
            f"""UPDATE {self.test_task_conversations_table} SET status = %s, score = %s WHERE id = %s""",
            (status, score, task_conversations_id)
        )

    def cancel_task(self, task_id: int):
        """Request cancellation of a task and mark it cancelled in the database.

        The in-memory cancel event is set first so a running worker stops
        before its next conversation. The task row and all still-pending
        conversations are then flipped to cancelled in one transaction.
        Database errors are logged and rolled back, not raised.
        """
        # .get() instead of check-then-index: the worker may delete these
        # entries concurrently when it finishes.
        cancel_event = self.task_events.get(task_id)
        if cancel_event is not None:
            cancel_event.set()
        future = self.task_futures.get(task_id)
        if future is not None and future.cancel():
            # The future was still queued, so the worker (and its cleanup)
            # will never run; release the bookkeeping here so the task can
            # be resumed later.
            self._release_task(task_id)
        conn = self.db.get_connection()
        try:
            conn.begin()
            with conn.cursor() as cursor:
                cursor.execute(
                    f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
                    (TestTaskStatus.CANCELLED.value, task_id)
                )
            # Only pending conversations are cancelled; finished ones keep
            # their result.
            with conn.cursor() as cursor:
                cursor.execute(
                    f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE task_id = %s AND status = %s""",
                    (TestTaskConversationsStatus.CANCELLED.value, task_id, TestTaskConversationsStatus.PENDING.value)
                )
            conn.commit()
            logger.info(f"Cancelled task {task_id} and its pending {self.test_task_conversations_table}")
        except Exception as e:
            conn.rollback()
            logger.error(f"Failed to cancel task {task_id}: {str(e)}")
        finally:
            self.db.release_connection(conn)

    def resume_task(self, task_id: int) -> bool:
        """Resume a cancelled task.

        Returns:
            ``False`` when the task does not exist, is not cancelled, or
            the database update fails; ``True`` once the rows were restored
            and the task was handed to ``_execute_task``.
        """
        task = self.get_task(task_id)
        if not task or task['status'] != TestTaskStatus.CANCELLED.value:
            return False
        conn = self.db.get_connection()
        try:
            conn.begin()
            with conn.cursor() as cursor:
                cursor.execute(
                    f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
                    (TestTaskStatus.NOT_STARTED.value, task_id)
                )
            # Put every cancelled conversation back into the pending state.
            with conn.cursor() as cursor:
                cursor.execute(
                    f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE task_id = %s AND status = %s""",
                    (TestTaskConversationsStatus.PENDING.value, task_id, TestTaskConversationsStatus.CANCELLED.value)
                )
            conn.commit()
            logger.info(f"Restored task {task_id} and its cancelled {self.test_task_conversations_table}")
        except Exception as e:
            conn.rollback()
            logger.error(f"Failed to resume task {task_id}: {str(e)}")
            # The rows are unchanged, so the task must not be re-run.
            return False
        finally:
            self.db.release_connection(conn)
        # NOTE(review): if the previous worker for this task has not
        # finished winding down, _execute_task returns without scheduling.
        self._execute_task(task_id)
        logger.info(f"Resumed task {task_id}")
        return True

    def _release_task(self, task_id: int):
        """Drop the running marker, cancel event and future of one task.

        The per-task lock entry itself is kept so that concurrent callers
        always synchronise on the same lock object.
        """
        with self.task_locks.setdefault(task_id, threading.Lock()):
            self.running_tasks.discard(task_id)
            self.task_events.pop(task_id, None)
            self.task_futures.pop(task_id, None)

    def _execute_task(self, task_id: int):
        """Submit a task to the worker pool unless it is already running.

        Registration (running marker, cancel event, future) happens under
        the per-task lock and before the worker can reach its cleanup, so
        a fast worker cannot leave a stale entry behind.
        """
        with self.task_locks.setdefault(task_id, threading.Lock()):
            if task_id in self.running_tasks:
                return
            self.task_events[task_id] = threading.Event()
            self.running_tasks.add(task_id)
            try:
                self.task_futures[task_id] = self.executor.submit(self._process_task, task_id)
            except Exception:
                # Undo the registration inline; the lock is already held
                # and is not re-entrant.
                self.running_tasks.discard(task_id)
                self.task_events.pop(task_id, None)
                raise

    def _process_task(self, task_id: int):
        """Worker entry point: run every pending conversation of a task.

        The cancel event is checked before each conversation. After the
        loop the task is marked completed when every conversation ended in
        success or failure, or cancelled when nothing is pending or running
        any more and the task row is not already cancelled.
        """
        try:
            cancel_event = self.task_events[task_id]
            self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
            for task_conversation in self.get_pending_task_conversations(task_id):
                if cancel_event.is_set():
                    break
                conversation_id = task_conversation['id']
                self.update_task_conversations_status(conversation_id,
                                                      TestTaskConversationsStatus.RUNNING.value)
                try:
                    # TODO: placeholder for the real conversation execution.
                    time.sleep(1)
                    score = '{"score":0.05}'
                    self.update_task_conversations_res(conversation_id,
                                                       TestTaskConversationsStatus.SUCCESS.value, score)
                except Exception as e:
                    logger.error(f"Error executing task {task_id}: {str(e)}")
                    self.update_task_conversations_status(conversation_id,
                                                          TestTaskConversationsStatus.FAILED.value)
            # Re-read all conversations to decide the task's final status.
            statuses = [row['status'] for row in self.get_task_conversations(task_id)]
            finished = (TestTaskConversationsStatus.SUCCESS.value, TestTaskConversationsStatus.FAILED.value)
            unfinished = (TestTaskConversationsStatus.PENDING.value, TestTaskConversationsStatus.RUNNING.value)
            if all(status in finished for status in statuses):
                self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
                logger.info(f"Task {task_id} completed")
            elif not any(status in unfinished for status in statuses):
                # Nothing left to run but not everything finished: the rest
                # was cancelled, so make the task row agree.
                current_status = self.get_task(task_id)['status']
                if current_status != TestTaskStatus.CANCELLED.value:
                    self.update_task_status(task_id, TestTaskStatus.CANCELLED.value)
        except Exception as e:
            logger.error(f"Error executing task {task_id}: {str(e)}")
            # NOTE(review): a crashed task is recorded as COMPLETED here;
            # confirm whether a dedicated failure status should be used.
            self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
        finally:
            self._release_task(task_id)

    def shutdown(self):
        """Shut the worker pool down without waiting for running tasks."""
        self.executor.shutdown(wait=False)
        logger.info("Task executor shutdown")
|