123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- # task_manager.py
- import asyncio
- from datetime import datetime
- from typing import Optional
class TaskManager:
    """Task pool shared across requests: maps a trace_id to its asyncio.Task plus metadata.

    Entries are added by ``register`` and removed only by ``cancel``. Tasks that
    finish on their own stay in the pool (reported with ``done=True`` by
    ``list``) until ``cancel`` is called for their id.
    """

    def __init__(self):
        # task_id -> registered task (running, or finished but not yet cleaned up)
        self._tasks: dict[str, asyncio.Task] = {}
        # task_id -> {"name": ..., "started": ...}; always has the same keys as _tasks
        self._meta: dict[str, dict] = {}

    def register(self, task_id: str, task: asyncio.Task, task_name: str) -> None:
        """Add *task* to the pool under *task_id*, recording its name and start time.

        Registering an id that is already present replaces the previous entry;
        the previously registered task is NOT cancelled.
        """
        self._tasks[task_id] = task
        self._meta[task_id] = {
            "name": task_name,
            # Naive local wall-clock time, second precision.
            "started": datetime.now().isoformat(timespec="seconds"),
        }

    def list(self):
        """Return a snapshot of every registered task and its state.

        Each entry is a dict with keys ``task_id``, ``name``, ``started``,
        ``done`` and ``cancelled``. Finished tasks are included until they are
        removed via ``cancel``.
        """
        return [
            {
                "task_id": tid,
                "name": self._meta[tid]["name"],
                "started": self._meta[tid]["started"],
                "done": t.done(),
                "cancelled": t.cancelled(),
            }
            for tid, t in self._tasks.items()
        ]

    async def cancel(self, task_id: str) -> bool:
        """Cancel the task registered under *task_id*, wait for it to unwind, and drop it.

        Returns:
            Always ``True``, including when *task_id* is unknown or the task
            had already finished.

        Raises:
            asyncio.CancelledError: if the coroutine calling ``cancel`` is
                itself cancelled while waiting. The entry is still removed.
            Exception: whatever the task raised if it ignored the
                cancellation request and failed during teardown.
        """
        t: Optional[asyncio.Task] = self._tasks.get(task_id)
        if t is None:  # already cleaned up, or the id is wrong
            return True
        if t.done():
            self._cleanup(task_id, t)
            return True
        t.cancel()
        try:
            # Wait for the task to run its except/finally blocks. Unlike
            # ``await t``, asyncio.wait neither re-raises the task's outcome
            # nor forwards *our* cancellation into ``t``. Any CancelledError
            # raised here therefore belongs to the caller and must propagate
            # rather than being swallowed.
            await asyncio.wait({t})
        finally:
            self._cleanup(task_id, t)
        if not t.cancelled():
            # The task did not end cancelled. Surface a failure raised during
            # its teardown, as ``await t`` used to; retrieving it also prevents
            # an "exception was never retrieved" warning.
            exc = t.exception()
            if exc is not None:
                raise exc
        return True

    def _cleanup(self, task_id: str, task: Optional[asyncio.Task] = None):
        """Remove *task_id* from the pool.

        If *task* is given, the entry is removed only while it still refers to
        that exact task. A different task re-registered under the same id
        (e.g. while ``cancel`` was awaiting) is left in place.
        """
        if task is not None and self._tasks.get(task_id) is not task:
            return
        self._tasks.pop(task_id, None)
        self._meta.pop(task_id, None)
|