123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596 |
- import json
- import threading
- import concurrent.futures
- import time
- from concurrent.futures import ThreadPoolExecutor
- from datetime import datetime
- from typing import Dict
- from sqlalchemy import func
- from pqai_agent import logging_service
- from pqai_agent.agents.message_push_agent import MessagePushAgent
- from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
- from pqai_agent.data_models.agent_configuration import AgentConfiguration
- from pqai_agent.data_models.agent_test_task import AgentTestTask
- from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
- from pqai_agent.data_models.service_module import ServiceModule
- from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc
- from concurrent.futures import ThreadPoolExecutor
- logger = logging_service.logger
class TaskManager:
    """Create, run, cancel, resume and query agent test tasks.

    A task (``AgentTestTask``) is expanded into one sub-task
    (``AgentTestTaskConversations``) per conversation of every dataset bound
    to the task's service module.  Sub-task creation runs on
    ``create_task_executor``; task execution runs on ``executor`` and fans
    the sub-tasks out to a short-lived inner thread pool.

    In-memory bookkeeping (``running_tasks``, ``task_events`` and
    ``task_futures``) is touched from several threads and is therefore only
    mutated while holding ``_state_lock``.
    """

    # Timestamp layout used for every date field returned to API callers.
    _TIME_FORMAT = "%Y-%m-%d %H:%M:%S"
    # Number of sub-task rows written per database round trip.
    _SUBTASK_BATCH_SIZE = 100
    # Concurrency of the per-task pool that runs the sub-tasks.
    _SUBTASK_WORKERS = 20
    # The generated push is dated this long after the last chat message.
    _PUSH_DELAY_SECONDS = 24 * 3600

    def __init__(self, session_maker, dataset_service):
        """Store collaborators and create the worker pools.

        Args:
            session_maker: Callable returning a session usable as a context
                manager (used as ``with session_maker() as session``).
            dataset_service: Service used to read dataset, conversation and
                profile data.
        """
        self.session_maker = session_maker
        self.dataset_service = dataset_service
        self.task_events = {}  # task id -> Event, set to request cancellation
        # Retained only so existing readers of the attribute keep working;
        # synchronisation now goes through ``_state_lock``.
        self.task_locks = {}
        self.running_tasks = set()  # ids of tasks submitted and not yet cleaned up
        self.executor = ThreadPoolExecutor(max_workers=20, thread_name_prefix='TaskWorker')
        self.create_task_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='CreateTaskWorker')
        self.task_futures = {}  # task id -> Future of the running _process_task
        # Re-entrant so cleanup may run on the submitting thread while
        # _execute_task still holds the lock.
        self._state_lock = threading.RLock()

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _total_pages(total: int, page_size: int) -> int:
        """Return the number of pages for ``total`` rows, never less than 1."""
        return max(1, -(-total // page_size))  # ceiling division

    @classmethod
    def _format_time(cls, value):
        """Format a datetime for API output; ``None`` stays ``None``."""
        return None if value is None else value.strftime(cls._TIME_FORMAT)

    @staticmethod
    def _compose_input(raw_input):
        """Render a stored JSON dialogue, or return ``None`` if nothing is stored.

        Sub-tasks are created without an input; it is only written together
        with the result, so it can legitimately be empty here.
        """
        if not raw_input:
            return None
        return MultiModalChatAgent.compose_dialogue(json.loads(raw_input))

    @staticmethod
    def _build_push_context(user_profile, avatar, staff_profile, current_datetime):
        """Assemble the prompt context passed to the agent for one conversation.

        Raises:
            KeyError: If ``user_profile`` lacks one of the expected keys.
        """
        return {
            "formatted_staff_profile": staff_profile,
            "nickname": user_profile['nickname'],
            "name": user_profile['name'],
            "avatar": avatar,
            "preferred_nickname": user_profile['preferred_nickname'],
            "gender": user_profile['gender'],
            "age": user_profile['age'],
            "region": user_profile['region'],
            "health_conditions": user_profile['health_conditions'],
            "medications": user_profile['medications'],
            "interests": user_profile['interests'],
            "current_datetime": current_datetime,
        }

    # ------------------------------------------------------------------
    # Paged queries
    # ------------------------------------------------------------------
    def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
        """Return one page of test tasks with the name of their agent.

        Args:
            page_num: 1-based page index; values below 1 are treated as 1.
            page_size: Rows per page; values below 1 are treated as 1.

        Returns:
            Dict with ``currentPage``, ``pageSize``, ``totalSize`` (page
            count), ``total`` (row count) and ``list``.
        """
        page_num = max(page_num, 1)
        page_size = max(page_size, 1)  # also avoids dividing by zero below
        offset = (page_num - 1) * page_size
        with self.session_maker() as session:
            rows = (session.query(AgentTestTask, AgentConfiguration)
                    .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
                    # LIMIT/OFFSET without an ORDER BY may repeat or skip rows.
                    .order_by(AgentTestTask.id)
                    .limit(page_size).offset(offset).all())
            total = session.query(func.count(AgentTestTask.id)).scalar()
            response_data = [
                {
                    "id": task.id,
                    # The outer join yields None when the agent row is missing.
                    "agentName": configuration.name if configuration is not None else None,
                    "createUser": task.create_user,
                    "updateUser": task.update_user,
                    "statusName": get_test_task_status_desc(task.status),
                    "createTime": self._format_time(task.create_time),
                    "updateTime": self._format_time(task.update_time),
                }
                for task, configuration in rows
            ]
        return {
            "currentPage": page_num,
            "pageSize": page_size,
            "totalSize": self._total_pages(total, page_size),
            "total": total,
            "list": response_data,
        }

    def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
        """Return one page of the sub-tasks that belong to ``task_id``.

        Args:
            task_id: Id of the owning task.
            page_num: 1-based page index; values below 1 are treated as 1.
            page_size: Rows per page; values below 1 are treated as 1.

        Returns:
            Dict with ``currentPage``, ``pageSize``, ``totalSize`` (page
            count), ``total`` (row count) and ``list``.
        """
        page_num = max(page_num, 1)
        page_size = max(page_size, 1)
        offset = (page_num - 1) * page_size
        with self.session_maker() as session:
            rows = (session.query(AgentTestTaskConversations, AgentConfiguration)
                    .outerjoin(AgentConfiguration, AgentTestTaskConversations.agent_id == AgentConfiguration.id)
                    .filter(AgentTestTaskConversations.task_id == task_id)
                    .order_by(AgentTestTaskConversations.id)
                    .limit(page_size).offset(offset).all())
            # The count must use the same filter as the page query, otherwise
            # the totals describe every task's sub-tasks.
            total = (session.query(func.count(AgentTestTaskConversations.id))
                     .filter(AgentTestTaskConversations.task_id == task_id)
                     .scalar())
            response_data = [
                {
                    "id": conversation.id,
                    "agentName": configuration.name if configuration is not None else None,
                    "input": self._compose_input(conversation.input),
                    "output": conversation.output,
                    "score": conversation.score,
                    # NOTE(review): this uses the task-level status description
                    # for a sub-task status - confirm the two enums line up.
                    "statusName": get_test_task_status_desc(conversation.status),
                    "createTime": self._format_time(conversation.create_time),
                    "updateTime": self._format_time(conversation.update_time),
                }
                for conversation, configuration in rows
            ]
        return {
            "currentPage": page_num,
            "pageSize": page_size,
            "totalSize": self._total_pages(total, page_size),
            "total": total,
            "list": response_data,
        }

    # ------------------------------------------------------------------
    # Task creation
    # ------------------------------------------------------------------
    def create_task(self, agent_id: int, module_id: int) -> Dict:
        """Insert a task in CREATING state and generate its sub-tasks asynchronously.

        Returns:
            The freshly stored task as loaded by ``get_task``.
        """
        with self.session_maker() as session:
            agent_test_task = AgentTestTask(agent_id=agent_id, module_id=module_id,
                                            status=TestTaskStatus.CREATING.value)
            session.add(agent_test_task)
            session.commit()  # the id is needed before the background job starts
            task_id = agent_test_task.id
        self.create_task_executor.submit(
            self._generate_agent_test_task_conversation_batch, task_id, agent_id, module_id)
        return self.get_task(task_id)

    def _generate_agent_test_task_conversation_batch(self, task_id: int, agent_id: int, module_id: int):
        """Create every sub-task of a task, then start executing the task.

        Any failure, including a failed batch insert, marks the task
        CREATED_FAIL instead of starting it with missing sub-tasks.
        """
        try:
            dataset_module_list = self.dataset_service.get_dataset_module_list_by_module(module_id)
            batch = []
            for dataset_module in dataset_module_list:
                conversation_datas = self.dataset_service.get_conversation_data_list_by_dataset(
                    dataset_module.dataset_id)
                for conversation_data in conversation_datas:
                    batch.append(AgentTestTaskConversations(
                        task_id=task_id,
                        agent_id=agent_id,
                        dataset_id=dataset_module.dataset_id,
                        conversation_id=conversation_data.id,
                        status=TestTaskConversationsStatus.PENDING.value
                    ))
                    # Flush in chunks to limit the number of DB round trips.
                    if len(batch) >= self._SUBTASK_BATCH_SIZE:
                        self.save_agent_test_task_conversation_batch(batch)
                        batch = []
            if batch:
                self.save_agent_test_task_conversation_batch(batch)
            self.update_task_status(task_id, TestTaskStatus.NOT_STARTED.value)
            # Tasks are executed automatically as soon as they are created.
            self._execute_task(task_id)
        except Exception as e:
            logger.error(f"生成子任务失败: {str(e)}")
            self.update_task_status(task_id, TestTaskStatus.CREATED_FAIL.value)

    def save_agent_test_task_conversation_batch(self, agent_test_task_conversation_batch: list):
        """Insert a batch of sub-tasks in a single transaction.

        Raises:
            Exception: Whatever the database layer raised.  The error is
                logged and re-raised so the caller cannot mistake a failed
                insert for success.
        """
        if not agent_test_task_conversation_batch:
            return
        try:
            with self.session_maker() as session:
                with session.begin():
                    session.add_all(agent_test_task_conversation_batch)
        except Exception as e:
            logger.error(e)
            raise

    # ------------------------------------------------------------------
    # Simple reads
    # ------------------------------------------------------------------
    def get_agent_configuration_by_task_id(self, task_id: int):
        """Return the agent configuration used by a task, or ``None``."""
        with self.session_maker() as session:
            return (session.query(AgentConfiguration)
                    .join(AgentTestTask, AgentTestTask.agent_id == AgentConfiguration.id)
                    .filter(AgentTestTask.id == task_id)
                    .one_or_none())

    def get_service_module_by_task_id(self, task_id: int):
        """Return the service module a task was created for, or ``None``."""
        with self.session_maker() as session:
            return (session.query(ServiceModule)
                    .join(AgentTestTask, AgentTestTask.module_id == ServiceModule.id)
                    .filter(AgentTestTask.id == task_id)
                    .one_or_none())

    def get_task(self, task_id: int):
        """Return the task with the given id.

        Raises:
            Exception: Raised by ``Query.one()`` when no such task exists.
        """
        with self.session_maker() as session:
            return session.query(AgentTestTask).filter(AgentTestTask.id == task_id).one()

    def get_in_progress_task(self):
        """Return all tasks whose status is IN_PROGRESS."""
        with self.session_maker() as session:
            return (session.query(AgentTestTask)
                    .filter(AgentTestTask.status == TestTaskStatus.IN_PROGRESS.value)
                    .all())

    def get_creating_task(self):
        """Return all tasks whose status is CREATING."""
        with self.session_maker() as session:
            return (session.query(AgentTestTask)
                    .filter(AgentTestTask.status == TestTaskStatus.CREATING.value)
                    .all())

    def get_task_conversations(self, task_id: int):
        """Return every sub-task of a task."""
        with self.session_maker() as session:
            return (session.query(AgentTestTaskConversations)
                    .filter(AgentTestTaskConversations.task_id == task_id)
                    .all())

    def del_task_conversations(self, task_id: int):
        """Delete every sub-task of a task."""
        with self.session_maker() as session:
            (session.query(AgentTestTaskConversations)
             .filter(AgentTestTaskConversations.task_id == task_id)
             .delete())
            session.commit()

    def get_pending_task_conversations(self, task_id: int):
        """Return the sub-tasks of a task that are PENDING or RUNNING.

        RUNNING rows are included so that work interrupted by a restart is
        picked up again.
        """
        unfinished = [
            TestTaskConversationsStatus.PENDING.value,
            TestTaskConversationsStatus.RUNNING.value,
        ]
        with self.session_maker() as session:
            return (session.query(AgentTestTaskConversations)
                    .filter(AgentTestTaskConversations.task_id == task_id)
                    .filter(AgentTestTaskConversations.status.in_(unfinished))
                    .all())

    # ------------------------------------------------------------------
    # Simple writes
    # ------------------------------------------------------------------
    def update_task_status(self, task_id: int, status: int):
        """Set the status of a task and refresh its update time."""
        with self.session_maker() as session:
            session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
                {"status": status, "update_time": datetime.now()})
            session.commit()

    def update_task_conversations_status(self, task_conversations_id: int, status: int):
        """Set the status of a sub-task and refresh its update time."""
        with self.session_maker() as session:
            session.query(AgentTestTaskConversations).filter(
                AgentTestTaskConversations.id == task_conversations_id).update(
                {"status": status, "update_time": datetime.now()})
            session.commit()

    def update_task_conversations_res(self, task_conversations_id: int, status: int, input: str, output: str,
                                      score: str):
        """Store the result (status, input, output, score) of a sub-task."""
        with self.session_maker() as session:
            session.query(AgentTestTaskConversations).filter(
                AgentTestTaskConversations.id == task_conversations_id).update(
                {"status": status, "input": input, "output": output, "score": score,
                 "update_time": datetime.now()})
            session.commit()

    # ------------------------------------------------------------------
    # Cancel / resume / recover
    # ------------------------------------------------------------------
    def _signal_cancellation(self, task_id: int):
        """Ask a locally running task to stop and drop it if it never started."""
        with self._state_lock:
            event = self.task_events.get(task_id)
            future = self.task_futures.get(task_id)
        if event is not None:
            event.set()
        # cancel() only succeeds while the job is still queued.  In that case
        # _process_task, and with it the cleanup in its ``finally``, never
        # runs, so the bookkeeping has to be released here.
        if future is not None and future.cancel():
            self._cleanup_task_resources(task_id)

    def cancel_task(self, task_id: int):
        """Cancel a task: stop local execution and mark it CANCELLED.

        The task row and its still-PENDING sub-tasks are updated in one
        transaction; sub-tasks that are already running are left to finish.
        """
        self._signal_cancellation(task_id)
        now = datetime.now()
        with self.session_maker() as session:
            with session.begin():  # commits on normal exit
                session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
                    {"status": TestTaskStatus.CANCELLED.value, "update_time": now})
                (session.query(AgentTestTaskConversations)
                 .filter(AgentTestTaskConversations.task_id == task_id)
                 .filter(AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value)
                 .update({"status": TestTaskConversationsStatus.CANCELLED.value, "update_time": now}))

    def resume_task(self, task_id: int) -> bool:
        """Resume a cancelled task.

        Returns:
            ``True`` when the task was re-queued.  ``False`` when the task
            does not exist, is not CANCELLED, or its previous run has not
            finished stopping yet.
        """
        with self._state_lock:
            still_running = task_id in self.running_tasks
        if still_running:
            # _execute_task would ignore the request while the old run is
            # winding down, leaving the task NOT_STARTED with nothing running.
            logger.info(f"Task {task_id} is still stopping, resume rejected")
            return False

        now = datetime.now()
        with self.session_maker() as session:
            with session.begin():  # commits on normal exit
                # scalar() gives None for an unknown id instead of raising.
                current_status = (session.query(AgentTestTask.status)
                                  .filter(AgentTestTask.id == task_id)
                                  .scalar())
                if current_status != TestTaskStatus.CANCELLED.value:
                    return False
                session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
                    {"status": TestTaskStatus.NOT_STARTED.value, "update_time": now})
                (session.query(AgentTestTaskConversations)
                 .filter(AgentTestTaskConversations.task_id == task_id)
                 .filter(AgentTestTaskConversations.status == TestTaskConversationsStatus.CANCELLED.value)
                 .update({"status": TestTaskConversationsStatus.PENDING.value, "update_time": now}))
        self._execute_task(task_id)
        logger.info(f"Resumed task {task_id}")
        return True

    def recover_tasks(self):
        """Re-queue unfinished work; intended to be called at service start-up.

        Tasks stuck in CREATING get their partial sub-tasks deleted and are
        generated again; tasks in IN_PROGRESS are executed again.
        """
        for task in self.get_creating_task():
            self.del_task_conversations(task.id)
            self.create_task_executor.submit(
                self._generate_agent_test_task_conversation_batch, task.id, task.agent_id, task.module_id)
        for task in self.get_in_progress_task():
            self._execute_task(task.id)

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------
    def _execute_task(self, task_id: int):
        """Queue a task on the worker pool unless it is already running."""
        # Check, registration and submit happen under one lock: the worker's
        # cleanup needs the same lock, so it cannot run before the future is
        # stored, and two callers cannot both pass the "already running" test.
        with self._state_lock:
            if task_id in self.running_tasks:
                return
            # Always start from a fresh, unset event.
            self.task_events[task_id] = threading.Event()
            self.running_tasks.add(task_id)
            try:
                self.task_futures[task_id] = self.executor.submit(self._process_task, task_id)
            except Exception:
                # e.g. the pool was shut down; do not leave the id marked running
                self.running_tasks.discard(task_id)
                self.task_events.pop(task_id, None)
                raise

    def _process_task(self, task_id: int):
        """Run all unfinished sub-tasks of a task concurrently."""
        try:
            with self._state_lock:
                cancel_event = self.task_events.setdefault(task_id, threading.Event())
            if cancel_event.is_set():
                # Cancelled while queued: leave the CANCELLED state untouched.
                return
            self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
            pending = self.get_pending_task_conversations(task_id)
            if not pending:
                self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
                return
            agent_configuration = self.get_agent_configuration_by_task_id(task_id)
            if agent_configuration is None:
                raise ValueError(f"Agent configuration not found for task {task_id}")
            query_prompt_template = agent_configuration.task_prompt

            with ThreadPoolExecutor(max_workers=self._SUBTASK_WORKERS) as pool:
                futures = {}
                for task_conversation in pending:
                    if cancel_event.is_set():
                        break  # stop handing out work once cancelled
                    future = pool.submit(
                        self._process_single_conversation,
                        task_id,
                        task_conversation,
                        query_prompt_template,
                        agent_configuration
                    )
                    futures[future] = task_conversation.id
                for future in concurrent.futures.as_completed(futures):
                    conv_id = futures[future]
                    try:
                        future.result()  # re-raises whatever the sub-task raised
                    except Exception as e:
                        logger.error(f"Subtask {conv_id} failed: {str(e)}")
                        # Safety net for failures raised before the sub-task
                        # could record its own FAILED state.
                        self.update_task_conversations_status(
                            conv_id,
                            TestTaskConversationsStatus.FAILED.value
                        )
            self._update_final_task_status(task_id)
        except Exception as e:
            logger.error(f"Error processing task {task_id}: {str(e)}")
            self.update_task_status(task_id, TestTaskStatus.FAILED.value)
        finally:
            self._cleanup_task_resources(task_id)

    def _process_single_conversation(self, task_id, task_conversation, query_prompt_template, agent_configuration):
        """Generate and store the agent reply for one sub-task.

        Does nothing when the owning task has been cancelled.  On failure the
        sub-task is marked FAILED and the exception is re-raised for the
        caller to log.
        """
        cancel_event = self.task_events.get(task_id)
        if cancel_event is not None and cancel_event.is_set():
            return
        if task_conversation.status == TestTaskConversationsStatus.PENDING.value:
            self.update_task_conversations_status(
                task_conversation.id,
                TestTaskConversationsStatus.RUNNING.value
            )
        try:
            # One agent instance per sub-task so no agent is shared by threads.
            agent = MultiModalChatAgent(
                model=agent_configuration.execution_model,
                system_prompt=agent_configuration.system_prompt,
                tools=json.loads(agent_configuration.tools)
            )
            conversation_data = self.dataset_service.get_conversation_data_by_id(
                task_conversation.conversation_id)
            user_profile_data = self.dataset_service.get_user_profile_data(
                conversation_data.user_id,
                conversation_data.version_date.replace("-", ""))
            user_profile = json.loads(user_profile_data['profile_data_v1'])
            avatar = user_profile_data['iconurl']
            staff_profile_data = self.dataset_service.get_staff_profile_data(
                conversation_data.staff_id).agent_profile
            conversations = self.dataset_service.get_chat_conversation_list_by_ids(
                json.loads(conversation_data.conversation),
                conversation_data.staff_id
            )
            conversations = sorted(conversations, key=lambda item: item['timestamp'])

            # The message timestamp is divided by 1000 (milliseconds to
            # seconds) and the push is dated one day after the last message.
            last_timestamp = int(conversations[-1]["timestamp"])
            push_time = int(last_timestamp / 1000) + self._PUSH_DELAY_SECONDS
            push_dt = datetime.fromtimestamp(push_time).strftime(self._TIME_FORMAT)
            push_message = agent._generate_message(
                context=self._build_push_context(user_profile, avatar, staff_profile_data, push_dt),
                dialogue_history=conversations,
                query_prompt_template=query_prompt_template
            )
            # TODO: replace the placeholder with a real score.
            score = '{"score":0.05}'
            self.update_task_conversations_res(
                task_conversation.id,
                TestTaskConversationsStatus.SUCCESS.value,
                json.dumps(conversations, ensure_ascii=False),
                json.dumps(push_message, ensure_ascii=False),
                score
            )
        except Exception as e:
            logger.error(f"Subtask {task_conversation.id} failed: {str(e)}")
            self.update_task_conversations_status(
                task_conversation.id,
                TestTaskConversationsStatus.FAILED.value
            )
            raise

    def _update_final_task_status(self, task_id):
        """Derive the task status from the statuses of its sub-tasks.

        * every sub-task SUCCESS/FAILED -> COMPLETED
        * any sub-task PENDING/RUNNING  -> left unchanged
        * otherwise                     -> CANCELLED, unless already so
        """
        finished = (TestTaskConversationsStatus.SUCCESS.value,
                    TestTaskConversationsStatus.FAILED.value)
        unfinished = (TestTaskConversationsStatus.PENDING.value,
                      TestTaskConversationsStatus.RUNNING.value)
        statuses = [conv.status for conv in self.get_task_conversations(task_id)]

        if all(status in finished for status in statuses):
            self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
            logger.info(f"Task {task_id} completed")
            return
        if any(status in unfinished for status in statuses):
            return
        # Nothing left to run but not everything finished, so the remaining
        # sub-tasks were cancelled.
        if self.get_task(task_id).status != TestTaskStatus.CANCELLED.value:
            self.update_task_status(task_id, TestTaskStatus.CANCELLED.value)

    def _cleanup_task_resources(self, task_id):
        """Forget all in-memory state of a task; safe to call repeatedly."""
        with self._state_lock:
            self.running_tasks.discard(task_id)
            self.task_events.pop(task_id, None)
            self.task_futures.pop(task_id, None)
            self.task_locks.pop(task_id, None)

    def shutdown(self):
        """Shut down both worker pools without waiting for running jobs."""
        self.executor.shutdown(wait=False)
        self.create_task_executor.shutdown(wait=False)
        logger.info("Task executor shutdown")
|