import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict

from pyarrow.dataset import dataset
from sqlalchemy import func

from pqai_agent import logging_service
from pqai_agent.data_models.agent_configuration import AgentConfiguration
from pqai_agent.data_models.agent_test_task import AgentTestTask
from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc

logger = logging_service.logger


class TaskManager:
    """Task manager.

    Persists agent test tasks and their per-conversation sub-tasks through
    ``session_maker`` and runs each task on a thread pool. Per-task
    bookkeeping (cancel event, lock, future, "running" marker) is kept in
    plain dicts/sets keyed by task id.
    """

    # Timestamp format used for every time field returned by the list APIs.
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, session_maker, dataset_server, max_workers: int = 10):
        """Store the collaborators and create the worker pool.

        Args:
            session_maker: Callable returning a session usable as a context
                manager (used as ``with session_maker() as session``).
            dataset_server: Object providing the dataset / conversation /
                profile lookups used when creating and running tasks.
            max_workers: Size of the thread pool that executes tasks.
        """
        self.session_maker = session_maker
        self.dataset_server = dataset_server
        self.task_events = {}  # task id -> Event (used to cancel a task)
        self.task_locks = {}  # task id -> Lock (guards the task's bookkeeping)
        self.running_tasks = set()
        self.executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='TaskWorker')
        self.task_futures = {}  # task id -> Future

    @classmethod
    def _format_time(cls, value):
        """Format a datetime with ``_TIME_FORMAT``; pass ``None`` through."""
        return value.strftime(cls._TIME_FORMAT) if value is not None else None

    @staticmethod
    def _build_page(page_num: int, page_size: int, total: int, items: list) -> Dict:
        """Build the pagination envelope shared by the list endpoints."""
        # Ceiling division; an empty result set is still reported as one page.
        total_page = max(1, -(-total // page_size))
        return {
            "currentPage": page_num,
            "pageSize": page_size,
            "totalSize": total_page,
            "total": total,
            "list": items,
        }

    def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
        """Return one page of test tasks with their agent name.

        Args:
            page_num: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; values below 1 are treated as 1.

        Returns:
            Dict with ``currentPage``, ``pageSize``, ``totalSize`` (number of
            pages), ``total`` (number of rows) and ``list``.
        """
        # Clamp so a bad request cannot produce a negative offset or a
        # division by zero.
        page_num = max(page_num, 1)
        page_size = max(page_size, 1)
        offset = (page_num - 1) * page_size
        with self.session_maker() as session:
            # Order by primary key: LIMIT/OFFSET without ORDER BY does not
            # guarantee stable pages.
            rows = (session.query(AgentTestTask, AgentConfiguration)
                    .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
                    .order_by(AgentTestTask.id)
                    .limit(page_size).offset(offset).all())
            total = session.query(func.count(AgentTestTask.id)).scalar()
            items = [
                {
                    "id": task.id,
                    # Outer join: the configuration may be missing.
                    "agentName": configuration.name if configuration is not None else None,
                    "createUser": task.create_user,
                    "updateUser": task.update_user,
                    "statusName": get_test_task_status_desc(task.status),
                    "createTime": self._format_time(task.create_time),
                    "updateTime": self._format_time(task.update_time),
                }
                for task, configuration in rows
            ]
            return self._build_page(page_num, page_size, total, items)

    def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
        """Return one page of the sub-tasks (conversations) of a task.

        Args:
            task_id: Id of the parent test task.
            page_num: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; values below 1 are treated as 1.

        Returns:
            Dict with ``currentPage``, ``pageSize``, ``totalSize`` (number of
            pages), ``total`` (number of rows for this task) and ``list``.
        """
        page_num = max(page_num, 1)
        page_size = max(page_size, 1)
        offset = (page_num - 1) * page_size
        with self.session_maker() as session:
            rows = (session.query(AgentTestTaskConversations, AgentConfiguration)
                    .outerjoin(AgentConfiguration,
                               AgentTestTaskConversations.agent_id == AgentConfiguration.id)
                    .filter(AgentTestTaskConversations.task_id == task_id)
                    .order_by(AgentTestTaskConversations.id)
                    .limit(page_size).offset(offset).all())
            # The count must use the same task filter as the page query,
            # otherwise the totals cover every task's conversations.
            total = (session.query(func.count(AgentTestTaskConversations.id))
                     .filter(AgentTestTaskConversations.task_id == task_id)
                     .scalar())
            items = [
                {
                    "id": conversation.id,
                    # Outer join: the configuration may be missing.
                    "agentName": configuration.name if configuration is not None else None,
                    "input": conversation.input,
                    "output": conversation.output,
                    "score": conversation.score,
                    # NOTE(review): this maps a conversation status through the
                    # *task* status description helper — confirm the two enums
                    # share values or switch to a conversation-specific helper.
                    "statusName": get_test_task_status_desc(conversation.status),
                    "createTime": self._format_time(conversation.create_time),
                    "updateTime": self._format_time(conversation.update_time),
                }
                for conversation, configuration in rows
            ]
            return self._build_page(page_num, page_size, total, items)

    def create_task(self, agent_id: int, module_id: int) -> AgentTestTask:
        """Create a task plus one sub-task per conversation, then start it.

        Args:
            agent_id: Agent the task is run against.
            module_id: Module whose datasets supply the conversations.

        Returns:
            The freshly loaded ``AgentTestTask`` row.
        """
        with self.session_maker() as session:
            with session.begin():
                agent_test_task = AgentTestTask(agent_id=agent_id, module_id=module_id)
                session.add(agent_test_task)
                session.flush()  # emit the INSERT to obtain the id without committing
                task_id = agent_test_task.id
                conversations = [
                    AgentTestTaskConversations(task_id=task_id, agent_id=agent_id,
                                               dataset_id=datasets.id,
                                               conversation_id=conversation_data.id)
                    for datasets in self.dataset_server.get_dataset_list_by_module(module_id)
                    for conversation_data in
                    self.dataset_server.get_conversation_data_list_by_dataset(datasets.id)
                ]
                session.add_all(conversations)
        # Start the task only after the transaction has committed, so the
        # worker thread's own session can see the rows.
        self._execute_task(task_id)
        return self.get_task(task_id)

    def get_task(self, task_id: int):
        """Return the task row; ``Query.one()`` raises if it does not exist."""
        with self.session_maker() as session:
            return session.query(AgentTestTask).filter(AgentTestTask.id == task_id).one()

    def get_task_conversations(self, task_id: int):
        """Return all sub-tasks of a task."""
        with self.session_maker() as session:
            return session.query(AgentTestTaskConversations).filter(
                AgentTestTaskConversations.task_id == task_id).all()

    def get_pending_task_conversations(self, task_id: int):
        """Return the sub-tasks of a task that are still PENDING."""
        with self.session_maker() as session:
            return session.query(AgentTestTaskConversations).filter(
                AgentTestTaskConversations.task_id == task_id).filter(
                AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value).all()

    def update_task_status(self, task_id: int, status: int):
        """Set the status of a task and commit."""
        with self.session_maker() as session:
            session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update({"status": status})
            session.commit()

    def update_task_conversations_status(self, task_conversations_id: int, status: int):
        """Set the status of one sub-task and commit."""
        with self.session_maker() as session:
            session.query(AgentTestTaskConversations).filter(
                AgentTestTaskConversations.id == task_conversations_id).update({"status": status})
            session.commit()

    def update_task_conversations_res(self, task_conversations_id: int, status: int, score: str):
        """Set the status and score of one sub-task and commit."""
        with self.session_maker() as session:
            session.query(AgentTestTaskConversations).filter(
                AgentTestTaskConversations.id == task_conversations_id).update(
                {"status": status, "score": score})
            session.commit()

    def _release_task(self, task_id: int):
        """Drop the in-memory bookkeeping (running mark, event, future) of a task."""
        lock = self.task_locks.setdefault(task_id, threading.Lock())
        with lock:
            self.running_tasks.discard(task_id)
            self.task_events.pop(task_id, None)
            self.task_futures.pop(task_id, None)

    def cancel_task(self, task_id: int):
        """Cancel a task: signal the worker and mark the rows CANCELLED.

        The task row and its still-PENDING sub-tasks are updated in a single
        transaction.
        """
        # Signal a worker that is already running.
        cancel_event = self.task_events.get(task_id)
        if cancel_event is not None:
            cancel_event.set()

        # A future that has not started yet can be cancelled outright. In that
        # case the worker never runs, so its cleanup never happens; release
        # the bookkeeping here or the task could never be resumed.
        future = self.task_futures.get(task_id)
        if future is not None and future.cancel():
            self._release_task(task_id)

        with self.session_maker() as session:
            # session.begin() commits on successful exit of the block.
            with session.begin():
                session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
                    {"status": TestTaskStatus.CANCELLED.value})
                session.query(AgentTestTaskConversations).filter(
                    AgentTestTaskConversations.task_id == task_id).filter(
                    AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value).update(
                    {"status": TestTaskConversationsStatus.CANCELLED.value})

    def resume_task(self, task_id: int) -> bool:
        """Resume a cancelled task.

        Returns:
            ``True`` when the task was CANCELLED and has been re-submitted;
            ``False`` when the task does not exist or is not CANCELLED.
        """
        with self.session_maker() as session:
            # session.begin() commits on successful exit of the block.
            with session.begin():
                task = session.query(AgentTestTask).filter(AgentTestTask.id == task_id).one_or_none()
                if task is None or task.status != TestTaskStatus.CANCELLED.value:
                    return False
                session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
                    {"status": TestTaskStatus.NOT_STARTED.value})
                session.query(AgentTestTaskConversations).filter(
                    AgentTestTaskConversations.task_id == task_id).filter(
                    AgentTestTaskConversations.status == TestTaskConversationsStatus.CANCELLED.value).update(
                    {"status": TestTaskConversationsStatus.PENDING.value})
        # Re-run the task once the status reset is committed.
        self._execute_task(task_id)
        logger.info(f"Resumed task {task_id}")
        return True

    def _execute_task(self, task_id: int):
        """Submit a task to the thread pool unless it is already running."""
        lock = self.task_locks.setdefault(task_id, threading.Lock())
        # Register everything under the task lock *before* the worker can
        # reach its cleanup (which takes the same lock); otherwise a fast
        # worker could finish first and leave stale "running" state behind.
        with lock:
            if task_id in self.running_tasks:
                return
            # Always start from a fresh, unset event so an earlier
            # cancellation cannot abort this run immediately.
            self.task_events[task_id] = threading.Event()
            self.running_tasks.add(task_id)
            try:
                self.task_futures[task_id] = self.executor.submit(self._process_task, task_id)
            except Exception:
                # e.g. the executor was shut down: undo the registration.
                self.running_tasks.discard(task_id)
                self.task_events.pop(task_id, None)
                raise

    def _run_conversation(self, task_id: int, task_conversation):
        """Run one sub-task and record SUCCESS (with score) or FAILED."""
        self.update_task_conversations_status(task_conversation.id,
                                              TestTaskConversationsStatus.RUNNING.value)
        try:
            conversation_data = self.dataset_server.get_conversation_data_by_id(
                task_conversation.conversation_id)
            # Loaded for the real evaluation that will replace the
            # placeholder below; currently unused.
            user_profile_data = self.dataset_server.get_user_profile_data(conversation_data.user_id)
            staff_profile_data = self.dataset_server.get_staff_profile_data(conversation_data.staff_id)
            # Simulated execution.
            # TODO: replace with the actual task execution.
            time.sleep(1)
            score = '{"score":0.05}'
            self.update_task_conversations_res(task_conversation.id,
                                               TestTaskConversationsStatus.SUCCESS.value, score)
        except Exception as e:
            logger.error(f"Error executing task {task_id}: {str(e)}")
            self.update_task_conversations_status(task_conversation.id,
                                                  TestTaskConversationsStatus.FAILED.value)

    def _finalize_task_status(self, task_id: int):
        """Derive the final task status from the states of its sub-tasks."""
        finished_states = (TestTaskConversationsStatus.SUCCESS.value,
                           TestTaskConversationsStatus.FAILED.value)
        open_states = (TestTaskConversationsStatus.PENDING.value,
                       TestTaskConversationsStatus.RUNNING.value)
        statuses = [conversation.status for conversation in self.get_task_conversations(task_id)]
        if all(status in finished_states for status in statuses):
            self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
            logger.info(f"Task {task_id} completed")
        elif not any(status in open_states for status in statuses):
            # Nothing left to run but not everything finished (e.g. some
            # sub-tasks were cancelled): make sure the task is CANCELLED.
            if self.get_task(task_id).status != TestTaskStatus.CANCELLED.value:
                self.update_task_status(task_id, TestTaskStatus.CANCELLED.value)

    def _process_task(self, task_id: int):
        """Worker entry point: run every pending sub-task of a task."""
        cancel_event = self.task_events.get(task_id)
        try:
            self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
            for task_conversation in self.get_pending_task_conversations(task_id):
                # Stop before starting another sub-task once cancelled.
                if cancel_event is not None and cancel_event.is_set():
                    break
                self._run_conversation(task_id, task_conversation)
            self._finalize_task_status(task_id)
        except Exception as e:
            logger.error(f"Error executing task {task_id}: {str(e)}")
            # NOTE(review): a crashed task is recorded as COMPLETED — confirm
            # this is intended rather than a dedicated failure status.
            self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
        finally:
            # Release the bookkeeping so the task can be submitted again.
            self._release_task(task_id)

    def shutdown(self):
        """Shut down the executor without waiting for running tasks."""
        self.executor.shutdown(wait=False)
        logger.info("Task executor shutdown")