|
@@ -19,13 +19,14 @@ logger = logging_service.logger
|
|
class TaskManager:
|
|
class TaskManager:
|
|
"""任务管理器"""
|
|
"""任务管理器"""
|
|
|
|
|
|
- def __init__(self, session_maker, dataset_service, max_workers: int = 10):
|
|
|
|
|
|
+ def __init__(self, session_maker, dataset_service):
|
|
self.session_maker = session_maker
|
|
self.session_maker = session_maker
|
|
self.dataset_service = dataset_service
|
|
self.dataset_service = dataset_service
|
|
self.task_events = {} # 任务ID -> Event (用于取消任务)
|
|
self.task_events = {} # 任务ID -> Event (用于取消任务)
|
|
self.task_locks = {} # 任务ID -> Lock (用于任务状态同步)
|
|
self.task_locks = {} # 任务ID -> Lock (用于任务状态同步)
|
|
self.running_tasks = set()
|
|
self.running_tasks = set()
|
|
- self.executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='TaskWorker')
|
|
|
|
|
|
+ self.executor = ThreadPoolExecutor(max_workers=20, thread_name_prefix='TaskWorker')
|
|
|
|
+ self.create_task_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='CreateTaskWorker')
|
|
self.task_futures = {} # 任务ID -> Future
|
|
self.task_futures = {} # 任务ID -> Future
|
|
|
|
|
|
def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
|
|
def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
|
|
@@ -98,26 +99,71 @@ class TaskManager:
|
|
|
|
|
|
def create_task(self, agent_id: int, module_id: int) -> Dict:
    """Create a new task record and schedule generation of its sub-tasks.

    The task row is inserted with status ``CREATING`` and committed straight
    away; the sub-tasks are then generated on ``create_task_executor`` so
    this call does not wait for them.

    Args:
        agent_id: Id of the agent the task belongs to.
        module_id: Id of the module whose datasets feed the task.

    Returns:
        Whatever ``get_task`` returns for the newly created task id.
    """
    with self.session_maker() as db:
        new_task = AgentTestTask(
            agent_id=agent_id,
            module_id=module_id,
            status=TestTaskStatus.CREATING.value,
        )
        db.add(new_task)
        db.commit()  # explicit commit so the row exists before the worker runs
        task_id = new_task.id

    # Generate the sub-tasks asynchronously.
    generation_args = (task_id, agent_id, module_id)
    self.create_task_executor.submit(self._generate_agent_test_task_conversation_batch, *generation_args)
    return self.get_task(task_id)
|
|
|
|
|
|
|
|
def _generate_agent_test_task_conversation_batch(self, task_id: int, agent_id: int, module_id: int):
    """Generate and persist the sub-tasks of a task, then start the task.

    One sub-task (status ``PENDING``) is created for every conversation of
    every dataset of ``module_id``. Sub-tasks are written in batches to keep
    the number of database round trips down. On success the task is moved to
    ``NOT_STARTED`` and handed to ``_execute_task``; on any exception it is
    moved to ``CREATED_FAIL``.

    This method is submitted to ``create_task_executor`` and the returned
    Future is not kept, so nothing is raised to a caller: every failure is
    logged here instead.

    NOTE(review): a failed batch write is only detected here if
    ``save_agent_test_task_conversation_batch`` propagates the error -
    confirm that it does.

    Args:
        task_id: Id of the parent task.
        agent_id: Agent id copied onto every sub-task.
        module_id: Module whose datasets are expanded into sub-tasks.
    """
    batch_size = 100  # sub-tasks written per batch
    try:
        datasets_list = self.dataset_service.get_dataset_list_by_module(module_id)

        pending_batch = []
        for dataset in datasets_list:
            conversation_datas = self.dataset_service.get_conversation_data_list_by_dataset(dataset.id)

            for conversation_data in conversation_datas:
                pending_batch.append(AgentTestTaskConversations(
                    task_id=task_id,
                    agent_id=agent_id,
                    dataset_id=dataset.id,
                    conversation_id=conversation_data.id,
                    status=TestTaskConversationsStatus.PENDING.value
                ))

                # Flush a full batch and start collecting the next one.
                if len(pending_batch) >= batch_size:
                    self.save_agent_test_task_conversation_batch(pending_batch)
                    pending_batch = []

        # Write whatever is left over from the last, partial batch.
        if pending_batch:
            self.save_agent_test_task_conversation_batch(pending_batch)

        # All sub-tasks exist: mark the parent task as not started ...
        self.update_task_status(task_id, TestTaskStatus.NOT_STARTED.value)

        # ... and submit it for execution straight away.
        self._execute_task(task_id)

    except Exception as e:
        # logger.exception keeps the traceback, which a plain error() drops.
        logger.exception(f"生成子任务失败: {str(e)}")
        try:
            self.update_task_status(task_id, TestTaskStatus.CREATED_FAIL.value)
        except Exception:
            # The Future is discarded by the submitter, so an error raised
            # here would otherwise vanish without a trace.
            logger.exception(f"更新任务状态为创建失败时出错: task_id={task_id}")
|
|
|
|
+
|
|
|
|
def save_agent_test_task_conversation_batch(self, agent_test_task_conversation_batch: list):
    """Persist a batch of sub-task objects in a single transaction.

    Args:
        agent_test_task_conversation_batch: Objects to add to the session.
            An empty batch is a no-op and opens no session.

    Raises:
        Exception: Any error raised while opening the session or writing the
            batch is logged and re-raised, so the caller can mark the task as
            failed instead of carrying on with sub-tasks missing.
    """
    if not agent_test_task_conversation_batch:
        return
    try:
        with self.session_maker() as session:
            # session.begin() is used as a context manager around the write.
            with session.begin():
                session.add_all(agent_test_task_conversation_batch)
    except Exception:
        logger.exception("批量保存子任务失败")
        raise
