123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- import asyncio
- from typing import Callable, Coroutine, List, Any, Dict
- from tqdm.asyncio import tqdm
- # 使用asyncio.TaskGroup 来高效处理I/O密集型任务
- async def run_tasks_with_asyncio_task_group(
- task_list: List[Any],
- handler: Callable[[Any], Coroutine[Any, Any, None]],
- *,
- description: str = None, # 任务介绍
- unit: str,
- max_concurrency: int = 20, # 最大并发数
- fail_fast: bool = False, # 是否遇到错误就退出整个tasks
- ) -> Dict[str, Any]:
- """using asyncio.TaskGroup to process I/O-intensive tasks"""
- if not task_list:
- return {"total_task": 0, "processed_task": 0, "errors": []}
- processed_task = 0
- total_task = len(task_list)
- errors: List[tuple[int, Any, Exception]] = []
- semaphore = asyncio.Semaphore(max_concurrency)
- processing_bar = tqdm(total=total_task, unit=unit, desc=description)
- async def _run_single_task(task_obj: Any, idx: int):
- nonlocal processed_task
- async with semaphore:
- try:
- await handler(task_obj)
- processed_task += 1
- except Exception as e:
- if fail_fast:
- raise e
- errors.append((idx, task_obj, e))
- finally:
- processing_bar.update()
- async with asyncio.TaskGroup() as task_group:
- for index, task in enumerate(task_list, start=1):
- task_group.create_task(
- _run_single_task(task, index), name=f"processing {description}-{index}"
- )
- processing_bar.close()
- return {
- "total_task": total_task,
- "processed_task": processed_task,
- "errors": errors,
- }
|