浏览代码

新增头条搜索账号模式

luojunhui 1 月之前
父节点
当前提交
756ceacf8b
共有 2 个文件被更改，包括 31 次插入和 27 次删除
  1. 2 1
      applications/tasks/task_mapper.py
  2. 29 26
      applications/tasks/task_scheduler_v2.py

+ 2 - 1
applications/tasks/task_mapper.py

@@ -19,6 +19,7 @@ class Const:
     UPDATE_ROOT_SOURCE_ID_TIMEOUT = 3600
     CRAWLER_TOUTIAO_ARTICLES_TIMEOUT = 5 * 3600
     ARTICLE_POOL_COLD_START_TIMEOUT = 4 * 3600
+    TASK_MAX_NUM = 5
 
 
 class TaskMapper(Const):
@@ -58,4 +59,4 @@ class TaskMapper(Const):
             case _:
                 expire_duration = self.DEFAULT_TIMEOUT
 
-        return {"expire_duration": expire_duration}
+        return {"expire_duration": expire_duration, "task_max_num": self.TASK_MAX_NUM}

+ 29 - 26
applications/tasks/task_scheduler_v2.py

@@ -1,4 +1,5 @@
 import asyncio
+import json
 import time
 import traceback
 from datetime import datetime
@@ -39,8 +40,8 @@ class TaskScheduler(TaskMapper):
         """新建记录(若同键已存在则忽略)"""
         query = (
             f"insert ignore into {self.table} "
-            "(date_string, task_name, start_timestamp, task_status, trace_id) "
-            "values (%s, %s, %s, %s, %s);"
+            "(date_string, task_name, start_timestamp, task_status, trace_id, data) "
+            "values (%s, %s, %s, %s, %s, %s);"
         )
         await self.db_client.async_save(
             query=query,
@@ -50,74 +51,76 @@ class TaskScheduler(TaskMapper):
                 int(time.time()),
                 self.TASK_INIT_STATUS,
                 self.trace_id,
+                json.dumps(self.data, ensure_ascii=False),
             ),
         )
 
-    async def _try_lock_task(self, task_name: str, date_str: str) -> bool:
+    async def _try_lock_task(self) -> bool:
         """Try to take the lock with a single UPDATE; return True when the lock was acquired."""
         query = (
             f"update {self.table} "
             "set task_status = %s "
-            "where task_name = %s and date_string = %s and task_status = %s;"
+            "where trace_id = %s  and task_status = %s;"
         )
         res = await self.db_client.async_save(
             query=query,
             params=(
                 self.TASK_PROCESSING_STATUS,
-                task_name,
-                date_str,
+                self.trace_id,
                 self.TASK_INIT_STATUS,
             ),
         )
         return True if res else False
 
-    async def _release_task(self, task_name: str, date_str: str, status: int) -> None:
+    async def _release_task(self, status: int) -> None:
+        """Write the final status and finish timestamp for this run's row.
+
+        The row is selected by ``self.trace_id`` and is only updated while it
+        is still in ``TASK_PROCESSING_STATUS``.
+
+        Args:
+            status: task status value to store on the row.
+        """
         query = (
             f"update {self.table} set task_status=%s, finish_timestamp=%s "
-            "where task_name=%s and date_string=%s and task_status=%s;"
+            "where trace_id=%s and task_status=%s;"
         )
         await self.db_client.async_save(
             query=query,
             params=(
                 status,
                 int(time.time()),
-                task_name,
-                date_str,
+                self.trace_id,
                 self.TASK_PROCESSING_STATUS,
             ),
         )
 
-    async def _is_processing_overtime(self, task_name: str) -> bool:
-        """Check whether a task with the same name is already running and has timed out; a Feishu alert is sent on timeout."""
-        query = f"select start_timestamp from {self.table} where task_name=%s and task_status=%s"
+    async def _is_processing_overtime(self, task_name) -> bool:
+        """Return True (after sending a Feishu alert) when the number of rows in processing status for task_name has reached task_max_num. NOTE(review): no elapsed-time check remains here despite the method name."""
+        query = f"select trace_id from {self.table} where task_status = %s and task_name = %s;"
         rows = await self.db_client.async_fetch(
-            query=query, params=(task_name, self.TASK_PROCESSING_STATUS)
+            query=query, params=(self.TASK_PROCESSING_STATUS, task_name)
         )
         if not rows:
             return False
-        start_ts = rows[0]["start_timestamp"]
-        if int(time.time()) - start_ts >= self.get_task_config(task_name).get(
-            "expire_duration", self.DEFAULT_TIMEOUT
+
+        processing_task_num = len(rows)
+        if processing_task_num >= self.get_task_config(task_name).get(
+            "task_max_num", self.TASK_MAX_NUM
         ):
             await feishu_robot.bot(
-                title=f"{task_name} is overtime",
-                detail={"start_ts": start_ts},
+                title=f"multi {task_name} is processing ",
+                detail={"detail": rows},
             )
-        return True
+            return True
+
+        return False
 
     async def _run_with_guard(
         self, task_name: str, date_str: str, task_coro: Callable[[], Awaitable[int]]
     ):
         """公共:检查、建记录、抢锁、后台运行"""
-        # 1. 超时检测(若有正在执行的同名任务则拒绝)
+        # 1. 超时检测
         if await self._is_processing_overtime(task_name):
             return await task_schedule_response.fail_response(
-                "5001", "task is processing"
+                "5005", "muti tasks with same task_name is processing"
             )
 
         # 2. 记录并尝试抢锁
         await self._insert_or_ignore_task(task_name, date_str)
-        if not await self._try_lock_task(task_name, date_str):
+        if not await self._try_lock_task():
             return await task_schedule_response.fail_response(
                 "5001", "task is processing"
             )
@@ -145,7 +148,7 @@ class TaskScheduler(TaskMapper):
                     },
                 )
             finally:
-                await self._release_task(task_name, date_str, status)
+                await self._release_task(status)
 
         asyncio.create_task(_wrapper(), name=task_name)
         return await task_schedule_response.success_response(
@@ -158,7 +161,7 @@ class TaskScheduler(TaskMapper):
         task_name: str | None = self.data.get("task_name")
         if not task_name:
             return await task_schedule_response.fail_response(
-                "4002", "task_name must be input"
+                "4003", "task_name must be input"
             )
 
         date_str = self.data.get("date_string") or datetime.now().strftime("%Y-%m-%d")
@@ -251,8 +254,8 @@ class TaskScheduler(TaskMapper):
 
     async def _crawler_toutiao_handler(self) -> int:
         sub_task = CrawlerToutiao(self.db_client, self.log_client, self.trace_id)
-        media_type = self.data.get("media_type", "article")
         method = self.data.get("method", "account")
+        media_type = self.data.get("media_type", "article")
         category_list = self.data.get("category_list", [])
 
         match method: