Jelajahi Sumber

异步任务取消

luojunhui 3 minggu lalu
induk
melakukan
1f0ecd5909

+ 10 - 1
app/api/service/task_manager_service.py

@@ -9,8 +9,17 @@ class TaskConst:
     INIT_STATUS = 0
     PROCESSING_STATUS = 1
     FINISHED_STATUS = 2
+    CANCELLING_STATUS = 3
+    CANCELLED_STATUS = 4
     FAILED_STATUS = 99
-    STATUS_TEXT = {0: "初始化", 1: "处理中", 2: "完成", 99: "失败"}
+    STATUS_TEXT = {
+        0: "初始化",
+        1: "处理中",
+        2: "完成",
+        3: "待取消",
+        4: "已取消",
+        99: "失败",
+    }
 
     DEFAULT_PAGE = 1
     DEFAULT_SIZE = 50

+ 72 - 19
app/api/service/task_scheduler.py

@@ -21,6 +21,7 @@ from app.jobs.task_utils import (
 from app.core.config import GlobalConfigSettings
 from app.core.database import DatabaseManager
 from app.core.observability import LogService
+from app.core.task_registry import TaskRegistry
 
 
 class TaskScheduler(TaskHandler):
@@ -39,9 +40,11 @@ class TaskScheduler(TaskHandler):
         db_client: DatabaseManager,
         trace_id: str,
         config: GlobalConfigSettings,
+        task_registry: TaskRegistry,
     ):
         super().__init__(data, log_service, db_client, trace_id, config)
         self.table = TaskUtils.validate_table_name(TaskConstants.TASK_TABLE)
+        self.task_registry = task_registry
 
     # ==================== 数据库操作 ====================
 
@@ -231,12 +234,30 @@ class TaskScheduler(TaskHandler):
 
         # 3. 后台执行任务
         async def _task_wrapper():
-            """任务执行包装器 - 处理错误和重试"""
+            """任务执行包装器 - 处理错误、重试和取消"""
             status = TaskStatus.FAILED
             retry_count = 0
             config = get_task_config(task_name)
             start_time = time.time()
 
+            async def _cancel_watchdog(_task: asyncio.Task):
+                """定期轮询 DB,检测任务是否被外部取消"""
+                while True:
+                    await asyncio.sleep(config.cancel_check_interval)
+                    row = await self.db_client.async_fetch_one(
+                        f"SELECT task_status FROM {self.table} WHERE trace_id = %s",
+                        (self.trace_id,),
+                    )
+                    if row and row["task_status"] == TaskStatus.CANCELLING:
+                        _task.cancel()
+                        return
+
+            main_task = asyncio.current_task()
+            watchdog = asyncio.create_task(
+                _cancel_watchdog(main_task),
+                name=f"watchdog_{task_name}_{self.trace_id}",
+            )
+
             try:
                 await self._log_task_event("task_started", task_name=task_name)
 
@@ -251,6 +272,14 @@ class TaskScheduler(TaskHandler):
                     duration=duration,
                 )
 
+            except asyncio.CancelledError:
+                duration = time.time() - start_time
+                await self._log_task_event(
+                    "task_cancelled", task_name=task_name, duration=duration
+                )
+                status = TaskStatus.CANCELLED
+                raise
+
             except TaskError as e:
                 # 已知的任务错误
                 duration = time.time() - start_time
@@ -304,10 +333,17 @@ class TaskScheduler(TaskHandler):
                 )
 
             finally:
+                watchdog.cancel()
+                try:
+                    await watchdog
+                except asyncio.CancelledError:
+                    pass
                 await self._release_task(status)
+                await self.task_registry.unregister(self.trace_id)
 
