registry.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import asyncio
  2. import logging
  3. from typing import Dict, List, Any, Optional
  4. logger = logging.getLogger(__name__)
  5. class TaskRegistry:
  6. """全局任务注册表 - 管理所有运行中的 asyncio.Task"""
  7. def __init__(self):
  8. self._tasks: Dict[str, asyncio.Task] = {}
  9. self._lock = asyncio.Lock()
  10. async def register(self, trace_id: str, task: asyncio.Task) -> None:
  11. async with self._lock:
  12. self._tasks[trace_id] = task
  13. async def unregister(self, trace_id: str) -> None:
  14. async with self._lock:
  15. self._tasks.pop(trace_id, None)
  16. async def cancel_task(self, trace_id: str) -> bool:
  17. async with self._lock:
  18. task = self._tasks.get(trace_id)
  19. if task and not task.done():
  20. task.cancel()
  21. return True
  22. return False
  23. async def cancel_all(self) -> int:
  24. async with self._lock:
  25. count = 0
  26. for task in self._tasks.values():
  27. if not task.done():
  28. task.cancel()
  29. count += 1
  30. return count
  31. async def get_running_tasks(self) -> List[Dict[str, Any]]:
  32. async with self._lock:
  33. return [
  34. {
  35. "trace_id": trace_id,
  36. "task_name": task.get_name(),
  37. "done": task.done(),
  38. "cancelled": task.cancelled(),
  39. }
  40. for trace_id, task in self._tasks.items()
  41. ]
  42. __all__ = ["TaskRegistry"]