|
- import threading
- import threading
- import time
- from concurrent.futures import ThreadPoolExecutor
- from typing import Dict
- from pyarrow.dataset import dataset
- from sqlalchemy import func
- from pqai_agent import logging_service
- from pqai_agent.data_models.agent_configuration import AgentConfiguration
- from pqai_agent.data_models.agent_test_task import AgentTestTask
- from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
- from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc
- logger = logging_service.logger
class TaskManager:
    """Create, list, cancel, resume and execute agent test tasks.

    A task (``AgentTestTask``) owns one ``AgentTestTaskConversations`` row per
    conversation of every dataset in its module. Tasks are executed on a
    thread pool; each running task has a cancellation ``Event`` and a
    ``Future`` registered under its id.

    Bookkeeping invariants (all guarded by the per-task lock in
    ``task_locks``):

    * a task id is in ``running_tasks`` from the moment it is scheduled until
      ``_release_task`` runs for it;
    * ``task_events`` / ``task_futures`` hold an entry for exactly that span.

    Per-task locks are intentionally never removed (removing them would race
    with a concurrent scheduler); the cost is one small object per task id.
    """

    def __init__(self, session_maker, dataset_server, max_workers: int = 10):
        """Store collaborators and create the worker pool.

        Args:
            session_maker: Callable returning a session usable as a context
                manager (used as ``with session_maker() as session``).
            dataset_server: Object providing dataset / conversation lookups.
            max_workers: Size of the thread pool that runs tasks.
        """
        self.session_maker = session_maker
        self.dataset_server = dataset_server
        self.task_events = {}  # task id -> Event (set to request cancellation)
        self.task_locks = {}  # task id -> Lock (serialises scheduling/cleanup)
        self.running_tasks = set()
        self.executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='TaskWorker')
        self.task_futures = {}  # task id -> Future

    # ------------------------------------------------------------------ #
    # Paging helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _page_offset(page_num: int, page_size: int) -> int:
        """Return the row offset of a 1-based page.

        Raises:
            ValueError: If ``page_num`` or ``page_size`` is smaller than 1.
        """
        if page_num < 1 or page_size < 1:
            raise ValueError("page_num and page_size must be positive integers")
        return (page_num - 1) * page_size

    @staticmethod
    def _total_pages(total: int, page_size: int) -> int:
        """Return the page count for ``total`` rows; never less than 1."""
        # Ceiling division without floats.
        return max(1, -(-total // page_size))

    @staticmethod
    def _format_time(value):
        """Format a datetime as ``YYYY-mm-dd HH:MM:SS``; ``None`` stays ``None``."""
        return value.strftime("%Y-%m-%d %H:%M:%S") if value is not None else None

    # ------------------------------------------------------------------ #
    # Read-only queries
    # ------------------------------------------------------------------ #
    def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
        """Return one page of test tasks joined with their agent configuration.

        Args:
            page_num: 1-based page number.
            page_size: Number of rows per page.

        Returns:
            Dict with ``currentPage``, ``pageSize``, ``totalSize`` (page
            count), ``total`` (row count) and ``list`` (row dicts).

        Raises:
            ValueError: If ``page_num`` or ``page_size`` is smaller than 1.
        """
        offset = self._page_offset(page_num, page_size)
        with self.session_maker() as session:
            # ORDER BY makes LIMIT/OFFSET paging deterministic.
            rows = (session.query(AgentTestTask, AgentConfiguration)
                    .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
                    .order_by(AgentTestTask.id)
                    .limit(page_size).offset(offset).all())
            total = session.query(func.count(AgentTestTask.id)).scalar()
            response_data = [
                {
                    "id": task.id,
                    # The outer join yields None when the agent row is missing.
                    "agentName": agent.name if agent is not None else None,
                    "createUser": task.create_user,
                    "updateUser": task.update_user,
                    "statusName": get_test_task_status_desc(task.status),
                    "createTime": self._format_time(task.create_time),
                    "updateTime": self._format_time(task.update_time),
                }
                for task, agent in rows
            ]
        return {
            "currentPage": page_num,
            "pageSize": page_size,
            "totalSize": self._total_pages(total, page_size),
            "total": total,
            "list": response_data,
        }

    def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
        """Return one page of the conversations (sub-tasks) of ``task_id``.

        Args:
            task_id: Id of the owning task.
            page_num: 1-based page number.
            page_size: Number of rows per page.

        Returns:
            Dict with ``currentPage``, ``pageSize``, ``totalSize`` (page
            count), ``total`` (row count for this task) and ``list``.

        Raises:
            ValueError: If ``page_num`` or ``page_size`` is smaller than 1.
        """
        offset = self._page_offset(page_num, page_size)
        with self.session_maker() as session:
            rows = (session.query(AgentTestTaskConversations, AgentConfiguration)
                    .outerjoin(AgentConfiguration,
                               AgentTestTaskConversations.agent_id == AgentConfiguration.id)
                    .filter(AgentTestTaskConversations.task_id == task_id)
                    .order_by(AgentTestTaskConversations.id)
                    .limit(page_size).offset(offset).all())
            # The count must be restricted to this task, otherwise the pager
            # reports the conversations of every task.
            total = (session.query(func.count(AgentTestTaskConversations.id))
                     .filter(AgentTestTaskConversations.task_id == task_id)
                     .scalar())
            response_data = [
                {
                    "id": conversation.id,
                    "agentName": agent.name if agent is not None else None,
                    "input": conversation.input,
                    "output": conversation.output,
                    "score": conversation.score,
                    # NOTE(review): this maps a *conversation* status through the
                    # *task* status description helper - confirm the two enums
                    # share their values or switch to a conversation-specific helper.
                    "statusName": get_test_task_status_desc(conversation.status),
                    "createTime": self._format_time(conversation.create_time),
                    "updateTime": self._format_time(conversation.update_time),
                }
                for conversation, agent in rows
            ]
        return {
            "currentPage": page_num,
            "pageSize": page_size,
            "totalSize": self._total_pages(total, page_size),
            "total": total,
            "list": response_data,
        }

    def get_task(self, task_id: int):
        """Return the ``AgentTestTask`` with ``task_id``.

        Uses ``Query.one()``, so it raises when no such row exists.
        """
        with self.session_maker() as session:
            return session.query(AgentTestTask).filter(AgentTestTask.id == task_id).one()

    def get_task_conversations(self, task_id: int):
        """Return every conversation (sub-task) row of ``task_id``."""
        with self.session_maker() as session:
            return (session.query(AgentTestTaskConversations)
                    .filter(AgentTestTaskConversations.task_id == task_id)
                    .all())

    def get_pending_task_conversations(self, task_id: int):
        """Return the conversations of ``task_id`` whose status is PENDING."""
        with self.session_maker() as session:
            return (session.query(AgentTestTaskConversations)
                    .filter(AgentTestTaskConversations.task_id == task_id)
                    .filter(AgentTestTaskConversations.status
                            == TestTaskConversationsStatus.PENDING.value)
                    .all())

    # ------------------------------------------------------------------ #
    # Status updates
    # ------------------------------------------------------------------ #
    def update_task_status(self, task_id: int, status: int):
        """Set the status of task ``task_id`` and commit."""
        with self.session_maker() as session:
            session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update({"status": status})
            session.commit()

    def update_task_conversations_status(self, task_conversations_id: int, status: int):
        """Set the status of one conversation (sub-task) and commit."""
        with self.session_maker() as session:
            session.query(AgentTestTaskConversations).filter(
                AgentTestTaskConversations.id == task_conversations_id).update({"status": status})
            session.commit()

    def update_task_conversations_res(self, task_conversations_id: int, status: int, score: str):
        """Set status and score of one conversation (sub-task) and commit."""
        with self.session_maker() as session:
            session.query(AgentTestTaskConversations).filter(
                AgentTestTaskConversations.id == task_conversations_id).update(
                {"status": status, "score": score})
            session.commit()

    # ------------------------------------------------------------------ #
    # Task lifecycle
    # ------------------------------------------------------------------ #
    def create_task(self, agent_id: int, module_id: int) -> "AgentTestTask":
        """Create a task plus one conversation row per dataset conversation, then run it.

        Args:
            agent_id: Agent under test.
            module_id: Module whose datasets supply the conversations.

        Returns:
            The freshly created ``AgentTestTask`` (re-read via ``get_task``).
        """
        with self.session_maker() as session:
            with session.begin():
                agent_test_task = AgentTestTask(agent_id=agent_id, module_id=module_id)
                session.add(agent_test_task)
                session.flush()  # emit the INSERT to obtain the id without committing
                task_id = agent_test_task.id

                conversations = []
                for ds in self.dataset_server.get_dataset_list_by_module(module_id):
                    conversations.extend(
                        AgentTestTaskConversations(task_id=task_id, agent_id=agent_id,
                                                   dataset_id=ds.id,
                                                   conversation_id=conversation_data.id)
                        for conversation_data in
                        self.dataset_server.get_conversation_data_list_by_dataset(ds.id)
                    )
                session.add_all(conversations)
        # Schedule only after the transaction has committed so the worker
        # thread can see the rows it is supposed to process.
        self._execute_task(task_id)
        return self.get_task(task_id)

    def cancel_task(self, task_id: int):
        """Cancel a task: signal the worker, then mark task and pending rows CANCELLED."""
        # .get() instead of "in" + index: the worker may drop these entries
        # concurrently when it finishes.
        event = self.task_events.get(task_id)
        if event is not None:
            event.set()
        future = self.task_futures.get(task_id)
        if future is not None and future.cancel():
            # The future was still queued, so _process_task (and its cleanup)
            # will never run - release the bookkeeping here or the task would
            # stay "running" forever and could not be resumed.
            self._release_task(task_id)

        with self.session_maker() as session:
            # session.begin() commits on successful exit; no explicit commit needed.
            with session.begin():
                session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
                    {"status": TestTaskStatus.CANCELLED.value})
                session.query(AgentTestTaskConversations).filter(
                    AgentTestTaskConversations.task_id == task_id).filter(
                    AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value).update(
                    {"status": TestTaskConversationsStatus.CANCELLED.value})

    def resume_task(self, task_id: int) -> bool:
        """Resume a cancelled task.

        Returns:
            ``True`` when the task was reset and re-scheduled; ``False`` when
            the task does not exist, is not CANCELLED, or its previous run is
            still winding down in this process (retry once it has stopped).
        """
        # A cancelled worker only stops at its next sub-task boundary. While it
        # is still registered, _execute_task would refuse to start a new run,
        # so do not flip the rows back to PENDING yet.
        if task_id in self.running_tasks:
            return False

        with self.session_maker() as session:
            # Check and reset inside one transaction.
            with session.begin():
                task = (session.query(AgentTestTask)
                        .filter(AgentTestTask.id == task_id)
                        .one_or_none())
                if task is None or task.status != TestTaskStatus.CANCELLED.value:
                    return False
                session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
                    {"status": TestTaskStatus.NOT_STARTED.value})
                session.query(AgentTestTaskConversations).filter(
                    AgentTestTaskConversations.task_id == task_id).filter(
                    AgentTestTaskConversations.status == TestTaskConversationsStatus.CANCELLED.value).update(
                    {"status": TestTaskConversationsStatus.PENDING.value})

        self._execute_task(task_id)
        logger.info(f"Resumed task {task_id}")
        return True

    # ------------------------------------------------------------------ #
    # Execution
    # ------------------------------------------------------------------ #
    def _execute_task(self, task_id: int):
        """Submit ``task_id`` to the thread pool unless it is already running."""
        lock = self.task_locks.setdefault(task_id, threading.Lock())
        with lock:
            if task_id in self.running_tasks:
                return
            # Register everything *before* submitting: a fast worker must find
            # its event, and its cleanup (which takes the same lock) must not
            # run before the bookkeeping below is complete.
            self.running_tasks.add(task_id)
            # Always a fresh event - a leftover one could already be set.
            self.task_events[task_id] = threading.Event()
            try:
                self.task_futures[task_id] = self.executor.submit(self._process_task, task_id)
            except Exception:
                # e.g. the executor was shut down - undo the registration.
                self.running_tasks.discard(task_id)
                self.task_events.pop(task_id, None)
                raise

    def _release_task(self, task_id: int):
        """Drop the running-state bookkeeping (running flag, event, future) of a task."""
        with self.task_locks.setdefault(task_id, threading.Lock()):
            self.running_tasks.discard(task_id)
            self.task_events.pop(task_id, None)
            self.task_futures.pop(task_id, None)

    def _process_task(self, task_id: int):
        """Worker entry point: run every pending conversation of the task."""
        try:
            self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
            cancel_event = self.task_events.get(task_id)
            for task_conversation in self.get_pending_task_conversations(task_id):
                # Stop at the next sub-task boundary once cancellation is requested.
                if cancel_event is not None and cancel_event.is_set():
                    break
                self._run_conversation(task_id, task_conversation)
            self._finalize_task(task_id)
        except Exception as e:
            logger.error(f"Error executing task {task_id}: {str(e)}")
            # NOTE(review): an unexpected failure marks the task COMPLETED;
            # a dedicated failure status may be more appropriate if one exists.
            self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
        finally:
            self._release_task(task_id)

    def _run_conversation(self, task_id: int, task_conversation):
        """Execute one conversation (sub-task) and record SUCCESS or FAILED."""
        self.update_task_conversations_status(task_conversation.id,
                                              TestTaskConversationsStatus.RUNNING.value)
        try:
            conversation_data = self.dataset_server.get_conversation_data_by_id(
                task_conversation.conversation_id)
            # Results are not used by the placeholder below, but the lookups are
            # kept: a failure here marks the sub-task FAILED.
            self.dataset_server.get_user_profile_data(conversation_data.user_id)
            self.dataset_server.get_staff_profile_data(conversation_data.staff_id)
            # Simulated work - TODO: replace with the real evaluation logic.
            time.sleep(1)
            score = '{"score":0.05}'
            self.update_task_conversations_res(task_conversation.id,
                                               TestTaskConversationsStatus.SUCCESS.value, score)
        except Exception as e:
            logger.error(f"Error executing task {task_id}: {str(e)}")
            self.update_task_conversations_status(task_conversation.id,
                                                  TestTaskConversationsStatus.FAILED.value)

    def _finalize_task(self, task_id: int):
        """Derive the final task status from the statuses of its conversations."""
        finished = (TestTaskConversationsStatus.SUCCESS.value,
                    TestTaskConversationsStatus.FAILED.value)
        unfinished = (TestTaskConversationsStatus.PENDING.value,
                      TestTaskConversationsStatus.RUNNING.value)
        statuses = [conversation.status for conversation in self.get_task_conversations(task_id)]

        if all(status in finished for status in statuses):
            self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
            logger.info(f"Task {task_id} completed")
        elif not any(status in unfinished for status in statuses):
            # Nothing left to run but not everything finished: the remainder
            # was cancelled. Make sure the task itself reflects that.
            if self.get_task(task_id).status != TestTaskStatus.CANCELLED.value:
                self.update_task_status(task_id, TestTaskStatus.CANCELLED.value)

    def shutdown(self):
        """Shut the executor down without waiting for running tasks."""
        self.executor.shutdown(wait=False)
        logger.info("Task executor shutdown")
|