# task_scheduler.py
import asyncio
import json
import time
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, List

from app.infra.external import feishu_robot
from app.infra.shared import task_schedule_response
from app.jobs.task_handler import TaskHandler
from app.jobs.task_config import (
    TaskStatus,
    TaskConstants,
    get_task_config,
)
from app.jobs.task_utils import (
    TaskError,
    TaskValidationError,
    TaskConcurrencyError,
    TaskUtils,
)
from app.core.config import GlobalConfigSettings
from app.core.database import DatabaseManager
from app.core.observability import LogService
  23. class TaskScheduler(TaskHandler):
  24. """
  25. 统一任务调度器
  26. 使用方法:
  27. scheduler = TaskScheduler(data, log_service, db_client, trace_id)
  28. result = await scheduler.deal()
  29. """
  30. def __init__(
  31. self,
  32. data: dict,
  33. log_service: LogService,
  34. db_client: DatabaseManager,
  35. trace_id: str,
  36. config: GlobalConfigSettings,
  37. ):
  38. super().__init__(data, log_service, db_client, trace_id, config)
  39. self.table = TaskUtils.validate_table_name(TaskConstants.TASK_TABLE)
  40. # ==================== 数据库操作 ====================
  41. async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
  42. """新建任务记录(若同键已存在则忽略)"""
  43. query = f"""
  44. INSERT IGNORE INTO {self.table}
  45. (date_string, task_name, start_timestamp, task_status, trace_id, data)
  46. VALUES (%s, %s, %s, %s, %s, %s)
  47. """
  48. await self.db_client.async_save(
  49. query=query,
  50. params=(
  51. date_str,
  52. task_name,
  53. int(time.time()),
  54. TaskStatus.INIT,
  55. self.trace_id,
  56. json.dumps(self.data, ensure_ascii=False),
  57. ),
  58. )
  59. async def _try_lock_task(self) -> bool:
  60. """
  61. 尝试获取任务锁(CAS 操作)
  62. 返回 True 表示成功获取锁
  63. """
  64. query = f"""
  65. UPDATE {self.table}
  66. SET task_status = %s
  67. WHERE trace_id = %s AND task_status = %s
  68. """
  69. result = await self.db_client.async_save(
  70. query=query,
  71. params=(TaskStatus.PROCESSING, self.trace_id, TaskStatus.INIT),
  72. )
  73. return bool(result)
  74. async def _release_task(self, status: int) -> None:
  75. """释放任务锁并更新状态"""
  76. query = f"""
  77. UPDATE {self.table}
  78. SET task_status = %s, finish_timestamp = %s
  79. WHERE trace_id = %s AND task_status = %s
  80. """
  81. await self.db_client.async_save(
  82. query=query,
  83. params=(
  84. status,
  85. int(time.time()),
  86. self.trace_id,
  87. TaskStatus.PROCESSING,
  88. ),
  89. )
  90. async def _get_processing_tasks(self, task_name: str) -> List[Dict[str, Any]]:
  91. """获取正在处理中的任务列表"""
  92. query = f"""
  93. SELECT trace_id, start_timestamp, data
  94. FROM {self.table}
  95. WHERE task_status = %s AND task_name = %s
  96. """
  97. rows = await self.db_client.async_fetch(
  98. query=query,
  99. params=(TaskStatus.PROCESSING, task_name),
  100. )
  101. return rows or []
  102. # ==================== 任务检查 ====================
  103. async def _check_task_concurrency_and_timeout(self, task_name: str) -> None:
  104. """
  105. 检查任务并发数和超时情况
  106. 优化点:
  107. 1. 真正检查任务是否超时(基于时间)
  108. 2. 分别处理超时和并发限制
  109. 3. 可选择自动释放超时任务
  110. Raises:
  111. TaskTimeoutError: 发现超时任务
  112. TaskConcurrencyError: 超过并发限制
  113. """
  114. processing_tasks = await self._get_processing_tasks(task_name)
  115. if not processing_tasks:
  116. return
  117. config = get_task_config(task_name)
  118. current_time = int(time.time())
  119. # 检查超时任务
  120. timeout_tasks = [
  121. task
  122. for task in processing_tasks
  123. if current_time - task["start_timestamp"] > config.timeout
  124. ]
  125. if timeout_tasks:
  126. await self._log_task_event(
  127. "task_timeout_detected",
  128. task_name=task_name,
  129. timeout_count=len(timeout_tasks),
  130. timeout_tasks=[t["trace_id"] for t in timeout_tasks],
  131. )
  132. await feishu_robot.bot(
  133. title=f"Task Timeout Alert: {task_name}",
  134. detail={
  135. "task_name": task_name,
  136. "timeout_count": len(timeout_tasks),
  137. "timeout_threshold": config.timeout,
  138. "timeout_tasks": [
  139. {
  140. "trace_id": t["trace_id"],
  141. "running_time": current_time - t["start_timestamp"],
  142. }
  143. for t in timeout_tasks
  144. ],
  145. },
  146. )
  147. # 可选:自动释放超时任务(需要谨慎使用)
  148. for task in timeout_tasks:
  149. await self._force_release_task(task["trace_id"], TaskStatus.FAILED)
  150. # 检查并发限制(排除超时任务)
  151. active_tasks = [
  152. task
  153. for task in processing_tasks
  154. if current_time - task["start_timestamp"] <= config.timeout
  155. ]
  156. if len(active_tasks) >= config.max_concurrent:
  157. await self._log_task_event(
  158. "task_concurrency_limit",
  159. task_name=task_name,
  160. current_count=len(active_tasks),
  161. max_concurrent=config.max_concurrent,
  162. )
  163. await feishu_robot.bot(
  164. title=f"Task Concurrency Limit: {task_name}",
  165. detail={
  166. "task_name": task_name,
  167. "current_count": len(active_tasks),
  168. "max_concurrent": config.max_concurrent,
  169. "active_tasks": [t["trace_id"] for t in active_tasks],
  170. },
  171. )
  172. raise TaskConcurrencyError(
  173. f"Task {task_name} has reached max concurrency limit "
  174. f"({len(active_tasks)}/{config.max_concurrent})",
  175. task_name=task_name,
  176. )
  177. # ==================== 任务执行 ====================
  178. async def _run_with_guard(
  179. self,
  180. task_name: str,
  181. date_str: str,
  182. task_handler,
  183. ) -> dict:
  184. """
  185. 带保护的任务执行
  186. 优化点:
  187. 1. 更好的错误处理和重试机制
  188. 2. 统一的日志记录
  189. 3. 详细的错误信息
  190. """
  191. # 1. 检查并发和超时
  192. try:
  193. await self._check_task_concurrency_and_timeout(task_name)
  194. except TaskConcurrencyError as e:
  195. return await task_schedule_response.fail_response("5005", str(e))
  196. # 2. 创建任务记录并尝试获取锁
  197. await self._insert_or_ignore_task(task_name, date_str)
  198. if not await self._try_lock_task():
  199. return await task_schedule_response.fail_response(
  200. "5001", "Task is already processing"
  201. )
  202. # 3. 后台执行任务
  203. async def _task_wrapper():
  204. """任务执行包装器 - 处理错误和重试"""
  205. status = TaskStatus.FAILED
  206. retry_count = 0
  207. config = get_task_config(task_name)
  208. start_time = time.time()
  209. try:
  210. await self._log_task_event("task_started", task_name=task_name)
  211. # 执行任务
  212. status = await task_handler()
  213. duration = time.time() - start_time
  214. await self._log_task_event(
  215. "task_completed",
  216. task_name=task_name,
  217. status=status,
  218. duration=duration,
  219. )
  220. except TaskError as e:
  221. # 已知的任务错误
  222. duration = time.time() - start_time
  223. error_detail = TaskUtils.format_error_detail(e)
  224. await self._log_task_event(
  225. "task_failed",
  226. task_name=task_name,
  227. error=error_detail,
  228. duration=duration,
  229. retry_count=retry_count,
  230. )
  231. # 根据错误类型决定是否告警
  232. if config.alert_on_failure:
  233. await feishu_robot.bot(
  234. title=f"Task Failed: {task_name}",
  235. detail={
  236. "task_name": task_name,
  237. "trace_id": self.trace_id,
  238. "error": error_detail,
  239. "duration": duration,
  240. "retryable": e.retryable,
  241. },
  242. )
  243. # TODO: 实现重试逻辑
  244. # if e.retryable and retry_count < config.retry_times:
  245. # await self._schedule_retry(task_name, retry_count + 1)
  246. except Exception as e:
  247. # 未知错误
  248. duration = time.time() - start_time
  249. error_detail = TaskUtils.format_error_detail(e)
  250. await self._log_task_event(
  251. "task_error",
  252. task_name=task_name,
  253. error=error_detail,
  254. duration=duration,
  255. )
  256. await feishu_robot.bot(
  257. title=f"Task Error: {task_name}",
  258. detail={
  259. "task_name": task_name,
  260. "trace_id": self.trace_id,
  261. "error": error_detail,
  262. "duration": duration,
  263. },
  264. )
  265. finally:
  266. await self._release_task(status)
  267. # 创建后台任务
  268. asyncio.create_task(_task_wrapper(), name=f"{task_name}_{self.trace_id}")
  269. return await task_schedule_response.success_response(
  270. task_name=task_name,
  271. data={
  272. "code": 0,
  273. "message": "Task started successfully",
  274. "trace_id": self.trace_id,
  275. },
  276. )
  277. # ==================== 任务管理接口 ====================
  278. async def get_task_status(
  279. self, trace_id: Optional[str] = None
  280. ) -> Optional[Dict[str, Any]]:
  281. """
  282. 查询任务状态
  283. Args:
  284. trace_id: 任务追踪 ID,默认使用当前实例的 trace_id
  285. Returns:
  286. 任务信息字典,如果不存在返回 None
  287. """
  288. trace_id = trace_id or self.trace_id
  289. query = f"SELECT * FROM {self.table} WHERE trace_id = %s"
  290. result = await self.db_client.async_fetch_one(query, (trace_id,))
  291. return result
  292. async def cancel_task(self, trace_id: Optional[str] = None) -> bool:
  293. """
  294. 取消任务(将状态设置为失败)
  295. Args:
  296. trace_id: 任务追踪 ID,默认使用当前实例的 trace_id
  297. Returns:
  298. 是否成功取消
  299. """
  300. trace_id = trace_id or self.trace_id
  301. query = f"""
  302. UPDATE {self.table}
  303. SET task_status = %s, finish_timestamp = %s
  304. WHERE trace_id = %s AND task_status IN (%s, %s)
  305. """
  306. result = await self.db_client.async_save(
  307. query,
  308. (
  309. TaskStatus.FAILED,
  310. int(time.time()),
  311. trace_id,
  312. TaskStatus.INIT,
  313. TaskStatus.PROCESSING,
  314. ),
  315. )
  316. if result:
  317. await self._log_task_event("task_cancelled", trace_id=trace_id)
  318. return bool(result)
  319. async def retry_task(self, trace_id: Optional[str] = None) -> bool:
  320. """
  321. 重试任务(将状态重置为初始化)
  322. Args:
  323. trace_id: 任务追踪 ID,默认使用当前实例的 trace_id
  324. Returns:
  325. 是否成功重置
  326. """
  327. trace_id = trace_id or self.trace_id
  328. query = f"""
  329. UPDATE {self.table}
  330. SET task_status = %s, start_timestamp = %s, finish_timestamp = NULL
  331. WHERE trace_id = %s
  332. """
  333. result = await self.db_client.async_save(
  334. query,
  335. (TaskStatus.INIT, int(time.time()), trace_id),
  336. )
  337. if result:
  338. await self._log_task_event("task_retried", trace_id=trace_id)
  339. return bool(result)
  340. async def _force_release_task(self, trace_id: str, status: int) -> None:
  341. """强制释放任务(用于超时任务清理)"""
  342. query = f"""
  343. UPDATE {self.table}
  344. SET task_status = %s, finish_timestamp = %s
  345. WHERE trace_id = %s
  346. """
  347. await self.db_client.async_save(
  348. query,
  349. (status, int(time.time()), trace_id),
  350. )
  351. await self._log_task_event(
  352. "task_force_released", trace_id=trace_id, status=status
  353. )
  354. # ==================== 主入口 ====================
  355. async def deal(self) -> dict:
  356. """
  357. 任务调度主入口
  358. Returns:
  359. 调度结果字典
  360. """
  361. # 验证任务名
  362. task_name = self.data.get("task_name")
  363. if not task_name:
  364. return await task_schedule_response.fail_response(
  365. "4003", "task_name is required"
  366. )
  367. try:
  368. task_name = TaskUtils.validate_task_name(task_name)
  369. except TaskValidationError as e:
  370. return await task_schedule_response.fail_response("4003", str(e))
  371. # 获取日期
  372. date_str = self.data.get("date_string") or (
  373. datetime.utcnow() + timedelta(hours=8)
  374. ).strftime("%Y-%m-%d")
  375. # 获取任务处理器
  376. handler = self.get_handler(task_name)
  377. if not handler:
  378. return await task_schedule_response.fail_response(
  379. "4001",
  380. f"Unknown task: {task_name}. "
  381. f"Available tasks: {', '.join(self.list_registered_tasks())}",
  382. )
  383. # 执行任务
  384. return await self._run_with_guard(
  385. task_name,
  386. date_str,
  387. lambda: handler(self),
  388. )
  389. __all__ = ["TaskScheduler"]