import asyncio
import time
from datetime import datetime
from typing import Awaitable, Callable, ClassVar, Dict, Optional, Set

from applications.api import feishu_robot
from applications.utils import task_schedule_response, generate_task_trace_id

from applications.tasks.cold_start_tasks import ArticlePoolColdStart
from applications.tasks.crawler_tasks import CrawlerToutiao
from applications.tasks.data_recycle_tasks import CheckDailyPublishArticlesTask
from applications.tasks.data_recycle_tasks import RecycleDailyPublishArticlesTask
from applications.tasks.data_recycle_tasks import UpdateRootSourceIdAndUpdateTimeTask
from applications.tasks.llm_tasks import TitleRewrite
from applications.tasks.monitor_tasks import check_kimi_balance
from applications.tasks.monitor_tasks import GetOffVideos
from applications.tasks.monitor_tasks import CheckVideoAuditStatus
from applications.tasks.monitor_tasks import InnerGzhArticlesMonitor
from applications.tasks.monitor_tasks import OutsideGzhArticlesMonitor
from applications.tasks.monitor_tasks import OutsideGzhArticlesCollector
from applications.tasks.monitor_tasks import TaskProcessingMonitor
from applications.tasks.task_mapper import TaskMapper


class TaskScheduler(TaskMapper):
    """Unified scheduling entry point.

    External callers only need ``await TaskScheduler(data, log_cli, db_cli).deal()``.

    ``deal()`` looks up ``data["task_name"]`` in a handler table, records the
    run in ``long_articles_task_manager`` (keyed by task name + date string),
    acquires a DB-row "lock" by moving the row from INIT to PROCESSING, then
    starts the handler as a background asyncio task and returns immediately.
    """

    # Strong references to the background runners started by _run_with_guard.
    # The asyncio event loop only keeps weak references to tasks, so a
    # fire-and-forget create_task() result can be garbage-collected before it
    # finishes. That would leave its DB row stuck in PROCESSING. The set is
    # class-level because a new scheduler instance is built for each request.
    _background_tasks: ClassVar[Set[asyncio.Task]] = set()

    # ---------- initialisation ----------
    def __init__(self, data, log_service, db_client):
        self.data = data
        self.log_client = log_service
        self.db_client = db_client
        self.table = "long_articles_task_manager"
        self.trace_id = generate_task_trace_id()
        # Filled in lazily by _resolve_date_string(), so that deal() and the
        # background handlers agree on one date even across midnight.
        self._date_string: Optional[str] = None

    def _resolve_date_string(self) -> str:
        """Return the target date string for this request, computed once.

        Uses ``data["date_string"]`` when present and truthy. Otherwise it
        uses today's local date formatted as ``%Y-%m-%d``. The first result is
        cached on the instance, so later calls (e.g. from a background handler)
        return the same value.
        """
        if self._date_string is None:
            self._date_string = self.data.get("date_string") or datetime.now().strftime(
                "%Y-%m-%d"
            )
        return self._date_string

    # ---------- shared database helpers ----------
    async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
        """Create the task record in INIT status; do nothing if the key already exists.

        Uses MySQL ``INSERT IGNORE``. If a row for the same key already exists
        (whatever its status), it is left untouched.
        """
        query = (
            f"insert ignore into {self.table} "
            "(date_string, task_name, start_timestamp, task_status, trace_id) "
            "values (%s, %s, %s, %s, %s);"
        )
        await self.db_client.async_save(
            query=query,
            params=(
                date_str,
                task_name,
                int(time.time()),
                self.TASK_INIT_STATUS,
                self.trace_id,
            ),
        )

    async def _try_lock_task(self, task_name: str, date_str: str) -> bool:
        """Try to take the lock with a single conditional UPDATE (INIT -> PROCESSING).

        Returns:
            True if the lock was acquired. The code treats a truthy return from
            ``async_save`` as success, presumably the affected-row count, so
            only one concurrent caller can win. TODO confirm against db_client.
        """
        query = (
            f"update {self.table} "
            "set task_status = %s "
            "where task_name = %s and date_string = %s and task_status = %s;"
        )
        res = await self.db_client.async_save(
            query=query,
            params=(
                self.TASK_PROCESSING_STATUS,
                task_name,
                date_str,
                self.TASK_INIT_STATUS,
            ),
        )
        return bool(res)

    async def _release_task(self, task_name: str, date_str: str, status: int) -> None:
        """Move the row from PROCESSING to *status* and record the finish timestamp."""
        query = (
            f"update {self.table} set task_status=%s, finish_timestamp=%s "
            "where task_name=%s and date_string=%s and task_status=%s;"
        )
        await self.db_client.async_save(
            query=query,
            params=(
                status,
                int(time.time()),
                task_name,
                date_str,
                self.TASK_PROCESSING_STATUS,
            ),
        )

    async def _is_processing_overtime(self, task_name: str) -> bool:
        """Report whether a task with this name is already in PROCESSING status.

        Returns True whenever any PROCESSING row exists for *task_name*, on any
        date and whether or not it has timed out. If the first returned row
        started at least ``expire_duration`` seconds ago (from the task config,
        defaulting to ``DEFAULT_TIMEOUT``), a Feishu alert is also sent.
        """
        query = f"select start_timestamp from {self.table} where task_name=%s and task_status=%s"
        rows = await self.db_client.async_fetch(
            query=query, params=(task_name, self.TASK_PROCESSING_STATUS)
        )
        if not rows:
            return False

        start_ts = rows[0]["start_timestamp"]
        expire_duration = self.get_task_config(task_name).get(
            "expire_duration", self.DEFAULT_TIMEOUT
        )
        if int(time.time()) - start_ts >= expire_duration:
            await feishu_robot.bot(
                title=f"{task_name} is overtime",
                detail={"start_ts": start_ts},
            )
        return True

    async def _run_with_guard(
        self, task_name: str, date_str: str, task_coro: Callable[[], Awaitable[int]]
    ):
        """Check, record, lock, then run *task_coro* in the background.

        Returns a fail response ("5001") if the task is already running or the
        lock could not be taken. Otherwise it returns a success response right
        away, while the task keeps running on the event loop.
        """
        # 1. Refuse if a same-name task is already running (alerts if overtime).
        if await self._is_processing_overtime(task_name):
            return await task_schedule_response.fail_response(
                "5001", "task is processing"
            )

        # 2. Record the run and try to acquire the lock.
        await self._insert_or_ignore_task(task_name, date_str)
        if not await self._try_lock_task(task_name, date_str):
            return await task_schedule_response.fail_response(
                "5001", "task is processing"
            )

        # 3. Run the task in a background coroutine so the entry point is not blocked.
        async def _wrapper():
            status = self.TASK_FAILED_STATUS
            try:
                # Handlers are expected to return TASK_SUCCESS_STATUS / TASK_FAILED_STATUS.
                # NOTE(review): a handler returning None (e.g. a sub-task whose
                # deal() has no return value) would store None as task_status.
                # Confirm that every handler returns an int status.
                status = await task_coro()
            except Exception as e:
                await self.log_client.log(
                    contents={
                        "trace_id": self.trace_id,
                        "task": task_name,
                        "err": str(e),
                    }
                )
                await feishu_robot.bot(
                    title=f"{task_name} is failed",
                    detail={"task": task_name, "err": str(e)},
                )
            finally:
                # Always release the lock, even on failure or cancellation.
                await self._release_task(task_name, date_str, status)

        task = asyncio.create_task(_wrapper(), name=task_name)
        # Keep a strong reference until the task finishes (see _background_tasks).
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

        return await task_schedule_response.success_response(
            task_name=task_name, data={"code": 0, "message": "task started"}
        )

    # ---------- main entry point ----------
    async def deal(self):
        """Dispatch the task named in ``data["task_name"]``.

        Returns a fail response with "4002" when no task name is given, and
        "4001" when the name is not registered. Otherwise it returns whatever
        ``_run_with_guard`` returns.
        """
        task_name: str | None = self.data.get("task_name")
        if not task_name:
            return await task_schedule_response.fail_response(
                "4002", "task_name must be input"
            )

        date_str = self._resolve_date_string()

        # === Every task is registered here: name -> async callable returning an int status ===
        handlers: Dict[str, Callable[[], Awaitable[int]]] = {
            "check_kimi_balance": check_kimi_balance,
            "get_off_videos": self._get_off_videos_task,
            "check_publish_video_audit_status": self._check_video_audit_status,
            "task_processing_monitor": self._task_processing_monitor,
            "outside_article_monitor": self._outside_monitor_handler,
            "inner_article_monitor": self._inner_gzh_articles_monitor,
            "title_rewrite": self._title_rewrite,
            "daily_publish_articles_recycle": self._recycle_handler,
            "update_root_source_id": self._update_root_source_id,
            "crawler_toutiao_articles": self._crawler_toutiao_handler,
            # NOTE(review): "pool_pool" looks like a typo, but existing callers
            # (and rows in the task table) use this exact key, so it is kept.
            "article_pool_pool_cold_start": self._article_pool_cold_start_handler,
        }

        if task_name not in handlers:
            return await task_schedule_response.fail_response(
                "4001", "wrong task name input"
            )
        return await self._run_with_guard(task_name, date_str, handlers[task_name])

    # ---------- handlers for individual / composite tasks ----------
    # Kept as separate methods so the handler table above stays compact.

    async def _get_off_videos_task(self):
        """Run the GetOffVideos sub-task and return its deal() result."""
        sub_task = GetOffVideos(self.db_client, self.log_client, self.trace_id)
        return await sub_task.deal()

    async def _check_video_audit_status(self):
        """Run the CheckVideoAuditStatus sub-task and return its deal() result."""
        sub_task = CheckVideoAuditStatus(self.db_client, self.log_client, self.trace_id)
        return await sub_task.deal()

    async def _task_processing_monitor(self):
        """Run the TaskProcessingMonitor sub-task and return its deal() result."""
        sub_task = TaskProcessingMonitor(self.db_client)
        return await sub_task.deal()

    async def _inner_gzh_articles_monitor(self):
        """Run the InnerGzhArticlesMonitor sub-task and return its deal() result."""
        sub_task = InnerGzhArticlesMonitor(self.db_client)
        return await sub_task.deal()

    async def _title_rewrite(self):
        """Run the TitleRewrite sub-task and return its deal() result."""
        sub_task = TitleRewrite(self.db_client, self.log_client)
        return await sub_task.deal()

    async def _update_root_source_id(self) -> int:
        """Run UpdateRootSourceIdAndUpdateTimeTask and return its deal() result."""
        sub_task = UpdateRootSourceIdAndUpdateTimeTask(self.db_client, self.log_client)
        return await sub_task.deal()

    async def _outside_monitor_handler(self) -> int:
        """Collect outside GZH articles, then run the monitor; return the monitor's status."""
        collector = OutsideGzhArticlesCollector(self.db_client)
        await collector.deal()
        monitor = OutsideGzhArticlesMonitor(self.db_client)
        return await monitor.deal()  # expected to return SUCCESS / FAILED

    async def _recycle_handler(self) -> int:
        """Recycle the day's published articles, then verify; return the check's status.

        Uses the same date string that deal() locked the task row under.
        """
        date_str = self._resolve_date_string()
        recycle = RecycleDailyPublishArticlesTask(
            self.db_client, self.log_client, date_str
        )
        await recycle.deal()
        check = CheckDailyPublishArticlesTask(self.db_client, self.log_client, date_str)
        return await check.deal()

    async def _crawler_toutiao_handler(self) -> int:
        """Crawl Toutiao by account or by recommendation feed.

        Reads ``media_type`` (default "article"), ``method`` (default
        "account") and ``category_list`` (default []) from the request data.

        Raises:
            ValueError: if ``method`` is neither "account" nor "recommend".
        """
        sub_task = CrawlerToutiao(self.db_client, self.log_client, self.trace_id)
        media_type = self.data.get("media_type", "article")
        method = self.data.get("method", "account")
        category_list = self.data.get("category_list", [])
        if method == "account":
            await sub_task.crawler_task(media_type=media_type)
        elif method == "recommend":
            await sub_task.crawl_toutiao_recommend_task(category_list)
        else:
            raise ValueError(f"Unsupported method {method}")
        return self.TASK_SUCCESS_STATUS

    async def _article_pool_cold_start_handler(self) -> int:
        """Run the article-pool cold start for ``platform`` (default "weixin").

        Passes ``crawler_methods`` from the request data (default []) through
        as ``crawl_methods``.
        """
        cold_start = ArticlePoolColdStart(
            self.db_client, self.log_client, self.trace_id
        )
        platform = self.data.get("platform", "weixin")
        crawler_methods = self.data.get("crawler_methods", [])
        await cold_start.deal(platform=platform, crawl_methods=crawler_methods)
        return self.TASK_SUCCESS_STATUS