-        # 创建后台任务
-        asyncio.create_task(_task_wrapper(), name=f"{task_name}_{self.trace_id}")
+        # 创建后台任务并注册
+        task = asyncio.create_task(_task_wrapper(), name=f"{task_name}_{self.trace_id}")
+        await self.task_registry.register(self.trace_id, task)
 
         return await task_schedule_response.success_response(
             task_name=task_name,
@@ -339,35 +375,52 @@ class TaskScheduler(TaskHandler):
 
     async def cancel_task(self, trace_id: Optional[str] = None) -> bool:
         """
-        取消任务(将状态设置为失败)
+        请求取消任务
+
+        流程: PROCESSING -> CANCELLING (由 watchdog 检测后取消协程) -> CANCELLED
+        对于 INIT 状态的任务直接标记为 CANCELLED
 
         Args:
             trace_id: 任务追踪 ID,默认使用当前实例的 trace_id
 
         Returns:
-            是否成功取消
+            是否成功发起取消
         """
         trace_id = trace_id or self.trace_id
-        query = f"""
+
+        # INIT 状态的任务还没开始执行,直接标记为 CANCELLED
+        init_query = f"""
             UPDATE {self.table}
             SET task_status = %s, finish_timestamp = %s
-            WHERE trace_id = %s AND task_status IN (%s, %s)
+            WHERE trace_id = %s AND task_status = %s
         """
-        result = await self.db_client.async_save(
-            query,
-            (
-                TaskStatus.FAILED,
-                int(time.time()),
-                trace_id,
-                TaskStatus.INIT,
-                TaskStatus.PROCESSING,
-            ),
+        init_result = await self.db_client.async_save(
+            init_query,
+            (TaskStatus.CANCELLED, int(time.time()), trace_id, TaskStatus.INIT),
         )
-
-        if result:
+        if init_result:
             await self._log_task_event("task_cancelled", trace_id=trace_id)
+            return True
 
-        return bool(result)
+        # PROCESSING 状态的任务标记为 CANCELLING,等待 watchdog 检测并取消协程
+        processing_query = f"""
+            UPDATE {self.table}
+            SET task_status = %s
+            WHERE trace_id = %s AND task_status = %s
+        """
+        processing_result = await self.db_client.async_save(
+            processing_query,
+            (TaskStatus.CANCELLING, trace_id, TaskStatus.PROCESSING),
+        )
+
+        if processing_result:
+            # 同进程优化:直接取消协程,不用等 watchdog 轮询
+            await self.task_registry.cancel_task(trace_id)
+            await self._log_task_event(
+                "task_cancelling", trace_id=trace_id
+            )
+
+        return bool(processing_result)
 
     async def retry_task(self, trace_id: Optional[str] = None) -> bool:
         """

+ 28 - 2
app/api/v1/endpoints/tasks.py

@@ -5,7 +5,7 @@ from quart import Blueprint, jsonify
 
 from app.api.service import TaskManager, TaskScheduler
 from app.api.v1.utils import ApiDependencies
-from app.api.v1.utils import RunTaskRequest, TaskListRequest
+from app.api.v1.utils import RunTaskRequest, TaskListRequest, CancelTaskRequest
 from app.api.v1.utils import parse_json, validation_error_response
 from app.infra.shared.tools import generate_task_trace_id
 
@@ -23,7 +23,9 @@ def create_tasks_bp(deps: ApiDependencies) -> Blueprint:
             payload, status = validation_error_response(e)
             return jsonify(payload), status
 
-        scheduler = TaskScheduler(body, deps.log, deps.db, trace_id, deps.config)
+        scheduler = TaskScheduler(
+            body, deps.log, deps.db, trace_id, deps.config, deps.task_registry
+        )
         result = await scheduler.deal()
         return jsonify(result)
 
@@ -39,4 +41,28 @@ def create_tasks_bp(deps: ApiDependencies) -> Blueprint:
         result = await manager.list_tasks()
         return jsonify(result)
 
+    @bp.route("/cancel_task", methods=["POST"])
+    async def cancel_task():  # POST handler: request cancellation of the task named by body["trace_id"]
+        try:
+            _, body = await parse_json(CancelTaskRequest)
+        except ValidationError as e:
+            payload, status = validation_error_response(e)  # invalid body: reply with the validation payload and its status
+            return jsonify(payload), status
+
+        trace_id = body["trace_id"]
+        scheduler = TaskScheduler(  # built per request; the trace_id to cancel is passed as the scheduler's own trace_id
+            body, deps.log, deps.db, trace_id, deps.config, deps.task_registry
+        )
+        cancelled = await scheduler.cancel_task(trace_id)  # truthy for INIT -> CANCELLED and for PROCESSING -> CANCELLING (requested, not yet stopped)
+        return jsonify({
+            "code": 0 if cancelled else 4004,  # 4004: no row was in INIT or PROCESSING state for this trace_id
+            "message": "Task cancelled" if cancelled else "Task not found or already finished",
+            "trace_id": trace_id,
+        })
+
+    @bp.route("/running_tasks", methods=["GET"])
+    async def running_tasks():  # GET handler: list the tasks currently held in the injected TaskRegistry
+        tasks = await deps.task_registry.get_running_tasks()  # one dict per registered task: trace_id, task_name, done, cancelled
+        return jsonify({"code": 0, "data": tasks, "total": len(tasks)})
+
     return bp

+ 6 - 2
app/api/v1/routes/routes.py

@@ -15,6 +15,7 @@ from app.api.v1.endpoints import (
 from app.core.config import GlobalConfigSettings
 from app.core.database import DatabaseManager
 from app.core.observability import LogService
+from app.core.task_registry import TaskRegistry
 
 
 def register_v1_blueprints(deps: ApiDependencies) -> Blueprint:
@@ -41,10 +42,13 @@ def register_v1_blueprints(deps: ApiDependencies) -> Blueprint:
 
 
 def server_routes(
-    pools: DatabaseManager, log_service: LogService, config: GlobalConfigSettings
+    pools: DatabaseManager,
+    log_service: LogService,
+    config: GlobalConfigSettings,
+    task_registry: TaskRegistry,
 ) -> Blueprint:
     """
     兼容旧入口:保留 server_routes 签名,内部转为新的 deps + 统一注册。
     """
-    deps = ApiDependencies(db=pools, log=log_service, config=config)
+    deps = ApiDependencies(db=pools, log=log_service, config=config, task_registry=task_registry)
     return register_v1_blueprints(deps)

+ 2 - 0
app/api/v1/utils/__init__.py

@@ -3,6 +3,7 @@ from .deps import ApiDependencies
 from .schemas import (
     RunTaskRequest,
     TaskListRequest,
+    CancelTaskRequest,
     SaveTokenRequest,
     GetCoverRequest,
     LongArticlesMcpRequest,
@@ -15,6 +16,7 @@ __all__ = [
     "validation_error_response",
     "RunTaskRequest",
     "TaskListRequest",
+    "CancelTaskRequest",
     "SaveTokenRequest",
     "GetCoverRequest",
     "LongArticlesMcpRequest",

+ 2 - 0
app/api/v1/utils/deps.py

@@ -5,6 +5,7 @@ from dataclasses import dataclass
 from app.core.config import GlobalConfigSettings
 from app.core.database import DatabaseManager
 from app.core.observability import LogService
+from app.core.task_registry import TaskRegistry
 
 
 @dataclass(frozen=True)
@@ -14,3 +15,4 @@ class ApiDependencies:
     db: DatabaseManager
     log: LogService
     config: GlobalConfigSettings
+    task_registry: TaskRegistry

+ 4 - 0
app/api/v1/utils/schemas.py

@@ -28,6 +28,10 @@ class TaskListRequest(BaseRequest):
     task_status: Optional[int] = None
 
 
+class CancelTaskRequest(BaseRequest):  # request body schema parsed by the /cancel_task endpoint
+    trace_id: str = Field(..., min_length=1)  # no default given; min_length=1 presumably rejects "" (assumes pydantic Field semantics - confirm)
+
+
 class GetCoverRequest(BaseRequest):
     """GetCoverService 的请求体字段不固定,先保持兼容。"""
 

+ 5 - 0
app/core/bootstrap/resource_manager.py

@@ -20,6 +20,11 @@ class AppContext:
         logger.info("aliyun log service init successfully")
 
     async def shutdown(self):
+        logger.info("取消所有运行中的任务")
+        registry = self.container.task_registry()
+        cancelled_count = await registry.cancel_all()
+        logger.info(f"已取消 {cancelled_count} 个运行中的任务")
+
         logger.info("关闭数据库连接池")
         mysql = self.container.mysql_manager()
         await mysql.close_pools()

+ 4 - 0
app/core/dependency/dependencies.py

@@ -3,6 +3,7 @@ from dependency_injector import containers, providers
 from app.core.config import GlobalConfigSettings
 from app.core.database import DatabaseManager
 from app.core.observability import LogService
+from app.core.task_registry import TaskRegistry
 
 
 class ServerContainer(containers.DeclarativeContainer):
@@ -15,6 +16,9 @@ class ServerContainer(containers.DeclarativeContainer):
     # MySQL
     mysql_manager = providers.Singleton(DatabaseManager, config=config)
 
+    # 任务注册表
+    task_registry = providers.Singleton(TaskRegistry)
+
 
 __all__ = [
     "ServerContainer",

+ 3 - 0
app/core/task_registry/__init__.py

@@ -0,0 +1,3 @@
+from .registry import TaskRegistry
+
+__all__ = ["TaskRegistry"]

+ 53 - 0
app/core/task_registry/registry.py

@@ -0,0 +1,53 @@
+import asyncio
+import logging
+from typing import Dict, List, Any, Optional
+
+logger = logging.getLogger(__name__)
+
+
+class TaskRegistry:
+    """Global task registry - tracks the asyncio.Task objects registered by trace_id."""
+
+    def __init__(self):
+        self._tasks: Dict[str, asyncio.Task] = {}  # trace_id -> asyncio.Task
+        self._lock = asyncio.Lock()  # every method below takes this lock before touching _tasks
+
+    async def register(self, trace_id: str, task: asyncio.Task) -> None:  # store task under trace_id, replacing any existing entry
+        async with self._lock:
+            self._tasks[trace_id] = task
+
+    async def unregister(self, trace_id: str) -> None:  # drop the entry; silently does nothing if trace_id is unknown
+        async with self._lock:
+            self._tasks.pop(trace_id, None)
+
+    async def cancel_task(self, trace_id: str) -> bool:  # True only if a registered, not-yet-done task was found
+        async with self._lock:
+            task = self._tasks.get(trace_id)
+            if task and not task.done():
+                task.cancel()  # requests cancellation only; the task is not awaited here
+                return True
+            return False  # unknown trace_id, or the task had already finished
+
+    async def cancel_all(self) -> int:  # request cancellation of every unfinished task; returns how many were asked
+        async with self._lock:
+            count = 0
+            for task in self._tasks.values():
+                if not task.done():
+                    task.cancel()  # entries stay registered; nothing is awaited or removed here
+                    count += 1
+            return count
+
+    async def get_running_tasks(self) -> List[Dict[str, Any]]:  # snapshot of ALL registered entries, finished ones included
+        async with self._lock:
+            return [
+                {
+                    "trace_id": trace_id,
+                    "task_name": task.get_name(),
+                    "done": task.done(),  # may be True for a task that has not been unregistered yet
+                    "cancelled": task.cancelled(),
+                }
+                for trace_id, task in self._tasks.items()
+            ]
+
+
+__all__ = ["TaskRegistry"]

+ 3 - 0
app/jobs/task_config.py

@@ -10,6 +10,7 @@ class TaskConfig:
     retry_times: int = 0  # 重试次数
     retryable: bool = True  # 是否可重试
     alert_on_failure: bool = True  # 失败时是否告警
+    cancel_check_interval: int = 5  # 取消检测轮询间隔(秒)
 
 
 class TaskStatus:
@@ -18,6 +19,8 @@ class TaskStatus:
     INIT = 0
     PROCESSING = 1
     SUCCESS = 2
+    CANCELLING = 3  # 待取消(用户请求取消)
+    CANCELLED = 4  # 已取消(协程已停止)
     FAILED = 99
 
 

+ 2 - 1
task_app.py

@@ -17,9 +17,10 @@ ctx = AppContext(server_container)
 config = server_container.config()
 log_service = server_container.log_service()
 mysql_manager = server_container.mysql_manager()
+task_registry = server_container.task_registry()
 
 
-routes = server_routes(mysql_manager, log_service, config)
+routes = server_routes(mysql_manager, log_service, config, task_registry)
 app.register_blueprint(routes)