task_manager.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # task_manager.py
  2. import asyncio
  3. from datetime import datetime
  4. from typing import Optional
  class TaskManager:
      """Task pool shared across requests: trace_id -> asyncio.Task + metadata.

      For every registered id the manager keeps the Task itself and a small
      metadata dict (the task's name and its registration time as a local,
      second-precision ISO-8601 string).

      NOTE(review): entries are removed only by ``cancel()``. A task that
      finishes on its own stays registered (``list()`` reports it with
      ``done=True``) until ``cancel()`` is called for its id, so the pool
      grows unless callers do that -- confirm this is intended.
      """

      def __init__(self):
          # task id -> the running (or already finished) asyncio.Task
          self._tasks: dict[str, asyncio.Task] = {}
          # task id -> {"name": ..., "started": ...}; kept in lockstep with
          # self._tasks (both are written by register() and cleared together
          # by _cleanup()).
          self._meta: dict[str, dict] = {}

      def register(self, task_id: str, task: asyncio.Task, task_name: str) -> None:
          """Track *task* under *task_id* and record its name and start time.

          Registering an id that is already present overwrites the old entry;
          the previous task is not cancelled.
          """
          self._tasks[task_id] = task
          self._meta[task_id] = {
              "name": task_name,
              "started": datetime.now().isoformat(timespec="seconds"),
          }

      def list(self):
          """Return a snapshot of all registered tasks and their status.

          Each item is a dict with keys ``task_id``, ``name``, ``started``,
          ``done`` and ``cancelled``. Tasks that have already finished are
          still listed (``done=True``) until they are removed by ``cancel()``.
          """
          # NOTE: this method shadows the builtin ``list`` inside the class body,
          # so annotations on later methods in this class must not use list[...].
          return [
              {
                  "task_id": task_id,
                  "name": self._meta[task_id]["name"],
                  "started": self._meta[task_id]["started"],
                  "done": task.done(),
                  "cancelled": task.cancelled(),
              }
              for task_id, task in self._tasks.items()
          ]

      async def cancel(self, task_id: str) -> bool:
          """Cancel the task registered under *task_id*, wait for it, forget it.

          Returns:
              True whenever it returns normally: an unknown id (already removed
              or mistyped) and an already-finished task both count as success.

          Raises:
              asyncio.CancelledError: if the *caller* of ``cancel()`` is itself
                  cancelled while waiting; the entry is still removed.
              Exception: any non-cancellation exception the task raises while
                  shutting down (e.g. from its ``finally``) is re-raised after
                  the entry has been removed.
          """
          t: Optional[asyncio.Task] = self._tasks.get(task_id)
          if t is None:  # already finished and removed, or the id is wrong
              return True
          if t.done():
              self._cleanup(task_id)
              return True
          t.cancel()
          try:
              # Wait for the task to run its own cleanup (finally blocks).
              # asyncio.wait() never raises the task's outcome, so a
              # CancelledError escaping here can only mean that OUR caller was
              # cancelled, and it must propagate. Awaiting `t` directly and
              # catching CancelledError cannot distinguish the two cases and
              # would swallow the caller's cancellation.
              await asyncio.wait({t})
          finally:
              self._cleanup(task_id)
          if not t.cancelled():
              # The task suppressed the cancellation and then returned or raised;
              # surface a raised exception exactly as `await t` would have.
              t.result()
          return True

      def _cleanup(self, task_id: str) -> None:
          """Drop *task_id* from both registries; a missing id is ignored."""
          self._tasks.pop(task_id, None)
          self._meta.pop(task_id, None)