task_scheduler_v2.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. import asyncio
  2. import time
  3. from datetime import datetime
  4. from typing import Awaitable, Callable, Dict
  5. from applications.api import feishu_robot
  6. from applications.utils import task_schedule_response, generate_task_trace_id
  7. from applications.tasks.cold_start_tasks import ArticlePoolColdStart
  8. from applications.tasks.crawler_tasks import CrawlerToutiao
  9. from applications.tasks.data_recycle_tasks import CheckDailyPublishArticlesTask
  10. from applications.tasks.data_recycle_tasks import RecycleDailyPublishArticlesTask
  11. from applications.tasks.data_recycle_tasks import UpdateRootSourceIdAndUpdateTimeTask
  12. from applications.tasks.llm_tasks import TitleRewrite
  13. from applications.tasks.monitor_tasks import check_kimi_balance
  14. from applications.tasks.monitor_tasks import GetOffVideos
  15. from applications.tasks.monitor_tasks import CheckVideoAuditStatus
  16. from applications.tasks.monitor_tasks import InnerGzhArticlesMonitor
  17. from applications.tasks.monitor_tasks import OutsideGzhArticlesMonitor
  18. from applications.tasks.monitor_tasks import OutsideGzhArticlesCollector
  19. from applications.tasks.monitor_tasks import TaskProcessingMonitor
  20. from applications.tasks.task_mapper import TaskMapper
  21. class TaskScheduler(TaskMapper):
  22. """统一调度入口:外部只需调用 `await TaskScheduler(data, log_cli, db_cli).deal()`"""
  23. # ---------- 初始化 ----------
  24. def __init__(self, data, log_service, db_client):
  25. self.data = data
  26. self.log_client = log_service
  27. self.db_client = db_client
  28. self.table = "long_articles_task_manager"
  29. self.trace_id = generate_task_trace_id()
  30. # ---------- 公共数据库工具 ----------
  31. async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
  32. """新建记录(若同键已存在则忽略)"""
  33. query = (
  34. f"insert ignore into {self.table} "
  35. "(date_string, task_name, start_timestamp, task_status, trace_id) "
  36. "values (%s, %s, %s, %s, %s);"
  37. )
  38. await self.db_client.async_save(
  39. query=query,
  40. params=(
  41. date_str,
  42. task_name,
  43. int(time.time()),
  44. self.TASK_INIT_STATUS,
  45. self.trace_id,
  46. ),
  47. )
  48. async def _try_lock_task(self, task_name: str, date_str: str) -> bool:
  49. """一次 UPDATE 抢锁;返回 True 表示成功上锁"""
  50. query = (
  51. f"update {self.table} "
  52. "set task_status = %s "
  53. "where task_name = %s and date_string = %s and task_status = %s;"
  54. )
  55. res = await self.db_client.async_save(
  56. query=query,
  57. params=(
  58. self.TASK_PROCESSING_STATUS,
  59. task_name,
  60. date_str,
  61. self.TASK_INIT_STATUS,
  62. ),
  63. )
  64. return True if res else False
  65. async def _release_task(self, task_name: str, date_str: str, status: int) -> None:
  66. query = (
  67. f"update {self.table} set task_status=%s, finish_timestamp=%s "
  68. "where task_name=%s and date_string=%s and task_status=%s;"
  69. )
  70. await self.db_client.async_save(
  71. query=query,
  72. params=(
  73. status,
  74. int(time.time()),
  75. task_name,
  76. date_str,
  77. self.TASK_PROCESSING_STATUS,
  78. ),
  79. )
  80. async def _is_processing_overtime(self, task_name: str) -> bool:
  81. """检测是否已有同名任务在执行且超时。若超时会发飞书告警"""
  82. query = f"select start_timestamp from {self.table} where task_name=%s and task_status=%s"
  83. rows = await self.db_client.async_fetch(
  84. query=query, params=(task_name, self.TASK_PROCESSING_STATUS)
  85. )
  86. if not rows:
  87. return False
  88. start_ts = rows[0]["start_timestamp"]
  89. if int(time.time()) - start_ts >= self.get_task_config(task_name).get(
  90. "expire_duration", self.DEFAULT_TIMEOUT
  91. ):
  92. await feishu_robot.bot(
  93. title=f"{task_name} is overtime",
  94. detail={"start_ts": start_ts},
  95. )
  96. return True
  97. async def _run_with_guard(
  98. self, task_name: str, date_str: str, task_coro: Callable[[], Awaitable[int]]
  99. ):
  100. """公共:检查、建记录、抢锁、后台运行"""
  101. # 1. 超时检测(若有正在执行的同名任务则拒绝)
  102. if await self._is_processing_overtime(task_name):
  103. return await task_schedule_response.fail_response(
  104. "5001", "task is processing"
  105. )
  106. # 2. 记录并尝试抢锁
  107. await self._insert_or_ignore_task(task_name, date_str)
  108. if not await self._try_lock_task(task_name, date_str):
  109. return await task_schedule_response.fail_response(
  110. "5001", "task is processing"
  111. )
  112. # 3. 真正执行任务 —— 使用后台协程保证不阻塞调度入口
  113. async def _wrapper():
  114. status = self.TASK_FAILED_STATUS
  115. try:
  116. status = (
  117. await task_coro()
  118. ) # 你的任务函数需返回 TASK_SUCCESS_STATUS / FAILED_STATUS
  119. except Exception as e:
  120. await self.log_client.log(
  121. contents={
  122. "trace_id": self.trace_id,
  123. "task": task_name,
  124. "err": str(e),
  125. }
  126. )
  127. await feishu_robot.bot(
  128. title=f"{task_name} is failed",
  129. detail={"task": task_name, "err": str(e)},
  130. )
  131. finally:
  132. await self._release_task(task_name, date_str, status)
  133. asyncio.create_task(_wrapper(), name=task_name)
  134. return await task_schedule_response.success_response(
  135. task_name=task_name, data={"code": 0, "message": "task started"}
  136. )
  137. # ---------- 主入口 ----------
  138. async def deal(self):
  139. task_name: str | None = self.data.get("task_name")
  140. if not task_name:
  141. return await task_schedule_response.fail_response(
  142. "4002", "task_name must be input"
  143. )
  144. date_str = self.data.get("date_string") or datetime.now().strftime("%Y-%m-%d")
  145. # === 所有任务在此注册:映射到一个返回 int 状态码的异步函数 ===
  146. handlers: Dict[str, Callable[[], Awaitable[int]]] = {
  147. "check_kimi_balance": lambda: check_kimi_balance(),
  148. "get_off_videos": self._get_off_videos_task,
  149. "check_publish_video_audit_status": self._check_video_audit_status,
  150. "task_processing_monitor": self._task_processing_monitor,
  151. "outside_article_monitor": self._outside_monitor_handler,
  152. "inner_article_monitor": self._inner_gzh_articles_monitor,
  153. "title_rewrite": self._title_rewrite,
  154. "daily_publish_articles_recycle": self._recycle_handler,
  155. "update_root_source_id": self._update_root_source_id,
  156. "crawler_toutiao_articles": self._crawler_toutiao_handler,
  157. "article_pool_pool_cold_start": self._article_pool_cold_start_handler,
  158. }
  159. if task_name not in handlers:
  160. return await task_schedule_response.fail_response(
  161. "4001", "wrong task name input"
  162. )
  163. return await self._run_with_guard(task_name, date_str, handlers[task_name])
  164. # ---------- 下面是若干复合任务的局部实现 ----------
  165. # 写成独立方法保持清爽
  166. async def _get_off_videos_task(self):
  167. sub_task = GetOffVideos(self.db_client, self.log_client, self.trace_id)
  168. return await sub_task.deal()
  169. async def _check_video_audit_status(self):
  170. sub_task = CheckVideoAuditStatus(self.db_client, self.log_client, self.trace_id)
  171. return await sub_task.deal()
  172. async def _task_processing_monitor(self):
  173. sub_task = TaskProcessingMonitor(self.db_client)
  174. return await sub_task.deal()
  175. async def _inner_gzh_articles_monitor(self):
  176. sub_task = InnerGzhArticlesMonitor(self.db_client)
  177. return await sub_task.deal()
  178. async def _title_rewrite(self):
  179. sub_task = TitleRewrite(self.db_client, self.log_client)
  180. return await sub_task.deal()
  181. async def _update_root_source_id(self) -> int:
  182. sub_task = UpdateRootSourceIdAndUpdateTimeTask(self.db_client, self.log_client)
  183. return await sub_task.deal()
  184. async def _outside_monitor_handler(self) -> int:
  185. collector = OutsideGzhArticlesCollector(self.db_client)
  186. await collector.deal()
  187. monitor = OutsideGzhArticlesMonitor(self.db_client)
  188. return await monitor.deal() # 应返回 SUCCESS / FAILED
  189. async def _recycle_handler(self) -> int:
  190. date_str = self.data.get("date_string") or datetime.now().strftime("%Y-%m-%d")
  191. recycle = RecycleDailyPublishArticlesTask(
  192. self.db_client, self.log_client, date_str
  193. )
  194. await recycle.deal()
  195. check = CheckDailyPublishArticlesTask(self.db_client, self.log_client, date_str)
  196. return await check.deal()
  197. async def _crawler_toutiao_handler(self) -> int:
  198. sub_task = CrawlerToutiao(self.db_client, self.log_client, self.trace_id)
  199. media_type = self.data.get("media_type", "article")
  200. method = self.data.get("method", "account")
  201. category_list = self.data.get("category_list", [])
  202. if method == "account":
  203. await sub_task.crawler_task(media_type=media_type)
  204. elif method == "recommend":
  205. await sub_task.crawl_toutiao_recommend_task(category_list)
  206. else:
  207. raise ValueError(f"Unsupported method {method}")
  208. return self.TASK_SUCCESS_STATUS
  209. async def _article_pool_cold_start_handler(self) -> int:
  210. cold_start = ArticlePoolColdStart(
  211. self.db_client, self.log_client, self.trace_id
  212. )
  213. platform = self.data.get("platform", "weixin")
  214. crawler_methods = self.data.get("crawler_methods", [])
  215. await cold_start.deal(platform=platform, crawl_methods=crawler_methods)
  216. return self.TASK_SUCCESS_STATUS