task_scheduler.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509
  1. import asyncio
  2. import json
  3. import time
  4. from datetime import datetime, timedelta
  5. from typing import Optional, Dict, Any, List
  6. from app.infra.external import feishu_robot
  7. from app.infra.shared import task_schedule_response
  8. from app.jobs.task_handler import TaskHandler
  9. from app.jobs.task_config import (
  10. TaskStatus,
  11. TaskConstants,
  12. get_task_config,
  13. )
  14. from app.jobs.task_utils import (
  15. TaskError,
  16. TaskValidationError,
  17. TaskConcurrencyError,
  18. TaskUtils,
  19. )
  20. from app.core.config import GlobalConfigSettings
  21. from app.core.database import DatabaseManager
  22. from app.core.observability import LogService
  23. from app.core.task_registry import TaskRegistry
  24. class TaskScheduler(TaskHandler):
  25. """
  26. 统一任务调度器
  27. 使用方法:
  28. scheduler = TaskScheduler(data, log_service, db_client, trace_id)
  29. result = await scheduler.deal()
  30. """
  def __init__(
      self,
      data: dict,
      log_service: LogService,
      db_client: DatabaseManager,
      trace_id: str,
      config: GlobalConfigSettings,
      task_registry: TaskRegistry,
  ):
      """
      Build a scheduler on top of the base ``TaskHandler`` state.

      Args:
          data: Request payload; forwarded to ``TaskHandler.__init__``.
          log_service: Forwarded to ``TaskHandler.__init__``.
          db_client: Forwarded to ``TaskHandler.__init__``.
          trace_id: Forwarded to ``TaskHandler.__init__``.
          config: Forwarded to ``TaskHandler.__init__``.
          task_registry: Registry kept on the instance as ``task_registry``.
      """
      super().__init__(data, log_service, db_client, trace_id, config)
      # The table name is interpolated into SQL text by the other methods, so it
      # is passed through the validator once here and cached on the instance.
      validated_table = TaskUtils.validate_table_name(TaskConstants.TASK_TABLE)
      self.table = validated_table
      self.task_registry = task_registry
  43. # ==================== 数据库操作 ====================
  async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
      """
      Insert a new task row; ``INSERT IGNORE`` leaves an existing row untouched.

      The row carries the current Unix time as ``start_timestamp``, the
      ``TaskStatus.INIT`` status, this instance's ``trace_id`` and the request
      payload serialised as JSON with non-ASCII characters kept verbatim.

      Args:
          task_name: Value for the ``task_name`` column.
          date_str: Value for the ``date_string`` column.
      """
      insert_sql = f"""
          INSERT IGNORE INTO {self.table}
          (date_string, task_name, start_timestamp, task_status, trace_id, data)
          VALUES (%s, %s, %s, %s, %s, %s)
      """
      started_at = int(time.time())
      payload = json.dumps(self.data, ensure_ascii=False)
      row_values = (
          date_str,
          task_name,
          started_at,
          TaskStatus.INIT,
          self.trace_id,
          payload,
      )
      await self.db_client.async_save(query=insert_sql, params=row_values)
  async def _try_lock_task(self) -> bool:
      """
      Try to take the task lock with a compare-and-set UPDATE.

      The row for this instance's ``trace_id`` is moved from ``TaskStatus.INIT``
      to ``TaskStatus.PROCESSING``; the status condition in the WHERE clause is
      what makes the update a compare-and-set.

      Returns:
          True when ``db_client.async_save`` reports a truthy result (the lock
          was obtained), False otherwise.
      """
      cas_sql = f"""
          UPDATE {self.table}
          SET task_status = %s
          WHERE trace_id = %s AND task_status = %s
      """
      cas_params = (TaskStatus.PROCESSING, self.trace_id, TaskStatus.INIT)
      affected = await self.db_client.async_save(query=cas_sql, params=cas_params)
      return True if affected else False
  async def _release_task(self, status: int) -> None:
      """
      Release the task lock and write the final status.

      The row for this instance's ``trace_id`` is updated only while it is still
      owned by the running coroutine, i.e. while its status is
      ``TaskStatus.PROCESSING`` or ``TaskStatus.CANCELLING``. Accepting
      ``CANCELLING`` matters: ``cancel_task`` moves a running task to that state
      before the coroutine is cancelled, and a guard on ``PROCESSING`` alone
      would then match nothing and leave the row in ``CANCELLING`` forever.
      Rows in any other state are left untouched.

      Args:
          status: Final status to store in ``task_status``.
      """
      query = f"""
          UPDATE {self.table}
          SET task_status = %s, finish_timestamp = %s
          WHERE trace_id = %s AND task_status IN (%s, %s)
      """
      await self.db_client.async_save(
          query=query,
          params=(
              status,
              int(time.time()),
              self.trace_id,
              TaskStatus.PROCESSING,
              TaskStatus.CANCELLING,
          ),
      )
  async def _get_processing_tasks(self, task_name: str) -> List[Dict[str, Any]]:
      """
      Fetch the rows of ``task_name`` whose status is ``TaskStatus.PROCESSING``.

      Args:
          task_name: Task name to filter on.

      Returns:
          The fetched rows (``trace_id``, ``start_timestamp``, ``data``), or an
          empty list when the fetch result is falsy.
      """
      select_sql = f"""
          SELECT trace_id, start_timestamp, data
          FROM {self.table}
          WHERE task_status = %s AND task_name = %s
      """
      fetched = await self.db_client.async_fetch(
          query=select_sql,
          params=(TaskStatus.PROCESSING, task_name),
      )
      if not fetched:
          return []
      return fetched
  105. # ==================== 任务检查 ====================
  async def _check_task_concurrency_and_timeout(self, task_name: str) -> None:
      """
      Enforce the per-task timeout and concurrency limits before a new run.

      Rows currently in PROCESSING are split by age against ``config.timeout``:

      * Rows older than the timeout are logged, reported through
        ``feishu_robot.bot`` and force-released as ``TaskStatus.FAILED``.
        No exception is raised for them.
      * The remaining rows count against ``config.max_concurrent``; when the
        limit is reached the situation is logged, reported, and a
        ``TaskConcurrencyError`` is raised.

      Args:
          task_name: Task whose running instances are inspected.

      Raises:
          TaskConcurrencyError: The number of non-timed-out running instances
              is at or above ``config.max_concurrent``.
      """
      running = await self._get_processing_tasks(task_name)
      if not running:
          return

      config = get_task_config(task_name)
      now = int(time.time())

      # Partition once: a row is either past the timeout or still active.
      expired, alive = [], []
      for row in running:
          if now - row["start_timestamp"] > config.timeout:
              expired.append(row)
          else:
              alive.append(row)

      if expired:
          await self._log_task_event(
              "task_timeout_detected",
              task_name=task_name,
              timeout_count=len(expired),
              timeout_tasks=[row["trace_id"] for row in expired],
          )
          timeout_report = []
          for row in expired:
              timeout_report.append(
                  {
                      "trace_id": row["trace_id"],
                      "running_time": now - row["start_timestamp"],
                  }
              )
          await feishu_robot.bot(
              title=f"Task Timeout Alert: {task_name}",
              detail={
                  "task_name": task_name,
                  "timeout_count": len(expired),
                  "timeout_threshold": config.timeout,
                  "timeout_tasks": timeout_report,
              },
          )
          # Timed-out rows are released automatically; use with care, the
          # update only touches the database row.
          for row in expired:
              await self._force_release_task(row["trace_id"], TaskStatus.FAILED)

      # Timed-out rows were excluded above, so only live runs count here.
      alive_count = len(alive)
      if alive_count < config.max_concurrent:
          return

      await self._log_task_event(
          "task_concurrency_limit",
          task_name=task_name,
          current_count=alive_count,
          max_concurrent=config.max_concurrent,
      )
      await feishu_robot.bot(
          title=f"Task Concurrency Limit: {task_name}",
          detail={
              "task_name": task_name,
              "current_count": alive_count,
              "max_concurrent": config.max_concurrent,
              "active_tasks": [row["trace_id"] for row in alive],
          },
      )
      raise TaskConcurrencyError(
          f"Task {task_name} has reached max concurrency limit "
          f"({alive_count}/{config.max_concurrent})",
          task_name=task_name,
      )
  180. # ==================== 任务执行 ====================
  async def _run_with_guard(
      self,
      task_name: str,
      date_str: str,
      task_handler,
  ) -> dict:
      """
      Admit, lock and launch a task as a background asyncio task.

      Steps:
          1. Run the concurrency/timeout check; a ``TaskConcurrencyError`` is
             converted into a ``"5005"`` failure response.
          2. Insert the task row (ignored when it already exists) and try the
             INIT -> PROCESSING lock; losing it yields a ``"5001"`` failure
             response.
          3. Start ``task_handler`` inside a wrapper that logs the outcome,
             sends alerts, polls the database for a cancel request, and always
             releases the lock and unregisters itself when it ends.

      Args:
          task_name: Name of the task being scheduled.
          date_str: Value stored in the row's ``date_string`` column.
          task_handler: Zero-argument callable returning an awaitable; its
              result is written as the task's final status.

      Returns:
          A failure response from step 1 or 2, otherwise the success response
          carrying this instance's ``trace_id``. The success response is
          returned as soon as the background task is registered, not when the
          task finishes.
      """
      # 1. Concurrency / timeout admission check.
      try:
          await self._check_task_concurrency_and_timeout(task_name)
      except TaskConcurrencyError as e:
          return await task_schedule_response.fail_response("5005", str(e))

      # 2. Create the row and try to take the lock.
      await self._insert_or_ignore_task(task_name, date_str)
      if not await self._try_lock_task():
          return await task_schedule_response.fail_response(
              "5001", "Task is already processing"
          )

      # 3. Run the handler in the background.
      async def _task_wrapper():
          """Run the handler and deal with errors, alerts and cancellation."""
          status = TaskStatus.FAILED
          retry_count = 0
          config = get_task_config(task_name)
          start_time = time.time()

          async def _cancel_watchdog(_task: asyncio.Task):
              """Poll the row periodically and cancel ``_task`` on CANCELLING."""
              while True:
                  await asyncio.sleep(config.cancel_check_interval)
                  try:
                      row = await self.db_client.async_fetch_one(
                          f"SELECT task_status FROM {self.table} WHERE trace_id = %s",
                          (self.trace_id,),
                      )
                  except asyncio.CancelledError:
                      # Never swallow our own cancellation (CancelledError is
                      # an Exception subclass on older Python versions).
                      raise
                  except Exception as poll_error:
                      # One failed poll must not end cancel detection for the
                      # rest of the run; report it and try again next round.
                      await self._log_task_event(
                          "task_watchdog_error",
                          task_name=task_name,
                          error=TaskUtils.format_error_detail(poll_error),
                      )
                      continue
                  if row and row["task_status"] == TaskStatus.CANCELLING:
                      _task.cancel()
                      return

          main_task = asyncio.current_task()
          watchdog = asyncio.create_task(
              _cancel_watchdog(main_task),
              name=f"watchdog_{task_name}_{self.trace_id}",
          )
          try:
              await self._log_task_event("task_started", task_name=task_name)
              # Run the actual task.
              status = await task_handler()
              duration = time.time() - start_time
              await self._log_task_event(
                  "task_completed",
                  task_name=task_name,
                  status=status,
                  duration=duration,
              )
          except asyncio.CancelledError:
              # Record the outcome before awaiting anything, so a failing log
              # call cannot leave the status at FAILED.
              status = TaskStatus.CANCELLED
              duration = time.time() - start_time
              await self._log_task_event(
                  "task_cancelled", task_name=task_name, duration=duration
              )
              raise
          except TaskError as e:
              # Known task error; status stays FAILED.
              duration = time.time() - start_time
              error_detail = TaskUtils.format_error_detail(e)
              await self._log_task_event(
                  "task_failed",
                  task_name=task_name,
                  error=error_detail,
                  duration=duration,
                  retry_count=retry_count,
              )
              # Alerting is configurable per task for known errors.
              if config.alert_on_failure:
                  await feishu_robot.bot(
                      title=f"Task Failed: {task_name}",
                      detail={
                          "task_name": task_name,
                          "trace_id": self.trace_id,
                          "error": error_detail,
                          "duration": duration,
                          "retryable": e.retryable,
                      },
                  )
              # TODO: implement retry logic, e.g. schedule a retry when
              # e.retryable and retry_count < config.retry_times.
          except Exception as e:
              # Unknown error; always alert. Status stays FAILED.
              duration = time.time() - start_time
              error_detail = TaskUtils.format_error_detail(e)
              await self._log_task_event(
                  "task_error",
                  task_name=task_name,
                  error=error_detail,
                  duration=duration,
              )
              await feishu_robot.bot(
                  title=f"Task Error: {task_name}",
                  detail={
                      "task_name": task_name,
                      "trace_id": self.trace_id,
                      "error": error_detail,
                      "duration": duration,
                  },
              )
          finally:
              try:
                  watchdog.cancel()
                  try:
                      await watchdog
                  except asyncio.CancelledError:
                      pass
              finally:
                  # Whatever happened to the watchdog, the lock must be
                  # released and the registry entry removed; otherwise the row
                  # stays PROCESSING and the registry leaks.
                  try:
                      await self._release_task(status)
                  finally:
                      await self.task_registry.unregister(self.trace_id)

      # Create the background task and register it for in-process cancellation.
      task = asyncio.create_task(_task_wrapper(), name=f"{task_name}_{self.trace_id}")
      await self.task_registry.register(self.trace_id, task)
      return await task_schedule_response.success_response(
          task_name=task_name,
          data={
              "code": 0,
              "message": "Task started successfully",
              "trace_id": self.trace_id,
          },
      )
  310. # ==================== 任务管理接口 ====================
  async def get_task_status(
      self, trace_id: Optional[str] = None
  ) -> Optional[Dict[str, Any]]:
      """
      Look up the task row for a trace id.

      Args:
          trace_id: Trace id to look up; a falsy value means "use this
              instance's ``trace_id``".

      Returns:
          Whatever ``db_client.async_fetch_one`` returns for the query: the
          task row, or presumably None when no row matches (confirm against
          the database client).
      """
      lookup_id = trace_id if trace_id else self.trace_id
      return await self.db_client.async_fetch_one(
          f"SELECT * FROM {self.table} WHERE trace_id = %s",
          (lookup_id,),
      )
  async def cancel_task(self, trace_id: Optional[str] = None) -> bool:
      """
      Request cancellation of a task.

      Two compare-and-set updates are tried in order:

      * INIT -> CANCELLED: the task has not started, so it is finalised
        directly (with ``finish_timestamp`` set).
      * PROCESSING -> CANCELLING: the running coroutine is expected to notice
        this state through its watchdog; the task registry is also asked to
        cancel it right away, which saves a polling interval when the task
        runs in this process.

      Args:
          trace_id: Trace id of the task; a falsy value means "use this
              instance's ``trace_id``".

      Returns:
          True when either update reported a truthy result, False otherwise.
      """
      target_id = trace_id or self.trace_id

      # Not started yet: mark CANCELLED immediately.
      finalize_sql = f"""
          UPDATE {self.table}
          SET task_status = %s, finish_timestamp = %s
          WHERE trace_id = %s AND task_status = %s
      """
      finalized = await self.db_client.async_save(
          finalize_sql,
          (TaskStatus.CANCELLED, int(time.time()), target_id, TaskStatus.INIT),
      )
      if finalized:
          await self._log_task_event("task_cancelled", trace_id=target_id)
          return True

      # Running: flag CANCELLING and let the owner coroutine wind down.
      request_sql = f"""
          UPDATE {self.table}
          SET task_status = %s
          WHERE trace_id = %s AND task_status = %s
      """
      requested = await self.db_client.async_save(
          request_sql,
          (TaskStatus.CANCELLING, target_id, TaskStatus.PROCESSING),
      )
      if not requested:
          return False

      # Same-process shortcut: cancel now instead of waiting for the watchdog.
      await self.task_registry.cancel_task(target_id)
      await self._log_task_event(
          "task_cancelling", trace_id=target_id
      )
      return True
  async def retry_task(self, trace_id: Optional[str] = None) -> bool:
      """
      Reset a task to ``TaskStatus.INIT`` so that it can be run again.

      ``start_timestamp`` is set to the current Unix time and
      ``finish_timestamp`` is cleared. Rows that are ``PROCESSING`` or
      ``CANCELLING`` are deliberately skipped: a coroutine still owns them, and
      flipping such a row back to INIT would make its final release (which
      checks the current status) match nothing while the task is in fact still
      running.

      Args:
          trace_id: Trace id of the task; a falsy value means "use this
              instance's ``trace_id``".

      Returns:
          True when the reset update reported a truthy result, False otherwise
          (including when the task is still running or being cancelled).
      """
      trace_id = trace_id or self.trace_id
      query = f"""
          UPDATE {self.table}
          SET task_status = %s, start_timestamp = %s, finish_timestamp = NULL
          WHERE trace_id = %s AND task_status NOT IN (%s, %s)
      """
      result = await self.db_client.async_save(
          query,
          (
              TaskStatus.INIT,
              int(time.time()),
              trace_id,
              TaskStatus.PROCESSING,
              TaskStatus.CANCELLING,
          ),
      )
      if result:
          await self._log_task_event("task_retried", trace_id=trace_id)
      return bool(result)
  async def _force_release_task(self, trace_id: str, status: int) -> None:
      """
      Overwrite a task's status unconditionally (used to clean up timed-out runs).

      The update is keyed on ``trace_id`` alone, with no check of the row's
      current status, and stamps ``finish_timestamp`` with the current Unix
      time. A ``task_force_released`` event is logged afterwards.

      Args:
          trace_id: Trace id of the row to release.
          status: Status to write into ``task_status``.
      """
      force_sql = f"""
          UPDATE {self.table}
          SET task_status = %s, finish_timestamp = %s
          WHERE trace_id = %s
      """
      finished_at = int(time.time())
      await self.db_client.async_save(
          force_sql,
          (status, finished_at, trace_id),
      )
      await self._log_task_event(
          "task_force_released", trace_id=trace_id, status=status
      )
  401. # ==================== 主入口 ====================
  async def deal(self) -> dict:
      """
      Main scheduling entry point.

      Reads ``task_name`` (required) and ``date_string`` (optional) from the
      request payload, resolves the handler registered for the task and hands
      it to ``_run_with_guard``.

      Returns:
          The scheduling response dict: a ``"4003"`` failure when the task name
          is missing or rejected by validation, a ``"4001"`` failure when no
          handler is registered for it, otherwise the result of
          ``_run_with_guard``.
      """
      # Function-scope import: the module header only imports datetime/timedelta.
      from datetime import timezone

      # Validate the task name.
      task_name = self.data.get("task_name")
      if not task_name:
          return await task_schedule_response.fail_response(
              "4003", "task_name is required"
          )
      try:
          task_name = TaskUtils.validate_task_name(task_name)
      except TaskValidationError as e:
          return await task_schedule_response.fail_response("4003", str(e))

      # Resolve the date: default to today's date at UTC+8. A timezone-aware
      # "now" replaces the naive datetime.utcnow() (deprecated since Python
      # 3.12) plus manual offset; the formatted value is the same.
      date_str = self.data.get("date_string") or datetime.now(
          timezone(timedelta(hours=8))
      ).strftime("%Y-%m-%d")

      # Resolve the task handler.
      handler = self.get_handler(task_name)
      if not handler:
          return await task_schedule_response.fail_response(
              "4001",
              f"Unknown task: {task_name}. "
              f"Available tasks: {', '.join(self.list_registered_tasks())}",
          )

      # Run the task; the handler receives this scheduler instance.
      return await self._run_with_guard(
          task_name,
          date_str,
          lambda: handler(self),
      )
  436. __all__ = ["TaskScheduler"]