# task_scheduler_v2.py
  1. import asyncio
  2. import json
  3. import time
  4. import traceback
  5. from datetime import datetime
  6. from typing import Awaitable, Callable, Dict
  7. from applications.api import feishu_robot
  8. from applications.utils import task_schedule_response, generate_task_trace_id
  9. from applications.tasks.cold_start_tasks import ArticlePoolColdStart
  10. from applications.tasks.crawler_tasks import CrawlerToutiao
  11. from applications.tasks.data_recycle_tasks import RecycleDailyPublishArticlesTask
  12. from applications.tasks.data_recycle_tasks import CheckDailyPublishArticlesTask
  13. from applications.tasks.data_recycle_tasks import UpdateRootSourceIdAndUpdateTimeTask
  14. from applications.tasks.llm_tasks import TitleRewrite
  15. from applications.tasks.llm_tasks import CandidateAccountQualityScoreRecognizer
  16. from applications.tasks.monitor_tasks import check_kimi_balance
  17. from applications.tasks.monitor_tasks import GetOffVideos
  18. from applications.tasks.monitor_tasks import CheckVideoAuditStatus
  19. from applications.tasks.monitor_tasks import InnerGzhArticlesMonitor
  20. from applications.tasks.monitor_tasks import OutsideGzhArticlesMonitor
  21. from applications.tasks.monitor_tasks import OutsideGzhArticlesCollector
  22. from applications.tasks.monitor_tasks import TaskProcessingMonitor
  23. from applications.tasks.task_mapper import TaskMapper
  24. class TaskScheduler(TaskMapper):
  25. """统一调度入口:外部只需调用 `await TaskScheduler(data, log_cli, db_cli).deal()`"""
  26. # ---------- 初始化 ----------
  27. def __init__(self, data, log_service, db_client):
  28. self.data = data
  29. self.log_client = log_service
  30. self.db_client = db_client
  31. self.table = "long_articles_task_manager"
  32. self.trace_id = generate_task_trace_id()
  33. # ---------- 公共数据库工具 ----------
  34. async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
  35. """新建记录(若同键已存在则忽略)"""
  36. query = (
  37. f"insert ignore into {self.table} "
  38. "(date_string, task_name, start_timestamp, task_status, trace_id, data) "
  39. "values (%s, %s, %s, %s, %s, %s);"
  40. )
  41. await self.db_client.async_save(
  42. query=query,
  43. params=(
  44. date_str,
  45. task_name,
  46. int(time.time()),
  47. self.TASK_INIT_STATUS,
  48. self.trace_id,
  49. json.dumps(self.data, ensure_ascii=False),
  50. ),
  51. )
  52. async def _try_lock_task(self) -> bool:
  53. """一次 UPDATE 抢锁;返回 True 表示成功上锁"""
  54. query = (
  55. f"update {self.table} "
  56. "set task_status = %s "
  57. "where trace_id = %s and task_status = %s;"
  58. )
  59. res = await self.db_client.async_save(
  60. query=query,
  61. params=(
  62. self.TASK_PROCESSING_STATUS,
  63. self.trace_id,
  64. self.TASK_INIT_STATUS,
  65. ),
  66. )
  67. return True if res else False
  68. async def _release_task(self, status: int) -> None:
  69. query = (
  70. f"update {self.table} set task_status=%s, finish_timestamp=%s "
  71. "where trace_id=%s and task_status=%s;"
  72. )
  73. await self.db_client.async_save(
  74. query=query,
  75. params=(
  76. status,
  77. int(time.time()),
  78. self.trace_id,
  79. self.TASK_PROCESSING_STATUS,
  80. ),
  81. )
  82. async def _is_processing_overtime(self, task_name) -> bool:
  83. """检测在处理任务是否超时,或者超过最大并行数,若超时会发飞书告警"""
  84. query = f"select trace_id from {self.table} where task_status = %s and task_name = %s;"
  85. rows = await self.db_client.async_fetch(
  86. query=query, params=(self.TASK_PROCESSING_STATUS, task_name)
  87. )
  88. if not rows:
  89. return False
  90. processing_task_num = len(rows)
  91. if processing_task_num >= self.get_task_config(task_name).get(
  92. "task_max_num", self.TASK_MAX_NUM
  93. ):
  94. await feishu_robot.bot(
  95. title=f"multi {task_name} is processing ",
  96. detail={"detail": rows},
  97. )
  98. return True
  99. return False
  100. async def _run_with_guard(
  101. self, task_name: str, date_str: str, task_coro: Callable[[], Awaitable[int]]
  102. ):
  103. """公共:检查、建记录、抢锁、后台运行"""
  104. # 1. 超时检测
  105. if await self._is_processing_overtime(task_name):
  106. return await task_schedule_response.fail_response(
  107. "5005", "muti tasks with same task_name is processing"
  108. )
  109. # 2. 记录并尝试抢锁
  110. await self._insert_or_ignore_task(task_name, date_str)
  111. if not await self._try_lock_task():
  112. return await task_schedule_response.fail_response(
  113. "5001", "task is processing"
  114. )
  115. # 3. 真正执行任务 —— 使用后台协程保证不阻塞调度入口
  116. async def _wrapper():
  117. status = self.TASK_FAILED_STATUS
  118. try:
  119. status = await task_coro()
  120. except Exception as e:
  121. await self.log_client.log(
  122. contents={
  123. "trace_id": self.trace_id,
  124. "function": "cor_wrapper",
  125. "task": task_name,
  126. "error": str(e),
  127. }
  128. )
  129. await feishu_robot.bot(
  130. title=f"{task_name} is failed",
  131. detail={
  132. "task": task_name,
  133. "err": str(e),
  134. "traceback": traceback.format_exc(),
  135. },
  136. )
  137. finally:
  138. await self._release_task(status)
  139. asyncio.create_task(_wrapper(), name=task_name)
  140. return await task_schedule_response.success_response(
  141. task_name=task_name,
  142. data={"code": 0, "message": "task started", "trace_id": self.trace_id},
  143. )
  144. # ---------- 主入口 ----------
  145. async def deal(self):
  146. task_name: str | None = self.data.get("task_name")
  147. if not task_name:
  148. return await task_schedule_response.fail_response(
  149. "4003", "task_name must be input"
  150. )
  151. date_str = self.data.get("date_string") or datetime.now().strftime("%Y-%m-%d")
  152. # === 所有任务在此注册:映射到一个返回 int 状态码的异步函数 ===
  153. handlers: Dict[str, Callable[[], Awaitable[int]]] = {
  154. # 校验kimi余额
  155. "check_kimi_balance": self._check_kimi_balance_handler,
  156. # 长文视频发布之后,三天后下架
  157. "get_off_videos": self._get_off_videos_task_handler,
  158. # 长文视频发布之后,三天内保持视频可见状态
  159. "check_publish_video_audit_status": self._check_video_audit_status_handler,
  160. # 外部服务号发文监测
  161. "outside_article_monitor": self._outside_monitor_handler,
  162. # 站内发文监测
  163. "inner_article_monitor": self._inner_gzh_articles_monitor_handler,
  164. # 标题重写(代测试)
  165. "title_rewrite": self._title_rewrite_handler,
  166. # 每日发文数据回收
  167. "daily_publish_articles_recycle": self._recycle_article_data_handler,
  168. # 每日发文更新root_source_id
  169. "update_root_source_id": self._update_root_source_id_handler,
  170. # 头条文章,视频抓取
  171. "crawler_toutiao": self._crawler_toutiao_handler,
  172. # 文章池冷启动发布
  173. "article_pool_cold_start": self._article_pool_cold_start_handler,
  174. # 任务超时监控
  175. "task_processing_monitor": self._task_processing_monitor_handler,
  176. # 候选账号质量分析
  177. "candidate_account_quality_analysis": self._candidate_account_quality_score_handler,
  178. }
  179. if task_name not in handlers:
  180. return await task_schedule_response.fail_response(
  181. "4001", "wrong task name input"
  182. )
  183. return await self._run_with_guard(task_name, date_str, handlers[task_name])
  184. # ---------- 下面是若干复合任务的局部实现 ----------
  185. async def _check_kimi_balance_handler(self) -> int:
  186. response = await check_kimi_balance()
  187. await self.log_client.log(
  188. contents={
  189. "trace_id": self.trace_id,
  190. "task": "check_kimi_balance",
  191. "data": response,
  192. }
  193. )
  194. return self.TASK_SUCCESS_STATUS
  195. async def _get_off_videos_task_handler(self) -> int:
  196. sub_task = GetOffVideos(self.db_client, self.log_client, self.trace_id)
  197. return await sub_task.deal()
  198. async def _check_video_audit_status_handler(self) -> int:
  199. sub_task = CheckVideoAuditStatus(self.db_client, self.log_client, self.trace_id)
  200. return await sub_task.deal()
  201. async def _task_processing_monitor_handler(self) -> int:
  202. sub_task = TaskProcessingMonitor(self.db_client)
  203. await sub_task.deal()
  204. return self.TASK_SUCCESS_STATUS
  205. async def _inner_gzh_articles_monitor_handler(self) -> int:
  206. sub_task = InnerGzhArticlesMonitor(self.db_client)
  207. return await sub_task.deal()
  208. async def _title_rewrite_handler(self):
  209. sub_task = TitleRewrite(self.db_client, self.log_client)
  210. return await sub_task.deal()
  211. async def _update_root_source_id_handler(self) -> int:
  212. sub_task = UpdateRootSourceIdAndUpdateTimeTask(self.db_client, self.log_client)
  213. await sub_task.deal()
  214. return self.TASK_SUCCESS_STATUS
  215. async def _outside_monitor_handler(self) -> int:
  216. collector = OutsideGzhArticlesCollector(self.db_client)
  217. await collector.deal()
  218. monitor = OutsideGzhArticlesMonitor(self.db_client)
  219. return await monitor.deal() # 应返回 SUCCESS / FAILED
  220. async def _recycle_article_data_handler(self) -> int:
  221. date_str = self.data.get("date_string") or datetime.now().strftime("%Y-%m-%d")
  222. recycle = RecycleDailyPublishArticlesTask(
  223. self.db_client, self.log_client, date_str
  224. )
  225. await recycle.deal()
  226. check = CheckDailyPublishArticlesTask(self.db_client, self.log_client, date_str)
  227. await check.deal()
  228. return self.TASK_SUCCESS_STATUS
  229. async def _crawler_toutiao_handler(self) -> int:
  230. sub_task = CrawlerToutiao(self.db_client, self.log_client, self.trace_id)
  231. method = self.data.get("method", "account")
  232. media_type = self.data.get("media_type", "article")
  233. category_list = self.data.get("category_list", [])
  234. match method:
  235. case "account":
  236. await sub_task.crawler_task(media_type=media_type)
  237. case "recommend":
  238. await sub_task.crawl_toutiao_recommend_task(category_list)
  239. case "search":
  240. await sub_task.search_candidate_accounts()
  241. case _:
  242. raise ValueError(f"Unsupported method {method}")
  243. return self.TASK_SUCCESS_STATUS
  244. async def _article_pool_cold_start_handler(self) -> int:
  245. cold_start = ArticlePoolColdStart(
  246. self.db_client, self.log_client, self.trace_id
  247. )
  248. platform = self.data.get("platform", "weixin")
  249. crawler_methods = self.data.get("crawler_methods", [])
  250. await cold_start.deal(platform=platform, crawl_methods=crawler_methods)
  251. return self.TASK_SUCCESS_STATUS
  252. async def _candidate_account_quality_score_handler(self) -> int:
  253. task = CandidateAccountQualityScoreRecognizer(
  254. self.db_client, self.log_client, self.trace_id
  255. )
  256. await task.deal()
  257. return self.TASK_SUCCESS_STATUS