|
@@ -1,4 +1,5 @@
|
|
|
import asyncio
|
|
|
+import json
|
|
|
import time
|
|
|
import traceback
|
|
|
from datetime import datetime
|
|
@@ -39,8 +40,8 @@ class TaskScheduler(TaskMapper):
|
|
|
"""新建记录(若同键已存在则忽略)"""
|
|
|
query = (
|
|
|
f"insert ignore into {self.table} "
|
|
|
- "(date_string, task_name, start_timestamp, task_status, trace_id) "
|
|
|
- "values (%s, %s, %s, %s, %s);"
|
|
|
+ "(date_string, task_name, start_timestamp, task_status, trace_id, data) "
|
|
|
+ "values (%s, %s, %s, %s, %s, %s);"
|
|
|
)
|
|
|
await self.db_client.async_save(
|
|
|
query=query,
|
|
@@ -50,74 +51,76 @@ class TaskScheduler(TaskMapper):
|
|
|
int(time.time()),
|
|
|
self.TASK_INIT_STATUS,
|
|
|
self.trace_id,
|
|
|
+ json.dumps(self.data, ensure_ascii=False),
|
|
|
),
|
|
|
)
|
|
|
|
|
|
- async def _try_lock_task(self, task_name: str, date_str: str) -> bool:
|
|
|
+ async def _try_lock_task(self) -> bool:
|
|
|
"""一次 UPDATE 抢锁;返回 True 表示成功上锁"""
|
|
|
query = (
|
|
|
f"update {self.table} "
|
|
|
"set task_status = %s "
|
|
|
- "where task_name = %s and date_string = %s and task_status = %s;"
|
|
|
+ "where trace_id = %s and task_status = %s;"
|
|
|
)
|
|
|
res = await self.db_client.async_save(
|
|
|
query=query,
|
|
|
params=(
|
|
|
self.TASK_PROCESSING_STATUS,
|
|
|
- task_name,
|
|
|
- date_str,
|
|
|
+ self.trace_id,
|
|
|
self.TASK_INIT_STATUS,
|
|
|
),
|
|
|
)
|
|
|
return True if res else False
|
|
|
|
|
|
- async def _release_task(self, task_name: str, date_str: str, status: int) -> None:
|
|
|
+ async def _release_task(self, status: int) -> None:
|
|
|
query = (
|
|
|
f"update {self.table} set task_status=%s, finish_timestamp=%s "
|
|
|
- "where task_name=%s and date_string=%s and task_status=%s;"
|
|
|
+            "where trace_id=%s and task_status=%s;"
|
|
|
)
|
|
|
await self.db_client.async_save(
|
|
|
query=query,
|
|
|
params=(
|
|
|
status,
|
|
|
int(time.time()),
|
|
|
- task_name,
|
|
|
- date_str,
|
|
|
+ self.trace_id,
|
|
|
self.TASK_PROCESSING_STATUS,
|
|
|
),
|
|
|
)
|
|
|
|
|
|
- async def _is_processing_overtime(self, task_name: str) -> bool:
|
|
|
- """检测是否已有同名任务在执行且超时。若超时会发飞书告警"""
|
|
|
- query = f"select start_timestamp from {self.table} where task_name=%s and task_status=%s"
|
|
|
+    async def _is_processing_overtime(self, task_name: str) -> bool:
|
|
|
+ """检测在处理任务是否超时,或者超过最大并行数,若超时会发飞书告警"""
|
|
|
+ query = f"select trace_id from {self.table} where task_status = %s and task_name = %s;"
|
|
|
rows = await self.db_client.async_fetch(
|
|
|
- query=query, params=(task_name, self.TASK_PROCESSING_STATUS)
|
|
|
+ query=query, params=(self.TASK_PROCESSING_STATUS, task_name)
|
|
|
)
|
|
|
if not rows:
|
|
|
return False
|
|
|
- start_ts = rows[0]["start_timestamp"]
|
|
|
- if int(time.time()) - start_ts >= self.get_task_config(task_name).get(
|
|
|
- "expire_duration", self.DEFAULT_TIMEOUT
|
|
|
+
|
|
|
+ processing_task_num = len(rows)
|
|
|
+ if processing_task_num >= self.get_task_config(task_name).get(
|
|
|
+ "task_max_num", self.TASK_MAX_NUM
|
|
|
):
|
|
|
await feishu_robot.bot(
|
|
|
- title=f"{task_name} is overtime",
|
|
|
- detail={"start_ts": start_ts},
|
|
|
+                title=f"multiple {task_name} tasks are processing",
|
|
|
+ detail={"detail": rows},
|
|
|
)
|
|
|
- return True
|
|
|
+ return True
|
|
|
+
|
|
|
+ return False
|
|
|
|
|
|
async def _run_with_guard(
|
|
|
self, task_name: str, date_str: str, task_coro: Callable[[], Awaitable[int]]
|
|
|
):
|
|
|
"""公共:检查、建记录、抢锁、后台运行"""
|
|
|
- # 1. 超时检测(若有正在执行的同名任务则拒绝)
|
|
|
+ # 1. 超时检测
|
|
|
if await self._is_processing_overtime(task_name):
|
|
|
return await task_schedule_response.fail_response(
|
|
|
- "5001", "task is processing"
|
|
|
+                "5005", "multiple tasks with the same task_name are processing"
|
|
|
)
|
|
|
|
|
|
# 2. 记录并尝试抢锁
|
|
|
await self._insert_or_ignore_task(task_name, date_str)
|
|
|
- if not await self._try_lock_task(task_name, date_str):
|
|
|
+ if not await self._try_lock_task():
|
|
|
return await task_schedule_response.fail_response(
|
|
|
"5001", "task is processing"
|
|
|
)
|
|
@@ -145,7 +148,7 @@ class TaskScheduler(TaskMapper):
|
|
|
},
|
|
|
)
|
|
|
finally:
|
|
|
- await self._release_task(task_name, date_str, status)
|
|
|
+ await self._release_task(status)
|
|
|
|
|
|
asyncio.create_task(_wrapper(), name=task_name)
|
|
|
return await task_schedule_response.success_response(
|
|
@@ -158,7 +161,7 @@ class TaskScheduler(TaskMapper):
|
|
|
task_name: str | None = self.data.get("task_name")
|
|
|
if not task_name:
|
|
|
return await task_schedule_response.fail_response(
|
|
|
- "4002", "task_name must be input"
|
|
|
+ "4003", "task_name must be input"
|
|
|
)
|
|
|
|
|
|
date_str = self.data.get("date_string") or datetime.now().strftime("%Y-%m-%d")
|
|
@@ -251,8 +254,8 @@ class TaskScheduler(TaskMapper):
|
|
|
|
|
|
async def _crawler_toutiao_handler(self) -> int:
|
|
|
sub_task = CrawlerToutiao(self.db_client, self.log_client, self.trace_id)
|
|
|
- media_type = self.data.get("media_type", "article")
|
|
|
method = self.data.get("method", "account")
|
|
|
+ media_type = self.data.get("media_type", "article")
|
|
|
category_list = self.data.get("category_list", [])
|
|
|
|
|
|
match method:
|