import asyncio
import logging
from typing import Dict, List, Any, Optional

logger = logging.getLogger(__name__)


class TaskRegistry:
    """Global task registry that tracks asyncio.Task objects by trace_id.

    Entries are added with ``register`` and removed only by ``unregister``;
    finished tasks stay in the registry until explicitly unregistered.
    """

    def __init__(self):
        # trace_id -> task; at most one task is tracked per trace_id.
        self._tasks: Dict[str, asyncio.Task] = {}
        self._lock = asyncio.Lock()

    async def register(self, trace_id: str, task: asyncio.Task) -> None:
        """Track ``task`` under ``trace_id``.

        If another task is already registered under the same id, it is
        replaced. A warning is logged when the replaced task is still running,
        because it can no longer be cancelled through this registry.
        """
        async with self._lock:
            previous = self._tasks.get(trace_id)
            if previous is not None and previous is not task and not previous.done():
                logger.warning(
                    "trace_id %s re-registered while previous task %s is still "
                    "running; the previous task is no longer tracked",
                    trace_id,
                    previous.get_name(),
                )
            self._tasks[trace_id] = task

    async def unregister(
        self, trace_id: str, task: Optional[asyncio.Task] = None
    ) -> None:
        """Stop tracking the task registered under ``trace_id``.

        Args:
            trace_id: Id the task was registered under. Unknown ids are ignored.
            task: Optional. When given, the entry is removed only if it is this
                exact task. This stops the cleanup of a replaced task from
                evicting a newer task registered under the same id.
        """
        async with self._lock:
            if task is not None and self._tasks.get(trace_id) is not task:
                return
            self._tasks.pop(trace_id, None)

    async def cancel_task(self, trace_id: str) -> bool:
        """Request cancellation of the task registered under ``trace_id``.

        Returns:
            True if a not-yet-done task was found and ``cancel()`` was called on
            it; False if the id is unknown or the task has already finished.
            The entry itself is not removed.
        """
        async with self._lock:
            task = self._tasks.get(trace_id)
            if task and not task.done():
                task.cancel()
                return True
            return False

    async def cancel_all(self) -> int:
        """Request cancellation of every registered task that is not done.

        Returns:
            The number of tasks on which ``cancel()`` was called. Entries are
            not removed.
        """
        async with self._lock:
            count = 0
            for task in self._tasks.values():
                if not task.done():
                    task.cancel()
                    count += 1
            return count

    async def get_running_tasks(self) -> List[Dict[str, Any]]:
        """Return a snapshot of all registered tasks.

        Despite the name, finished and cancelled tasks are included until they
        are unregistered; inspect the ``done``/``cancelled`` flags to filter.
        """
        async with self._lock:
            return [
                {
                    "trace_id": trace_id,
                    "task_name": task.get_name(),
                    "done": task.done(),
                    "cancelled": task.cancelled(),
                }
                for trace_id, task in self._tasks.items()
            ]


__all__ = ["TaskRegistry"]