# task_server.py — TaskManager for agent test tasks.
import concurrent.futures
import json
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Dict, Optional

from sqlalchemy import func

from pqai_agent import logging_service
from pqai_agent.agents.message_push_agent import MessagePushAgent
from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
from pqai_agent.data_models.agent_configuration import AgentConfiguration
from pqai_agent.data_models.agent_test_task import AgentTestTask
from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
from pqai_agent.data_models.service_module import ServiceModule
from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc
# Module-wide logger taken from the project's logging service; all TaskManager diagnostics go through it.
logger = logging_service.logger
  19. class TaskManager:
  20. """任务管理器"""
  21. def __init__(self, session_maker, dataset_service):
  22. self.session_maker = session_maker
  23. self.dataset_service = dataset_service
  24. self.task_events = {} # 任务ID -> Event (用于取消任务)
  25. self.task_locks = {} # 任务ID -> Lock (用于任务状态同步)
  26. self.running_tasks = set()
  27. self.executor = ThreadPoolExecutor(max_workers=20, thread_name_prefix='TaskWorker')
  28. self.create_task_executor = ThreadPoolExecutor(max_workers=10, thread_name_prefix='CreateTaskWorker')
  29. self.task_futures = {} # 任务ID -> Future
  30. def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
  31. with self.session_maker() as session:
  32. # 计算偏移量
  33. offset = (page_num - 1) * page_size
  34. # 查询分页数据
  35. result = (session.query(AgentTestTask, AgentConfiguration)
  36. .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
  37. .limit(page_size).offset(offset).all())
  38. # 查询总记录数
  39. total = session.query(func.count(AgentTestTask.id)).scalar()
  40. total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
  41. total_page = 1 if total_page <= 0 else total_page
  42. response_data = [
  43. {
  44. "id": agent_test_task.id,
  45. "agentName": agent_configuration.name,
  46. "createUser": agent_test_task.create_user,
  47. "updateUser": agent_test_task.update_user,
  48. "statusName": get_test_task_status_desc(agent_test_task.status),
  49. "createTime": agent_test_task.create_time.strftime("%Y-%m-%d %H:%M:%S"),
  50. "updateTime": agent_test_task.update_time.strftime("%Y-%m-%d %H:%M:%S")
  51. }
  52. for agent_test_task, agent_configuration in result
  53. ]
  54. return {
  55. "currentPage": page_num,
  56. "pageSize": page_size,
  57. "totalSize": total_page,
  58. "total": total,
  59. "list": response_data,
  60. }
  61. def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
  62. with self.session_maker() as session:
  63. # 计算偏移量
  64. offset = (page_num - 1) * page_size
  65. # 查询分页数据
  66. result = (session.query(AgentTestTaskConversations, AgentConfiguration)
  67. .outerjoin(AgentConfiguration, AgentTestTaskConversations.agent_id == AgentConfiguration.id)
  68. .filter(AgentTestTaskConversations.task_id == task_id)
  69. .limit(page_size).offset(offset).all())
  70. # 查询总记录数
  71. total = session.query(func.count(AgentTestTaskConversations.id)).scalar()
  72. total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
  73. total_page = 1 if total_page <= 0 else total_page
  74. response_data = [
  75. {
  76. "id": agent_test_task_conversation.id,
  77. "agentName": agent_configuration.name,
  78. "input":MultiModalChatAgent.compose_dialogue(json.loads(agent_test_task_conversation.input)),
  79. "output": agent_test_task_conversation.output,
  80. "score": agent_test_task_conversation.score,
  81. "statusName": get_test_task_status_desc(agent_test_task_conversation.status),
  82. "createTime": agent_test_task_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S"),
  83. "updateTime": agent_test_task_conversation.update_time.strftime("%Y-%m-%d %H:%M:%S")
  84. }
  85. for agent_test_task_conversation, agent_configuration in result
  86. ]
  87. return {
  88. "currentPage": page_num,
  89. "pageSize": page_size,
  90. "totalSize": total_page,
  91. "total": total,
  92. "list": response_data,
  93. }
  94. def create_task(self, agent_id: int, module_id: int) -> Dict:
  95. """创建新任务"""
  96. with self.session_maker() as session:
  97. agent_test_task = AgentTestTask(agent_id=agent_id, module_id=module_id,
  98. status=TestTaskStatus.CREATING.value)
  99. session.add(agent_test_task)
  100. session.commit() # 显式提交
  101. task_id = agent_test_task.id
  102. # 异步执行创建任务
  103. self.create_task_executor.submit(self._generate_agent_test_task_conversation_batch, task_id, agent_id,
  104. module_id)
  105. return self.get_task(task_id)
  106. def _generate_agent_test_task_conversation_batch(self, task_id: int, agent_id: int, module_id: int):
  107. """异步生成子任务"""
  108. try:
  109. # 获取数据集列表
  110. dataset_module_list = self.dataset_service.get_dataset_module_list_by_module(module_id)
  111. # 批量处理数据集 - 减少数据库交互
  112. batch_size = 100 # 每批处理100个子任务
  113. agent_test_task_conversation_batch = []
  114. for dataset_module in dataset_module_list:
  115. # 获取对话数据列表
  116. conversation_datas = self.dataset_service.get_conversation_data_list_by_dataset(dataset_module.dataset_id)
  117. for conversation_data in conversation_datas:
  118. # 创建子任务对象
  119. agent_test_task_conversation = AgentTestTaskConversations(
  120. task_id=task_id,
  121. agent_id=agent_id,
  122. dataset_id=dataset_module.dataset_id,
  123. conversation_id=conversation_data.id,
  124. status=TestTaskConversationsStatus.PENDING.value
  125. )
  126. agent_test_task_conversation_batch.append(agent_test_task_conversation)
  127. # 批量提交
  128. if len(agent_test_task_conversation_batch) >= batch_size:
  129. self.save_agent_test_task_conversation_batch(agent_test_task_conversation_batch)
  130. agent_test_task_conversation_batch = []
  131. # 提交剩余的子任务
  132. if agent_test_task_conversation_batch:
  133. self.save_agent_test_task_conversation_batch(agent_test_task_conversation_batch)
  134. # 更新主任务状态为未开始
  135. self.update_task_status(task_id, TestTaskStatus.NOT_STARTED.value)
  136. # 自动提交任务执行
  137. self._execute_task(task_id)
  138. except Exception as e:
  139. logger.error(f"生成子任务失败: {str(e)}")
  140. # 更新任务状态为失败
  141. self.update_task_status(task_id, TestTaskStatus.CREATED_FAIL.value)
  142. def save_agent_test_task_conversation_batch(self, agent_test_task_conversation_batch: list):
  143. """批量保存子任务到数据库"""
  144. try:
  145. with self.session_maker() as session:
  146. with session.begin():
  147. session.add_all(agent_test_task_conversation_batch)
  148. except Exception as e:
  149. logger.error(e)
  150. def get_agent_configuration_by_task_id(self, task_id: int):
  151. """获取指定任务ID对应的Agent配置信息"""
  152. with self.session_maker() as session:
  153. return session.query(AgentConfiguration) \
  154. .join(AgentTestTask, AgentTestTask.agent_id == AgentConfiguration.id) \
  155. .filter(AgentTestTask.id == task_id) \
  156. .one_or_none() # 返回单个对象或None(如果未找到)
  157. def get_service_module_by_task_id(self, task_id: int):
  158. """获取指定任务ID对应的Agent配置信息"""
  159. with self.session_maker() as session:
  160. return session.query(ServiceModule) \
  161. .join(AgentTestTask, AgentTestTask.module_id == ServiceModule.id) \
  162. .filter(AgentTestTask.id == task_id) \
  163. .one_or_none() # 返回单个对象或None(如果未找到)
  164. def get_task(self, task_id: int):
  165. """获取任务信息"""
  166. with self.session_maker() as session:
  167. return session.query(AgentTestTask).filter(AgentTestTask.id == task_id).one()
  168. def get_in_progress_task(self):
  169. """获取执行中任务"""
  170. with self.session_maker() as session:
  171. return session.query(AgentTestTask).filter(AgentTestTask.status == TestTaskStatus.IN_PROGRESS.value).all()
  172. def get_creating_task(self):
  173. """获取执行中任务"""
  174. with self.session_maker() as session:
  175. return session.query(AgentTestTask).filter(AgentTestTask.status == TestTaskStatus.CREATING.value).all()
  176. def get_task_conversations(self, task_id: int):
  177. """获取任务的所有子任务"""
  178. with self.session_maker() as session:
  179. return session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).all()
  180. def del_task_conversations(self, task_id: int):
  181. with self.session_maker() as session:
  182. session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).delete()
  183. # 提交事务生效
  184. session.commit()
  185. def get_pending_task_conversations(self, task_id: int):
  186. """获取待处理的子任务"""
  187. with self.session_maker() as session:
  188. return session.query(AgentTestTaskConversations).filter(
  189. AgentTestTaskConversations.task_id == task_id).filter(
  190. AgentTestTaskConversations.status.in_([
  191. TestTaskConversationsStatus.PENDING.value,
  192. TestTaskConversationsStatus.RUNNING.value
  193. ])).all()
  194. def update_task_status(self, task_id: int, status: int):
  195. """更新任务状态"""
  196. with self.session_maker() as session:
  197. session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
  198. {"status": status, "update_time": datetime.now()})
  199. session.commit()
  200. def update_task_conversations_status(self, task_conversations_id: int, status: int):
  201. """更新子任务状态"""
  202. with self.session_maker() as session:
  203. session.query(AgentTestTaskConversations).filter(
  204. AgentTestTaskConversations.id == task_conversations_id).update(
  205. {"status": status, "update_time": datetime.now()})
  206. session.commit()
  207. def update_task_conversations_res(self, task_conversations_id: int, status: int, input: str, output: str,
  208. score: str):
  209. """更新子任务结果"""
  210. with self.session_maker() as session:
  211. session.query(AgentTestTaskConversations).filter(
  212. AgentTestTaskConversations.id == task_conversations_id).update(
  213. {"status": status, "input": input, "output": output, "score": score, "update_time": datetime.now()})
  214. session.commit()
  215. def cancel_task(self, task_id: int):
  216. """取消任务(带事务支持)"""
  217. # 设置取消事件
  218. if task_id in self.task_events:
  219. self.task_events[task_id].set()
  220. # 如果任务正在执行,尝试取消Future
  221. if task_id in self.task_futures:
  222. self.task_futures[task_id].cancel()
  223. with self.session_maker() as session:
  224. with session.begin():
  225. session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
  226. {"status": TestTaskStatus.CANCELLED.value})
  227. session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).filter(
  228. AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value).update(
  229. {"status": TestTaskConversationsStatus.CANCELLED.value})
  230. session.commit()
  231. def resume_task(self, task_id: int) -> bool:
  232. """恢复已取消的任务"""
  233. task = self.get_task(task_id)
  234. if not task or task.status != TestTaskStatus.CANCELLED.value:
  235. return False
  236. with self.session_maker() as session:
  237. with session.begin():
  238. session.query(AgentTestTask).filter(AgentTestTask.id == task_id).update(
  239. {"status": TestTaskStatus.NOT_STARTED.value})
  240. session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).filter(
  241. AgentTestTaskConversations.status == TestTaskConversationsStatus.CANCELLED.value).update(
  242. {"status": TestTaskConversationsStatus.PENDING.value})
  243. session.commit()
  244. # 重新执行任务
  245. self._execute_task(task_id)
  246. logger.info(f"Resumed task {task_id}")
  247. return True
  248. def recover_tasks(self):
  249. """服务启动时恢复未完成的任务"""
  250. creating = self.get_creating_task()
  251. for task in creating:
  252. task_id = task.id
  253. agent_id = task.agent_id
  254. module_id = task.module_id
  255. self.del_task_conversations(task_id)
  256. # 重新提交任务
  257. # 异步执行创建任务
  258. self.create_task_executor.submit(self._generate_agent_test_task_conversation_batch, task_id, agent_id,
  259. module_id)
  260. # 获取所有进行中的任务ID(根据实际状态定义查询)
  261. in_progress_tasks = self.get_in_progress_task()
  262. for task in in_progress_tasks:
  263. task_id = task.id
  264. # 重新提交任务
  265. self._execute_task(task_id)
  266. def _execute_task(self, task_id: int):
  267. """提交任务到线程池执行"""
  268. # 确保任务状态一致性
  269. if task_id in self.running_tasks:
  270. return
  271. # 创建任务事件和锁
  272. if task_id not in self.task_events:
  273. self.task_events[task_id] = threading.Event()
  274. if task_id not in self.task_locks:
  275. self.task_locks[task_id] = threading.Lock()
  276. # 提交到线程池
  277. future = self.executor.submit(self._process_task, task_id)
  278. self.task_futures[task_id] = future
  279. # 标记任务为运行中
  280. with self.task_locks[task_id]:
  281. self.running_tasks.add(task_id)
  282. def _process_task(self, task_id: int):
  283. """处理任务的所有子任务(并发执行)"""
  284. try:
  285. self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
  286. task_conversations = self.get_pending_task_conversations(task_id)
  287. if not task_conversations:
  288. self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
  289. return
  290. agent_configuration = self.get_agent_configuration_by_task_id(task_id)
  291. query_prompt_template = agent_configuration.task_prompt
  292. # 使用线程池执行子任务
  293. with ThreadPoolExecutor(max_workers=20) as executor: # 可根据需要调整并发数
  294. futures = {}
  295. for task_conversation in task_conversations:
  296. if self.task_events[task_id].is_set():
  297. break # 检查任务取消事件
  298. # 提交子任务到线程池
  299. future = executor.submit(
  300. self._process_single_conversation,
  301. task_id,
  302. task_conversation,
  303. query_prompt_template,
  304. agent_configuration
  305. )
  306. futures[future] = task_conversation.id
  307. # 等待所有子任务完成或取消
  308. for future in concurrent.futures.as_completed(futures):
  309. conv_id = futures[future]
  310. try:
  311. # 设置单个任务超时时间(秒),可根据任务复杂度调整
  312. future.result() # 获取结果(如有异常会在此抛出)
  313. except Exception as e:
  314. logger.error(f"Subtask {conv_id} failed: {str(e)}")
  315. self.update_task_conversations_status(
  316. conv_id,
  317. TestTaskConversationsStatus.FAILED.value
  318. )
  319. # 检查最终任务状态
  320. self._update_final_task_status(task_id)
  321. except Exception as e:
  322. logger.error(f"Error processing task {task_id}: {str(e)}")
  323. self.update_task_status(task_id, TestTaskStatus.FAILED.value)
  324. finally:
  325. self._cleanup_task_resources(task_id)
  326. def _process_single_conversation(self, task_id, task_conversation, query_prompt_template, agent_configuration):
  327. """处理单个对话子任务(线程安全)"""
  328. # 检查任务是否被取消
  329. if self.task_events[task_id].is_set():
  330. return
  331. # 更新子任务状态
  332. if task_conversation.status == TestTaskConversationsStatus.PENDING.value:
  333. self.update_task_conversations_status(
  334. task_conversation.id,
  335. TestTaskConversationsStatus.RUNNING.value
  336. )
  337. try:
  338. # 创建独立的agent实例(确保线程安全)
  339. agent = MultiModalChatAgent(
  340. model=agent_configuration.execution_model,
  341. system_prompt=agent_configuration.system_prompt,
  342. tools=json.loads(agent_configuration.tools)
  343. )
  344. # 获取对话数据(与原始代码相同)
  345. conversation_data = self.dataset_service.get_conversation_data_by_id(
  346. task_conversation.conversation_id)
  347. user_profile_data = self.dataset_service.get_user_profile_data(
  348. conversation_data.user_id,
  349. conversation_data.version_date.replace("-", ""))
  350. user_profile = json.loads(user_profile_data['profile_data_v1'])
  351. avatar = user_profile_data['iconurl']
  352. staff_profile_data = self.dataset_service.get_staff_profile_data(
  353. conversation_data.staff_id).agent_profile
  354. conversations = self.dataset_service.get_chat_conversation_list_by_ids(
  355. json.loads(conversation_data.conversation),
  356. conversation_data.staff_id
  357. )
  358. conversations = sorted(conversations, key=lambda i: i['timestamp'], reverse=False)
  359. # 生成推送消息(与原始代码相同)
  360. last_timestamp = int(conversations[-1]["timestamp"])
  361. push_time = int(last_timestamp / 1000) + 24 * 3600
  362. push_dt = datetime.fromtimestamp(push_time).strftime('%Y-%m-%d %H:%M:%S')
  363. push_message = agent._generate_message(
  364. context={
  365. "formatted_staff_profile": staff_profile_data,
  366. "nickname": user_profile['nickname'],
  367. "name": user_profile['name'],
  368. "avatar": avatar,
  369. "preferred_nickname": user_profile['preferred_nickname'],
  370. "gender": user_profile['gender'],
  371. "age": user_profile['age'],
  372. "region": user_profile['region'],
  373. "health_conditions": user_profile['health_conditions'],
  374. "medications": user_profile['medications'],
  375. "interests": user_profile['interests'],
  376. "current_datetime": push_dt
  377. },
  378. dialogue_history=conversations,
  379. query_prompt_template=query_prompt_template
  380. )
  381. # 获取打分(TODO: 实际实现)
  382. score = '{"score":0.05}'
  383. # 更新子任务结果
  384. self.update_task_conversations_res(
  385. task_conversation.id,
  386. TestTaskConversationsStatus.SUCCESS.value,
  387. json.dumps(conversations, ensure_ascii=False),
  388. json.dumps(push_message, ensure_ascii=False),
  389. score
  390. )
  391. except Exception as e:
  392. logger.error(f"Subtask {task_conversation.id} failed: {str(e)}")
  393. self.update_task_conversations_status(
  394. task_conversation.id,
  395. TestTaskConversationsStatus.FAILED.value
  396. )
  397. raise # 重新抛出异常以便主线程捕获
  398. def _update_final_task_status(self, task_id):
  399. """更新任务的最终状态"""
  400. task_conversations = self.get_task_conversations(task_id)
  401. all_completed = all(
  402. conv.status in (TestTaskConversationsStatus.SUCCESS.value,
  403. TestTaskConversationsStatus.FAILED.value)
  404. for conv in task_conversations
  405. )
  406. if all_completed:
  407. self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
  408. logger.info(f"Task {task_id} completed")
  409. elif not any(
  410. conv.status in (TestTaskConversationsStatus.PENDING.value,
  411. TestTaskConversationsStatus.RUNNING.value)
  412. for conv in task_conversations
  413. ):
  414. current_status = self.get_task(task_id).status
  415. if current_status != TestTaskStatus.CANCELLED.value:
  416. new_status = TestTaskStatus.COMPLETED.value if all_completed else TestTaskStatus.CANCELLED.value
  417. self.update_task_status(task_id, new_status)
  418. def _cleanup_task_resources(self, task_id):
  419. """清理任务资源(线程安全)"""
  420. with self.task_locks[task_id]:
  421. if task_id in self.running_tasks:
  422. self.running_tasks.remove(task_id)
  423. if task_id in self.task_events:
  424. del self.task_events[task_id]
  425. if task_id in self.task_futures:
  426. del self.task_futures[task_id]
  427. # def _process_task(self, task_id: int):
  428. # """处理任务的所有子任务"""
  429. # try:
  430. # # 更新任务状态为运行中
  431. # self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
  432. #
  433. # # 获取所有待处理的子任务
  434. # task_conversations = self.get_pending_task_conversations(task_id)
  435. #
  436. # agent_configuration = self.get_agent_configuration_by_task_id(task_id)
  437. # query_prompt_template = agent_configuration.task_prompt
  438. # agent = MultiModalChatAgent(model=agent_configuration.execution_model,
  439. # system_prompt=agent_configuration.system_prompt,
  440. # tools=json.loads(agent_configuration.tools))
  441. # # 执行每个子任务
  442. # for task_conversation in task_conversations:
  443. # # 检查任务是否被取消
  444. # if self.task_events[task_id].is_set():
  445. # break
  446. #
  447. # # 更新子任务状态为运行中
  448. # if task_conversation.status == TestTaskConversationsStatus.PENDING.value:
  449. # self.update_task_conversations_status(task_conversation.id,
  450. # TestTaskConversationsStatus.RUNNING.value)
  451. # try:
  452. # conversation_data = self.dataset_service.get_conversation_data_by_id(
  453. # task_conversation.conversation_id)
  454. # user_profile_data = self.dataset_service.get_user_profile_data(conversation_data.user_id,
  455. # conversation_data.version_date.replace(
  456. # "-", ""))
  457. # user_profile = json.loads(user_profile_data['profile_data_v1'])
  458. # avatar = user_profile_data['iconurl']
  459. # staff_profile_data = self.dataset_service.get_staff_profile_data(
  460. # conversation_data.staff_id).agent_profile
  461. # conversations = self.dataset_service.get_chat_conversation_list_by_ids(
  462. # json.loads(conversation_data.conversation), conversation_data.staff_id)
  463. # conversations = sorted(conversations, key=lambda i: i['timestamp'], reverse=False)
  464. #
  465. # last_timestamp = int(conversations[-1]["timestamp"])
  466. # push_time = int(last_timestamp / 1000) + 24 * 3600
  467. # push_dt = datetime.fromtimestamp(push_time).strftime('%Y-%m-%d %H:%M:%S')
  468. # push_message = agent._generate_message(
  469. # context={
  470. # "formatted_staff_profile": staff_profile_data,
  471. # "nickname": user_profile['nickname'],
  472. # "name": user_profile['name'],
  473. # "avatar": avatar,
  474. # "preferred_nickname": user_profile['preferred_nickname'],
  475. # "gender": user_profile['gender'],
  476. # "age": user_profile['age'],
  477. # "region": user_profile['region'],
  478. # "health_conditions": user_profile['health_conditions'],
  479. # "medications": user_profile['medications'],
  480. # "interests": user_profile['interests'],
  481. # "current_datetime": push_dt
  482. # },
  483. # dialogue_history=conversations,
  484. # query_prompt_template=query_prompt_template
  485. # )
  486. # # TODO 获取打分
  487. # score = '{"score":0.05}'
  488. # # 更新子任务状态为已完成
  489. # self.update_task_conversations_res(task_conversation.id,
  490. # TestTaskConversationsStatus.SUCCESS.value,
  491. # json.dumps(conversations, ensure_ascii=False),
  492. # json.dumps(push_message, ensure_ascii=False),
  493. # score)
  494. # except Exception as e:
  495. # logger.error(f"Error executing task {task_id}: {str(e)}")
  496. # self.update_task_conversations_status(task_conversation.id,
  497. # TestTaskConversationsStatus.FAILED.value)
  498. #
  499. # # 检查任务是否完成
  500. # task_conversations = self.get_task_conversations(task_id)
  501. # all_completed = all(task_conversation.status in
  502. # (TestTaskConversationsStatus.SUCCESS.value, TestTaskConversationsStatus.FAILED.value)
  503. # for task_conversation in task_conversations)
  504. # any_pending = any(task_conversation.status in
  505. # (TestTaskConversationsStatus.PENDING.value, TestTaskConversationsStatus.RUNNING.value)
  506. # for task_conversation in task_conversations)
  507. #
  508. # if all_completed:
  509. # self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
  510. # logger.info(f"Task {task_id} completed")
  511. # elif not any_pending:
  512. # # 没有待处理子任务但未全部完成(可能是取消了)
  513. # current_status = self.get_task(task_id).status
  514. # if current_status != TestTaskStatus.CANCELLED.value:
  515. # self.update_task_status(task_id, TestTaskStatus.COMPLETED.value
  516. # if all_completed else TestTaskStatus.CANCELLED.value)
  517. # except Exception as e:
  518. # logger.error(f"Error executing task {task_id}: {str(e)}")
  519. # self.update_task_status(task_id, TestTaskStatus.FAILED.value)
  520. # finally:
  521. # # 清理资源
  522. # with self.task_locks[task_id]:
  523. # if task_id in self.running_tasks:
  524. # self.running_tasks.remove(task_id)
  525. # if task_id in self.task_events:
  526. # del self.task_events[task_id]
  527. # if task_id in self.task_futures:
  528. # del self.task_futures[task_id]
  529. def shutdown(self):
  530. """关闭执行器"""
  531. self.executor.shutdown(wait=False)
  532. logger.info("Task executor shutdown")