| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- import asyncio
- import logging
- from typing import Dict, List, Any, Optional
# Logger for this module, named after the module itself.
logger: logging.Logger = logging.getLogger(__name__)
class TaskRegistry:
    """Global task registry that tracks ``asyncio.Task`` objects by trace ID.

    Each trace ID maps to at most one task; registering a task under an ID
    that is already present replaces the previous entry. Entries are only
    removed by :meth:`unregister` -- finished or cancelled tasks remain in the
    registry (and in :meth:`get_running_tasks` output) until then.
    """

    def __init__(self) -> None:
        # trace_id -> the task currently registered under that id.
        self._tasks: Dict[str, asyncio.Task] = {}
        # Guards every read/write of ``_tasks``.
        self._lock = asyncio.Lock()

    async def register(self, trace_id: str, task: asyncio.Task) -> None:
        """Register *task* under *trace_id*, replacing any existing entry.

        If a *different* task that has not finished is already registered
        under *trace_id*, a warning is logged: once replaced, that task can no
        longer be cancelled through this registry.

        Args:
            trace_id: Key to register the task under.
            task: The task to track.
        """
        async with self._lock:
            previous = self._tasks.get(trace_id)
            if previous is not None and previous is not task and not previous.done():
                logger.warning(
                    "Replacing still-running task %r registered under trace_id %s",
                    previous.get_name(),
                    trace_id,
                )
            self._tasks[trace_id] = task

    async def unregister(
        self, trace_id: str, task: Optional[asyncio.Task] = None
    ) -> None:
        """Remove the entry for *trace_id*, if one exists.

        Args:
            trace_id: Key the task was registered under.
            task: Optional identity guard. When given, the entry is removed
                only if it still refers to this exact task, so a finishing
                task cannot evict a newer task that was later registered under
                the same *trace_id*. When ``None`` (the default) the entry is
                removed unconditionally, as before.
        """
        async with self._lock:
            if task is None or self._tasks.get(trace_id) is task:
                self._tasks.pop(trace_id, None)

    async def cancel_task(self, trace_id: str) -> bool:
        """Request cancellation of the task registered under *trace_id*.

        The entry is not removed; call :meth:`unregister` for that.

        Returns:
            ``True`` if a registered, unfinished task was found and
            ``cancel()`` was called on it; ``False`` if nothing is registered
            under *trace_id* or the task has already finished.
        """
        async with self._lock:
            task = self._tasks.get(trace_id)
            if task is None or task.done():
                return False
            task.cancel()
            return True

    async def cancel_all(self) -> int:
        """Request cancellation of every registered task that has not finished.

        Entries are not removed from the registry.

        Returns:
            The number of tasks on which ``cancel()`` was called.
        """
        async with self._lock:
            count = 0
            for task in self._tasks.values():
                if not task.done():
                    task.cancel()
                    count += 1
            return count

    async def get_running_tasks(self) -> List[Dict[str, Any]]:
        """Return a snapshot describing every registered task.

        Despite the name, finished and cancelled tasks are included as long as
        they are still registered; inspect the ``done`` / ``cancelled`` flags.

        Returns:
            One dict per entry with keys ``trace_id``, ``task_name``,
            ``done`` and ``cancelled``, in registration order.
        """
        async with self._lock:
            return [
                {
                    "trace_id": trace_id,
                    "task_name": task.get_name(),
                    "done": task.done(),
                    "cancelled": task.cancelled(),
                }
                for trace_id, task in self._tasks.items()
            ]
# Public names exported by ``from <module> import *``.
__all__ = [
    "TaskRegistry",
]
|