|
@@ -1,10 +1,10 @@
|
|
|
-import threading
|
|
|
+import json
|
|
|
import threading
|
|
|
import time
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
+from datetime import datetime
|
|
|
from typing import Dict
|
|
|
|
|
|
-from pyarrow.dataset import dataset
|
|
|
from sqlalchemy import func
|
|
|
|
|
|
from pqai_agent import logging_service
|
|
@@ -19,9 +19,9 @@ logger = logging_service.logger
|
|
|
class TaskManager:
|
|
|
"""任务管理器"""
|
|
|
|
|
|
- def __init__(self, session_maker, dataset_server, max_workers: int = 10):
|
|
|
+ def __init__(self, session_maker, dataset_service, max_workers: int = 10):
|
|
|
self.session_maker = session_maker
|
|
|
- self.dataset_server = dataset_server
|
|
|
+ self.dataset_service = dataset_service
|
|
|
self.task_events = {} # 任务ID -> Event (用于取消任务)
|
|
|
self.task_locks = {} # 任务ID -> Lock (用于任务状态同步)
|
|
|
self.running_tasks = set()
|
|
@@ -105,9 +105,9 @@ class TaskManager:
|
|
|
session.flush() # 强制SQL执行,但不提交事务
|
|
|
task_id = agent_test_task.id
|
|
|
agent_test_task_conversations = []
|
|
|
- datasets_list = self.dataset_server.get_dataset_list_by_module(module_id)
|
|
|
+ datasets_list = self.dataset_service.get_dataset_list_by_module(module_id)
|
|
|
for datasets in datasets_list:
|
|
|
- conversation_datas = self.dataset_server.get_conversation_data_list_by_dataset(datasets.id)
|
|
|
+ conversation_datas = self.dataset_service.get_conversation_data_list_by_dataset(datasets.id)
|
|
|
for conversation_data in conversation_datas:
|
|
|
agent_test_task_conversation = AgentTestTaskConversations(task_id=task_id, agent_id=agent_id,
|
|
|
dataset_id=datasets.id,
|
|
@@ -123,6 +123,11 @@ class TaskManager:
|
|
|
with self.session_maker() as session:
|
|
|
return session.query(AgentTestTask).filter(AgentTestTask.id == task_id).one()
|
|
|
|
|
|
+ def get_in_progress_task(self):
|
|
|
+ """获取执行中任务"""
|
|
|
+ with self.session_maker() as session:
|
|
|
+ return session.query(AgentTestTask).filter(AgentTestTask.status == TestTaskStatus.IN_PROGRESS.value).all()
|
|
|
+
|
|
|
def get_task_conversations(self, task_id: int):
|
|
|
"""获取任务的所有子任务"""
|
|
|
with self.session_maker() as session:
|
|
@@ -133,26 +138,32 @@ class TaskManager:
|
|
|
with self.session_maker() as session:
|
|
|
return session.query(AgentTestTaskConversations).filter(
|
|
|
AgentTestTaskConversations.task_id == task_id).filter(
|
|
|
- AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value).all()
|
|
|
+ AgentTestTaskConversations.status.in_([
|
|
|
+ TestTaskConversationsStatus.PENDING.value,
|
|
|
+ TestTaskConversationsStatus.RUNNING.value
|
|
|
+ ])).all()
|
|
|
|
|
|
def update_task_status(self, task_id: int, status: int):
|
|
|
"""更新任务状态"""
|
|
|
with self.session_maker() as session:
|
|
|
- session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update({"status": status})
|
|
|
+ session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
|
|
|
+ {"status": status, "update_time": datetime.now()})
|
|
|
session.commit()
|
|
|
|
|
|
def update_task_conversations_status(self, task_conversations_id: int, status: int):
|
|
|
"""更新子任务状态"""
|
|
|
with self.session_maker() as session:
|
|
|
session.query(AgentTestTaskConversations).filter(
|
|
|
- AgentTestTaskConversations.id == task_conversations_id).update({"status": status})
|
|
|
+ AgentTestTaskConversations.id == task_conversations_id).update(
|
|
|
+ {"status": status, "update_time": datetime.now()})
|
|
|
session.commit()
|
|
|
|
|
|
def update_task_conversations_res(self, task_conversations_id: int, status: int, score: str):
|
|
|
"""更新子任务结果"""
|
|
|
with self.session_maker() as session:
|
|
|
session.query(AgentTestTaskConversations).filter(
|
|
|
- AgentTestTaskConversations.id == task_conversations_id).update({"status": status, "score": score})
|
|
|
+ AgentTestTaskConversations.id == task_conversations_id).update(
|
|
|
+ {"status": status, "score": score, "update_time": datetime.now()})
|
|
|
session.commit()
|
|
|
|
|
|
def cancel_task(self, task_id: int):
|
|
@@ -193,6 +204,16 @@ class TaskManager:
|
|
|
logger.info(f"Resumed task {task_id}")
|
|
|
return True
|
|
|
|
|
|
+ def recover_tasks(self):
|
|
|
+ """服务启动时恢复未完成的任务"""
|
|
|
+ # 获取所有未完成的任务ID(根据实际状态定义查询)
|
|
|
+ in_progress_tasks = self.get_in_progress_task()
|
|
|
+
|
|
|
+ for task in in_progress_tasks:
|
|
|
+ task_id = task.id
|
|
|
+ # 重新提交任务
|
|
|
+ self._execute_task(task_id)
|
|
|
+
|
|
|
def _execute_task(self, task_id: int):
|
|
|
"""提交任务到线程池执行"""
|
|
|
# 确保任务状态一致性
|
|
@@ -229,14 +250,15 @@ class TaskManager:
|
|
|
break
|
|
|
|
|
|
# 更新子任务状态为运行中
|
|
|
- self.update_task_conversations_status(task_conversation.id,
|
|
|
- TestTaskConversationsStatus.RUNNING.value)
|
|
|
+ if task_conversation.status == TestTaskConversationsStatus.PENDING.value:
|
|
|
+ self.update_task_conversations_status(task_conversation.id,
|
|
|
+ TestTaskConversationsStatus.RUNNING.value)
|
|
|
try:
|
|
|
- conversation_data = self.dataset_server.get_conversation_data_by_id(
|
|
|
+ conversation_data = self.dataset_service.get_conversation_data_by_id(
|
|
|
task_conversation.conversation_id)
|
|
|
- user_profile_data = self.dataset_server.get_user_profile_data(conversation_data.user_id)
|
|
|
- staff_profile_data = self.dataset_server.get_staff_profile_data(conversation_data.staff_id)
|
|
|
-
|
|
|
+ user_profile_data = self.dataset_service.get_user_profile_data(conversation_data.user_id)['profile_data_v1']
|
|
|
+ staff_profile_data = self.dataset_service.get_staff_profile_data(conversation_data.staff_id).agent_profile
|
|
|
+ conversations = self.dataset_service.get_conversation_list_by_ids(json.loads(conversation_data.conversation))
|
|
|
|
|
|
# 模拟任务执行 - 在实际应用中替换为实际业务逻辑
|
|
|
# TODO 后续改成实际任务执行
|
|
@@ -270,7 +292,7 @@ class TaskManager:
|
|
|
if all_completed else TestTaskStatus.CANCELLED.value)
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error executing task {task_id}: {str(e)}")
|
|
|
- self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
|
|
|
+ self.update_task_status(task_id, TestTaskStatus.FAILED.value)
|
|
|
finally:
|
|
|
# 清理资源
|
|
|
with self.task_locks[task_id]:
|