# async_tasks.py
import asyncio
from typing import Any, Callable, Coroutine, Dict, List

from tqdm.asyncio import tqdm
  4. # 使用asyncio.TaskGroup 来高效处理I/O密集型任务
  5. async def run_tasks_with_asyncio_task_group(
  6. task_list: List[Any],
  7. handler: Callable[[Any], Coroutine[Any, Any, None]],
  8. *,
  9. description: str = None, # 任务介绍
  10. unit: str,
  11. max_concurrency: int = 20, # 最大并发数
  12. fail_fast: bool = False, # 是否遇到错误就退出整个tasks
  13. ) -> Dict[str, Any]:
  14. """using asyncio.TaskGroup to process I/O-intensive tasks"""
  15. if not task_list:
  16. return {"total_task": 0, "processed_task": 0, "errors": []}
  17. processed_task = 0
  18. total_task = len(task_list)
  19. errors: List[tuple[int, Any, Exception]] = []
  20. semaphore = asyncio.Semaphore(max_concurrency)
  21. processing_bar = tqdm(total=total_task, unit=unit, desc=description)
  22. async def _run_single_task(task_obj: Any, idx: int):
  23. nonlocal processed_task
  24. async with semaphore:
  25. try:
  26. await handler(task_obj)
  27. processed_task += 1
  28. except Exception as e:
  29. if fail_fast:
  30. raise e
  31. errors.append((idx, task_obj, e))
  32. finally:
  33. processing_bar.update()
  34. async with asyncio.TaskGroup() as task_group:
  35. for index, task in enumerate(task_list, start=1):
  36. task_group.create_task(
  37. _run_single_task(task, index), name=f"processing {description}-{index}"
  38. )
  39. processing_bar.close()
  40. return {
  41. "total_task": total_task,
  42. "processed_task": processed_task,
  43. "errors": errors,
  44. }