|
@@ -24,20 +24,19 @@ class TaskManager:
|
|
|
def __init__(self, session_maker, dataset_service):
|
|
|
self.session_maker = session_maker
|
|
|
self.dataset_service = dataset_service
|
|
|
- self.task_events = {} # 任务ID -> Event (用于取消任务)
|
|
|
self.task_locks = {} # 任务ID -> Lock (用于任务状态同步)
|
|
|
self.running_tasks = set()
|
|
|
self.executor = ThreadPoolExecutor(max_workers=20, thread_name_prefix='TaskWorker')
|
|
|
self.create_task_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='CreateTaskWorker')
|
|
|
- self.task_futures = {} # 任务ID -> Future
|
|
|
|
|
|
def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
|
|
|
with self.session_maker() as session:
|
|
|
# 计算偏移量
|
|
|
offset = (page_num - 1) * page_size
|
|
|
# 查询分页数据
|
|
|
- result = (session.query(AgentTestTask, AgentConfiguration)
|
|
|
+ result = (session.query(AgentTestTask, AgentConfiguration, ServiceModule)
|
|
|
.outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
|
|
|
+ .outerjoin(ServiceModule, AgentTestTask.module_id == ServiceModule.id)
|
|
|
.limit(page_size).offset(offset).all())
|
|
|
# 查询总记录数
|
|
|
total = session.query(func.count(AgentTestTask.id)).scalar()
|
|
@@ -47,14 +46,17 @@ class TaskManager:
|
|
|
response_data = [
|
|
|
{
|
|
|
"id": agent_test_task.id,
|
|
|
- "agentName": agent_configuration.name,
|
|
|
+ "agentId": agent_configuration.id,
|
|
|
+ "agentName": agent_configuration.display_name,
|
|
|
+ "moduleName": service_module.display_name,
|
|
|
"createUser": agent_test_task.create_user,
|
|
|
"updateUser": agent_test_task.update_user,
|
|
|
+ "status": agent_test_task.status,
|
|
|
"statusName": get_test_task_status_desc(agent_test_task.status),
|
|
|
"createTime": agent_test_task.create_time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
|
"updateTime": agent_test_task.update_time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
}
|
|
|
- for agent_test_task, agent_configuration in result
|
|
|
+ for agent_test_task, agent_configuration, service_module in result
|
|
|
]
|
|
|
return {
|
|
|
"currentPage": page_num,
|
|
@@ -82,7 +84,8 @@ class TaskManager:
|
|
|
response_data = [
|
|
|
{
|
|
|
"id": agent_test_task_conversation.id,
|
|
|
- "agentName": agent_configuration.name,
|
|
|
+ "datasetId": agent_test_task_conversation.dataset_id,
|
|
|
+ "conversationId": agent_test_task_conversation.conversation_id,
|
|
|
"input": MultiModalChatAgent.compose_dialogue(json.loads(agent_test_task_conversation.input))
|
|
|
if agent_test_task_conversation.input and agent_test_task_conversation.input.strip()
|
|
|
else None,
|
|
@@ -102,11 +105,11 @@ class TaskManager:
|
|
|
"list": response_data,
|
|
|
}
|
|
|
|
|
|
- def create_task(self, agent_id: int, module_id: int, evaluate_type: int) -> Dict:
|
|
|
+ def create_task(self, agent_id: int, module_id: int, evaluate_type: int, user: str) -> Dict:
|
|
|
"""创建新任务"""
|
|
|
with self.session_maker() as session:
|
|
|
agent_test_task = AgentTestTask(agent_id=agent_id, module_id=module_id, evaluate_type=evaluate_type,
|
|
|
- status=TestTaskStatus.CREATING.value)
|
|
|
+ status=TestTaskStatus.CREATING.value, create_user=user, update_user=user)
|
|
|
session.add(agent_test_task)
|
|
|
session.commit() # 显式提交
|
|
|
task_id = agent_test_task.id
|
|
@@ -224,6 +227,11 @@ class TaskManager:
|
|
|
TestTaskConversationsStatus.PENDING.value,
|
|
|
TestTaskConversationsStatus.RUNNING.value
|
|
|
])).all()
|
|
|
+ def get_task_conversation(self, task_conversation_id: int):
|
|
|
+ """获取待处理的子任务"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ return session.query(AgentTestTaskConversations).filter(
|
|
|
+ AgentTestTaskConversations.id == task_conversation_id).one()
|
|
|
|
|
|
def update_task_status(self, task_id: int, status: int):
|
|
|
"""更新任务状态"""
|
|
@@ -249,26 +257,19 @@ class TaskManager:
|
|
|
{"status": status, "input": input, "output": output, "score": score, "update_time": datetime.now()})
|
|
|
session.commit()
|
|
|
|
|
|
- def cancel_task(self, task_id: int):
|
|
|
- """取消任务(带事务支持)"""
|
|
|
- # 设置取消事件
|
|
|
- if task_id in self.task_events:
|
|
|
- self.task_events[task_id].set()
|
|
|
- # 如果任务正在执行,尝试取消Future
|
|
|
- if task_id in self.task_futures:
|
|
|
- self.task_futures[task_id].cancel()
|
|
|
-
|
|
|
+ def cancel_task(self, task_id: int, user: str):
|
|
|
with self.session_maker() as session:
|
|
|
with session.begin():
|
|
|
session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
|
|
|
- {"status": TestTaskStatus.CANCELLED.value})
|
|
|
+ {"status": TestTaskStatus.CANCELLED.value, "update_user": user, "update_time": datetime.now()})
|
|
|
session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).filter(
|
|
|
AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value).update(
|
|
|
- {"status": TestTaskConversationsStatus.CANCELLED.value})
|
|
|
+ {"status": TestTaskConversationsStatus.CANCELLED.value, "update_time": datetime.now()})
|
|
|
session.commit()
|
|
|
|
|
|
- def resume_task(self, task_id: int) -> bool:
|
|
|
- """恢复已取消的任务"""
|
|
|
+ self._cleanup_task_resources(task_id)
|
|
|
+
|
|
|
+ def resume_task(self, task_id: int, user: str) -> bool:
|
|
|
task = self.get_task(task_id)
|
|
|
if not task or task.status != TestTaskStatus.CANCELLED.value:
|
|
|
return False
|
|
@@ -276,12 +277,11 @@ class TaskManager:
|
|
|
with self.session_maker() as session:
|
|
|
with session.begin():
|
|
|
session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
|
|
|
- {"status": TestTaskStatus.NOT_STARTED.value})
|
|
|
+ {"status": TestTaskStatus.NOT_STARTED.value, "update_user": user, "update_time": datetime.now()})
|
|
|
session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).filter(
|
|
|
AgentTestTaskConversations.status == TestTaskConversationsStatus.CANCELLED.value).update(
|
|
|
- {"status": TestTaskConversationsStatus.PENDING.value})
|
|
|
+ {"status": TestTaskConversationsStatus.PENDING.value, "update_time": datetime.now()})
|
|
|
session.commit()
|
|
|
-
|
|
|
# 重新执行任务
|
|
|
self._execute_task(task_id)
|
|
|
logger.info(f"Resumed task {task_id}")
|
|
@@ -316,14 +316,11 @@ class TaskManager:
|
|
|
return
|
|
|
|
|
|
# 创建任务事件和锁
|
|
|
- if task_id not in self.task_events:
|
|
|
- self.task_events[task_id] = threading.Event()
|
|
|
if task_id not in self.task_locks:
|
|
|
self.task_locks[task_id] = threading.Lock()
|
|
|
|
|
|
# 提交到线程池
|
|
|
- future = self.executor.submit(self._process_task, task_id)
|
|
|
- self.task_futures[task_id] = future
|
|
|
+ self.executor.submit(self._process_task, task_id)
|
|
|
|
|
|
# 标记任务为运行中
|
|
|
with self.task_locks[task_id]:
|
|
@@ -348,9 +345,6 @@ class TaskManager:
|
|
|
with ThreadPoolExecutor(max_workers=8) as executor: # 可根据需要调整并发数
|
|
|
futures = {}
|
|
|
for task_conversation in task_conversations:
|
|
|
- if self.task_events[task_id].is_set():
|
|
|
- break # 检查任务取消事件
|
|
|
-
|
|
|
# 提交子任务到线程池
|
|
|
future = executor.submit(
|
|
|
self._process_single_conversation,
|
|
@@ -385,26 +379,17 @@ class TaskManager:
|
|
|
|
|
|
def _process_single_conversation(self, task_id, task, task_conversation, query_prompt_template,
|
|
|
agent_configuration):
|
|
|
- """处理单个对话子任务(线程安全)"""
|
|
|
- # 检查任务是否被取消
|
|
|
- if self.task_events[task_id].is_set():
|
|
|
- return
|
|
|
-
|
|
|
+ task_conversation = self.get_task_conversation(task_conversation.id)
|
|
|
# 更新子任务状态
|
|
|
if task_conversation.status == TestTaskConversationsStatus.PENDING.value:
|
|
|
self.update_task_conversations_status(
|
|
|
task_conversation.id,
|
|
|
TestTaskConversationsStatus.RUNNING.value
|
|
|
)
|
|
|
+ else:
|
|
|
+ return
|
|
|
|
|
|
try:
|
|
|
- # 创建独立的agent实例(确保线程安全)
|
|
|
- agent = MultiModalChatAgent(
|
|
|
- model=agent_configuration.execution_model,
|
|
|
- system_prompt=agent_configuration.system_prompt,
|
|
|
- tools=json.loads(agent_configuration.tools)
|
|
|
- )
|
|
|
-
|
|
|
# 获取对话数据
|
|
|
conversation_data = self.dataset_service.get_conversation_data_by_id(
|
|
|
task_conversation.conversation_id)
|
|
@@ -431,6 +416,21 @@ class TaskManager:
|
|
|
case _:
|
|
|
raise ValueError("evaluate_type must be 0 or 1")
|
|
|
send_time = datetime.fromtimestamp(send_timestamp).strftime('%Y-%m-%d %H:%M:%S')
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Subtask {task_conversation.id} failed: {str(e)}")
|
|
|
+ self.update_task_conversations_status(
|
|
|
+ task_conversation.id,
|
|
|
+ TestTaskConversationsStatus.FAILED.value
|
|
|
+ )
|
|
|
+ return
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 创建独立的agent实例(确保线程安全)
|
|
|
+ agent = MultiModalChatAgent(
|
|
|
+ model=agent_configuration.execution_model,
|
|
|
+ system_prompt=agent_configuration.system_prompt,
|
|
|
+ tools=json.loads(agent_configuration.tools)
|
|
|
+ )
|
|
|
message = agent._generate_message(
|
|
|
context={
|
|
|
"formatted_staff_profile": staff_profile,
|
|
@@ -456,7 +456,15 @@ class TaskManager:
|
|
|
TestTaskConversationsStatus.MESSAGE_FAILED.value
|
|
|
)
|
|
|
return
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Subtask {task_conversation.id} failed: {str(e)}")
|
|
|
+ self.update_task_conversations_status(
|
|
|
+ task_conversation.id,
|
|
|
+ TestTaskConversationsStatus.MESSAGE_FAILED.value
|
|
|
+ )
|
|
|
+ return
|
|
|
|
|
|
+ try:
|
|
|
param = {}
|
|
|
param["dialogue_history"] = conversations
|
|
|
param["message"] = message
|
|
@@ -478,23 +486,24 @@ class TaskManager:
|
|
|
TestTaskConversationsStatus.SUCCESS.value,
|
|
|
json.dumps(conversations, ensure_ascii=False),
|
|
|
json.dumps(message, ensure_ascii=False),
|
|
|
- json.dumps(score)
|
|
|
+ json.dumps(score, ensure_ascii=False)
|
|
|
)
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Subtask {task_conversation.id} failed: {str(e)}")
|
|
|
self.update_task_conversations_status(
|
|
|
task_conversation.id,
|
|
|
- TestTaskConversationsStatus.FAILED.value
|
|
|
+ TestTaskConversationsStatus.SCORE_FAILED.value
|
|
|
)
|
|
|
- raise # 重新抛出异常以便主线程捕获
|
|
|
|
|
|
def _update_final_task_status(self, task_id):
|
|
|
"""更新任务的最终状态"""
|
|
|
task_conversations = self.get_task_conversations(task_id)
|
|
|
all_completed = all(
|
|
|
conv.status in (TestTaskConversationsStatus.SUCCESS.value,
|
|
|
- TestTaskConversationsStatus.FAILED.value)
|
|
|
+ TestTaskConversationsStatus.FAILED.value,
|
|
|
+ TestTaskConversationsStatus.MESSAGE_FAILED.value,
|
|
|
+ TestTaskConversationsStatus.SCORE_FAILED.value)
|
|
|
for conv in task_conversations
|
|
|
)
|
|
|
|
|
@@ -516,10 +525,6 @@ class TaskManager:
|
|
|
with self.task_locks[task_id]:
|
|
|
if task_id in self.running_tasks:
|
|
|
self.running_tasks.remove(task_id)
|
|
|
- if task_id in self.task_events:
|
|
|
- del self.task_events[task_id]
|
|
|
- if task_id in self.task_futures:
|
|
|
- del self.task_futures[task_id]
|
|
|
|
|
|
def shutdown(self):
|
|
|
"""关闭执行器"""
|