import asyncio
import json
import time
import traceback
from datetime import datetime, timedelta, timezone
from typing import Awaitable, Callable, ClassVar, Dict

from applications.api import feishu_robot
from applications.utils import task_schedule_response
from applications.tasks.task_handler import TaskHandler


class TaskScheduler(TaskHandler):
    """Unified scheduling entry point.

    Callers only need:
    ``await TaskScheduler(data, log_service, db_client, trace_id).deal()``

    ``deal`` looks up the handler registered for ``data["task_name"]``,
    records the run in ``long_articles_task_manager``, acquires a row-level
    "lock" by moving the record from INIT to PROCESSING, and runs the handler
    as a background asyncio task so the caller is not blocked.
    """

    # Strong references to in-flight background tasks. The event loop keeps
    # only weak references to tasks created with ``asyncio.create_task``; if
    # nothing else references a task it may be garbage-collected mid-run, and
    # its ``finally`` (which releases the DB lock) would never execute.
    # Class-level so references outlive the per-request scheduler instance.
    # Holds asyncio.Task objects.
    _background_tasks: ClassVar[set] = set()

    # ---------- Initialization ----------
    def __init__(self, data, log_service, db_client, trace_id):
        super().__init__(data, log_service, db_client, trace_id)
        self.data = data
        self.log_client = log_service
        self.db_client = db_client
        self.table = "long_articles_task_manager"
        self.trace_id = trace_id

    # ---------- Shared database helpers ----------
    async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
        """Insert a task record in INIT status; ignored if the same key exists.

        ``insert ignore`` skips the insert on a unique-key conflict. Which
        columns form that key is defined by the table schema (not visible
        here). The request payload is stored as JSON in the ``data`` column.
        """
        query = (
            f"insert ignore into {self.table} "
            "(date_string, task_name, start_timestamp, task_status, trace_id, data) "
            "values (%s, %s, %s, %s, %s, %s);"
        )
        await self.db_client.async_save(
            query=query,
            params=(
                date_str,
                task_name,
                int(time.time()),
                self.TASK_INIT_STATUS,
                self.trace_id,
                json.dumps(self.data, ensure_ascii=False),
            ),
        )

    async def _try_lock_task(self) -> bool:
        """Grab the task "lock" with a single UPDATE.

        Moves this trace's record from INIT to PROCESSING. The
        ``task_status = INIT`` predicate makes the transition happen at most
        once per record.

        Returns:
            True when ``async_save`` returned a truthy value (presumably the
            affected-row count -- confirm against the DB client), i.e. the
            lock was acquired.
        """
        query = (
            f"update {self.table} "
            "set task_status = %s "
            "where trace_id = %s and task_status = %s;"
        )
        res = await self.db_client.async_save(
            query=query,
            params=(
                self.TASK_PROCESSING_STATUS,
                self.trace_id,
                self.TASK_INIT_STATUS,
            ),
        )
        return bool(res)

    async def _release_task(self, status: int) -> None:
        """Set the final status and finish timestamp of this trace's record.

        Only a record currently in PROCESSING status is updated.
        """
        query = (
            f"update {self.table} set task_status=%s, finish_timestamp=%s "
            "where trace_id=%s and task_status=%s;"
        )
        await self.db_client.async_save(
            query=query,
            params=(
                status,
                int(time.time()),
                self.trace_id,
                self.TASK_PROCESSING_STATUS,
            ),
        )

    async def _is_processing_overtime(self, task_name) -> bool:
        """Check whether too many runs of ``task_name`` are already PROCESSING.

        Despite the name, no elapsed-time check is performed. The check is
        purely a count of PROCESSING rows compared with the task's
        ``task_max_num`` config, falling back to ``self.TASK_MAX_NUM``.

        Returns:
            True when the limit is reached. In that case a Feishu alert
            listing the offending trace_ids is sent first.
        """
        query = f"select trace_id from {self.table} where task_status = %s and task_name = %s;"
        rows = await self.db_client.async_fetch(
            query=query, params=(self.TASK_PROCESSING_STATUS, task_name)
        )
        if not rows:
            return False

        processing_task_num = len(rows)
        if processing_task_num >= self.get_task_config(task_name).get(
            "task_max_num", self.TASK_MAX_NUM
        ):
            await feishu_robot.bot(
                title=f"multi {task_name} is processing ",
                detail={"detail": rows},
            )
            return True

        return False

    async def _run_with_guard(
        self, task_name: str, date_str: str, task_coro: Callable[[], Awaitable[int]]
    ):
        """Shared flow: concurrency check, record creation, lock, background run.

        Args:
            task_name: Registered task name (also used as the asyncio task name).
            date_str: ``date_string`` value stored in the task record.
            task_coro: Zero-arg async callable returning the final int status.

        Returns:
            A ``task_schedule_response`` fail response (5005 / 5001) or a
            success response once the background task has been scheduled.
        """
        # 1. Concurrency check: refuse when the parallel limit is reached.
        if await self._is_processing_overtime(task_name):
            return await task_schedule_response.fail_response(
                "5005", "muti tasks with same task_name is processing"
            )

        # 2. Create the record, then try to grab the lock.
        await self._insert_or_ignore_task(task_name, date_str)
        if not await self._try_lock_task():
            return await task_schedule_response.fail_response(
                "5001", "task is processing"
            )

        # 3. Run the task in a background coroutine so the entry point is not blocked.
        async def _wrapper():
            # Defaults to FAILED so an exception still leaves a terminal status.
            status = self.TASK_FAILED_STATUS
            try:
                status = await task_coro()
            except Exception as e:
                await self.log_client.log(
                    contents={
                        "trace_id": self.trace_id,
                        "function": "cor_wrapper",
                        "task": task_name,
                        "error": str(e),
                    }
                )
                await feishu_robot.bot(
                    title=f"{task_name} is failed",
                    detail={
                        "task": task_name,
                        "err": str(e),
                        "traceback": traceback.format_exc(),
                    },
                )
            finally:
                # Always release the lock, even on failure or cancellation.
                await self._release_task(status)

        background = asyncio.create_task(_wrapper(), name=task_name)
        # Keep a strong reference until completion (see _background_tasks).
        self._background_tasks.add(background)
        background.add_done_callback(self._background_tasks.discard)

        return await task_schedule_response.success_response(
            task_name=task_name,
            data={"code": 0, "message": "task started", "trace_id": self.trace_id},
        )

    # ---------- Main entry ----------
    async def deal(self):
        """Dispatch ``self.data["task_name"]`` to its registered handler.

        ``date_string`` defaults to today's date in UTC+8.

        Returns:
            A ``task_schedule_response`` response: 4003 when task_name is
            missing, 4001 when it is not registered, otherwise the result of
            ``_run_with_guard``.
        """
        task_name: str | None = self.data.get("task_name")
        if not task_name:
            return await task_schedule_response.fail_response(
                "4003", "task_name must be input"
            )

        # datetime.utcnow() is deprecated (3.12+); an aware UTC "now" yields the same date.
        date_str = self.data.get("date_string") or (
            datetime.now(timezone.utc) + timedelta(hours=8)
        ).strftime("%Y-%m-%d")

        # === All tasks are registered here: name -> async callable returning an int status ===
        # (handler methods are presumably provided by TaskHandler -- not visible here)
        handlers: Dict[str, Callable[[], Awaitable[int]]] = {
            # Check Kimi account balance
            "check_kimi_balance": self._check_kimi_balance_handler,
            # Take long-article videos offline three days after publication
            "get_off_videos": self._get_off_videos_task_handler,
            # Keep long-article videos visible during the three days after publication
            "check_publish_video_audit_status": self._check_video_audit_status_handler,
            # Monitor articles published by external service accounts
            "outside_article_monitor": self._outside_monitor_handler,
            # Monitor in-house article publications
            "inner_article_monitor": self._inner_gzh_articles_monitor_handler,
            # Title rewriting (pending testing)
            "title_rewrite": self._title_rewrite_handler,
            # Daily recycling of published-article data
            "daily_publish_articles_recycle": self._recycle_article_data_handler,
            # Daily root_source_id update for published articles
            "update_root_source_id": self._update_root_source_id_handler,
            # Toutiao article / video crawling
            "crawler_toutiao": self._crawler_toutiao_handler,
            # Cold-start publishing from the article pool
            "article_pool_cold_start": self._article_pool_cold_start_handler,
            # Task timeout monitoring
            "task_processing_monitor": self._task_processing_monitor_handler,
            # Candidate account quality analysis
            "candidate_account_quality_analysis": self._candidate_account_quality_score_handler,
            # Article content pool: title category generation
            "article_pool_category_generation": self._article_pool_category_generation_handler,
            # Crawler account management
            "crawler_account_manager": self._crawler_account_manager_handler,
            # WeChat official account article crawling
            "crawler_gzh_articles": self._crawler_gzh_article_handler,
            # Service-account article recycling
            "fwh_daily_recycle": self._recycle_fwh_article_handler,
            # Publishing-account category analysis
            "account_category_analysis": self._account_category_analysis_handler,
            # Crawled article / video count analysis
            "crawler_detail_analysis": self._crawler_article_analysis_handler,
            # Mini-program share/fission detail processing
            "mini_program_detail_process": self._mini_program_detail_handler,
        }

        if task_name not in handlers:
            return await task_schedule_response.fail_response(
                "4001", "wrong task name input"
            )
        return await self._run_with_guard(task_name, date_str, handlers[task_name])


__all__ = ["TaskScheduler"]