task_scheduler.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
import asyncio
import json
import time
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, List

from applications.api import feishu_robot
from applications.utils import task_schedule_response
from applications.tasks.task_handler import TaskHandler
from applications.tasks.task_config import (
    TaskStatus,
    TaskConstants,
    get_task_config,
)
from applications.tasks.task_utils import (
    TaskError,
    TaskValidationError,
    TaskTimeoutError,
    TaskConcurrencyError,
    TaskLockError,
    TaskUtils,
)
  22. class TaskScheduler(TaskHandler):
  23. """
  24. 统一任务调度器
  25. 使用方法:
  26. scheduler = TaskScheduler(data, log_service, db_client, trace_id)
  27. result = await scheduler.deal()
  28. """
  29. def __init__(self, data: dict, log_service, db_client, trace_id: str):
  30. super().__init__(data, log_service, db_client, trace_id)
  31. self.table = TaskUtils.validate_table_name(TaskConstants.TASK_TABLE)
  32. # ==================== 数据库操作 ====================
  33. async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
  34. """新建任务记录(若同键已存在则忽略)"""
  35. query = f"""
  36. INSERT IGNORE INTO {self.table}
  37. (date_string, task_name, start_timestamp, task_status, trace_id, data)
  38. VALUES (%s, %s, %s, %s, %s, %s)
  39. """
  40. await self.db_client.async_save(
  41. query=query,
  42. params=(
  43. date_str,
  44. task_name,
  45. int(time.time()),
  46. TaskStatus.INIT,
  47. self.trace_id,
  48. json.dumps(self.data, ensure_ascii=False),
  49. ),
  50. )
  51. async def _try_lock_task(self) -> bool:
  52. """
  53. 尝试获取任务锁(CAS 操作)
  54. 返回 True 表示成功获取锁
  55. """
  56. query = f"""
  57. UPDATE {self.table}
  58. SET task_status = %s
  59. WHERE trace_id = %s AND task_status = %s
  60. """
  61. result = await self.db_client.async_save(
  62. query=query,
  63. params=(TaskStatus.PROCESSING, self.trace_id, TaskStatus.INIT),
  64. )
  65. return bool(result)
  66. async def _release_task(self, status: int) -> None:
  67. """释放任务锁并更新状态"""
  68. query = f"""
  69. UPDATE {self.table}
  70. SET task_status = %s, finish_timestamp = %s
  71. WHERE trace_id = %s AND task_status = %s
  72. """
  73. await self.db_client.async_save(
  74. query=query,
  75. params=(
  76. status,
  77. int(time.time()),
  78. self.trace_id,
  79. TaskStatus.PROCESSING,
  80. ),
  81. )
  82. async def _get_processing_tasks(self, task_name: str) -> List[Dict[str, Any]]:
  83. """获取正在处理中的任务列表"""
  84. query = f"""
  85. SELECT trace_id, start_timestamp, data
  86. FROM {self.table}
  87. WHERE task_status = %s AND task_name = %s
  88. """
  89. rows = await self.db_client.async_fetch(
  90. query=query,
  91. params=(TaskStatus.PROCESSING, task_name),
  92. )
  93. return rows or []
  94. # ==================== 任务检查 ====================
  95. async def _check_task_concurrency_and_timeout(self, task_name: str) -> None:
  96. """
  97. 检查任务并发数和超时情况
  98. 优化点:
  99. 1. 真正检查任务是否超时(基于时间)
  100. 2. 分别处理超时和并发限制
  101. 3. 可选择自动释放超时任务
  102. Raises:
  103. TaskTimeoutError: 发现超时任务
  104. TaskConcurrencyError: 超过并发限制
  105. """
  106. processing_tasks = await self._get_processing_tasks(task_name)
  107. if not processing_tasks:
  108. return
  109. config = get_task_config(task_name)
  110. current_time = int(time.time())
  111. # 检查超时任务
  112. timeout_tasks = [
  113. task
  114. for task in processing_tasks
  115. if current_time - task["start_timestamp"] > config.timeout
  116. ]
  117. if timeout_tasks:
  118. await self._log_task_event(
  119. "task_timeout_detected",
  120. task_name=task_name,
  121. timeout_count=len(timeout_tasks),
  122. timeout_tasks=[t["trace_id"] for t in timeout_tasks],
  123. )
  124. await feishu_robot.bot(
  125. title=f"Task Timeout Alert: {task_name}",
  126. detail={
  127. "task_name": task_name,
  128. "timeout_count": len(timeout_tasks),
  129. "timeout_threshold": config.timeout,
  130. "timeout_tasks": [
  131. {
  132. "trace_id": t["trace_id"],
  133. "running_time": current_time - t["start_timestamp"],
  134. }
  135. for t in timeout_tasks
  136. ],
  137. },
  138. )
  139. # 可选:自动释放超时任务(需要谨慎使用)
  140. for task in timeout_tasks:
  141. await self._force_release_task(task["trace_id"], TaskStatus.FAILED)
  142. # 检查并发限制(排除超时任务)
  143. active_tasks = [
  144. task
  145. for task in processing_tasks
  146. if current_time - task["start_timestamp"] <= config.timeout
  147. ]
  148. if len(active_tasks) >= config.max_concurrent:
  149. await self._log_task_event(
  150. "task_concurrency_limit",
  151. task_name=task_name,
  152. current_count=len(active_tasks),
  153. max_concurrent=config.max_concurrent,
  154. )
  155. await feishu_robot.bot(
  156. title=f"Task Concurrency Limit: {task_name}",
  157. detail={
  158. "task_name": task_name,
  159. "current_count": len(active_tasks),
  160. "max_concurrent": config.max_concurrent,
  161. "active_tasks": [t["trace_id"] for t in active_tasks],
  162. },
  163. )
  164. raise TaskConcurrencyError(
  165. f"Task {task_name} has reached max concurrency limit "
  166. f"({len(active_tasks)}/{config.max_concurrent})",
  167. task_name=task_name,
  168. )
  169. # ==================== 任务执行 ====================
  170. async def _run_with_guard(
  171. self,
  172. task_name: str,
  173. date_str: str,
  174. task_handler,
  175. ) -> dict:
  176. """
  177. 带保护的任务执行
  178. 优化点:
  179. 1. 更好的错误处理和重试机制
  180. 2. 统一的日志记录
  181. 3. 详细的错误信息
  182. """
  183. # 1. 检查并发和超时
  184. try:
  185. await self._check_task_concurrency_and_timeout(task_name)
  186. except TaskConcurrencyError as e:
  187. return await task_schedule_response.fail_response("5005", str(e))
  188. # 2. 创建任务记录并尝试获取锁
  189. await self._insert_or_ignore_task(task_name, date_str)
  190. if not await self._try_lock_task():
  191. return await task_schedule_response.fail_response(
  192. "5001", "Task is already processing"
  193. )
  194. # 3. 后台执行任务
  195. async def _task_wrapper():
  196. """任务执行包装器 - 处理错误和重试"""
  197. status = TaskStatus.FAILED
  198. retry_count = 0
  199. config = get_task_config(task_name)
  200. start_time = time.time()
  201. try:
  202. await self._log_task_event("task_started", task_name=task_name)
  203. # 执行任务
  204. status = await task_handler()
  205. duration = time.time() - start_time
  206. await self._log_task_event(
  207. "task_completed",
  208. task_name=task_name,
  209. status=status,
  210. duration=duration,
  211. )
  212. except TaskError as e:
  213. # 已知的任务错误
  214. duration = time.time() - start_time
  215. error_detail = TaskUtils.format_error_detail(e)
  216. await self._log_task_event(
  217. "task_failed",
  218. task_name=task_name,
  219. error=error_detail,
  220. duration=duration,
  221. retry_count=retry_count,
  222. )
  223. # 根据错误类型决定是否告警
  224. if config.alert_on_failure:
  225. await feishu_robot.bot(
  226. title=f"Task Failed: {task_name}",
  227. detail={
  228. "task_name": task_name,
  229. "trace_id": self.trace_id,
  230. "error": error_detail,
  231. "duration": duration,
  232. "retryable": e.retryable,
  233. },
  234. )
  235. # TODO: 实现重试逻辑
  236. # if e.retryable and retry_count < config.retry_times:
  237. # await self._schedule_retry(task_name, retry_count + 1)
  238. except Exception as e:
  239. # 未知错误
  240. duration = time.time() - start_time
  241. error_detail = TaskUtils.format_error_detail(e)
  242. await self._log_task_event(
  243. "task_error",
  244. task_name=task_name,
  245. error=error_detail,
  246. duration=duration,
  247. )
  248. await feishu_robot.bot(
  249. title=f"Task Error: {task_name}",
  250. detail={
  251. "task_name": task_name,
  252. "trace_id": self.trace_id,
  253. "error": error_detail,
  254. "duration": duration,
  255. },
  256. )
  257. finally:
  258. await self._release_task(status)
  259. # 创建后台任务
  260. asyncio.create_task(_task_wrapper(), name=f"{task_name}_{self.trace_id}")
  261. return await task_schedule_response.success_response(
  262. task_name=task_name,
  263. data={
  264. "code": 0,
  265. "message": "Task started successfully",
  266. "trace_id": self.trace_id,
  267. },
  268. )
  269. # ==================== 任务管理接口 ====================
  270. async def get_task_status(
  271. self, trace_id: Optional[str] = None
  272. ) -> Optional[Dict[str, Any]]:
  273. """
  274. 查询任务状态
  275. Args:
  276. trace_id: 任务追踪 ID,默认使用当前实例的 trace_id
  277. Returns:
  278. 任务信息字典,如果不存在返回 None
  279. """
  280. trace_id = trace_id or self.trace_id
  281. query = f"SELECT * FROM {self.table} WHERE trace_id = %s"
  282. result = await self.db_client.async_fetch_one(query, (trace_id,))
  283. return result
  284. async def cancel_task(self, trace_id: Optional[str] = None) -> bool:
  285. """
  286. 取消任务(将状态设置为失败)
  287. Args:
  288. trace_id: 任务追踪 ID,默认使用当前实例的 trace_id
  289. Returns:
  290. 是否成功取消
  291. """
  292. trace_id = trace_id or self.trace_id
  293. query = f"""
  294. UPDATE {self.table}
  295. SET task_status = %s, finish_timestamp = %s
  296. WHERE trace_id = %s AND task_status IN (%s, %s)
  297. """
  298. result = await self.db_client.async_save(
  299. query,
  300. (
  301. TaskStatus.FAILED,
  302. int(time.time()),
  303. trace_id,
  304. TaskStatus.INIT,
  305. TaskStatus.PROCESSING,
  306. ),
  307. )
  308. if result:
  309. await self._log_task_event("task_cancelled", trace_id=trace_id)
  310. return bool(result)
  311. async def retry_task(self, trace_id: Optional[str] = None) -> bool:
  312. """
  313. 重试任务(将状态重置为初始化)
  314. Args:
  315. trace_id: 任务追踪 ID,默认使用当前实例的 trace_id
  316. Returns:
  317. 是否成功重置
  318. """
  319. trace_id = trace_id or self.trace_id
  320. query = f"""
  321. UPDATE {self.table}
  322. SET task_status = %s, start_timestamp = %s, finish_timestamp = NULL
  323. WHERE trace_id = %s
  324. """
  325. result = await self.db_client.async_save(
  326. query,
  327. (TaskStatus.INIT, int(time.time()), trace_id),
  328. )
  329. if result:
  330. await self._log_task_event("task_retried", trace_id=trace_id)
  331. return bool(result)
  332. async def _force_release_task(self, trace_id: str, status: int) -> None:
  333. """强制释放任务(用于超时任务清理)"""
  334. query = f"""
  335. UPDATE {self.table}
  336. SET task_status = %s, finish_timestamp = %s
  337. WHERE trace_id = %s
  338. """
  339. await self.db_client.async_save(
  340. query,
  341. (status, int(time.time()), trace_id),
  342. )
  343. await self._log_task_event(
  344. "task_force_released", trace_id=trace_id, status=status
  345. )
  346. # ==================== 主入口 ====================
  347. async def deal(self) -> dict:
  348. """
  349. 任务调度主入口
  350. Returns:
  351. 调度结果字典
  352. """
  353. # 验证任务名
  354. task_name = self.data.get("task_name")
  355. if not task_name:
  356. return await task_schedule_response.fail_response(
  357. "4003", "task_name is required"
  358. )
  359. try:
  360. task_name = TaskUtils.validate_task_name(task_name)
  361. except TaskValidationError as e:
  362. return await task_schedule_response.fail_response("4003", str(e))
  363. # 获取日期
  364. date_str = self.data.get("date_string") or (
  365. datetime.utcnow() + timedelta(hours=8)
  366. ).strftime("%Y-%m-%d")
  367. # 获取任务处理器
  368. handler = self.get_handler(task_name)
  369. if not handler:
  370. return await task_schedule_response.fail_response(
  371. "4001",
  372. f"Unknown task: {task_name}. "
  373. f"Available tasks: {', '.join(self.list_registered_tasks())}",
  374. )
  375. # 执行任务
  376. return await self._run_with_guard(
  377. task_name,
  378. date_str,
  379. lambda: handler(self),
  380. )
  381. __all__ = ["TaskScheduler"]