123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244 |
import asyncio
import time
from datetime import datetime
from typing import Awaitable, Callable, ClassVar, Dict

from applications.api import feishu_robot
from applications.utils import task_schedule_response, generate_task_trace_id
from applications.tasks.cold_start_tasks import ArticlePoolColdStart
from applications.tasks.crawler_tasks import CrawlerToutiao
from applications.tasks.data_recycle_tasks import CheckDailyPublishArticlesTask
from applications.tasks.data_recycle_tasks import RecycleDailyPublishArticlesTask
from applications.tasks.data_recycle_tasks import UpdateRootSourceIdAndUpdateTimeTask
from applications.tasks.llm_tasks import TitleRewrite
from applications.tasks.monitor_tasks import check_kimi_balance
from applications.tasks.monitor_tasks import GetOffVideos
from applications.tasks.monitor_tasks import CheckVideoAuditStatus
from applications.tasks.monitor_tasks import InnerGzhArticlesMonitor
from applications.tasks.monitor_tasks import OutsideGzhArticlesMonitor
from applications.tasks.monitor_tasks import OutsideGzhArticlesCollector
from applications.tasks.monitor_tasks import TaskProcessingMonitor
from applications.tasks.task_mapper import TaskMapper
class TaskScheduler(TaskMapper):
    """Unified scheduling entry point.

    Callers only need ``await TaskScheduler(data, log_cli, db_cli).deal()``.

    Each run is tracked as a row in ``long_articles_task_manager`` identified
    by task name and date string. A run starts only when no row for the task
    is in PROCESSING status and this instance wins the INIT -> PROCESSING
    update. The task body then runs as a background asyncio task, and the row
    receives its final status when the body finishes. The ``TASK_*_STATUS``
    constants, ``DEFAULT_TIMEOUT`` and ``get_task_config`` are presumably
    provided by ``TaskMapper`` (not visible here).
    """

    # Strong references to in-flight background runs. The event loop keeps
    # only weak references to tasks created by ``asyncio.create_task``, so an
    # unreferenced run can be garbage-collected before it finishes. The set
    # lives on the class because the scheduler instance itself may be dropped
    # as soon as ``deal()`` returns.
    _background_tasks: ClassVar[set[asyncio.Task]] = set()

    # ---------- initialisation ----------
    def __init__(self, data, log_service, db_client):
        """Store the request payload and clients, and mint a trace id.

        Args:
            data: dict-like request payload, read via ``.get`` for
                ``task_name``, ``date_string`` and task-specific options.
            log_service: logging client exposing async ``log(contents=...)``.
            db_client: database client exposing async ``async_save`` and
                ``async_fetch``.
        """
        self.data = data
        self.log_client = log_service
        self.db_client = db_client
        self.table = "long_articles_task_manager"
        self.trace_id = generate_task_trace_id()
        # Target date, resolved lazily and only once so that the guard row
        # and the task bodies always agree on the same date.
        self._date_string: str | None = None

    def _resolve_date_string(self) -> str:
        """Return the target date string for this run.

        Uses ``data["date_string"]`` when it is truthy, otherwise today's
        local date as ``YYYY-MM-DD``. The value is cached on the instance.
        Without the cache, a request near midnight could lock the guard row
        for one day while the task body works on the next.
        """
        if self._date_string is None:
            self._date_string = self.data.get("date_string") or datetime.now().strftime(
                "%Y-%m-%d"
            )
        return self._date_string

    # ---------- shared database helpers ----------
    async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
        """Insert an INIT-status row for ``(task_name, date_str)``.

        The row records the current epoch second as ``start_timestamp`` and
        this instance's trace id. ``INSERT IGNORE`` leaves any existing row
        that collides on a unique key untouched. The key is presumably
        (date_string, task_name); confirm against the table schema.
        """
        query = (
            f"insert ignore into {self.table} "
            "(date_string, task_name, start_timestamp, task_status, trace_id) "
            "values (%s, %s, %s, %s, %s);"
        )
        await self.db_client.async_save(
            query=query,
            params=(
                date_str,
                task_name,
                int(time.time()),
                self.TASK_INIT_STATUS,
                self.trace_id,
            ),
        )

    async def _try_lock_task(self, task_name: str, date_str: str) -> bool:
        """Try to claim the run with a single conditional UPDATE.

        The row moves from INIT to PROCESSING only if it is still in INIT
        status. A caller that finds the row in any other status acquires
        nothing.

        Returns:
            True if the UPDATE reported a truthy result, i.e. the lock was
            acquired.
        """
        query = (
            f"update {self.table} "
            "set task_status = %s "
            "where task_name = %s and date_string = %s and task_status = %s;"
        )
        res = await self.db_client.async_save(
            query=query,
            params=(
                self.TASK_PROCESSING_STATUS,
                task_name,
                date_str,
                self.TASK_INIT_STATUS,
            ),
        )
        # ``async_save`` presumably returns the affected-row count; verify.
        return bool(res)

    async def _release_task(self, task_name: str, date_str: str, status: int) -> None:
        """Move a PROCESSING row to ``status`` and stamp ``finish_timestamp``.

        Only a row that is still in PROCESSING status for this task and date
        is updated.
        """
        query = (
            f"update {self.table} set task_status=%s, finish_timestamp=%s "
            "where task_name=%s and date_string=%s and task_status=%s;"
        )
        await self.db_client.async_save(
            query=query,
            params=(
                status,
                int(time.time()),
                task_name,
                date_str,
                self.TASK_PROCESSING_STATUS,
            ),
        )

    async def _is_processing_overtime(self, task_name: str) -> bool:
        """Report whether a run of ``task_name`` is already in PROCESSING.

        Despite the name, this returns True whenever any PROCESSING row
        exists for the task, on any date, whether or not that run is overtime.
        If the first returned row has been running for at least the task's
        ``expire_duration`` (falling back to ``DEFAULT_TIMEOUT``), a Feishu
        alert is sent as well.

        Returns:
            False when no PROCESSING row exists, True otherwise.
        """
        query = f"select start_timestamp from {self.table} where task_name=%s and task_status=%s"
        rows = await self.db_client.async_fetch(
            query=query, params=(task_name, self.TASK_PROCESSING_STATUS)
        )
        if not rows:
            return False

        start_ts = rows[0]["start_timestamp"]
        expire_duration = self.get_task_config(task_name).get(
            "expire_duration", self.DEFAULT_TIMEOUT
        )
        if int(time.time()) - start_ts >= expire_duration:
            await feishu_robot.bot(
                title=f"{task_name} is overtime",
                detail={"start_ts": start_ts},
            )
        return True

    async def _run_with_guard(
        self, task_name: str, date_str: str, task_coro: Callable[[], Awaitable[int]]
    ):
        """Check, record, lock, then start ``task_coro`` in the background.

        Args:
            task_name: registered task name; also used as the asyncio task name.
            date_str: date string identifying the guard row.
            task_coro: zero-argument callable returning an awaitable that
                yields the final status code to store.

        Returns:
            A ``task_schedule_response`` failure ("5001") if the task is
            already processing or the lock could not be acquired. Otherwise a
            success response sent as soon as the run has been scheduled,
            without waiting for it to finish.
        """
        # 1. Refuse if a run of this task is already in PROCESSING
        #    (an overtime run additionally triggers an alert).
        if await self._is_processing_overtime(task_name):
            return await task_schedule_response.fail_response(
                "5001", "task is processing"
            )

        # 2. Make sure a row exists, then try to claim it.
        await self._insert_or_ignore_task(task_name, date_str)
        if not await self._try_lock_task(task_name, date_str):
            return await task_schedule_response.fail_response(
                "5001", "task is processing"
            )

        # 3. Run the task body in the background so scheduling returns fast.
        async def _wrapper():
            # Assume failure until the body reports otherwise, so the row is
            # never left in PROCESSING when the body raises.
            status = self.TASK_FAILED_STATUS
            try:
                # Handlers are expected to return TASK_SUCCESS_STATUS or
                # TASK_FAILED_STATUS.
                status = await task_coro()
            except Exception as e:
                await self.log_client.log(
                    contents={
                        "trace_id": self.trace_id,
                        "task": task_name,
                        "err": str(e),
                    }
                )
                await feishu_robot.bot(
                    title=f"{task_name} is failed",
                    detail={"task": task_name, "err": str(e)},
                )
            finally:
                await self._release_task(task_name, date_str, status)

        background = asyncio.create_task(_wrapper(), name=task_name)
        # Hold a strong reference until the run completes; see
        # ``_background_tasks`` for why this is required.
        self._background_tasks.add(background)
        background.add_done_callback(self._background_tasks.discard)

        return await task_schedule_response.success_response(
            task_name=task_name, data={"code": 0, "message": "task started"}
        )

    # ---------- main entry point ----------
    async def deal(self):
        """Validate the request, look up its handler and start it under the guard.

        Returns:
            A ``task_schedule_response`` payload. The failure code is "4002"
            when ``task_name`` is missing and "4001" when it is not
            registered. Otherwise the result of ``_run_with_guard`` is
            returned.
        """
        task_name: str | None = self.data.get("task_name")
        if not task_name:
            return await task_schedule_response.fail_response(
                "4002", "task_name must be input"
            )

        date_str = self._resolve_date_string()

        # Every schedulable task is registered here, mapped to a zero-argument
        # async callable that returns an int status code.
        handlers: Dict[str, Callable[[], Awaitable[int]]] = {
            "check_kimi_balance": check_kimi_balance,
            "get_off_videos": self._get_off_videos_task,
            "check_publish_video_audit_status": self._check_video_audit_status,
            "task_processing_monitor": self._task_processing_monitor,
            "outside_article_monitor": self._outside_monitor_handler,
            "inner_article_monitor": self._inner_gzh_articles_monitor,
            "title_rewrite": self._title_rewrite,
            "daily_publish_articles_recycle": self._recycle_handler,
            "update_root_source_id": self._update_root_source_id,
            "crawler_toutiao_articles": self._crawler_toutiao_handler,
            # NOTE(review): "pool_pool" looks like a typo, but callers send
            # this exact name, so it is kept as-is.
            "article_pool_pool_cold_start": self._article_pool_cold_start_handler,
        }
        handler = handlers.get(task_name)
        if handler is None:
            return await task_schedule_response.fail_response(
                "4001", "wrong task name input"
            )
        return await self._run_with_guard(task_name, date_str, handler)

    # ---------- composite task implementations ----------
    # Kept as separate methods so the handler table above stays readable.
    async def _get_off_videos_task(self):
        """Run ``GetOffVideos`` and return its status."""
        sub_task = GetOffVideos(self.db_client, self.log_client, self.trace_id)
        return await sub_task.deal()

    async def _check_video_audit_status(self):
        """Run ``CheckVideoAuditStatus`` and return its status."""
        sub_task = CheckVideoAuditStatus(self.db_client, self.log_client, self.trace_id)
        return await sub_task.deal()

    async def _task_processing_monitor(self):
        """Run ``TaskProcessingMonitor`` and return its status."""
        sub_task = TaskProcessingMonitor(self.db_client)
        return await sub_task.deal()

    async def _inner_gzh_articles_monitor(self):
        """Run ``InnerGzhArticlesMonitor`` and return its status."""
        sub_task = InnerGzhArticlesMonitor(self.db_client)
        return await sub_task.deal()

    async def _title_rewrite(self):
        """Run ``TitleRewrite`` and return its status."""
        sub_task = TitleRewrite(self.db_client, self.log_client)
        return await sub_task.deal()

    async def _update_root_source_id(self) -> int:
        """Run ``UpdateRootSourceIdAndUpdateTimeTask`` and return its status."""
        sub_task = UpdateRootSourceIdAndUpdateTimeTask(self.db_client, self.log_client)
        return await sub_task.deal()

    async def _outside_monitor_handler(self) -> int:
        """Collect outside articles, then run the monitor and return its status."""
        collector = OutsideGzhArticlesCollector(self.db_client)
        await collector.deal()
        monitor = OutsideGzhArticlesMonitor(self.db_client)
        # Expected to return TASK_SUCCESS_STATUS / TASK_FAILED_STATUS.
        return await monitor.deal()

    async def _recycle_handler(self) -> int:
        """Recycle daily published articles, then run the check task.

        Uses the same target date as the guard row created by ``deal``.
        The status of the check task is returned.
        """
        date_str = self._resolve_date_string()
        recycle = RecycleDailyPublishArticlesTask(
            self.db_client, self.log_client, date_str
        )
        await recycle.deal()
        check = CheckDailyPublishArticlesTask(self.db_client, self.log_client, date_str)
        return await check.deal()

    async def _crawler_toutiao_handler(self) -> int:
        """Crawl Toutiao content, either by account or by recommendation feed.

        Reads ``media_type`` (default "article"), ``method`` (default
        "account") and ``category_list`` (default empty) from the payload.

        Raises:
            ValueError: if ``method`` is neither "account" nor "recommend".
        """
        sub_task = CrawlerToutiao(self.db_client, self.log_client, self.trace_id)
        media_type = self.data.get("media_type", "article")
        method = self.data.get("method", "account")
        category_list = self.data.get("category_list", [])
        if method == "account":
            await sub_task.crawler_task(media_type=media_type)
        elif method == "recommend":
            await sub_task.crawl_toutiao_recommend_task(category_list)
        else:
            raise ValueError(f"Unsupported method {method}")
        return self.TASK_SUCCESS_STATUS

    async def _article_pool_cold_start_handler(self) -> int:
        """Run the article-pool cold start.

        Reads ``platform`` (default "weixin") and ``crawler_methods``
        (default empty) from the payload. If the cold start does not raise,
        success is reported.
        """
        cold_start = ArticlePoolColdStart(
            self.db_client, self.log_client, self.trace_id
        )
        platform = self.data.get("platform", "weixin")
        crawler_methods = self.data.get("crawler_methods", [])
        await cold_start.deal(platform=platform, crawl_methods=crawler_methods)
        return self.TASK_SUCCESS_STATUS
|