|
@@ -21,6 +21,7 @@ from app.jobs.task_utils import (
|
|
|
from app.core.config import GlobalConfigSettings
|
|
from app.core.config import GlobalConfigSettings
|
|
|
from app.core.database import DatabaseManager
|
|
from app.core.database import DatabaseManager
|
|
|
from app.core.observability import LogService
|
|
from app.core.observability import LogService
|
|
|
|
|
+from app.core.task_registry import TaskRegistry
|
|
|
|
|
|
|
|
|
|
|
|
|
class TaskScheduler(TaskHandler):
|
|
class TaskScheduler(TaskHandler):
|
|
@@ -39,9 +40,11 @@ class TaskScheduler(TaskHandler):
|
|
|
db_client: DatabaseManager,
|
|
db_client: DatabaseManager,
|
|
|
trace_id: str,
|
|
trace_id: str,
|
|
|
config: GlobalConfigSettings,
|
|
config: GlobalConfigSettings,
|
|
|
|
|
+ task_registry: TaskRegistry,
|
|
|
):
|
|
):
|
|
|
super().__init__(data, log_service, db_client, trace_id, config)
|
|
super().__init__(data, log_service, db_client, trace_id, config)
|
|
|
self.table = TaskUtils.validate_table_name(TaskConstants.TASK_TABLE)
|
|
self.table = TaskUtils.validate_table_name(TaskConstants.TASK_TABLE)
|
|
|
|
|
+ self.task_registry = task_registry
|
|
|
|
|
|
|
|
# ==================== 数据库操作 ====================
|
|
# ==================== 数据库操作 ====================
|
|
|
|
|
|
|
@@ -231,12 +234,30 @@ class TaskScheduler(TaskHandler):
|
|
|
|
|
|
|
|
# 3. 后台执行任务
|
|
# 3. 后台执行任务
|
|
|
async def _task_wrapper():
|
|
async def _task_wrapper():
|
|
|
- """任务执行包装器 - 处理错误和重试"""
|
|
|
|
|
|
|
+ """任务执行包装器 - 处理错误、重试和取消"""
|
|
|
status = TaskStatus.FAILED
|
|
status = TaskStatus.FAILED
|
|
|
retry_count = 0
|
|
retry_count = 0
|
|
|
config = get_task_config(task_name)
|
|
config = get_task_config(task_name)
|
|
|
start_time = time.time()
|
|
start_time = time.time()
|
|
|
|
|
|
|
|
|
|
async def _cancel_watchdog(_task: asyncio.Task):
    """Poll the task table and cancel ``_task`` once a cancel is requested.

    Every ``config.cancel_check_interval`` seconds the row for
    ``self.trace_id`` is read; when its ``task_status`` equals
    ``TaskStatus.CANCELLING`` the watched task is cancelled and the
    watchdog stops.

    A failed poll is logged and retried on the next tick instead of being
    propagated. Letting it escape would end the watchdog permanently (no
    later cancel request could be detected) and the stored exception would
    be re-raised wherever the watchdog is awaited.

    Args:
        _task: The asyncio task to cancel when a cancel request is seen.
    """
    # Function-scope import: the file's import block is not fully visible.
    import logging

    status_query = f"SELECT task_status FROM {self.table} WHERE trace_id = %s"
    while True:
        await asyncio.sleep(config.cancel_check_interval)
        try:
            row = await self.db_client.async_fetch_one(
                status_query,
                (self.trace_id,),
            )
        except asyncio.CancelledError:
            # Cancellation of the watchdog itself must always propagate.
            raise
        except Exception:
            logging.getLogger(__name__).warning(
                "cancel watchdog poll failed for trace_id=%s; retrying",
                self.trace_id,
                exc_info=True,
            )
            continue
        if row and row["task_status"] == TaskStatus.CANCELLING:
            _task.cancel()
            return
