import concurrent.futures
import json
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Dict

from sqlalchemy import func

from pqai_agent import logging_service
from pqai_agent.agents.message_push_agent import MessagePushAgent
from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
from pqai_agent.data_models.agent_configuration import AgentConfiguration
from pqai_agent.data_models.agent_test_task import AgentTestTask
from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
from pqai_agent.data_models.service_module import ServiceModule
from pqai_agent.utils.prompt_utils import format_agent_profile
from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc
from scripts.evaluate_agent import evaluate_agent

logger = logging_service.logger

# Timestamp layout used for every createTime / updateTime / send_time string.
_TIME_FORMAT = "%Y-%m-%d %H:%M:%S"

# Seconds added to the last dialogue message's time to obtain the simulated
# send time, keyed by ``AgentTestTask.evaluate_type``.
_SEND_DELAY_SECONDS = {
    0: 10,
    1: 24 * 3600,
}


class TaskManager:
    """Create, run, cancel and resume agent test tasks.

    A task (``AgentTestTask``) owns many conversation sub-tasks
    (``AgentTestTaskConversations``). Tasks are executed on ``self.executor``;
    sub-task generation runs on ``self.create_task_executor``.
    """

    def __init__(self, session_maker, dataset_service):
        self.session_maker = session_maker
        self.dataset_service = dataset_service
        self.task_events = {}  # task id -> threading.Event, set to request cancellation
        self.task_locks = {}  # task id -> threading.Lock guarding that task's bookkeeping
        self.running_tasks = set()  # ids of tasks currently submitted to self.executor
        self.executor = ThreadPoolExecutor(max_workers=20, thread_name_prefix='TaskWorker')
        self.create_task_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='CreateTaskWorker')
        self.task_futures = {}  # task id -> Future returned by self.executor.submit

    @staticmethod
    def _count_pages(total: int, page_size: int) -> int:
        """Return the number of pages needed for ``total`` rows (never less than 1)."""
        pages, remainder = divmod(total, page_size)
        if remainder > 0:
            pages += 1
        return max(pages, 1)

    @staticmethod
    def _render_conversation_input(raw_input):
        """Render a stored JSON dialogue for display, or ``None`` when nothing is stored.

        Sub-tasks are created without ``input``; it is only written once a
        sub-task succeeds, so unfinished rows must not be passed to ``json.loads``.
        """
        if not raw_input:
            return None
        return MultiModalChatAgent.compose_dialogue(json.loads(raw_input))

    def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
        """Return one page of test tasks together with pagination metadata.

        Args:
            page_num: 1-based page index.
            page_size: number of rows per page.

        Returns:
            Dict with ``currentPage``, ``pageSize``, ``totalSize`` (page count),
            ``total`` (row count) and ``list`` (the rows of this page).
        """
        with self.session_maker() as session:
            offset = (page_num - 1) * page_size
            # Explicit ordering keeps LIMIT/OFFSET pagination deterministic.
            rows = (session.query(AgentTestTask, AgentConfiguration)
                    .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
                    .order_by(AgentTestTask.id)
                    .limit(page_size).offset(offset).all())
            total = session.query(func.count(AgentTestTask.id)).scalar()
            total_page = self._count_pages(total, page_size)
            response_data = [
                {
                    "id": agent_test_task.id,
                    # The outer join yields None when the agent configuration is missing.
                    "agentName": agent_configuration.name if agent_configuration is not None else None,
                    "createUser": agent_test_task.create_user,
                    "updateUser": agent_test_task.update_user,
                    "statusName": get_test_task_status_desc(agent_test_task.status),
                    "createTime": agent_test_task.create_time.strftime(_TIME_FORMAT),
                    "updateTime": agent_test_task.update_time.strftime(_TIME_FORMAT)
                }
                for agent_test_task, agent_configuration in rows
            ]
            return {
                "currentPage": page_num,
                "pageSize": page_size,
                "totalSize": total_page,
                "total": total,
                "list": response_data,
            }

    def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
        """Return one page of the conversation sub-tasks belonging to ``task_id``.

        Args:
            task_id: id of the owning test task.
            page_num: 1-based page index.
            page_size: number of rows per page.

        Returns:
            Dict with ``currentPage``, ``pageSize``, ``totalSize`` (page count),
            ``total`` (row count for this task) and ``list`` (the rows of this page).
        """
        with self.session_maker() as session:
            offset = (page_num - 1) * page_size
            # Explicit ordering keeps LIMIT/OFFSET pagination deterministic.
            rows = (session.query(AgentTestTaskConversations, AgentConfiguration)
                    .outerjoin(AgentConfiguration, AgentTestTaskConversations.agent_id == AgentConfiguration.id)
                    .filter(AgentTestTaskConversations.task_id == task_id)
                    .order_by(AgentTestTaskConversations.id)
                    .limit(page_size).offset(offset).all())
            # Count only this task's rows so the totals match the filtered page query.
            total = (session.query(func.count(AgentTestTaskConversations.id))
                     .filter(AgentTestTaskConversations.task_id == task_id)
                     .scalar())
            total_page = self._count_pages(total, page_size)
            response_data = [
                {
                    "id": conversation.id,
                    # The outer join yields None when the agent configuration is missing.
                    "agentName": agent_configuration.name if agent_configuration is not None else None,
                    "input": self._render_conversation_input(conversation.input),
                    "output": conversation.output,
                    "score": conversation.score,
                    "statusName": get_test_task_status_desc(conversation.status),
                    "createTime": conversation.create_time.strftime(_TIME_FORMAT),
                    "updateTime": conversation.update_time.strftime(_TIME_FORMAT)
                }
                for conversation, agent_configuration in rows
            ]
            return {
                "currentPage": page_num,
                "pageSize": page_size,
                "totalSize": total_page,
                "total": total,
                "list": response_data,
            }

    def create_task(self, agent_id: int, module_id: int, evaluate_type: int) -> AgentTestTask:
        """Persist a new task in CREATING state and generate its sub-tasks asynchronously.

        Returns:
            The freshly stored ``AgentTestTask`` row.
        """
        with self.session_maker() as session:
            agent_test_task = AgentTestTask(agent_id=agent_id, module_id=module_id,
                                            evaluate_type=evaluate_type,
                                            status=TestTaskStatus.CREATING.value)
            session.add(agent_test_task)
            session.commit()  # commit now so the generated primary key is available
            task_id = agent_test_task.id
        # Sub-task generation runs in the background on the creation pool.
        self.create_task_executor.submit(self._generate_agent_test_task_conversation_batch,
                                         task_id, agent_id, module_id)
        return self.get_task(task_id)

    def _generate_agent_test_task_conversation_batch(self, task_id: int, agent_id: int, module_id: int):
        """Create one PENDING sub-task per conversation of the module's datasets, then run the task.

        On any failure (including a batch that could not be saved) the task is
        marked CREATED_FAIL and is not executed.
        """
        try:
            dataset_module_list = self.dataset_service.get_dataset_module_list_by_module(module_id)
            batch_size = 100  # rows written per transaction, to limit database round trips
            pending_rows = []
            for dataset_module in dataset_module_list:
                conversation_datas = self.dataset_service.get_conversation_data_list_by_dataset(
                    dataset_module.dataset_id)
                for conversation_data in conversation_datas:
                    pending_rows.append(AgentTestTaskConversations(
                        task_id=task_id,
                        agent_id=agent_id,
                        dataset_id=dataset_module.dataset_id,
                        conversation_id=conversation_data.id,
                        status=TestTaskConversationsStatus.PENDING.value
                    ))
                    if len(pending_rows) >= batch_size:
                        self._save_batch_or_raise(task_id, pending_rows)
                        pending_rows = []
            # Flush whatever is left over after the last full batch.
            if pending_rows:
                self._save_batch_or_raise(task_id, pending_rows)
            self.update_task_status(task_id, TestTaskStatus.NOT_STARTED.value)
            # Start executing the task right after its sub-tasks exist.
            self._execute_task(task_id)
        except Exception as e:
            logger.error(f"生成子任务失败: {str(e)}")
            self.update_task_status(task_id, TestTaskStatus.CREATED_FAIL.value)

    def _save_batch_or_raise(self, task_id: int, rows: list):
        """Save ``rows`` and raise ``RuntimeError`` if the save reported failure."""
        if not self.save_agent_test_task_conversation_batch(rows):
            raise RuntimeError(f"failed to persist sub-task batch for task {task_id}")

    def save_agent_test_task_conversation_batch(self, agent_test_task_conversation_batch: list) -> bool:
        """Insert a batch of sub-tasks in a single transaction.

        Never raises: errors are logged and reported through the return value.

        Returns:
            ``True`` when the batch was committed, ``False`` otherwise.
        """
        try:
            with self.session_maker() as session:
                with session.begin():
                    session.add_all(agent_test_task_conversation_batch)
        except Exception as e:
            logger.error(e)
            return False
        return True

    def get_agent_configuration_by_task_id(self, task_id: int):
        """Return the ``AgentConfiguration`` referenced by the task, or ``None`` if not found."""
        with self.session_maker() as session:
            return (session.query(AgentConfiguration)
                    .join(AgentTestTask, AgentTestTask.agent_id == AgentConfiguration.id)
                    .filter(AgentTestTask.id == task_id)
                    .one_or_none())

    def get_service_module_by_task_id(self, task_id: int):
        """Return the ``ServiceModule`` referenced by the task, or ``None`` if not found."""
        with self.session_maker() as session:
            return (session.query(ServiceModule)
                    .join(AgentTestTask, AgentTestTask.module_id == ServiceModule.id)
                    .filter(AgentTestTask.id == task_id)
                    .one_or_none())

    def get_task(self, task_id: int):
        """Return the task row; ``Query.one()`` raises if it does not exist."""
        with self.session_maker() as session:
            return session.query(AgentTestTask).filter(AgentTestTask.id == task_id).one()

    def get_in_progress_task(self):
        """Return all tasks whose status is IN_PROGRESS."""
        with self.session_maker() as session:
            return (session.query(AgentTestTask)
                    .filter(AgentTestTask.status == TestTaskStatus.IN_PROGRESS.value)
                    .all())

    def get_creating_task(self):
        """Return all tasks whose status is CREATING."""
        with self.session_maker() as session:
            return (session.query(AgentTestTask)
                    .filter(AgentTestTask.status == TestTaskStatus.CREATING.value)
                    .all())

    def get_task_conversations(self, task_id: int):
        """Return every sub-task of the task."""
        with self.session_maker() as session:
            return (session.query(AgentTestTaskConversations)
                    .filter(AgentTestTaskConversations.task_id == task_id)
                    .all())

    def del_task_conversations(self, task_id: int):
        """Delete every sub-task of the task."""
        with self.session_maker() as session:
            (session.query(AgentTestTaskConversations)
             .filter(AgentTestTaskConversations.task_id == task_id)
             .delete())
            session.commit()

    def get_pending_task_conversations(self, task_id: int):
        """Return the task's sub-tasks that are PENDING or RUNNING."""
        unfinished = [
            TestTaskConversationsStatus.PENDING.value,
            TestTaskConversationsStatus.RUNNING.value
        ]
        with self.session_maker() as session:
            return (session.query(AgentTestTaskConversations)
                    .filter(AgentTestTaskConversations.task_id == task_id)
                    .filter(AgentTestTaskConversations.status.in_(unfinished))
                    .all())

    def update_task_status(self, task_id: int, status: int):
        """Set the task's status and refresh its ``update_time``."""
        with self.session_maker() as session:
            session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
                {"status": status, "update_time": datetime.now()})
            session.commit()

    def update_task_conversations_status(self, task_conversations_id: int, status: int):
        """Set a sub-task's status and refresh its ``update_time``."""
        with self.session_maker() as session:
            session.query(AgentTestTaskConversations).filter(
                AgentTestTaskConversations.id == task_conversations_id).update(
                {"status": status, "update_time": datetime.now()})
            session.commit()

    def update_task_conversations_res(self, task_conversations_id: int, status: int, input: str, output: str,
                                      score: str):
        """Store a sub-task's status, input, output and score, refreshing ``update_time``."""
        with self.session_maker() as session:
            session.query(AgentTestTaskConversations).filter(
                AgentTestTaskConversations.id == task_conversations_id).update(
                {"status": status, "input": input, "output": output, "score": score,
                 "update_time": datetime.now()})
            session.commit()

    def cancel_task(self, task_id: int):
        """Request cancellation of a task and mark it and its PENDING sub-tasks CANCELLED."""
        # Use .get(): the worker's cleanup may remove these entries concurrently.
        cancel_event = self.task_events.get(task_id)
        if cancel_event is not None:
            cancel_event.set()
        future = self.task_futures.get(task_id)
        if future is not None:
            # Only has an effect while the task is still queued in the executor.
            future.cancel()
        with self.session_maker() as session:
            # session.begin() commits on successful exit, so both updates are atomic.
            with session.begin():
                session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
                    {"status": TestTaskStatus.CANCELLED.value})
                (session.query(AgentTestTaskConversations)
                 .filter(AgentTestTaskConversations.task_id == task_id)
                 .filter(AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value)
                 .update({"status": TestTaskConversationsStatus.CANCELLED.value}))

    def resume_task(self, task_id: int) -> bool:
        """Resume a CANCELLED task.

        Returns:
            ``True`` when the task was reset and re-submitted; ``False`` when the
            task does not exist or is not in CANCELLED state.
        """
        with self.session_maker() as session:
            # one_or_none() so an unknown id yields False instead of raising.
            task = session.query(AgentTestTask).filter(AgentTestTask.id == task_id).one_or_none()
        if not task or task.status != TestTaskStatus.CANCELLED.value:
            return False
        with self.session_maker() as session:
            # session.begin() commits on successful exit, so both updates are atomic.
            with session.begin():
                session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
                    {"status": TestTaskStatus.NOT_STARTED.value})
                (session.query(AgentTestTaskConversations)
                 .filter(AgentTestTaskConversations.task_id == task_id)
                 .filter(AgentTestTaskConversations.status == TestTaskConversationsStatus.CANCELLED.value)
                 .update({"status": TestTaskConversationsStatus.PENDING.value}))
        self._execute_task(task_id)
        logger.info(f"Resumed task {task_id}")
        return True

    def recover_tasks(self):
        """Re-submit unfinished work; intended to be called at service start-up.

        Tasks still in CREATING state get their sub-tasks deleted and regenerated;
        tasks in IN_PROGRESS state are executed again.
        """
        for task in self.get_creating_task():
            task_id = task.id
            agent_id = task.agent_id
            module_id = task.module_id
            # Drop partially generated sub-tasks before regenerating them.
            self.del_task_conversations(task_id)
            self.create_task_executor.submit(self._generate_agent_test_task_conversation_batch,
                                             task_id, agent_id, module_id)
        for task in self.get_in_progress_task():
            self._execute_task(task.id)

    def _execute_task(self, task_id: int):
        """Submit the task to the worker pool unless it is already running.

        The task is registered as running *before* it is submitted and while the
        per-task lock is held, so a worker that finishes immediately cannot run
        its cleanup ahead of the registration and leave a stale entry behind.
        """
        task_lock = self.task_locks.setdefault(task_id, threading.Lock())
        with task_lock:
            if task_id in self.running_tasks:
                return
            # Fresh, unset cancellation flag for this run.
            self.task_events[task_id] = threading.Event()
            self.running_tasks.add(task_id)
            try:
                self.task_futures[task_id] = self.executor.submit(self._process_task, task_id)
            except Exception:
                # submit() failed (e.g. executor already shut down): undo the registration.
                self.running_tasks.discard(task_id)
                self.task_events.pop(task_id, None)
                raise

    def _process_task(self, task_id: int):
        """Run every PENDING/RUNNING sub-task of the task concurrently and settle its status."""
        try:
            self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
            task_conversations = self.get_pending_task_conversations(task_id)
            if not task_conversations:
                self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
                return
            agent_configuration = self.get_agent_configuration_by_task_id(task_id)
            query_prompt_template = agent_configuration.task_prompt
            task = self.get_task(task_id)

            with ThreadPoolExecutor(max_workers=8) as executor:  # per-task sub-task concurrency
                futures = {}
                for task_conversation in task_conversations:
                    # Stop submitting further sub-tasks once cancellation is requested.
                    if self.task_events[task_id].is_set():
                        break
                    future = executor.submit(
                        self._process_single_conversation,
                        task_id, task, task_conversation, query_prompt_template, agent_configuration
                    )
                    futures[future] = task_conversation.id

                # Wait for every submitted sub-task to finish.
                for future in concurrent.futures.as_completed(futures):
                    conv_id = futures[future]
                    try:
                        future.result()  # re-raises any exception from the sub-task
                    except Exception as e:
                        logger.error(f"Subtask {conv_id} failed: {str(e)}")
                        self.update_task_conversations_status(
                            conv_id,
                            TestTaskConversationsStatus.FAILED.value
                        )

            self._update_final_task_status(task_id)
        except Exception as e:
            logger.error(f"Error processing task {task_id}: {str(e)}")
            self.update_task_status(task_id, TestTaskStatus.FAILED.value)
        finally:
            self._cleanup_task_resources(task_id)

    def _process_single_conversation(self, task_id, task, task_conversation, query_prompt_template,
                                     agent_configuration):
        """Generate and score one message for a single conversation sub-task.

        The sub-task ends as SUCCESS, MESSAGE_FAILED (no message generated),
        SCORE_FAILED (falsy score) or FAILED (exception, which is re-raised).
        Returns without touching the sub-task when cancellation was requested.
        """
        if self.task_events[task_id].is_set():
            return

        if task_conversation.status == TestTaskConversationsStatus.PENDING.value:
            self.update_task_conversations_status(
                task_conversation.id,
                TestTaskConversationsStatus.RUNNING.value
            )
        try:
            # A separate agent instance is built for every sub-task.
            agent = MultiModalChatAgent(
                model=agent_configuration.execution_model,
                system_prompt=agent_configuration.system_prompt,
                tools=json.loads(agent_configuration.tools)
            )

            # Load the conversation record plus the user and staff profiles it refers to.
            conversation_data = self.dataset_service.get_conversation_data_by_id(
                task_conversation.conversation_id)
            user_profile_data = self.dataset_service.get_user_profile_data(
                conversation_data.user_id,
                conversation_data.version_date.replace("-", ""))
            user_profile = json.loads(user_profile_data['profile_data_v1'])
            avatar = user_profile_data['iconurl']
            staff_profile = self.dataset_service.get_staff_profile_data(
                conversation_data.staff_id).agent_profile
            conversations = self.dataset_service.get_chat_conversation_list_by_ids(
                json.loads(conversation_data.conversation),
                conversation_data.staff_id
            )
            conversations = sorted(conversations, key=lambda item: item['timestamp'], reverse=False)

            # The simulated send time is the last message time (divided by 1000,
            # i.e. presumably milliseconds - TODO confirm) plus a per-type delay.
            last_timestamp = int(conversations[-1]["timestamp"])
            try:
                delay_seconds = _SEND_DELAY_SECONDS[task.evaluate_type]
            except KeyError:
                raise ValueError("evaluate_type must be 0 or 1") from None
            send_timestamp = int(last_timestamp / 1000) + delay_seconds
            send_time = datetime.fromtimestamp(send_timestamp).strftime('%Y-%m-%d %H:%M:%S')

            message = agent._generate_message(
                context={
                    "formatted_staff_profile": staff_profile,
                    "nickname": user_profile['nickname'],
                    "name": user_profile['name'],
                    "avatar": avatar,
                    "preferred_nickname": user_profile['preferred_nickname'],
                    "gender": user_profile['gender'],
                    "age": user_profile['age'],
                    "region": user_profile['region'],
                    "health_conditions": user_profile['health_conditions'],
                    "medications": user_profile['medications'],
                    "interests": user_profile['interests'],
                    "current_datetime": send_time
                },
                dialogue_history=conversations,
                query_prompt_template=query_prompt_template
            )
            if not message:
                self.update_task_conversations_status(
                    task_conversation.id,
                    TestTaskConversationsStatus.MESSAGE_FAILED.value
                )
                return

            param = {
                "dialogue_history": conversations,
                "message": message,
                "send_time": send_time,
                "agent_profile": json.loads(staff_profile),
                "user_profile": user_profile,
            }
            score = evaluate_agent(param, task.evaluate_type)
            # NOTE(review): any falsy score is treated as a scoring failure.
            if not score:
                self.update_task_conversations_status(
                    task_conversation.id,
                    TestTaskConversationsStatus.SCORE_FAILED.value
                )
                return

            self.update_task_conversations_res(
                task_conversation.id,
                TestTaskConversationsStatus.SUCCESS.value,
                json.dumps(conversations, ensure_ascii=False),
                json.dumps(message, ensure_ascii=False),
                json.dumps(score)
            )
        except Exception as e:
            logger.error(f"Subtask {task_conversation.id} failed: {str(e)}")
            self.update_task_conversations_status(
                task_conversation.id,
                TestTaskConversationsStatus.FAILED.value
            )
            raise  # re-raise so the caller waiting on the future sees the failure

    def _update_final_task_status(self, task_id):
        """Derive the task's final status from the statuses of its sub-tasks.

        * every sub-task finished (SUCCESS, FAILED, MESSAGE_FAILED or
          SCORE_FAILED)  -> task COMPLETED;
        * some sub-task still PENDING/RUNNING -> task left unchanged;
        * otherwise (remaining sub-tasks are neither finished nor in flight,
          e.g. CANCELLED) -> task CANCELLED, unless it already is.
        """
        finished_statuses = (
            TestTaskConversationsStatus.SUCCESS.value,
            TestTaskConversationsStatus.FAILED.value,
            # These are terminal outcomes too; without them a task containing
            # such a sub-task would wrongly end up CANCELLED.
            TestTaskConversationsStatus.MESSAGE_FAILED.value,
            TestTaskConversationsStatus.SCORE_FAILED.value,
        )
        in_flight_statuses = (
            TestTaskConversationsStatus.PENDING.value,
            TestTaskConversationsStatus.RUNNING.value,
        )
        statuses = [conv.status for conv in self.get_task_conversations(task_id)]

        if all(status in finished_statuses for status in statuses):
            self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
            logger.info(f"Task {task_id} completed")
            return
        if any(status in in_flight_statuses for status in statuses):
            return
        if self.get_task(task_id).status != TestTaskStatus.CANCELLED.value:
            self.update_task_status(task_id, TestTaskStatus.CANCELLED.value)

    def _cleanup_task_resources(self, task_id):
        """Forget the task's running flag, cancellation event and future.

        The per-task lock itself is kept so a later run of the same task
        synchronises on the same lock object.
        """
        task_lock = self.task_locks.setdefault(task_id, threading.Lock())
        with task_lock:
            self.running_tasks.discard(task_id)
            self.task_events.pop(task_id, None)
            self.task_futures.pop(task_id, None)

    def shutdown(self):
        """Shut down both executors without waiting for running work."""
        self.executor.shutdown(wait=False)
        self.create_task_executor.shutdown(wait=False)
        logger.info("Task executor shutdown")