|
|
@@ -4,6 +4,85 @@ from typing import Callable, Coroutine, List, Any, Dict
|
|
|
from tqdm.asyncio import tqdm
|
|
|
|
|
|
|
|
|
async def run_tasks_with_async_worker_group(
    task_list: List[Any],
    handler: Callable[[Any], Coroutine[Any, Any, None]],
    *,
    description: str | None = None,
    unit: str,
    max_concurrency: int = 20,
    fail_fast: bool = False,
) -> Dict[str, Any]:
    """Process I/O-bound tasks with a bounded pool of async workers.

    Items from ``task_list`` are fed through a queue to ``max_concurrency``
    worker coroutines, each of which awaits ``handler(item)``. A tqdm
    progress bar tracks every item (including skipped/failed ones).

    Args:
        task_list: Items to process; each is passed to ``handler``.
        handler: Async callable invoked once per item.
        description: Label shown on the progress bar.
        unit: Unit label for the progress bar (e.g. ``"file"``).
        max_concurrency: Number of concurrent worker coroutines.
        fail_fast: If True, the first handler exception cancels remaining
            work (pending items are drained but skipped) and the exception
            is re-raised to the caller. If False, failures are collected
            in the returned ``errors`` list.

    Returns:
        Dict with ``total_task``, ``processed_task`` (count of successful
        items) and ``errors`` — a list of ``(index, item, exception)``
        tuples where ``index`` is 1-based.

    Raises:
        Exception: the first handler exception, when ``fail_fast`` is True.
    """
    if not task_list:
        return {"total_task": 0, "processed_task": 0, "errors": []}

    total_task = len(task_list)
    processed_task = 0
    errors: List[tuple[int, Any, Exception]] = []

    queue: asyncio.Queue[tuple[int, Any] | None] = asyncio.Queue()
    processing_bar = tqdm(total=total_task, unit=unit, desc=description)

    cancel_event = asyncio.Event()
    counter_lock = asyncio.Lock()

    async def worker(worker_id: int) -> None:
        nonlocal processed_task
        while True:
            item = await queue.get()
            if item is None:  # sentinel: shut this worker down
                queue.task_done()
                break

            idx, task_obj = item
            try:
                # After a fail-fast trigger, keep draining the queue while
                # skipping the actual work so queue.join() can complete.
                # (A bare `return` here would kill the worker after one
                # item and deadlock join() whenever more items remain
                # than live workers.)
                if cancel_event.is_set():
                    continue

                await handler(task_obj)
                async with counter_lock:
                    processed_task += 1
            except Exception as e:
                if fail_fast:
                    cancel_event.set()
                errors.append((idx, task_obj, e))
            finally:
                # Counts skipped items too, so the bar always reaches 100%.
                processing_bar.update()
                queue.task_done()

    workers = [asyncio.create_task(worker(i)) for i in range(max_concurrency)]

    try:
        for index, task in enumerate(task_list, start=1):
            await queue.put((index, task))
        await queue.join()
    finally:
        # Wake every worker so it can exit its loop, then reap them all.
        for _ in range(max_concurrency):
            await queue.put(None)
        await asyncio.gather(*workers, return_exceptions=True)
        processing_bar.close()

    # queue.join() never re-raises worker exceptions, so when fail_fast is
    # requested the first failure must be propagated explicitly here.
    if fail_fast and errors:
        raise errors[0][2]

    return {
        "total_task": total_task,
        "processed_task": processed_task,
        "errors": errors,
    }
|
|
|
+
|
|
|
+
|
|
|
# Use asyncio.TaskGroup to efficiently process I/O-bound tasks
|
|
|
async def run_tasks_with_asyncio_task_group(
|
|
|
task_list: List[Any],
|