# task_scheduler_v2.py
  1. import asyncio
  2. import json
  3. import time
  4. import traceback
  5. from datetime import datetime
  6. from typing import Awaitable, Callable, Dict
  7. from applications.api import feishu_robot
  8. from applications.utils import task_schedule_response, generate_task_trace_id
  9. from applications.tasks.cold_start_tasks import ArticlePoolColdStart
  10. from applications.tasks.crawler_tasks import CrawlerToutiao
  11. from applications.tasks.data_recycle_tasks import CheckDailyPublishArticlesTask
  12. from applications.tasks.data_recycle_tasks import RecycleDailyPublishArticlesTask
  13. from applications.tasks.data_recycle_tasks import UpdateRootSourceIdAndUpdateTimeTask
  14. from applications.tasks.llm_tasks import TitleRewrite
  15. from applications.tasks.monitor_tasks import check_kimi_balance
  16. from applications.tasks.monitor_tasks import GetOffVideos
  17. from applications.tasks.monitor_tasks import CheckVideoAuditStatus
  18. from applications.tasks.monitor_tasks import InnerGzhArticlesMonitor
  19. from applications.tasks.monitor_tasks import OutsideGzhArticlesMonitor
  20. from applications.tasks.monitor_tasks import OutsideGzhArticlesCollector
  21. from applications.tasks.monitor_tasks import TaskProcessingMonitor
  22. from applications.tasks.task_mapper import TaskMapper
  23. class TaskScheduler(TaskMapper):
  24. """统一调度入口:外部只需调用 `await TaskScheduler(data, log_cli, db_cli).deal()`"""
  25. # ---------- 初始化 ----------
  26. def __init__(self, data, log_service, db_client):
  27. self.data = data
  28. self.log_client = log_service
  29. self.db_client = db_client
  30. self.table = "long_articles_task_manager"
  31. self.trace_id = generate_task_trace_id()
  32. # ---------- 公共数据库工具 ----------
  33. async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
  34. """新建记录(若同键已存在则忽略)"""
  35. query = (
  36. f"insert ignore into {self.table} "
  37. "(date_string, task_name, start_timestamp, task_status, trace_id, data) "
  38. "values (%s, %s, %s, %s, %s, %s);"
  39. )
  40. await self.db_client.async_save(
  41. query=query,
  42. params=(
  43. date_str,
  44. task_name,
  45. int(time.time()),
  46. self.TASK_INIT_STATUS,
  47. self.trace_id,
  48. json.dumps(self.data, ensure_ascii=False),
  49. ),
  50. )
  51. async def _try_lock_task(self) -> bool:
  52. """一次 UPDATE 抢锁;返回 True 表示成功上锁"""
  53. query = (
  54. f"update {self.table} "
  55. "set task_status = %s "
  56. "where trace_id = %s and task_status = %s;"
  57. )
  58. res = await self.db_client.async_save(
  59. query=query,
  60. params=(
  61. self.TASK_PROCESSING_STATUS,
  62. self.trace_id,
  63. self.TASK_INIT_STATUS,
  64. ),
  65. )
  66. return True if res else False
  67. async def _release_task(self, status: int) -> None:
  68. query = (
  69. f"update {self.table} set task_status=%s, finish_timestamp=%s "
  70. "where trace_d=%s and task_status=%s;"
  71. )
  72. await self.db_client.async_save(
  73. query=query,
  74. params=(
  75. status,
  76. int(time.time()),
  77. self.trace_id,
  78. self.TASK_PROCESSING_STATUS,
  79. ),
  80. )
  81. async def _is_processing_overtime(self, task_name) -> bool:
  82. """检测在处理任务是否超时,或者超过最大并行数,若超时会发飞书告警"""
  83. query = f"select trace_id from {self.table} where task_status = %s and task_name = %s;"
  84. rows = await self.db_client.async_fetch(
  85. query=query, params=(self.TASK_PROCESSING_STATUS, task_name)
  86. )
  87. if not rows:
  88. return False
  89. processing_task_num = len(rows)
  90. if processing_task_num >= self.get_task_config(task_name).get(
  91. "task_max_num", self.TASK_MAX_NUM
  92. ):
  93. await feishu_robot.bot(
  94. title=f"multi {task_name} is processing ",
  95. detail={"detail": rows},
  96. )
  97. return True
  98. return False
  99. async def _run_with_guard(
  100. self, task_name: str, date_str: str, task_coro: Callable[[], Awaitable[int]]
  101. ):
  102. """公共:检查、建记录、抢锁、后台运行"""
  103. # 1. 超时检测
  104. if await self._is_processing_overtime(task_name):
  105. return await task_schedule_response.fail_response(
  106. "5005", "muti tasks with same task_name is processing"
  107. )
  108. # 2. 记录并尝试抢锁
  109. await self._insert_or_ignore_task(task_name, date_str)
  110. if not await self._try_lock_task():
  111. return await task_schedule_response.fail_response(
  112. "5001", "task is processing"
  113. )
  114. # 3. 真正执行任务 —— 使用后台协程保证不阻塞调度入口
  115. async def _wrapper():
  116. status = self.TASK_FAILED_STATUS
  117. try:
  118. status = await task_coro()
  119. except Exception as e:
  120. await self.log_client.log(
  121. contents={
  122. "trace_id": self.trace_id,
  123. "function": "cor_wrapper",
  124. "task": task_name,
  125. "error": str(e),
  126. }
  127. )
  128. await feishu_robot.bot(
  129. title=f"{task_name} is failed",
  130. detail={
  131. "task": task_name,
  132. "err": str(e),
  133. "traceback": traceback.format_exc(),
  134. },
  135. )
  136. finally:
  137. await self._release_task(status)
  138. asyncio.create_task(_wrapper(), name=task_name)
  139. return await task_schedule_response.success_response(
  140. task_name=task_name,
  141. data={"code": 0, "message": "task started", "trace_id": self.trace_id},
  142. )
  143. # ---------- 主入口 ----------
  144. async def deal(self):
  145. task_name: str | None = self.data.get("task_name")
  146. if not task_name:
  147. return await task_schedule_response.fail_response(
  148. "4003", "task_name must be input"
  149. )
  150. date_str = self.data.get("date_string") or datetime.now().strftime("%Y-%m-%d")
  151. # === 所有任务在此注册:映射到一个返回 int 状态码的异步函数 ===
  152. handlers: Dict[str, Callable[[], Awaitable[int]]] = {
  153. # 校验kimi余额
  154. "check_kimi_balance": self._check_kimi_balance_handler,
  155. # 长文视频发布之后,三天后下架
  156. "get_off_videos": self._get_off_videos_task_handler,
  157. # 长文视频发布之后,三天内保持视频可见状态
  158. "check_publish_video_audit_status": self._check_video_audit_status_handler,
  159. # 外部服务号发文监测
  160. "outside_article_monitor": self._outside_monitor_handler,
  161. # 站内发文监测
  162. "inner_article_monitor": self._inner_gzh_articles_monitor_handler,
  163. # 标题重写(代测试)
  164. "title_rewrite": self._title_rewrite_handler,
  165. # 每日发文数据回收
  166. "daily_publish_articles_recycle": self._recycle_article_data_handler,
  167. # 每日发文更新root_source_id
  168. "update_root_source_id": self._update_root_source_id_handler,
  169. # 头条文章,视频抓取
  170. "crawler_toutiao": self._crawler_toutiao_handler,
  171. # 文章池冷启动发布
  172. "article_pool_cold_start": self._article_pool_cold_start_handler,
  173. # 任务超时监控
  174. "task_processing_monitor": self._task_processing_monitor_handler,
  175. }
  176. if task_name not in handlers:
  177. return await task_schedule_response.fail_response(
  178. "4001", "wrong task name input"
  179. )
  180. return await self._run_with_guard(task_name, date_str, handlers[task_name])
  181. # ---------- 下面是若干复合任务的局部实现 ----------
  182. async def _check_kimi_balance_handler(self) -> int:
  183. response = await check_kimi_balance()
  184. await self.log_client.log(
  185. contents={
  186. "trace_id": self.trace_id,
  187. "task": "check_kimi_balance",
  188. "data": response,
  189. }
  190. )
  191. return self.TASK_SUCCESS_STATUS
  192. async def _get_off_videos_task_handler(self) -> int:
  193. sub_task = GetOffVideos(self.db_client, self.log_client, self.trace_id)
  194. return await sub_task.deal()
  195. async def _check_video_audit_status_handler(self) -> int:
  196. sub_task = CheckVideoAuditStatus(self.db_client, self.log_client, self.trace_id)
  197. return await sub_task.deal()
  198. async def _task_processing_monitor_handler(self) -> int:
  199. sub_task = TaskProcessingMonitor(self.db_client)
  200. await sub_task.deal()
  201. return self.TASK_SUCCESS_STATUS
  202. async def _inner_gzh_articles_monitor_handler(self) -> int:
  203. sub_task = InnerGzhArticlesMonitor(self.db_client)
  204. return await sub_task.deal()
  205. async def _title_rewrite_handler(self):
  206. sub_task = TitleRewrite(self.db_client, self.log_client)
  207. return await sub_task.deal()
  208. async def _update_root_source_id_handler(self) -> int:
  209. sub_task = UpdateRootSourceIdAndUpdateTimeTask(self.db_client, self.log_client)
  210. await sub_task.deal()
  211. return self.TASK_SUCCESS_STATUS
  212. async def _outside_monitor_handler(self) -> int:
  213. collector = OutsideGzhArticlesCollector(self.db_client)
  214. await collector.deal()
  215. monitor = OutsideGzhArticlesMonitor(self.db_client)
  216. return await monitor.deal() # 应返回 SUCCESS / FAILED
  217. async def _recycle_article_data_handler(self) -> int:
  218. date_str = self.data.get("date_string") or datetime.now().strftime("%Y-%m-%d")
  219. recycle = RecycleDailyPublishArticlesTask(
  220. self.db_client, self.log_client, date_str
  221. )
  222. await recycle.deal()
  223. check = CheckDailyPublishArticlesTask(self.db_client, self.log_client, date_str)
  224. await check.deal()
  225. return self.TASK_SUCCESS_STATUS
  226. async def _crawler_toutiao_handler(self) -> int:
  227. sub_task = CrawlerToutiao(self.db_client, self.log_client, self.trace_id)
  228. method = self.data.get("method", "account")
  229. media_type = self.data.get("media_type", "article")
  230. category_list = self.data.get("category_list", [])
  231. match method:
  232. case "account":
  233. await sub_task.crawler_task(media_type=media_type)
  234. case "recommend":
  235. await sub_task.crawl_toutiao_recommend_task(category_list)
  236. case "search":
  237. await sub_task.search_candidate_accounts()
  238. case _:
  239. raise ValueError(f"Unsupported method {method}")
  240. return self.TASK_SUCCESS_STATUS
  241. async def _article_pool_cold_start_handler(self) -> int:
  242. cold_start = ArticlePoolColdStart(
  243. self.db_client, self.log_client, self.trace_id
  244. )
  245. platform = self.data.get("platform", "weixin")
  246. crawler_methods = self.data.get("crawler_methods", [])
  247. await cold_start.deal(platform=platform, crawl_methods=crawler_methods)
  248. return self.TASK_SUCCESS_STATUS