|
|
|
|
|
+
|
|
|
|
|
+ main_task = asyncio.current_task()
|
|
|
|
|
+ watchdog = asyncio.create_task(
|
|
|
|
|
+ _cancel_watchdog(main_task),
|
|
|
|
|
+ name=f"watchdog_{task_name}_{self.trace_id}",
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
try:
|
|
try:
|
|
|
await self._log_task_event("task_started", task_name=task_name)
|
|
await self._log_task_event("task_started", task_name=task_name)
|
|
|
|
|
|
|
@@ -251,6 +272,14 @@ class TaskScheduler(TaskHandler):
|
|
|
duration=duration,
|
|
duration=duration,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+ except asyncio.CancelledError:
|
|
|
|
|
+ duration = time.time() - start_time
|
|
|
|
|
+ await self._log_task_event(
|
|
|
|
|
+ "task_cancelled", task_name=task_name, duration=duration
|
|
|
|
|
+ )
|
|
|
|
|
+ status = TaskStatus.CANCELLED
|
|
|
|
|
+ raise
|
|
|
|
|
+
|
|
|
except TaskError as e:
|
|
except TaskError as e:
|
|
|
# 已知的任务错误
|
|
# 已知的任务错误
|
|
|
duration = time.time() - start_time
|
|
duration = time.time() - start_time
|
|
@@ -304,10 +333,17 @@ class TaskScheduler(TaskHandler):
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
finally:
|
|
finally:
|
|
|
|
|
+ watchdog.cancel()
|
|
|
|
|
+ try:
|
|
|
|
|
+ await watchdog
|
|
|
|
|
+ except asyncio.CancelledError:
|
|
|
|
|
+ pass
|
|
|
await self._release_task(status)
|
|
await self._release_task(status)
|
|
|
|
|
+ await self.task_registry.unregister(self.trace_id)
|
|
|
|
|
|
|
|
- # 创建后台任务
|
|
|
|
|
- asyncio.create_task(_task_wrapper(), name=f"{task_name}_{self.trace_id}")
|
|
|
|
|
|
|
+ # 创建后台任务并注册
|
|
|
|
|
+ task = asyncio.create_task(_task_wrapper(), name=f"{task_name}_{self.trace_id}")
|
|
|
|
|
+ await self.task_registry.register(self.trace_id, task)
|
|
|
|
|
|
|
|
return await task_schedule_response.success_response(
|
|
return await task_schedule_response.success_response(
|
|
|
task_name=task_name,
|
|
task_name=task_name,
|
|
@@ -339,35 +375,52 @@ class TaskScheduler(TaskHandler):
|
|
|
|
|
|
|
|
async def cancel_task(self, trace_id: Optional[str] = None) -> bool:
    """
    Request cancellation of a task.

    Two conditional UPDATE statements are tried in order:
    an INIT row is set straight to CANCELLED (with a finish timestamp);
    otherwise a PROCESSING row is set to CANCELLING, after which the task
    registry is asked to cancel the coroutine right away.

    Args:
        trace_id: Trace ID of the task; falls back to this instance's
            ``trace_id`` when omitted or empty.

    Returns:
        True when either UPDATE reported a truthy result from
        ``db_client.async_save``, otherwise False.
    """
    target_trace_id = trace_id or self.trace_id

    # INIT: nothing is running yet, so the row can be closed out directly.
    cancelled_directly = await self.db_client.async_save(
        f"""
        UPDATE {self.table}
        SET task_status = %s, finish_timestamp = %s
        WHERE trace_id = %s AND task_status = %s
        """,
        (
            TaskStatus.CANCELLED,
            int(time.time()),
            target_trace_id,
            TaskStatus.INIT,
        ),
    )
    if cancelled_directly:
        await self._log_task_event("task_cancelled", trace_id=target_trace_id)
        return True

    # PROCESSING: flag the row as CANCELLING so a poller can pick it up.
    marked_cancelling = await self.db_client.async_save(
        f"""
        UPDATE {self.table}
        SET task_status = %s
        WHERE trace_id = %s AND task_status = %s
        """,
        (TaskStatus.CANCELLING, target_trace_id, TaskStatus.PROCESSING),
    )
    if not marked_cancelling:
        return False

    # Same-process shortcut: cancel the coroutine now instead of waiting
    # for the next poll.
    await self.task_registry.cancel_task(target_trace_id)
    await self._log_task_event("task_cancelling", trace_id=target_trace_id)
    return True
|
|
|
|
|
|
|
|
async def retry_task(self, trace_id: Optional[str] = None) -> bool:
|
|
async def retry_task(self, trace_id: Optional[str] = None) -> bool:
|
|
|
"""
|
|
"""
|