|
import asyncio
import json
import time
import traceback
from datetime import datetime, timedelta, timezone
from typing import Awaitable, Callable, Dict

from applications.api import feishu_robot
from applications.tasks.task_handler import TaskHandler
from applications.utils import task_schedule_response
class TaskScheduler(TaskHandler):
    """Unified scheduling entry point.

    Usage::

        await TaskScheduler(data, log_service, db_client, trace_id).deal()

    ``data`` must contain ``task_name`` and may contain ``date_string``
    (``YYYY-MM-DD``). Each run is recorded as a row in
    ``long_articles_task_manager`` identified by ``trace_id``. The row's
    ``task_status`` column doubles as a lock: it is moved from the init
    status to the processing status before the task runs, and to the
    task's final status afterwards.
    """

    # Strong references to in-flight background runs. The asyncio event loop
    # only keeps weak references to tasks, so a fire-and-forget task with no
    # other reference may be garbage-collected before it finishes.
    _background_tasks: "set[asyncio.Task]" = set()

    # ---------- init ----------
    def __init__(self, data, log_service, db_client, trace_id):
        super().__init__(data, log_service, db_client, trace_id)
        self.data = data
        self.log_client = log_service
        self.db_client = db_client
        self.table = "long_articles_task_manager"
        self.trace_id = trace_id

    # ---------- shared DB helpers ----------
    async def _insert_or_ignore_task(self, task_name: str, date_str: str) -> None:
        """Insert the task record in init status.

        Uses ``INSERT IGNORE``, so the insert is silently skipped when it
        collides with an existing row on the table's unique key (which key
        that is depends on the table schema — not visible here).

        Args:
            task_name: registered task name.
            date_str: business date string (``YYYY-MM-DD``).
        """
        query = (
            f"insert ignore into {self.table} "
            "(date_string, task_name, start_timestamp, task_status, trace_id, data) "
            "values (%s, %s, %s, %s, %s, %s);"
        )
        await self.db_client.async_save(
            query=query,
            params=(
                date_str,
                task_name,
                int(time.time()),
                self.TASK_INIT_STATUS,
                self.trace_id,
                json.dumps(self.data, ensure_ascii=False),
            ),
        )

    async def _try_lock_task(self) -> bool:
        """Try to acquire the lock with a single conditional UPDATE.

        Moves this trace_id's row from init status to processing status.

        Returns:
            True if the lock was acquired. Assumes ``async_save`` returns the
            affected-row count (truthy only when the UPDATE matched) —
            TODO confirm against the DB client.
        """
        query = (
            f"update {self.table} "
            "set task_status = %s "
            "where trace_id = %s and task_status = %s;"
        )
        res = await self.db_client.async_save(
            query=query,
            params=(
                self.TASK_PROCESSING_STATUS,
                self.trace_id,
                self.TASK_INIT_STATUS,
            ),
        )
        return bool(res)

    async def _release_task(self, status: int) -> None:
        """Release the lock by writing the final status and finish timestamp.

        Only a row still in processing status for this trace_id is updated.

        Args:
            status: final task status to persist.
        """
        query = (
            f"update {self.table} set task_status=%s, finish_timestamp=%s "
            "where trace_id=%s and task_status=%s;"
        )
        await self.db_client.async_save(
            query=query,
            params=(
                status,
                int(time.time()),
                self.trace_id,
                self.TASK_PROCESSING_STATUS,
            ),
        )

    async def _is_processing_overtime(self, task_name) -> bool:
        """Check whether too many runs of ``task_name`` are already processing.

        Despite the name, this does not look at elapsed time. It counts rows
        in processing status for ``task_name``. When the count reaches the
        configured ``task_max_num`` (falling back to ``self.TASK_MAX_NUM``),
        it sends a Feishu alert listing those rows and returns True.

        Args:
            task_name: registered task name.

        Returns:
            True if the concurrency limit is reached, otherwise False.
        """
        query = f"select trace_id from {self.table} where task_status = %s and task_name = %s;"
        rows = await self.db_client.async_fetch(
            query=query, params=(self.TASK_PROCESSING_STATUS, task_name)
        )
        if not rows:
            return False

        processing_task_num = len(rows)
        if processing_task_num >= self.get_task_config(task_name).get(
            "task_max_num", self.TASK_MAX_NUM
        ):
            await feishu_robot.bot(
                title=f"multi {task_name} is processing ",
                detail={"detail": rows},
            )
            return True

        return False

    async def _run_with_guard(
        self, task_name: str, date_str: str, task_coro: Callable[[], Awaitable[int]]
    ):
        """Check, record, lock, then run ``task_coro`` in the background.

        Args:
            task_name: registered task name.
            date_str: business date string (``YYYY-MM-DD``).
            task_coro: zero-argument coroutine function returning the final
                task status.

        Returns:
            A fail response if the concurrency limit is reached or the lock
            cannot be acquired. Otherwise a success response returned as soon
            as the background run has been scheduled.
        """
        # 1. Concurrency check (too many runs of this task already processing).
        if await self._is_processing_overtime(task_name):
            return await task_schedule_response.fail_response(
                "5005", "muti tasks with same task_name is processing"
            )

        # 2. Create the record and try to acquire the lock.
        await self._insert_or_ignore_task(task_name, date_str)
        if not await self._try_lock_task():
            return await task_schedule_response.fail_response(
                "5001", "task is processing"
            )

        # 3. Run the task in a background asyncio task so the scheduling entry
        #    point is not blocked.
        async def _wrapper():
            # Default to failed so an exception still releases the lock with a
            # failure status.
            status = self.TASK_FAILED_STATUS
            try:
                status = await task_coro()
            except Exception as e:
                await self.log_client.log(
                    contents={
                        "trace_id": self.trace_id,
                        "function": "cor_wrapper",
                        "task": task_name,
                        "error": str(e),
                    }
                )
                await feishu_robot.bot(
                    title=f"{task_name} is failed",
                    detail={
                        "task": task_name,
                        "err": str(e),
                        "traceback": traceback.format_exc(),
                    },
                )
            finally:
                await self._release_task(status)

        background = asyncio.create_task(_wrapper(), name=task_name)
        # Keep a strong reference until completion (see _background_tasks).
        self._background_tasks.add(background)
        background.add_done_callback(self._background_tasks.discard)

        return await task_schedule_response.success_response(
            task_name=task_name,
            data={"code": 0, "message": "task started", "trace_id": self.trace_id},
        )

    # ---------- main entry ----------
    async def deal(self):
        """Dispatch the task named in ``self.data["task_name"]``.

        ``date_string`` defaults to today's date in UTC+8.

        Returns:
            A fail response if ``task_name`` is missing or unknown. Otherwise
            the result of ``_run_with_guard``.
        """
        task_name: str | None = self.data.get("task_name")
        if not task_name:
            return await task_schedule_response.fail_response(
                "4003", "task_name must be input"
            )

        # Aware UTC+8 "now"; same date as the old utcnow() + 8h, without the
        # datetime.utcnow() deprecation (Python 3.12+).
        date_str = self.data.get("date_string") or datetime.now(
            timezone(timedelta(hours=8))
        ).strftime("%Y-%m-%d")

        # === Register every task here: name -> async function returning an int status ===
        handlers: Dict[str, Callable[[], Awaitable[int]]] = {
            # Check the Kimi account balance
            "check_kimi_balance": self._check_kimi_balance_handler,
            # Take long-article videos offline three days after publishing
            "get_off_videos": self._get_off_videos_task_handler,
            # Keep long-article videos visible for three days after publishing
            "check_publish_video_audit_status": self._check_video_audit_status_handler,
            # Monitor articles published by external service accounts
            "outside_article_monitor": self._outside_monitor_handler,
            # Monitor in-house article publishing
            "inner_article_monitor": self._inner_gzh_articles_monitor_handler,
            # Title rewrite (to be tested)
            "title_rewrite": self._title_rewrite_handler,
            # Daily recycling of published-article data
            "daily_publish_articles_recycle": self._recycle_article_data_handler,
            # Daily update of root_source_id for published articles
            "update_root_source_id": self._update_root_source_id_handler,
            # Crawl Toutiao articles and videos
            "crawler_toutiao": self._crawler_toutiao_handler,
            # Cold-start publishing from the article pool
            "article_pool_cold_start": self._article_pool_cold_start_handler,
            # Task timeout monitoring
            "task_processing_monitor": self._task_processing_monitor_handler,
            # Candidate account quality analysis
            "candidate_account_quality_analysis": self._candidate_account_quality_score_handler,
            # Article content pool: title category generation
            "article_pool_category_generation": self._article_pool_category_generation_handler,
            # Crawler account management
            "crawler_account_manager": self._crawler_account_manager_handler,
            # Crawl WeChat official-account articles
            "crawler_gzh_articles": self._crawler_gzh_article_handler,
            # Recycle service-account articles
            "fwh_daily_recycle": self._recycle_fwh_article_handler,
            # Category analysis of publishing accounts
            "account_category_analysis": self._account_category_analysis_handler,
            # Analysis of crawled article/video counts
            "crawler_detail_analysis": self._crawler_article_analysis_handler,
        }

        if task_name not in handlers:
            return await task_schedule_response.fail_response(
                "4001", "wrong task name input"
            )

        return await self._run_with_guard(task_name, date_str, handlers[task_name])


__all__ = ["TaskScheduler"]
|