|
|
|
|
+
|
|
def get_task(self, task_id: int):
|
|
def get_task(self, task_id: int):
|
|
"""获取任务信息"""
|
|
"""获取任务信息"""
|
|
with self.session_maker() as session:
|
|
with self.session_maker() as session:
|
|
@@ -128,11 +174,22 @@ class TaskManager:
|
|
with self.session_maker() as session:
|
|
with self.session_maker() as session:
|
|
return session.query(AgentTestTask).filter(AgentTestTask.status == TestTaskStatus.IN_PROGRESS.value).all()
|
|
return session.query(AgentTestTask).filter(AgentTestTask.status == TestTaskStatus.IN_PROGRESS.value).all()
|
|
|
|
|
|
|
|
def get_creating_task(self):
    """Return all tasks whose status is ``CREATING``.

    Returns:
        The list produced by ``.all()`` on the filtered ``AgentTestTask`` query.
    """
    with self.session_maker() as db:
        creating_query = db.query(AgentTestTask).filter(
            AgentTestTask.status == TestTaskStatus.CREATING.value
        )
        return creating_query.all()
|
|
|
|
+
|
|
def get_task_conversations(self, task_id: int):
    """Return every sub-task that belongs to the given task.

    Args:
        task_id: Id of the parent task.

    Returns:
        The list produced by ``.all()`` on the ``AgentTestTaskConversations``
        query filtered by ``task_id``.
    """
    with self.session_maker() as db:
        belongs_to_task = AgentTestTaskConversations.task_id == task_id
        sub_task_query = db.query(AgentTestTaskConversations).filter(belongs_to_task)
        return sub_task_query.all()
|
|
|
|
|
|
|
|
def del_task_conversations(self, task_id: int):
    """Delete every sub-task that belongs to the given task and commit.

    Args:
        task_id: Id of the parent task whose sub-tasks are removed.
    """
    with self.session_maker() as db:
        sub_tasks_of_task = db.query(AgentTestTaskConversations).filter(
            AgentTestTaskConversations.task_id == task_id
        )
        sub_tasks_of_task.delete()
        # Commit so the deletion takes effect.
        db.commit()
|
|
|
|
+
|
|
def get_pending_task_conversations(self, task_id: int):
|
|
def get_pending_task_conversations(self, task_id: int):
|
|
"""获取待处理的子任务"""
|
|
"""获取待处理的子任务"""
|
|
with self.session_maker() as session:
|
|
with self.session_maker() as session:
|
|
@@ -206,7 +263,19 @@ class TaskManager:
|
|
|
|
|
|
def recover_tasks(self):
|
|
def recover_tasks(self):
|
|
"""服务启动时恢复未完成的任务"""
|
|
"""服务启动时恢复未完成的任务"""
|
|
- # 获取所有未完成的任务ID(根据实际状态定义查询)
|
|
|
|
|
|
+
|
|
|
|
+ creating = self.get_creating_task()
|
|
|
|
+ for task in creating:
|
|
|
|
+ task_id = task.id
|
|
|
|
+ agent_id = task.agent_id
|
|
|
|
+ module_id = task.module_id
|
|
|
|
+ self.del_task_conversations(task_id)
|
|
|
|
+ # 重新提交任务
|
|
|
|
+ # 异步执行创建任务
|
|
|
|
+ self.create_task_executor.submit(self._generate_agent_test_task_conversation_batch, task_id, agent_id,
|
|
|
|
+ module_id)
|
|
|
|
+
|
|
|
|
+ # 获取所有进行中的任务ID(根据实际状态定义查询)
|
|
in_progress_tasks = self.get_in_progress_task()
|
|
in_progress_tasks = self.get_in_progress_task()
|
|
|
|
|
|
for task in in_progress_tasks:
|
|
for task in in_progress_tasks:
|
|
@@ -256,9 +325,13 @@ class TaskManager:
|
|
try:
|
|
try:
|
|
conversation_data = self.dataset_service.get_conversation_data_by_id(
|
|
conversation_data = self.dataset_service.get_conversation_data_by_id(
|
|
task_conversation.conversation_id)
|
|
task_conversation.conversation_id)
|
|
- user_profile_data = self.dataset_service.get_user_profile_data(conversation_data.user_id)['profile_data_v1']
|
|
|
|
- staff_profile_data = self.dataset_service.get_staff_profile_data(conversation_data.staff_id).agent_profile
|
|
|
|
- conversations = self.dataset_service.get_conversation_list_by_ids(json.loads(conversation_data.conversation))
|
|
|
|
|
|
+ user_profile_data = self.dataset_service.get_user_profile_data(conversation_data.user_id,
|
|
|
|
+ conversation_data.version_date.replace(
|
|
|
|
+ "-", ""))['profile_data_v1']
|
|
|
|
+ staff_profile_data = self.dataset_service.get_staff_profile_data(
|
|
|
|
+ conversation_data.staff_id).agent_profile
|
|
|
|
+ conversations = self.dataset_service.get_conversation_list_by_ids(
|
|
|
|
+ json.loads(conversation_data.conversation))
|
|
|
|
|
|
# 模拟任务执行 - 在实际应用中替换为实际业务逻辑
|
|
# 模拟任务执行 - 在实际应用中替换为实际业务逻辑
|
|
# TODO 后续改成实际任务执行
|
|
# TODO 后续改成实际任务执行
|