# task_server.py
import random
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from queue import Empty, Queue
from typing import Dict, List, Optional

import pymysql
from pymysql import Connection
from pymysql.cursors import DictCursor

from pqai_agent import logging_service
from pqai_agent_server.const.status_enum import (
    TestTaskConversationsStatus,
    TestTaskStatus,
    get_test_task_conversations_status_desc,
    get_test_task_status_desc,
)

logger = logging_service.logger
  14. class Database:
  15. """数据库操作类"""
  16. def __init__(self, db_config):
  17. self.db_config = db_config
  18. self.connection_pool = Queue(maxsize=10)
  19. self._initialize_pool()
  20. def _initialize_pool(self):
  21. """初始化数据库连接池"""
  22. for _ in range(5):
  23. conn = pymysql.connect(**self.db_config)
  24. self.connection_pool.put(conn)
  25. logger.info("Database connection pool initialized with 5 connections")
  26. def get_connection(self) -> Connection:
  27. """从连接池获取数据库连接"""
  28. return self.connection_pool.get()
  29. def release_connection(self, conn: Connection):
  30. """释放数据库连接回连接池"""
  31. self.connection_pool.put(conn)
  32. def execute(self, query: str, args: tuple = (), many: bool = False) -> int:
  33. """执行SQL语句并返回影响的行数"""
  34. conn = self.get_connection()
  35. try:
  36. with conn.cursor() as cursor:
  37. if many:
  38. cursor.executemany(query, args)
  39. else:
  40. cursor.execute(query, args)
  41. conn.commit()
  42. return cursor.rowcount
  43. except Exception as e:
  44. logger.error(f"Database error: {str(e)}")
  45. conn.rollback()
  46. raise
  47. finally:
  48. self.release_connection(conn)
  49. def insert(self, insert: str, args: tuple = (), many: bool = False) -> int:
  50. """执行插入SQL语句并主键"""
  51. conn = self.get_connection()
  52. try:
  53. with conn.cursor() as cursor:
  54. if many:
  55. cursor.executemany(insert, args)
  56. else:
  57. cursor.execute(insert, args)
  58. conn.commit()
  59. return cursor.lastrowid
  60. except Exception as e:
  61. logger.error(f"Database error: {str(e)}")
  62. conn.rollback()
  63. raise
  64. finally:
  65. self.release_connection(conn)
  66. def fetch(self, query: str, args: tuple = ()) -> List[Dict]:
  67. """执行SQL查询并返回结果列表"""
  68. conn = self.get_connection()
  69. try:
  70. with conn.cursor(DictCursor) as cursor:
  71. cursor.execute(query, args)
  72. return cursor.fetchall()
  73. except Exception as e:
  74. logger.error(f"Database error: {str(e)}")
  75. raise
  76. finally:
  77. self.release_connection(conn)
  78. def fetch_one(self, query: str, args: tuple = ()) -> Optional[Dict]:
  79. """执行SQL查询并返回单行结果"""
  80. conn = self.get_connection()
  81. try:
  82. with conn.cursor(DictCursor) as cursor:
  83. cursor.execute(query, args)
  84. return cursor.fetchone()
  85. except Exception as e:
  86. logger.error(f"Database error: {str(e)}")
  87. raise
  88. finally:
  89. self.release_connection(conn)
  90. def close_all(self):
  91. """关闭所有数据库连接"""
  92. while not self.connection_pool.empty():
  93. conn = self.connection_pool.get()
  94. conn.close()
  95. logger.info("All database connections closed")
  96. class TaskManager:
  97. """任务管理器"""
  98. def __init__(self, db_config, agent_configuration_table, test_task_table, test_task_conversations_table,
  99. max_workers: int = 10):
  100. self.db = Database(db_config)
  101. self.agent_configuration_table = agent_configuration_table
  102. self.test_task_table = test_task_table
  103. self.test_task_conversations_table = test_task_conversations_table
  104. self.task_events = {} # 任务ID -> Event (用于取消任务)
  105. self.task_locks = {} # 任务ID -> Lock (用于任务状态同步)
  106. self.running_tasks = set()
  107. self.executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='TaskWorker')
  108. self.task_futures = {} # 任务ID -> Future
  109. def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
  110. fetch_query = f"""
  111. select t1.id, t2.name, t1.create_user, t1.update_user, t1.status, t1.create_time, t1.update_time
  112. from {self.test_task_table} t1
  113. left join {self.agent_configuration_table} t2 on t1.agent_id = t2.id
  114. order by create_time desc
  115. limit %s, %s;
  116. """
  117. messages = self.db.fetch(fetch_query, ((page_num - 1) * page_size, page_size))
  118. total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_table}""")
  119. total = total_size["count"]
  120. total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
  121. total_page = 1 if total_page <= 0 else total_page
  122. response_data = [
  123. {
  124. "id": message["id"],
  125. "agentName": message["name"],
  126. "createUser": message["create_user"],
  127. "updateUser": message["update_user"],
  128. "statusName": get_test_task_status_desc(message["status"]),
  129. "createTime": message["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
  130. "updateTime": message["update_time"].strftime("%Y-%m-%d %H:%M:%S")
  131. }
  132. for message in messages
  133. ]
  134. return {
  135. "currentPage": page_num,
  136. "pageSize": page_size,
  137. "totalSize": total_page,
  138. "total": total,
  139. "list": response_data,
  140. }
  141. def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
  142. fetch_query = f"""
  143. select t1.id, t2.name, t3.create_user, t1.input, t1.output, t1.score, t1.status, t1.create_time, t1.update_time
  144. from {self.test_task_conversations_table} t1
  145. left join {self.agent_configuration_table} t2 on t1.agent_id = t2.id
  146. left join {self.test_task_table} t3 on t1.task_id = t3.id
  147. where t1.task_id = %s
  148. order by create_time desc
  149. limit %s, %s;
  150. """
  151. messages = self.db.fetch(fetch_query, (task_id, (page_num - 1) * page_size, page_size))
  152. total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_conversations_table} where task_id = %s""",
  153. (task_id,))
  154. total = total_size["count"]
  155. total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
  156. total_page = 1 if total_page <= 0 else total_page
  157. response_data = [
  158. {
  159. "id": message["id"],
  160. "agentName": message["name"],
  161. "createUser": message["create_user"],
  162. "input": message["input"],
  163. "output": message["output"],
  164. "score": message["score"],
  165. "statusName": get_test_task_conversations_status_desc(message["status"]),
  166. "createTime": message["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
  167. "updateTime": message["update_time"].strftime("%Y-%m-%d %H:%M:%S")
  168. }
  169. for message in messages
  170. ]
  171. return {
  172. "currentPage": page_num,
  173. "pageSize": page_size,
  174. "totalSize": total_page,
  175. "total": total,
  176. "list": response_data,
  177. }
  178. def create_task(self, agent_id: int) -> Dict:
  179. """创建新任务并添加100个子任务"""
  180. conn = self.db.get_connection()
  181. try:
  182. conn.begin()
  183. # TODO 插入任务
  184. with conn.cursor() as cursor:
  185. cursor.execute(
  186. f"""INSERT INTO {self.test_task_table} (agent_id, status, create_user, update_user) VALUES (%s, %s, %s, %s)""",
  187. (agent_id, 0, 'xueyiming', 'xueyiming')
  188. )
  189. task_id = cursor.lastrowid
  190. task_conversations = []
  191. # TODO 具体的数据集信息
  192. i = 0
  193. for _ in range(30):
  194. i = i + 1
  195. task_conversations.append((
  196. task_id, agent_id, i, i, "输入", "输出", 0
  197. ))
  198. with conn.cursor() as cursor:
  199. cursor.executemany(
  200. f"""INSERT INTO {self.test_task_conversations_table} (task_id, agent_id, dataset_id, conversation_id, input, output, status)
  201. VALUES (%s, %s, %s, %s, %s, %s, %s)""",
  202. task_conversations
  203. )
  204. conn.commit()
  205. except Exception as e:
  206. conn.rollback()
  207. logger.error(f"Failed to create task agent_id {agent_id}: {str(e)}")
  208. raise
  209. finally:
  210. self.db.release_connection(conn)
  211. logger.info(f"Created task {task_id} with 100 task_conversations")
  212. # 异步执行任务
  213. self._execute_task(task_id)
  214. return self.get_task(task_id)
  215. def get_task(self, task_id: int) -> Optional[Dict]:
  216. """获取任务信息"""
  217. return self.db.fetch_one(f"""SELECT * FROM {self.test_task_table} WHERE id = %s""", (task_id,))
  218. def get_task_conversations(self, task_id: int) -> List[Dict]:
  219. """获取任务的所有子任务"""
  220. return self.db.fetch(f"""SELECT * FROM {self.test_task_conversations_table} WHERE task_id = %s""", (task_id,))
  221. def get_pending_task_conversations(self, task_id: int) -> List[Dict]:
  222. """获取待处理的子任务"""
  223. return self.db.fetch(
  224. f"""SELECT * FROM {self.test_task_conversations_table} WHERE task_id = %s AND status = %s""",
  225. (task_id, TestTaskConversationsStatus.PENDING.value)
  226. )
  227. def update_task_status(self, task_id: int, status: int):
  228. """更新任务状态"""
  229. self.db.execute(
  230. f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
  231. (status, task_id)
  232. )
  233. def update_task_conversations_status(self, task_conversations_id: int, status: int):
  234. """更新子任务状态"""
  235. self.db.execute(
  236. f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE id = %s""",
  237. (status, task_conversations_id)
  238. )
  239. def update_task_conversations_res(self, task_conversations_id: int, status: int, score: float):
  240. """更新子任务状态"""
  241. self.db.execute(
  242. f"""UPDATE {self.test_task_conversations_table} SET status = %s, score = %s WHERE id = %s""",
  243. (status, score, task_conversations_id)
  244. )
  245. def cancel_task(self, task_id: int):
  246. """取消任务(带事务支持)"""
  247. # 设置取消事件
  248. if task_id in self.task_events:
  249. self.task_events[task_id].set()
  250. # 如果任务正在执行,尝试取消Future
  251. if task_id in self.task_futures:
  252. self.task_futures[task_id].cancel()
  253. conn = self.db.get_connection()
  254. try:
  255. conn.begin()
  256. # 更新任务状态为取消
  257. with conn.cursor() as cursor:
  258. cursor.execute(
  259. f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
  260. (TestTaskStatus.CANCELLED.value, task_id)
  261. )
  262. # 取消所有待处理的子任务
  263. with conn.cursor() as cursor:
  264. cursor.execute(
  265. f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE task_id = %s AND status = %s""",
  266. (TestTaskConversationsStatus.CANCELLED.value, task_id, TestTaskConversationsStatus.PENDING.value)
  267. )
  268. conn.commit()
  269. logger.info(f"Cancelled task {task_id} and its pending {self.test_task_conversations_table}")
  270. except Exception as e:
  271. conn.rollback()
  272. logger.error(f"Failed to cancel task {task_id}: {str(e)}")
  273. finally:
  274. self.db.release_connection(conn)
  275. def resume_task(self, task_id: int) -> bool:
  276. """恢复已取消的任务"""
  277. task = self.get_task(task_id)
  278. if not task or task['status'] != TestTaskStatus.CANCELLED.value:
  279. return False
  280. conn = self.db.get_connection()
  281. try:
  282. conn.begin()
  283. # 更新任务状态为待开始
  284. with conn.cursor() as cursor:
  285. cursor.execute(
  286. f"""UPDATE {self.test_task_table} SET status = %s WHERE id = %s""",
  287. (TestTaskStatus.NOT_STARTED.value, task_id)
  288. )
  289. # 恢复所有已取消的子任务
  290. with conn.cursor() as cursor:
  291. cursor.execute(
  292. f"""UPDATE {self.test_task_conversations_table} SET status = %s WHERE task_id = %s AND status = %s""",
  293. (TestTaskConversationsStatus.PENDING.value, task_id, TestTaskConversationsStatus.CANCELLED.value)
  294. )
  295. conn.commit()
  296. logger.info(f"Cancelled task {task_id} and its pending {self.test_task_conversations_table}")
  297. except Exception as e:
  298. conn.rollback()
  299. logger.error(f"Failed to cancel task {task_id}: {str(e)}")
  300. finally:
  301. self.db.release_connection(conn)
  302. # 重新执行任务
  303. self._execute_task(task_id)
  304. logger.info(f"Resumed task {task_id}")
  305. return True
  306. def _execute_task(self, task_id: int):
  307. """提交任务到线程池执行"""
  308. # 确保任务状态一致性
  309. if task_id in self.running_tasks:
  310. return
  311. # 创建任务事件和锁
  312. if task_id not in self.task_events:
  313. self.task_events[task_id] = threading.Event()
  314. if task_id not in self.task_locks:
  315. self.task_locks[task_id] = threading.Lock()
  316. # 提交到线程池
  317. future = self.executor.submit(self._process_task, task_id)
  318. self.task_futures[task_id] = future
  319. # 标记任务为运行中
  320. with self.task_locks[task_id]:
  321. self.running_tasks.add(task_id)
  322. def _process_task(self, task_id: int):
  323. """处理任务的所有子任务"""
  324. try:
  325. # 更新任务状态为运行中
  326. self.update_task_status(task_id, TestTaskStatus.IN_PROGRESS.value)
  327. # 获取所有待处理的子任务
  328. task_conversations = self.get_pending_task_conversations(task_id)
  329. # 执行每个子任务
  330. for task_conversation in task_conversations:
  331. # 检查任务是否被取消
  332. if self.task_events[task_id].is_set():
  333. break
  334. # 更新子任务状态为运行中
  335. self.update_task_conversations_status(task_conversation['id'],
  336. TestTaskConversationsStatus.RUNNING.value)
  337. try:
  338. # 模拟任务执行 - 在实际应用中替换为实际业务逻辑
  339. # 再次检查任务是否被取消
  340. # if self.task_events[task_id].is_set():
  341. # self.update_task_conversations_status(subtask['id'], 'cancelled')
  342. # break
  343. # TODO 实际任务执行
  344. time.sleep(1)
  345. score = random.random()
  346. # 更新子任务状态为已完成
  347. self.update_task_conversations_res(task_conversation['id'],
  348. TestTaskConversationsStatus.SUCCESS.value, score)
  349. except Exception as e:
  350. logger.error(f"Error executing task {task_id}: {str(e)}")
  351. self.update_task_conversations_status(task_conversation['id'],
  352. TestTaskConversationsStatus.FAILED.value)
  353. # 检查任务是否完成
  354. task_conversations = self.get_task_conversations(task_id)
  355. all_completed = all(task_conversation['status'] in
  356. (TestTaskConversationsStatus.SUCCESS.value, TestTaskConversationsStatus.FAILED.value)
  357. for task_conversation in task_conversations)
  358. any_pending = any(task_conversation['status'] in
  359. (TestTaskConversationsStatus.PENDING.value, TestTaskConversationsStatus.RUNNING.value)
  360. for task_conversation in task_conversations)
  361. if all_completed:
  362. self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
  363. logger.info(f"Task {task_id} completed")
  364. elif not any_pending:
  365. # 没有待处理子任务但未全部完成(可能是取消了)
  366. current_status = self.get_task(task_id)['status']
  367. if current_status != TestTaskStatus.CANCELLED.value:
  368. self.update_task_status(task_id, TestTaskStatus.COMPLETED.value
  369. if all_completed else TestTaskStatus.CANCELLED.value)
  370. except Exception as e:
  371. logger.error(f"Error executing task {task_id}: {str(e)}")
  372. self.update_task_status(task_id, TestTaskStatus.COMPLETED.value)
  373. finally:
  374. # 清理资源
  375. with self.task_locks[task_id]:
  376. if task_id in self.running_tasks:
  377. self.running_tasks.remove(task_id)
  378. if task_id in self.task_events:
  379. del self.task_events[task_id]
  380. if task_id in self.task_futures:
  381. del self.task_futures[task_id]
  382. def shutdown(self):
  383. """关闭执行器"""
  384. self.executor.shutdown(wait=False)
  385. logger.info("Task executor shutdown")