# task_scheduler.py (7.4 KB)
  1. import asyncio
  2. import json
  3. import time
  4. import traceback
  5. from datetime import datetime
  6. from typing import Awaitable, Callable, Dict
  7. from applications.api import feishu_robot
  8. from applications.utils import task_schedule_response
  9. from applications.tasks.task_handler import TaskHandler
  10. class TaskScheduler(TaskHandler):
  11. """统一调度入口:外部只需调用 `await TaskScheduler(data, log_cli, db_cli).deal()`"""
  12. # ---------- 初始化 ----------
  13. def __init__(self, task_manager, data, log_service, db_client, trace_id):
  14. super().__init__(data, log_service, db_client, trace_id)
  15. self.data = data
  16. self.log_client = log_service
  17. self.db_client = db_client
  18. self.table = "long_articles_task_manager"
  19. self.trace_id = trace_id
  20. self.task_manager = task_manager
  21. # ---------- 公共数据库工具 ----------
  22. async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
  23. """新建记录(若同键已存在则忽略)"""
  24. query = (
  25. f"insert ignore into {self.table} "
  26. "(date_string, task_name, start_timestamp, task_status, trace_id, data) "
  27. "values (%s, %s, %s, %s, %s, %s);"
  28. )
  29. await self.db_client.async_save(
  30. query=query,
  31. params=(
  32. date_str,
  33. task_name,
  34. int(time.time()),
  35. self.TASK_INIT_STATUS,
  36. self.trace_id,
  37. json.dumps(self.data, ensure_ascii=False),
  38. ),
  39. )
  40. async def _try_lock_task(self) -> bool:
  41. """一次 UPDATE 抢锁;返回 True 表示成功上锁"""
  42. query = (
  43. f"update {self.table} "
  44. "set task_status = %s "
  45. "where trace_id = %s and task_status = %s;"
  46. )
  47. res = await self.db_client.async_save(
  48. query=query,
  49. params=(
  50. self.TASK_PROCESSING_STATUS,
  51. self.trace_id,
  52. self.TASK_INIT_STATUS,
  53. ),
  54. )
  55. return True if res else False
  56. async def _release_task(self, status: int) -> None:
  57. query = (
  58. f"update {self.table} set task_status=%s, finish_timestamp=%s "
  59. "where trace_id=%s and task_status=%s;"
  60. )
  61. await self.db_client.async_save(
  62. query=query,
  63. params=(
  64. status,
  65. int(time.time()),
  66. self.trace_id,
  67. self.TASK_PROCESSING_STATUS,
  68. ),
  69. )
  70. async def _is_processing_overtime(self, task_name) -> bool:
  71. """检测在处理任务是否超时,或者超过最大并行数,若超时会发飞书告警"""
  72. query = f"select trace_id from {self.table} where task_status = %s and task_name = %s;"
  73. rows = await self.db_client.async_fetch(
  74. query=query, params=(self.TASK_PROCESSING_STATUS, task_name)
  75. )
  76. if not rows:
  77. return False
  78. processing_task_num = len(rows)
  79. if processing_task_num >= self.get_task_config(task_name).get(
  80. "task_max_num", self.TASK_MAX_NUM
  81. ):
  82. await feishu_robot.bot(
  83. title=f"multi {task_name} is processing ",
  84. detail={"detail": rows},
  85. )
  86. return True
  87. return False
  88. async def _run_with_guard(
  89. self, task_name: str, date_str: str, task_coro: Callable[[], Awaitable[int]]
  90. ):
  91. """公共:检查、建记录、抢锁、后台运行"""
  92. # 1. 超时检测
  93. if await self._is_processing_overtime(task_name):
  94. return await task_schedule_response.fail_response(
  95. "5005", "muti tasks with same task_name is processing"
  96. )
  97. # 2. 记录并尝试抢锁
  98. await self._insert_or_ignore_task(task_name, date_str)
  99. if not await self._try_lock_task():
  100. return await task_schedule_response.fail_response(
  101. "5001", "task is processing"
  102. )
  103. # 3. 真正执行任务 —— 使用后台协程保证不阻塞调度入口
  104. async def _wrapper():
  105. status = self.TASK_FAILED_STATUS
  106. try:
  107. status = await task_coro()
  108. except Exception as e:
  109. await self.log_client.log(
  110. contents={
  111. "trace_id": self.trace_id,
  112. "function": "cor_wrapper",
  113. "task": task_name,
  114. "error": str(e),
  115. }
  116. )
  117. await feishu_robot.bot(
  118. title=f"{task_name} is failed",
  119. detail={
  120. "task": task_name,
  121. "err": str(e),
  122. "traceback": traceback.format_exc(),
  123. },
  124. )
  125. finally:
  126. await self._release_task(status)
  127. task: asyncio.Task = asyncio.create_task(_wrapper(), name=task_name)
  128. self.task_manager.register(task_id=self.trace_id, task=task, task_name=task_name)
  129. return await task_schedule_response.success_response(
  130. task_name=task_name,
  131. data={"code": 0, "message": "task started", "trace_id": self.trace_id},
  132. )
  133. # ---------- 主入口 ----------
  134. async def deal(self):
  135. task_name: str | None = self.data.get("task_name")
  136. if not task_name:
  137. return await task_schedule_response.fail_response(
  138. "4003", "task_name must be input"
  139. )
  140. date_str = self.data.get("date_string") or datetime.now().strftime("%Y-%m-%d")
  141. # === 所有任务在此注册:映射到一个返回 int 状态码的异步函数 ===
  142. handlers: Dict[str, Callable[[], Awaitable[int]]] = {
  143. # 校验kimi余额
  144. "check_kimi_balance": self._check_kimi_balance_handler,
  145. # 长文视频发布之后,三天后下架
  146. "get_off_videos": self._get_off_videos_task_handler,
  147. # 长文视频发布之后,三天内保持视频可见状态
  148. "check_publish_video_audit_status": self._check_video_audit_status_handler,
  149. # 外部服务号发文监测
  150. "outside_article_monitor": self._outside_monitor_handler,
  151. # 站内发文监测
  152. "inner_article_monitor": self._inner_gzh_articles_monitor_handler,
  153. # 标题重写(代测试)
  154. "title_rewrite": self._title_rewrite_handler,
  155. # 每日发文数据回收
  156. "daily_publish_articles_recycle": self._recycle_article_data_handler,
  157. # 每日发文更新root_source_id
  158. "update_root_source_id": self._update_root_source_id_handler,
  159. # 头条文章,视频抓取
  160. "crawler_toutiao": self._crawler_toutiao_handler,
  161. # 文章池冷启动发布
  162. "article_pool_cold_start": self._article_pool_cold_start_handler,
  163. # 任务超时监控
  164. "task_processing_monitor": self._task_processing_monitor_handler,
  165. # 候选账号质量分析
  166. "candidate_account_quality_analysis": self._candidate_account_quality_score_handler,
  167. }
  168. if task_name not in handlers:
  169. return await task_schedule_response.fail_response(
  170. "4001", "wrong task name input"
  171. )
  172. return await self._run_with_guard(task_name, date_str, handlers[task_name])