"""Async task-runner helpers for I/O-bound workloads.

Two equivalent entry points: a fixed worker-pool built on asyncio.Queue,
and a semaphore-throttled asyncio.TaskGroup variant (Python 3.11+).
"""
  1. import asyncio
  2. from typing import Callable, Coroutine, List, Any, Dict
  3. from tqdm.asyncio import tqdm
  4. async def run_tasks_with_async_worker_group(
  5. task_list: List[Any],
  6. handler: Callable[[Any], Coroutine[Any, Any, None]],
  7. *,
  8. description: str = None,
  9. unit: str,
  10. max_concurrency: int = 20,
  11. fail_fast: bool = False,
  12. ) -> Dict[str, Any]:
  13. """using async worker pool to process I/O-intensive tasks"""
  14. if not task_list:
  15. return {"total_task": 0, "processed_task": 0, "errors": []}
  16. total_task = len(task_list)
  17. processed_task = 0
  18. errors: List[tuple[int, Any, Exception]] = []
  19. queue: asyncio.Queue[tuple[int, Any] | None] = asyncio.Queue()
  20. processing_bar = tqdm(total=total_task, unit=unit, desc=description)
  21. cancel_event = asyncio.Event()
  22. counter_lock = asyncio.Lock()
  23. async def worker(worker_id: int):
  24. nonlocal processed_task
  25. while True:
  26. item = await queue.get()
  27. if item is None:
  28. queue.task_done()
  29. break
  30. idx, task_obj = item
  31. try:
  32. if cancel_event.is_set():
  33. return
  34. await handler(task_obj)
  35. async with counter_lock:
  36. processed_task += 1
  37. except Exception as e:
  38. if fail_fast:
  39. cancel_event.set()
  40. raise
  41. errors.append((idx, task_obj, e))
  42. finally:
  43. processing_bar.update()
  44. queue.task_done()
  45. workers = [asyncio.create_task(worker(i)) for i in range(max_concurrency)]
  46. try:
  47. for index, task in enumerate(task_list, start=1):
  48. await queue.put((index, task))
  49. await queue.join()
  50. except Exception:
  51. # fail_fast=True 时,worker 抛异常会走到这里
  52. for w in workers:
  53. w.cancel()
  54. raise
  55. finally:
  56. for _ in range(max_concurrency):
  57. await queue.put(None)
  58. await asyncio.gather(*workers, return_exceptions=True)
  59. processing_bar.close()
  60. return {
  61. "total_task": total_task,
  62. "processed_task": processed_task,
  63. "errors": errors,
  64. }
  65. # 使用asyncio.TaskGroup 来高效处理I/O密集型任务
  66. async def run_tasks_with_asyncio_task_group(
  67. task_list: List[Any],
  68. handler: Callable[[Any], Coroutine[Any, Any, None]],
  69. *,
  70. description: str = None, # 任务介绍
  71. unit: str,
  72. max_concurrency: int = 20, # 最大并发数
  73. fail_fast: bool = False, # 是否遇到错误就退出整个tasks
  74. ) -> Dict[str, Any]:
  75. """using asyncio.TaskGroup to process I/O-intensive tasks"""
  76. if not task_list:
  77. return {"total_task": 0, "processed_task": 0, "errors": []}
  78. processed_task = 0
  79. total_task = len(task_list)
  80. errors: List[tuple[int, Any, Exception]] = []
  81. semaphore = asyncio.Semaphore(max_concurrency)
  82. processing_bar = tqdm(total=total_task, unit=unit, desc=description)
  83. async def _run_single_task(task_obj: Any, idx: int):
  84. nonlocal processed_task
  85. async with semaphore:
  86. try:
  87. await handler(task_obj)
  88. processed_task += 1
  89. except Exception as e:
  90. if fail_fast:
  91. raise e
  92. errors.append((idx, task_obj, e))
  93. finally:
  94. processing_bar.update()
  95. async with asyncio.TaskGroup() as task_group:
  96. for index, task in enumerate(task_list, start=1):
  97. task_group.create_task(
  98. _run_single_task(task, index), name=f"processing {description}-{index}"
  99. )
  100. processing_bar.close()
  101. return {
  102. "total_task": total_task,
  103. "processed_task": processed_task,
  104. "errors": errors,
  105. }