task_scheduler.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. import asyncio
  2. import json
  3. import time
  4. from datetime import datetime, timedelta
  5. from typing import Optional, Dict, Any, List
  6. from app.infra.external import feishu_robot
  7. from app.infra.utils import task_schedule_response
  8. from app.jobs.task_handler import TaskHandler
  9. from app.jobs.task_config import (
  10. TaskStatus,
  11. TaskConstants,
  12. get_task_config,
  13. )
  14. from app.jobs.task_utils import (
  15. TaskError,
  16. TaskValidationError,
  17. TaskConcurrencyError,
  18. TaskUtils,
  19. )
  20. class TaskScheduler(TaskHandler):
  21. """
  22. 统一任务调度器
  23. 使用方法:
  24. scheduler = TaskScheduler(data, log_service, db_client, trace_id)
  25. result = await scheduler.deal()
  26. """
  27. def __init__(self, data: dict, log_service, db_client, trace_id: str, config):
  28. super().__init__(data, log_service, db_client, trace_id, config)
  29. self.table = TaskUtils.validate_table_name(TaskConstants.TASK_TABLE)
  30. # ==================== 数据库操作 ====================
  31. async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
  32. """新建任务记录(若同键已存在则忽略)"""
  33. query = f"""
  34. INSERT IGNORE INTO {self.table}
  35. (date_string, task_name, start_timestamp, task_status, trace_id, data)
  36. VALUES (%s, %s, %s, %s, %s, %s)
  37. """
  38. await self.db_client.async_save(
  39. query=query,
  40. params=(
  41. date_str,
  42. task_name,
  43. int(time.time()),
  44. TaskStatus.INIT,
  45. self.trace_id,
  46. json.dumps(self.data, ensure_ascii=False),
  47. ),
  48. )
  49. async def _try_lock_task(self) -> bool:
  50. """
  51. 尝试获取任务锁(CAS 操作)
  52. 返回 True 表示成功获取锁
  53. """
  54. query = f"""
  55. UPDATE {self.table}
  56. SET task_status = %s
  57. WHERE trace_id = %s AND task_status = %s
  58. """
  59. result = await self.db_client.async_save(
  60. query=query,
  61. params=(TaskStatus.PROCESSING, self.trace_id, TaskStatus.INIT),
  62. )
  63. return bool(result)
  64. async def _release_task(self, status: int) -> None:
  65. """释放任务锁并更新状态"""
  66. query = f"""
  67. UPDATE {self.table}
  68. SET task_status = %s, finish_timestamp = %s
  69. WHERE trace_id = %s AND task_status = %s
  70. """
  71. await self.db_client.async_save(
  72. query=query,
  73. params=(
  74. status,
  75. int(time.time()),
  76. self.trace_id,
  77. TaskStatus.PROCESSING,
  78. ),
  79. )
  80. async def _get_processing_tasks(self, task_name: str) -> List[Dict[str, Any]]:
  81. """获取正在处理中的任务列表"""
  82. query = f"""
  83. SELECT trace_id, start_timestamp, data
  84. FROM {self.table}
  85. WHERE task_status = %s AND task_name = %s
  86. """
  87. rows = await self.db_client.async_fetch(
  88. query=query,
  89. params=(TaskStatus.PROCESSING, task_name),
  90. )
  91. return rows or []
  92. # ==================== 任务检查 ====================
  93. async def _check_task_concurrency_and_timeout(self, task_name: str) -> None:
  94. """
  95. 检查任务并发数和超时情况
  96. 优化点:
  97. 1. 真正检查任务是否超时(基于时间)
  98. 2. 分别处理超时和并发限制
  99. 3. 可选择自动释放超时任务
  100. Raises:
  101. TaskTimeoutError: 发现超时任务
  102. TaskConcurrencyError: 超过并发限制
  103. """
  104. processing_tasks = await self._get_processing_tasks(task_name)
  105. if not processing_tasks:
  106. return
  107. config = get_task_config(task_name)
  108. current_time = int(time.time())
  109. # 检查超时任务
  110. timeout_tasks = [
  111. task
  112. for task in processing_tasks
  113. if current_time - task["start_timestamp"] > config.timeout
  114. ]
  115. if timeout_tasks:
  116. await self._log_task_event(
  117. "task_timeout_detected",
  118. task_name=task_name,
  119. timeout_count=len(timeout_tasks),
  120. timeout_tasks=[t["trace_id"] for t in timeout_tasks],
  121. )
  122. await feishu_robot.bot(
  123. title=f"Task Timeout Alert: {task_name}",
  124. detail={
  125. "task_name": task_name,
  126. "timeout_count": len(timeout_tasks),
  127. "timeout_threshold": config.timeout,
  128. "timeout_tasks": [
  129. {
  130. "trace_id": t["trace_id"],
  131. "running_time": current_time - t["start_timestamp"],
  132. }
  133. for t in timeout_tasks
  134. ],
  135. },
  136. )
  137. # 可选:自动释放超时任务(需要谨慎使用)
  138. for task in timeout_tasks:
  139. await self._force_release_task(task["trace_id"], TaskStatus.FAILED)
  140. # 检查并发限制(排除超时任务)
  141. active_tasks = [
  142. task
  143. for task in processing_tasks
  144. if current_time - task["start_timestamp"] <= config.timeout
  145. ]
  146. if len(active_tasks) >= config.max_concurrent:
  147. await self._log_task_event(
  148. "task_concurrency_limit",
  149. task_name=task_name,
  150. current_count=len(active_tasks),
  151. max_concurrent=config.max_concurrent,
  152. )
  153. await feishu_robot.bot(
  154. title=f"Task Concurrency Limit: {task_name}",
  155. detail={
  156. "task_name": task_name,
  157. "current_count": len(active_tasks),
  158. "max_concurrent": config.max_concurrent,
  159. "active_tasks": [t["trace_id"] for t in active_tasks],
  160. },
  161. )
  162. raise TaskConcurrencyError(
  163. f"Task {task_name} has reached max concurrency limit "
  164. f"({len(active_tasks)}/{config.max_concurrent})",
  165. task_name=task_name,
  166. )
  167. # ==================== 任务执行 ====================
  168. async def _run_with_guard(
  169. self,
  170. task_name: str,
  171. date_str: str,
  172. task_handler,
  173. ) -> dict:
  174. """
  175. 带保护的任务执行
  176. 优化点:
  177. 1. 更好的错误处理和重试机制
  178. 2. 统一的日志记录
  179. 3. 详细的错误信息
  180. """
  181. # 1. 检查并发和超时
  182. try:
  183. await self._check_task_concurrency_and_timeout(task_name)
  184. except TaskConcurrencyError as e:
  185. return await task_schedule_response.fail_response("5005", str(e))
  186. # 2. 创建任务记录并尝试获取锁
  187. await self._insert_or_ignore_task(task_name, date_str)
  188. if not await self._try_lock_task():
  189. return await task_schedule_response.fail_response(
  190. "5001", "Task is already processing"
  191. )
  192. # 3. 后台执行任务
  193. async def _task_wrapper():
  194. """任务执行包装器 - 处理错误和重试"""
  195. status = TaskStatus.FAILED
  196. retry_count = 0
  197. config = get_task_config(task_name)
  198. start_time = time.time()
  199. try:
  200. await self._log_task_event("task_started", task_name=task_name)
  201. # 执行任务
  202. status = await task_handler()
  203. duration = time.time() - start_time
  204. await self._log_task_event(
  205. "task_completed",
  206. task_name=task_name,
  207. status=status,
  208. duration=duration,
  209. )
  210. except TaskError as e:
  211. # 已知的任务错误
  212. duration = time.time() - start_time
  213. error_detail = TaskUtils.format_error_detail(e)
  214. await self._log_task_event(
  215. "task_failed",
  216. task_name=task_name,
  217. error=error_detail,
  218. duration=duration,
  219. retry_count=retry_count,
  220. )
  221. # 根据错误类型决定是否告警
  222. if config.alert_on_failure:
  223. await feishu_robot.bot(
  224. title=f"Task Failed: {task_name}",
  225. detail={
  226. "task_name": task_name,
  227. "trace_id": self.trace_id,
  228. "error": error_detail,
  229. "duration": duration,
  230. "retryable": e.retryable,
  231. },
  232. )
  233. # TODO: 实现重试逻辑
  234. # if e.retryable and retry_count < config.retry_times:
  235. # await self._schedule_retry(task_name, retry_count + 1)
  236. except Exception as e:
  237. # 未知错误
  238. duration = time.time() - start_time
  239. error_detail = TaskUtils.format_error_detail(e)
  240. await self._log_task_event(
  241. "task_error",
  242. task_name=task_name,
  243. error=error_detail,
  244. duration=duration,
  245. )
  246. await feishu_robot.bot(
  247. title=f"Task Error: {task_name}",
  248. detail={
  249. "task_name": task_name,
  250. "trace_id": self.trace_id,
  251. "error": error_detail,
  252. "duration": duration,
  253. },
  254. )
  255. finally:
  256. await self._release_task(status)
  257. # 创建后台任务
  258. asyncio.create_task(_task_wrapper(), name=f"{task_name}_{self.trace_id}")
  259. return await task_schedule_response.success_response(
  260. task_name=task_name,
  261. data={
  262. "code": 0,
  263. "message": "Task started successfully",
  264. "trace_id": self.trace_id,
  265. },
  266. )
  267. # ==================== 任务管理接口 ====================
  268. async def get_task_status(
  269. self, trace_id: Optional[str] = None
  270. ) -> Optional[Dict[str, Any]]:
  271. """
  272. 查询任务状态
  273. Args:
  274. trace_id: 任务追踪 ID,默认使用当前实例的 trace_id
  275. Returns:
  276. 任务信息字典,如果不存在返回 None
  277. """
  278. trace_id = trace_id or self.trace_id
  279. query = f"SELECT * FROM {self.table} WHERE trace_id = %s"
  280. result = await self.db_client.async_fetch_one(query, (trace_id,))
  281. return result
  282. async def cancel_task(self, trace_id: Optional[str] = None) -> bool:
  283. """
  284. 取消任务(将状态设置为失败)
  285. Args:
  286. trace_id: 任务追踪 ID,默认使用当前实例的 trace_id
  287. Returns:
  288. 是否成功取消
  289. """
  290. trace_id = trace_id or self.trace_id
  291. query = f"""
  292. UPDATE {self.table}
  293. SET task_status = %s, finish_timestamp = %s
  294. WHERE trace_id = %s AND task_status IN (%s, %s)
  295. """
  296. result = await self.db_client.async_save(
  297. query,
  298. (
  299. TaskStatus.FAILED,
  300. int(time.time()),
  301. trace_id,
  302. TaskStatus.INIT,
  303. TaskStatus.PROCESSING,
  304. ),
  305. )
  306. if result:
  307. await self._log_task_event("task_cancelled", trace_id=trace_id)
  308. return bool(result)
  309. async def retry_task(self, trace_id: Optional[str] = None) -> bool:
  310. """
  311. 重试任务(将状态重置为初始化)
  312. Args:
  313. trace_id: 任务追踪 ID,默认使用当前实例的 trace_id
  314. Returns:
  315. 是否成功重置
  316. """
  317. trace_id = trace_id or self.trace_id
  318. query = f"""
  319. UPDATE {self.table}
  320. SET task_status = %s, start_timestamp = %s, finish_timestamp = NULL
  321. WHERE trace_id = %s
  322. """
  323. result = await self.db_client.async_save(
  324. query,
  325. (TaskStatus.INIT, int(time.time()), trace_id),
  326. )
  327. if result:
  328. await self._log_task_event("task_retried", trace_id=trace_id)
  329. return bool(result)
  330. async def _force_release_task(self, trace_id: str, status: int) -> None:
  331. """强制释放任务(用于超时任务清理)"""
  332. query = f"""
  333. UPDATE {self.table}
  334. SET task_status = %s, finish_timestamp = %s
  335. WHERE trace_id = %s
  336. """
  337. await self.db_client.async_save(
  338. query,
  339. (status, int(time.time()), trace_id),
  340. )
  341. await self._log_task_event(
  342. "task_force_released", trace_id=trace_id, status=status
  343. )
  344. # ==================== 主入口 ====================
  345. async def deal(self) -> dict:
  346. """
  347. 任务调度主入口
  348. Returns:
  349. 调度结果字典
  350. """
  351. # 验证任务名
  352. task_name = self.data.get("task_name")
  353. if not task_name:
  354. return await task_schedule_response.fail_response(
  355. "4003", "task_name is required"
  356. )
  357. try:
  358. task_name = TaskUtils.validate_task_name(task_name)
  359. except TaskValidationError as e:
  360. return await task_schedule_response.fail_response("4003", str(e))
  361. # 获取日期
  362. date_str = self.data.get("date_string") or (
  363. datetime.utcnow() + timedelta(hours=8)
  364. ).strftime("%Y-%m-%d")
  365. # 获取任务处理器
  366. handler = self.get_handler(task_name)
  367. if not handler:
  368. return await task_schedule_response.fail_response(
  369. "4001",
  370. f"Unknown task: {task_name}. "
  371. f"Available tasks: {', '.join(self.list_registered_tasks())}",
  372. )
  373. # 执行任务
  374. return await self._run_with_guard(
  375. task_name,
  376. date_str,
  377. lambda: handler(self),
  378. )
  379. __all__ = ["TaskScheduler"]