task_scheduler.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
import asyncio
import json
import logging
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

from app.infra.shared import task_schedule_response
from app.jobs.task_handler import TaskHandler
from app.jobs.task_config import (
    TaskStatus,
    TaskConstants,
    get_task_config,
)
from app.jobs.task_utils import (
    TaskError,
    TaskValidationError,
    TaskConcurrencyError,
    TaskUtils,
)
from app.jobs.task_lifecycle import TaskLifecycleManager
from app.core.config import GlobalConfigSettings
from app.core.database import DatabaseManager
from app.core.observability import LogService, AlertService
  24. logger = logging.getLogger(__name__)
  25. class TaskScheduler(TaskHandler):
  26. """
  27. 统一任务调度器
  28. 使用方法:
  29. scheduler = TaskScheduler(data, log_service, db_client, trace_id)
  30. result = await scheduler.deal()
  31. """
  32. def __init__(
  33. self,
  34. data: dict,
  35. log_service: LogService,
  36. db_client: DatabaseManager,
  37. trace_id: str,
  38. config: GlobalConfigSettings,
  39. ):
  40. super().__init__(data, log_service, db_client, trace_id, config)
  41. self.table = TaskUtils.validate_table_name(TaskConstants.TASK_TABLE)
  42. async def _send_alert(self, title: str, detail: dict, dedup_key: str = None):
  43. """发送告警(异步解耦,不阻塞主链路)"""
  44. alert = AlertService.get_instance()
  45. if alert:
  46. await alert.send_alert(title=title, detail=detail, dedup_key=dedup_key)
  47. # ==================== 数据库操作 ====================
  48. async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
  49. """新建任务记录(若同键已存在则忽略)"""
  50. query = f"""
  51. INSERT IGNORE INTO {self.table}
  52. (date_string, task_name, start_timestamp, task_status, trace_id, data)
  53. VALUES (%s, %s, %s, %s, %s, %s)
  54. """
  55. await self.db_client.async_save(
  56. query=query,
  57. params=(
  58. date_str,
  59. task_name,
  60. int(time.time()),
  61. TaskStatus.INIT,
  62. self.trace_id,
  63. json.dumps(self.data, ensure_ascii=False),
  64. ),
  65. )
  66. async def _try_lock_task(self) -> bool:
  67. """
  68. 尝试获取任务锁(CAS 操作)
  69. 返回 True 表示成功获取锁
  70. """
  71. query = f"""
  72. UPDATE {self.table}
  73. SET task_status = %s
  74. WHERE trace_id = %s AND task_status = %s
  75. """
  76. result = await self.db_client.async_save(
  77. query=query,
  78. params=(TaskStatus.PROCESSING, self.trace_id, TaskStatus.INIT),
  79. )
  80. return bool(result)
  81. async def _release_task(self, status: int) -> None:
  82. """释放任务锁并更新状态"""
  83. query = f"""
  84. UPDATE {self.table}
  85. SET task_status = %s, finish_timestamp = %s
  86. WHERE trace_id = %s AND task_status IN (%s, %s)
  87. """
  88. await self.db_client.async_save(
  89. query=query,
  90. params=(
  91. status,
  92. int(time.time()),
  93. self.trace_id,
  94. TaskStatus.PROCESSING,
  95. TaskStatus.CANCEL_REQUESTED,
  96. ),
  97. )
  98. async def _get_processing_tasks(self, task_name: str) -> List[Dict[str, Any]]:
  99. """获取正在处理中的任务列表"""
  100. query = f"""
  101. SELECT trace_id, start_timestamp, data
  102. FROM {self.table}
  103. WHERE task_status = %s AND task_name = %s
  104. """
  105. rows = await self.db_client.async_fetch(
  106. query=query,
  107. params=(TaskStatus.PROCESSING, task_name),
  108. )
  109. return rows or []
  110. # ==================== 任务检查 ====================
  111. async def _check_task_concurrency_and_timeout(self, task_name: str) -> None:
  112. """
  113. 检查任务并发数和超时情况
  114. 优化点:
  115. 1. 真正检查任务是否超时(基于时间)
  116. 2. 分别处理超时和并发限制
  117. 3. 可选择自动释放超时任务
  118. Raises:
  119. TaskTimeoutError: 发现超时任务
  120. TaskConcurrencyError: 超过并发限制
  121. """
  122. processing_tasks = await self._get_processing_tasks(task_name)
  123. if not processing_tasks:
  124. return
  125. config = get_task_config(task_name)
  126. current_time = int(time.time())
  127. # 检查超时任务
  128. timeout_tasks = [
  129. task
  130. for task in processing_tasks
  131. if current_time - task["start_timestamp"] > config.timeout
  132. ]
  133. if timeout_tasks:
  134. await self._log_task_event(
  135. "task_timeout_detected",
  136. task_name=task_name,
  137. timeout_count=len(timeout_tasks),
  138. timeout_tasks=[t["trace_id"] for t in timeout_tasks],
  139. )
  140. await self._send_alert(
  141. title=f"Task Timeout Alert: {task_name}",
  142. detail={
  143. "task_name": task_name,
  144. "timeout_count": len(timeout_tasks),
  145. "timeout_threshold": config.timeout,
  146. "timeout_tasks": [
  147. {
  148. "trace_id": t["trace_id"],
  149. "running_time": current_time - t["start_timestamp"],
  150. }
  151. for t in timeout_tasks
  152. ],
  153. },
  154. dedup_key=f"timeout_{task_name}",
  155. )
  156. # 可选:自动释放超时任务(需要谨慎使用)
  157. for task in timeout_tasks:
  158. await self._force_release_task(task["trace_id"], TaskStatus.FAILED)
  159. # 检查并发限制(排除超时任务)
  160. active_tasks = [
  161. task
  162. for task in processing_tasks
  163. if current_time - task["start_timestamp"] <= config.timeout
  164. ]
  165. if len(active_tasks) >= config.max_concurrent:
  166. await self._log_task_event(
  167. "task_concurrency_limit",
  168. task_name=task_name,
  169. current_count=len(active_tasks),
  170. max_concurrent=config.max_concurrent,
  171. )
  172. await self._send_alert(
  173. title=f"Task Concurrency Limit: {task_name}",
  174. detail={
  175. "task_name": task_name,
  176. "current_count": len(active_tasks),
  177. "max_concurrent": config.max_concurrent,
  178. "active_tasks": [t["trace_id"] for t in active_tasks],
  179. },
  180. dedup_key=f"concurrency_{task_name}",
  181. )
  182. raise TaskConcurrencyError(
  183. f"Task {task_name} has reached max concurrency limit "
  184. f"({len(active_tasks)}/{config.max_concurrent})",
  185. task_name=task_name,
  186. )
  187. # ==================== 任务执行 ====================
  188. async def _run_with_guard(
  189. self,
  190. task_name: str,
  191. date_str: str,
  192. task_handler,
  193. ) -> dict:
  194. """
  195. 带保护的任务执行
  196. 优化点:
  197. 1. 更好的错误处理和重试机制
  198. 2. 统一的日志记录
  199. 3. 详细的错误信息
  200. """
  201. # 1. 检查并发和超时
  202. try:
  203. await self._check_task_concurrency_and_timeout(task_name)
  204. except TaskConcurrencyError as e:
  205. return await task_schedule_response.fail_response("5005", str(e))
  206. # 2. 创建任务记录并尝试获取锁
  207. await self._insert_or_ignore_task(task_name, date_str)
  208. if not await self._try_lock_task():
  209. return await task_schedule_response.fail_response(
  210. "5001", "Task is already processing"
  211. )
  212. # 3. 后台执行任务
  213. async def _task_wrapper():
  214. """任务执行包装器 - 处理错误和重试"""
  215. status = TaskStatus.FAILED
  216. retry_count = 0
  217. config = get_task_config(task_name)
  218. start_time = time.time()
  219. try:
  220. await self._log_task_event("task_started", task_name=task_name)
  221. # 执行任务
  222. status = await task_handler()
  223. duration = time.time() - start_time
  224. await self._log_task_event(
  225. "task_completed",
  226. task_name=task_name,
  227. status=status,
  228. duration=duration,
  229. )
  230. except TaskError as e:
  231. # 已知的任务错误
  232. duration = time.time() - start_time
  233. error_detail = TaskUtils.format_error_detail(e)
  234. await self._log_task_event(
  235. "task_failed",
  236. task_name=task_name,
  237. error=error_detail,
  238. duration=duration,
  239. retry_count=retry_count,
  240. )
  241. # 根据错误类型决定是否告警
  242. if config.alert_on_failure:
  243. await self._send_alert(
  244. title=f"Task Failed: {task_name}",
  245. detail={
  246. "task_name": task_name,
  247. "trace_id": self.trace_id,
  248. "error": error_detail,
  249. "duration": duration,
  250. "retryable": e.retryable,
  251. },
  252. dedup_key=f"task_failed_{task_name}_{self.trace_id}",
  253. )
  254. # TODO: 实现重试逻辑
  255. # if e.retryable and retry_count < config.retry_times:
  256. # await self._schedule_retry(task_name, retry_count + 1)
  257. except asyncio.CancelledError:
  258. # 任务被取消
  259. status = TaskStatus.CANCELLED
  260. duration = time.time() - start_time
  261. await self._log_task_event(
  262. "task_cancelled",
  263. task_name=task_name,
  264. duration=duration,
  265. )
  266. raise
  267. except Exception as e:
  268. # 未知错误
  269. duration = time.time() - start_time
  270. error_detail = TaskUtils.format_error_detail(e)
  271. await self._log_task_event(
  272. "task_error",
  273. task_name=task_name,
  274. error=error_detail,
  275. duration=duration,
  276. )
  277. await self._send_alert(
  278. title=f"Task Error: {task_name}",
  279. detail={
  280. "task_name": task_name,
  281. "trace_id": self.trace_id,
  282. "error": error_detail,
  283. "duration": duration,
  284. },
  285. dedup_key=f"task_error_{task_name}_{self.trace_id}",
  286. )
  287. finally:
  288. await self._release_task(status)
  289. lifecycle = TaskLifecycleManager.get_instance()
  290. if lifecycle:
  291. await lifecycle.unregister(self.trace_id)
  292. # 创建后台任务
  293. task = asyncio.create_task(
  294. _task_wrapper(), name=f"{task_name}_{self.trace_id}"
  295. )
  296. lifecycle = TaskLifecycleManager.get_instance()
  297. if lifecycle:
  298. await lifecycle.register(self.trace_id, task)
  299. return await task_schedule_response.success_response(
  300. task_name=task_name,
  301. data={
  302. "code": 0,
  303. "message": "Task started successfully",
  304. "trace_id": self.trace_id,
  305. },
  306. )
  307. # ==================== 任务管理接口 ====================
  308. async def get_task_status(
  309. self, trace_id: Optional[str] = None
  310. ) -> Optional[Dict[str, Any]]:
  311. """
  312. 查询任务状态
  313. Args:
  314. trace_id: 任务追踪 ID,默认使用当前实例的 trace_id
  315. Returns:
  316. 任务信息字典,如果不存在返回 None
  317. """
  318. trace_id = trace_id or self.trace_id
  319. query = f"SELECT * FROM {self.table} WHERE trace_id = %s"
  320. result = await self.db_client.async_fetch_one(query, params=(trace_id,))
  321. return result
  322. async def cancel_task(self, trace_id: Optional[str] = None) -> bool:
  323. """
  324. 取消任务
  325. INIT 状态直接设为 CANCELLED,PROCESSING 状态设为 CANCEL_REQUESTED
  326. 等待轮询器检测到信号后取消本地协程
  327. Args:
  328. trace_id: 任务追踪 ID,默认使用当前实例的 trace_id
  329. Returns:
  330. 是否成功取消
  331. """
  332. trace_id = trace_id or self.trace_id
  333. query = f"""
  334. UPDATE {self.table}
  335. SET task_status = CASE
  336. WHEN task_status = %s THEN %s
  337. WHEN task_status = %s THEN %s
  338. END,
  339. finish_timestamp = CASE
  340. WHEN task_status = %s THEN %s
  341. ELSE finish_timestamp
  342. END
  343. WHERE trace_id = %s AND task_status IN (%s, %s)
  344. """
  345. result = await self.db_client.async_save(
  346. query,
  347. (
  348. TaskStatus.INIT,
  349. TaskStatus.CANCELLED,
  350. TaskStatus.PROCESSING,
  351. TaskStatus.CANCEL_REQUESTED,
  352. TaskStatus.INIT,
  353. int(time.time()),
  354. trace_id,
  355. TaskStatus.INIT,
  356. TaskStatus.PROCESSING,
  357. ),
  358. )
  359. if result:
  360. await self._log_task_event("task_cancel_requested", trace_id=trace_id)
  361. return bool(result)
  362. async def retry_task(self, trace_id: Optional[str] = None) -> bool:
  363. """
  364. 重试任务(将状态重置为初始化)
  365. Args:
  366. trace_id: 任务追踪 ID,默认使用当前实例的 trace_id
  367. Returns:
  368. 是否成功重置
  369. """
  370. trace_id = trace_id or self.trace_id
  371. query = f"""
  372. UPDATE {self.table}
  373. SET task_status = %s, start_timestamp = %s, finish_timestamp = NULL
  374. WHERE trace_id = %s
  375. """
  376. result = await self.db_client.async_save(
  377. query,
  378. (TaskStatus.INIT, int(time.time()), trace_id),
  379. )
  380. if result:
  381. await self._log_task_event("task_retried", trace_id=trace_id)
  382. return bool(result)
  383. async def _force_release_task(self, trace_id: str, status: int) -> None:
  384. """强制释放任务(用于超时任务清理)"""
  385. query = f"""
  386. UPDATE {self.table}
  387. SET task_status = %s, finish_timestamp = %s
  388. WHERE trace_id = %s
  389. """
  390. await self.db_client.async_save(
  391. query,
  392. (status, int(time.time()), trace_id),
  393. )
  394. await self._log_task_event(
  395. "task_force_released", trace_id=trace_id, status=status
  396. )
  397. # ==================== 主入口 ====================
  398. async def deal(self) -> dict:
  399. """
  400. 任务调度主入口
  401. Returns:
  402. 调度结果字典
  403. """
  404. # 验证任务名
  405. task_name = self.data.get("task_name")
  406. if not task_name:
  407. return await task_schedule_response.fail_response(
  408. "4003", "task_name is required"
  409. )
  410. try:
  411. task_name = TaskUtils.validate_task_name(task_name)
  412. except TaskValidationError as e:
  413. return await task_schedule_response.fail_response("4003", str(e))
  414. # 获取日期
  415. date_str = self.data.get("date_string") or (
  416. datetime.utcnow() + timedelta(hours=8)
  417. ).strftime("%Y-%m-%d")
  418. # 获取任务处理器
  419. handler = self.get_handler(task_name)
  420. if not handler:
  421. return await task_schedule_response.fail_response(
  422. "4001",
  423. f"Unknown task: {task_name}. "
  424. f"Available tasks: {', '.join(self.list_registered_tasks())}",
  425. )
  426. # 执行任务
  427. return await self._run_with_guard(
  428. task_name,
  429. date_str,
  430. lambda: handler(self),
  431. )
  432. __all__ = ["TaskScheduler"]