|
@@ -0,0 +1,528 @@
|
|
|
+import json
|
|
|
+import threading
|
|
|
+import concurrent.futures
|
|
|
+import time
|
|
|
+import traceback
|
|
|
+from concurrent.futures import ThreadPoolExecutor
|
|
|
+from datetime import datetime
|
|
|
+from typing import Dict
|
|
|
+
|
|
|
+from sqlalchemy import func
|
|
|
+
|
|
|
+from pqai_agent import logging_service
|
|
|
+from pqai_agent.agents.message_push_agent import MessagePushAgent
|
|
|
+from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
|
|
|
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
|
|
|
+from pqai_agent.data_models.agent_test_task import AgentTestTask
|
|
|
+from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
|
|
|
+from pqai_agent.data_models.service_module import ServiceModule
|
|
|
+from pqai_agent.utils.prompt_utils import format_agent_profile
|
|
|
+from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc
|
|
|
+from concurrent.futures import ThreadPoolExecutor
|
|
|
+
|
|
|
+from scripts.evaluate_agent import evaluate_agent
|
|
|
+
|
|
|
+logger = logging_service.logger
|
|
|
+
|
|
|
+
|
|
|
+class TaskManager:
|
|
|
+ """任务管理器"""
|
|
|
+
|
|
|
+ def __init__(self, session_maker, dataset_service):
|
|
|
+ self.session_maker = session_maker
|
|
|
+ self.dataset_service = dataset_service
|
|
|
+ self.task_events = {} # 任务ID -> Event (用于取消任务)
|
|
|
+ self.task_locks = {} # 任务ID -> Lock (用于任务状态同步)
|
|
|
+ self.running_tasks = set()
|
|
|
+ self.executor = ThreadPoolExecutor(max_workers=20, thread_name_prefix='TaskWorker')
|
|
|
+ self.create_task_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='CreateTaskWorker')
|
|
|
+ self.task_futures = {} # 任务ID -> Future
|
|
|
+
|
|
|
+ def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
|
|
|
+ with self.session_maker() as session:
|
|
|
+ # 计算偏移量
|
|
|
+ offset = (page_num - 1) * page_size
|
|
|
+ # 查询分页数据
|
|
|
+ result = (session.query(AgentTestTask, AgentConfiguration)
|
|
|
+ .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
|
|
|
+ .limit(page_size).offset(offset).all())
|
|
|
+ # 查询总记录数
|
|
|
+ total = session.query(func.count(AgentTestTask.id)).scalar()
|
|
|
+
|
|
|
+ total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
|
|
|
+ total_page = 1 if total_page <= 0 else total_page
|
|
|
+ response_data = [
|
|
|
+ {
|
|
|
+ "id": agent_test_task.id,
|
|
|
+ "agentName": agent_configuration.name,
|
|
|
+ "createUser": agent_test_task.create_user,
|
|
|
+ "updateUser": agent_test_task.update_user,
|
|
|
+ "statusName": get_test_task_status_desc(agent_test_task.status),
|
|
|
+ "createTime": agent_test_task.create_time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
|
+ "updateTime": agent_test_task.update_time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+ }
|
|
|
+ for agent_test_task, agent_configuration in result
|
|
|
+ ]
|
|
|
+ return {
|
|
|
+ "currentPage": page_num,
|
|
|
+ "pageSize": page_size,
|
|
|
+ "totalSize": total_page,
|
|
|
+ "total": total,
|
|
|
+ "list": response_data,
|
|
|
+ }
|
|
|
+
|
|
|
+ def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
|
|
|
+ with self.session_maker() as session:
|
|
|
+ # 计算偏移量
|
|
|
+ offset = (page_num - 1) * page_size
|
|
|
+ # 查询分页数据
|
|
|
+ result = (session.query(AgentTestTaskConversations, AgentConfiguration)
|
|
|
+ .outerjoin(AgentConfiguration, AgentTestTaskConversations.agent_id == AgentConfiguration.id)
|
|
|
+ .filter(AgentTestTaskConversations.task_id == task_id)
|
|
|
+ .limit(page_size).offset(offset).all())
|
|
|
+ # 查询总记录数
|
|
|
+ total = session.query(func.count(AgentTestTaskConversations.id)).scalar()
|
|
|
+
|
|
|
+ total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
|
|
|
+ total_page = 1 if total_page <= 0 else total_page
|
|
|
+ response_data = [
|
|
|
+ {
|
|
|
+ "id": agent_test_task_conversation.id,
|
|
|
+ "agentName": agent_configuration.name,
|
|
|
+ "input": MultiModalChatAgent.compose_dialogue(json.loads(agent_test_task_conversation.input)),
|
|
|
+ "output": agent_test_task_conversation.output,
|
|
|
+ "score": agent_test_task_conversation.score,
|
|
|
+ "statusName": get_test_task_status_desc(agent_test_task_conversation.status),
|
|
|
+ "createTime": agent_test_task_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
|
+ "updateTime": agent_test_task_conversation.update_time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+ }
|
|
|
+ for agent_test_task_conversation, agent_configuration in result
|
|
|
+ ]
|
|
|
+ return {
|
|
|
+ "currentPage": page_num,
|
|
|
+ "pageSize": page_size,
|
|
|
+ "totalSize": total_page,
|
|
|
+ "total": total,
|
|
|
+ "list": response_data,
|
|
|
+ }
|
|
|
+
|
|
|
+ def create_task(self, agent_id: int, module_id: int, evaluate_type: int) -> Dict:
|
|
|
+ """创建新任务"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ agent_test_task = AgentTestTask(agent_id=agent_id, module_id=module_id, evaluate_type=evaluate_type,
|
|
|
+ status=TestTaskStatus.CREATING.value)
|
|
|
+ session.add(agent_test_task)
|
|
|
+ session.commit() # 显式提交
|
|
|
+ task_id = agent_test_task.id
|
|
|
+ # 异步执行创建任务
|
|
|
+ self.create_task_executor.submit(self._generate_agent_test_task_conversation_batch, task_id, agent_id,
|
|
|
+ module_id)
|
|
|
+ return self.get_task(task_id)
|
|
|
+
|
|
|
+ def _generate_agent_test_task_conversation_batch(self, task_id: int, agent_id: int, module_id: int):
|
|
|
+ """异步生成子任务"""
|
|
|
+ try:
|
|
|
+ # 获取数据集列表
|
|
|
+ dataset_module_list = self.dataset_service.get_dataset_module_list_by_module(module_id)
|
|
|
+
|
|
|
+ # 批量处理数据集 - 减少数据库交互
|
|
|
+ batch_size = 100 # 每批处理100个子任务
|
|
|
+ agent_test_task_conversation_batch = []
|
|
|
+
|
|
|
+ for dataset_module in dataset_module_list:
|
|
|
+ # 获取对话数据列表
|
|
|
+ conversation_datas = self.dataset_service.get_conversation_data_list_by_dataset(
|
|
|
+ dataset_module.dataset_id)
|
|
|
+
|
|
|
+ for conversation_data in conversation_datas:
|
|
|
+ # 创建子任务对象
|
|
|
+ agent_test_task_conversation = AgentTestTaskConversations(
|
|
|
+ task_id=task_id,
|
|
|
+ agent_id=agent_id,
|
|
|
+ dataset_id=dataset_module.dataset_id,
|
|
|
+ conversation_id=conversation_data.id,
|
|
|
+ status=TestTaskConversationsStatus.PENDING.value
|
|
|
+ )
|
|
|
+ agent_test_task_conversation_batch.append(agent_test_task_conversation)
|
|
|
+
|
|
|
+ # 批量提交
|
|
|
+ if len(agent_test_task_conversation_batch) >= batch_size:
|
|
|
+ self.save_agent_test_task_conversation_batch(agent_test_task_conversation_batch)
|
|
|
+ agent_test_task_conversation_batch = []
|
|
|
+
|
|
|
+ # 提交剩余的子任务
|
|
|
+ if agent_test_task_conversation_batch:
|
|
|
+ self.save_agent_test_task_conversation_batch(agent_test_task_conversation_batch)
|
|
|
+
|
|
|
+ # 更新主任务状态为未开始
|
|
|
+ self.update_task_status(task_id, TestTaskStatus.NOT_STARTED.value)
|
|
|
+
|
|
|
+ # 自动提交任务执行
|
|
|
+ self._execute_task(task_id)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"生成子任务失败: {str(e)}")
|
|
|
+ # 更新任务状态为失败
|
|
|
+ self.update_task_status(task_id, TestTaskStatus.CREATED_FAIL.value)
|
|
|
+
|
|
|
+ def save_agent_test_task_conversation_batch(self, agent_test_task_conversation_batch: list):
|
|
|
+ """批量保存子任务到数据库"""
|
|
|
+ try:
|
|
|
+ with self.session_maker() as session:
|
|
|
+ with session.begin():
|
|
|
+ session.add_all(agent_test_task_conversation_batch)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(e)
|
|
|
+
|
|
|
+ def get_agent_configuration_by_task_id(self, task_id: int):
|
|
|
+ """获取指定任务ID对应的Agent配置信息"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ return session.query(AgentConfiguration) \
|
|
|
+ .join(AgentTestTask, AgentTestTask.agent_id == AgentConfiguration.id) \
|
|
|
+ .filter(AgentTestTask.id == task_id) \
|
|
|
+ .one_or_none() # 返回单个对象或None(如果未找到)
|
|
|
+
|
|
|
+ def get_service_module_by_task_id(self, task_id: int):
|
|
|
+ """获取指定任务ID对应的Agent配置信息"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ return session.query(ServiceModule) \
|
|
|
+ .join(AgentTestTask, AgentTestTask.module_id == ServiceModule.id) \
|
|
|
+ .filter(AgentTestTask.id == task_id) \
|
|
|
+ .one_or_none() # 返回单个对象或None(如果未找到)
|
|
|
+
|
|
|
+ def get_task(self, task_id: int):
|
|
|
+ """获取任务信息"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ return session.query(AgentTestTask).filter(AgentTestTask.id == task_id).one()
|
|
|
+
|
|
|
+ def get_in_progress_task(self):
|
|
|
+ """获取执行中任务"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ return session.query(AgentTestTask).filter(AgentTestTask.status == TestTaskStatus.IN_PROGRESS.value).all()
|
|
|
+
|
|
|
+ def get_creating_task(self):
|
|
|
+ """获取执行中任务"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ return session.query(AgentTestTask).filter(AgentTestTask.status == TestTaskStatus.CREATING.value).all()
|
|
|
+
|
|
|
+ def get_task_conversations(self, task_id: int):
|
|
|
+ """获取任务的所有子任务"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ return session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).all()
|
|
|
+
|
|
|
+ def del_task_conversations(self, task_id: int):
|
|
|
+ with self.session_maker() as session:
|
|
|
+ session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).delete()
|
|
|
+ # 提交事务生效
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+ def get_pending_task_conversations(self, task_id: int):
|
|
|
+ """获取待处理的子任务"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ return session.query(AgentTestTaskConversations).filter(
|
|
|
+ AgentTestTaskConversations.task_id == task_id).filter(
|
|
|
+ AgentTestTaskConversations.status.in_([
|
|
|
+ TestTaskConversationsStatus.PENDING.value,
|
|
|
+ TestTaskConversationsStatus.RUNNING.value
|
|
|
+ ])).all()
|
|
|
+
|
|
|
+ def update_task_status(self, task_id: int, status: int):
|
|
|
+ """更新任务状态"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
|
|
|
+ {"status": status, "update_time": datetime.now()})
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+ def update_task_conversations_status(self, task_conversations_id: int, status: int):
|
|
|
+ """更新子任务状态"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ session.query(AgentTestTaskConversations).filter(
|
|
|
+ AgentTestTaskConversations.id == task_conversations_id).update(
|
|
|
+ {"status": status, "update_time": datetime.now()})
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+ def update_task_conversations_res(self, task_conversations_id: int, status: int, input: str, output: str,
|
|
|
+ score: str):
|
|
|
+ """更新子任务结果"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ session.query(AgentTestTaskConversations).filter(
|
|
|
+ AgentTestTaskConversations.id == task_conversations_id).update(
|
|
|
+ {"status": status, "input": input, "output": output, "score": score, "update_time": datetime.now()})
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+ def cancel_task(self, task_id: int):
|
|
|
+ """取消任务(带事务支持)"""
|
|
|
+ # 设置取消事件
|
|
|
+ if task_id in self.task_events:
|
|
|
+ self.task_events[task_id].set()
|
|
|
+ # 如果任务正在执行,尝试取消Future
|
|
|
+ if task_id in self.task_futures:
|
|
|
+ self.task_futures[task_id].cancel()
|
|
|
+
|
|
|
+ with self.session_maker() as session:
|
|
|
+ with session.begin():
|
|
|
+ session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
|
|
|
+ {"status": TestTaskStatus.CANCELLED.value})
|
|
|
+ session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).filter(
|
|
|
+ AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value).update(
|
|
|
+ {"status": TestTaskConversationsStatus.CANCELLED.value})
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+ def resume_task(self, task_id: int) -> bool:
|
|
|
+ """恢复已取消的任务"""
|
|
|
+ task = self.get_task(task_id)
|
|
|
+ if not task or task.status != TestTaskStatus.CANCELLED.value:
|
|
|
+ return False
|
|
|
+
|
|
|
+ with self.session_maker() as session:
|
|
|
+ with session.begin():
|
|
|
+ session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
|
|
|
+ {"status": TestTaskStatus.NOT_STARTED.value})
|
|
|
+ session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).filter(
|
|
|
+ AgentTestTaskConversations.status == TestTaskConversationsStatus.CANCELLED.value).update(
|
|
|
+ {"status": TestTaskConversationsStatus.PENDING.value})
|
|
|
+ session.commit()
|
|
|
+
|
|
|
+ # 重新执行任务
|
|
|
+ self._execute_task(task_id)
|
|
|
+ logger.info(f"Resumed task {task_id}")
|
|
|
+ return True
|
|
|
+
|
|
|
+ def recover_tasks(self):
|
|
|
+ """服务启动时恢复未完成的任务"""
|
|
|
+
|
|
|
+ creating = self.get_creating_task()
|
|
|
+ for task in creating:
|
|
|
+ task_id = task.id
|
|
|
+ agent_id = task.agent_id
|
|
|
+ module_id = task.module_id
|
|
|
+ self.del_task_conversations(task_id)
|
|
|
+ # 重新提交任务
|
|
|
+ # 异步执行创建任务
|
|
|
+ self.create_task_executor.submit(self._generate_agent_test_task_conversation_batch, task_id, agent_id,
|
|
|
+ module_id)
|
|
|
+
|
|
|
+ # 获取所有进行中的任务ID(根据实际状态定义查询)
|
|
|
+ in_progress_tasks = self.get_in_progress_task()
|
|
|
+
|
|
|
+ for task in in_progress_tasks:
|
|
|
+ task_id = task.id
|
|
|
+ # 重新提交任务
|
|
|
+ self._execute_task(task_id)
|
|
|
+
|
|
|
+ def _execute_task(self, task_id: int):
|
|
|
+ """提交任务到线程池执行"""
|
|
|
+ # 确保任务状态一致性
|
|
|
+ if task_id in self.running_tasks:
|
|
|
+ return
|
|
|
+
|
|
|
+ # 创建任务事件和锁
|
|
|
+ if task_id not in self.task_events:
|
|
|
+ self.task_events[task_id] = threading.Event()
|
|
|
+ if task_id not in self.task_locks:
|
|
|
+ self.task_locks[task_id] = threading.Lock()
|
|
|
+
|
|
|
+ # 提交到线程池
|
|
|
+ future = self.executor.submit(self._process_task, task_id)
|
|
|
+ self.task_futures[task_id] = future
|
|
|
+
|
|
|
+ # 标记任务为运行中
|
|
|
+ with self.task_locks[task_id]:
|
|
|
+ self.running_tasks.add(task_id)
|
|
|
+
|
|
|
+ def _process_task(self, task_id: int):
|
|
|
+ """处理任务的所有子任务(并发执行)"""
|
|
|
+ try:
|
|
|
+ self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
|
|
|
+ task_conversations = self.get_pending_task_conversations(task_id)
|
|
|
+
|
|
|
+ if not task_conversations:
|
|
|
+ self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
|
|
|
+ return
|
|
|
+
|
|
|
+ agent_configuration = self.get_agent_configuration_by_task_id(task_id)
|
|
|
+ query_prompt_template = agent_configuration.task_prompt
|
|
|
+
|
|
|
+ task = self.get_task(task_id)
|
|
|
+
|
|
|
+ # 使用线程池执行子任务
|
|
|
+ with ThreadPoolExecutor(max_workers=8) as executor: # 可根据需要调整并发数
|
|
|
+ futures = {}
|
|
|
+ for task_conversation in task_conversations:
|
|
|
+ if self.task_events[task_id].is_set():
|
|
|
+ break # 检查任务取消事件
|
|
|
+
|
|
|
+ # 提交子任务到线程池
|
|
|
+ future = executor.submit(
|
|
|
+ self._process_single_conversation,
|
|
|
+ task_id,
|
|
|
+ task,
|
|
|
+ task_conversation,
|
|
|
+ query_prompt_template,
|
|
|
+ agent_configuration
|
|
|
+ )
|
|
|
+ futures[future] = task_conversation.id
|
|
|
+
|
|
|
+ # 等待所有子任务完成或取消
|
|
|
+ for future in concurrent.futures.as_completed(futures):
|
|
|
+ conv_id = futures[future]
|
|
|
+ try:
|
|
|
+ future.result() # 获取结果(如有异常会在此抛出)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Subtask {conv_id} failed: {str(e)}")
|
|
|
+ self.update_task_conversations_status(
|
|
|
+ conv_id,
|
|
|
+ TestTaskConversationsStatus.FAILED.value
|
|
|
+ )
|
|
|
+
|
|
|
+ # 检查最终任务状态
|
|
|
+ self._update_final_task_status(task_id)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error processing task {task_id}: {str(e)}")
|
|
|
+ self.update_task_status(task_id, TestTaskStatus.FAILED.value)
|
|
|
+ finally:
|
|
|
+ self._cleanup_task_resources(task_id)
|
|
|
+
|
|
|
+ def _process_single_conversation(self, task_id, task, task_conversation, query_prompt_template,
|
|
|
+ agent_configuration):
|
|
|
+ """处理单个对话子任务(线程安全)"""
|
|
|
+ # 检查任务是否被取消
|
|
|
+ if self.task_events[task_id].is_set():
|
|
|
+ return
|
|
|
+
|
|
|
+ # 更新子任务状态
|
|
|
+ if task_conversation.status == TestTaskConversationsStatus.PENDING.value:
|
|
|
+ self.update_task_conversations_status(
|
|
|
+ task_conversation.id,
|
|
|
+ TestTaskConversationsStatus.RUNNING.value
|
|
|
+ )
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 创建独立的agent实例(确保线程安全)
|
|
|
+ agent = MultiModalChatAgent(
|
|
|
+ model=agent_configuration.execution_model,
|
|
|
+ system_prompt=agent_configuration.system_prompt,
|
|
|
+ tools=json.loads(agent_configuration.tools)
|
|
|
+ )
|
|
|
+
|
|
|
+ # 获取对话数据
|
|
|
+ conversation_data = self.dataset_service.get_conversation_data_by_id(
|
|
|
+ task_conversation.conversation_id)
|
|
|
+ user_profile_data = self.dataset_service.get_user_profile_data(
|
|
|
+ conversation_data.user_id,
|
|
|
+ conversation_data.version_date.replace("-", ""))
|
|
|
+ user_profile = json.loads(user_profile_data['profile_data_v1'])
|
|
|
+ avatar = user_profile_data['iconurl']
|
|
|
+ staff_profile = self.dataset_service.get_staff_profile_data(
|
|
|
+ conversation_data.staff_id).agent_profile
|
|
|
+ conversations = self.dataset_service.get_chat_conversation_list_by_ids(
|
|
|
+ json.loads(conversation_data.conversation),
|
|
|
+ conversation_data.staff_id
|
|
|
+ )
|
|
|
+ conversations = sorted(conversations, key=lambda i: i['timestamp'], reverse=False)
|
|
|
+
|
|
|
+ # 生成推送消息
|
|
|
+ last_timestamp = int(conversations[-1]["timestamp"])
|
|
|
+ match task.evaluate_type:
|
|
|
+ case 0:
|
|
|
+ send_timestamp = int(last_timestamp / 1000) + 10
|
|
|
+ case 1:
|
|
|
+ send_timestamp = int(last_timestamp / 1000) + 24 * 3600
|
|
|
+ case _:
|
|
|
+ raise ValueError("evaluate_type must be 0 or 1")
|
|
|
+ send_time = datetime.fromtimestamp(send_timestamp).strftime('%Y-%m-%d %H:%M:%S')
|
|
|
+ message = agent._generate_message(
|
|
|
+ context={
|
|
|
+ "formatted_staff_profile": staff_profile,
|
|
|
+ "nickname": user_profile['nickname'],
|
|
|
+ "name": user_profile['name'],
|
|
|
+ "avatar": avatar,
|
|
|
+ "preferred_nickname": user_profile['preferred_nickname'],
|
|
|
+ "gender": user_profile['gender'],
|
|
|
+ "age": user_profile['age'],
|
|
|
+ "region": user_profile['region'],
|
|
|
+ "health_conditions": user_profile['health_conditions'],
|
|
|
+ "medications": user_profile['medications'],
|
|
|
+ "interests": user_profile['interests'],
|
|
|
+ "current_datetime": send_time
|
|
|
+ },
|
|
|
+ dialogue_history=conversations,
|
|
|
+ query_prompt_template=query_prompt_template
|
|
|
+ )
|
|
|
+
|
|
|
+ if not message:
|
|
|
+ self.update_task_conversations_status(
|
|
|
+ task_conversation.id,
|
|
|
+ TestTaskConversationsStatus.MESSAGE_FAILED.value
|
|
|
+ )
|
|
|
+ return
|
|
|
+
|
|
|
+ param = {}
|
|
|
+ param["dialogue_history"] = conversations
|
|
|
+ param["message"] = message
|
|
|
+ param["send_time"] = send_time
|
|
|
+ param["agent_profile"] = json.loads(staff_profile)
|
|
|
+ param["user_profile"] = user_profile
|
|
|
+ score = evaluate_agent(param, task.evaluate_type)
|
|
|
+
|
|
|
+ if not score:
|
|
|
+ self.update_task_conversations_status(
|
|
|
+ task_conversation.id,
|
|
|
+ TestTaskConversationsStatus.SCORE_FAILED.value
|
|
|
+ )
|
|
|
+ return
|
|
|
+
|
|
|
+ # 更新子任务结果
|
|
|
+ self.update_task_conversations_res(
|
|
|
+ task_conversation.id,
|
|
|
+ TestTaskConversationsStatus.SUCCESS.value,
|
|
|
+ json.dumps(conversations, ensure_ascii=False),
|
|
|
+ json.dumps(message, ensure_ascii=False),
|
|
|
+ json.dumps(score)
|
|
|
+ )
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Subtask {task_conversation.id} failed: {str(e)}")
|
|
|
+ self.update_task_conversations_status(
|
|
|
+ task_conversation.id,
|
|
|
+ TestTaskConversationsStatus.FAILED.value
|
|
|
+ )
|
|
|
+ raise # 重新抛出异常以便主线程捕获
|
|
|
+
|
|
|
+ def _update_final_task_status(self, task_id):
|
|
|
+ """更新任务的最终状态"""
|
|
|
+ task_conversations = self.get_task_conversations(task_id)
|
|
|
+ all_completed = all(
|
|
|
+ conv.status in (TestTaskConversationsStatus.SUCCESS.value,
|
|
|
+ TestTaskConversationsStatus.FAILED.value)
|
|
|
+ for conv in task_conversations
|
|
|
+ )
|
|
|
+
|
|
|
+ if all_completed:
|
|
|
+ self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
|
|
|
+ logger.info(f"Task {task_id} completed")
|
|
|
+ elif not any(
|
|
|
+ conv.status in (TestTaskConversationsStatus.PENDING.value,
|
|
|
+ TestTaskConversationsStatus.RUNNING.value)
|
|
|
+ for conv in task_conversations
|
|
|
+ ):
|
|
|
+ current_status = self.get_task(task_id).status
|
|
|
+ if current_status != TestTaskStatus.CANCELLED.value:
|
|
|
+ new_status = TestTaskStatus.COMPLETED.value if all_completed else TestTaskStatus.CANCELLED.value
|
|
|
+ self.update_task_status(task_id, new_status)
|
|
|
+
|
|
|
+ def _cleanup_task_resources(self, task_id):
|
|
|
+ """清理任务资源(线程安全)"""
|
|
|
+ with self.task_locks[task_id]:
|
|
|
+ if task_id in self.running_tasks:
|
|
|
+ self.running_tasks.remove(task_id)
|
|
|
+ if task_id in self.task_events:
|
|
|
+ del self.task_events[task_id]
|
|
|
+ if task_id in self.task_futures:
|
|
|
+ del self.task_futures[task_id]
|
|
|
+
|
|
|
+ def shutdown(self):
|
|
|
+ """关闭执行器"""
|
|
|
+ self.executor.shutdown(wait=False)
|
|
|
+ logger.info("Task executor shutdown")
|