luojunhui 1 месяц назад
Родитель
Commit
f5cd3b6a3d

+ 3 - 3
app/infra/mapper/aigc_mapper.py

@@ -24,9 +24,9 @@ class AigcDatabaseMapper:
         """
         COVER_TYPE = 2
         query = """
-            select image_url, oss_object_key
-            from crawler_content_image
-            where channel_content_id = %s and image_type = %s;
+            SELECT image_url, oss_object_key
+            FROM crawler_content_image
+            WHERE channel_content_id = %s AND image_type = %s;
         """
         return await pool.async_fetch(
             query=query, db_name="aigc", params=(channel_content_id, COVER_TYPE)

+ 1 - 1
app/infra/mapper/long_video_mapper.py

@@ -13,7 +13,7 @@ class LongVideoDatabaseMapper:
         use channel_content_id to find long video cover
         """
         query = """
-            select image_path from video_cover_snapshots where video_id = %s;
+            SELECT image_path FROM video_cover_snapshots WHERE video_id = %s;
         """
         return await pool.async_fetch(
             query=query, db_name="long_video", params=(video_id,)

+ 1 - 0
app/infra/shared/__init__.py

@@ -1,4 +1,5 @@
 from .async_tasks import run_tasks_with_asyncio_task_group
+from .async_tasks import run_tasks_with_async_worker_group
 from .http_client import AsyncHttpClient
 
 # server response

+ 79 - 0
app/infra/shared/async_tasks.py

@@ -4,6 +4,85 @@ from typing import Callable, Coroutine, List, Any, Dict
 from tqdm.asyncio import tqdm
 
 
async def run_tasks_with_async_worker_group(
    task_list: List[Any],
    handler: Callable[[Any], Coroutine[Any, Any, None]],
    *,
    description: str | None = None,
    unit: str,
    max_concurrency: int = 20,
    fail_fast: bool = False,
) -> Dict[str, Any]:
    """Process I/O-intensive tasks with a fixed-size pool of async workers.

    Args:
        task_list: items to process; ``handler`` is awaited once per item.
        handler: coroutine function invoked for each item.
        description: label shown on the tqdm progress bar.
        unit: tqdm unit label (keyword-only, required).
        max_concurrency: number of concurrent worker tasks.
        fail_fast: if True, stop running further items after the first
            failure and re-raise that exception to the caller.

    Returns:
        dict with ``total_task``, ``processed_task`` and ``errors`` — a list
        of ``(index, item, exception)`` tuples (1-based index).

    Raises:
        Exception: the first handler failure, when ``fail_fast`` is True.
    """
    if not task_list:
        return {"total_task": 0, "processed_task": 0, "errors": []}

    total_task = len(task_list)
    processed_task = 0
    errors: List[tuple[int, Any, Exception]] = []

    # None is the shutdown sentinel; one is enqueued per worker at the end.
    queue: asyncio.Queue[tuple[int, Any] | None] = asyncio.Queue()
    processing_bar = tqdm(total=total_task, unit=unit, desc=description)

    cancel_event = asyncio.Event()
    counter_lock = asyncio.Lock()

    async def worker(worker_id: int) -> None:
        nonlocal processed_task
        while True:
            item = await queue.get()
            try:
                if item is None:  # sentinel: shut this worker down
                    return
                idx, task_obj = item
                if cancel_event.is_set():
                    # Fail-fast already triggered: drain the item without
                    # running it so queue.join() can still complete.
                    # (Returning early here — as the original did — leaves
                    # remaining items without task_done() and deadlocks
                    # queue.join() forever.)
                    continue
                try:
                    await handler(task_obj)
                except Exception as exc:
                    # Record instead of re-raising: an exception raised inside
                    # a worker would be swallowed by gather(return_exceptions=True)
                    # and never reach the caller.
                    errors.append((idx, task_obj, exc))
                    if fail_fast:
                        cancel_event.set()
                else:
                    async with counter_lock:
                        processed_task += 1
                finally:
                    processing_bar.update()
            finally:
                # Every fetched item — real, drained, or sentinel — is
                # accounted for, so queue.join() cannot hang.
                queue.task_done()

    workers = [asyncio.create_task(worker(i)) for i in range(max_concurrency)]

    try:
        for index, task in enumerate(task_list, start=1):
            await queue.put((index, task))
        await queue.join()
    finally:
        # Stop the workers and release the progress bar even if the
        # producer loop itself failed (e.g. cancellation).
        for _ in range(max_concurrency):
            await queue.put(None)
        await asyncio.gather(*workers, return_exceptions=True)
        processing_bar.close()

    if fail_fast and errors:
        # Surface the first failure to the caller after cleanup.
        raise errors[0][2]

    return {
        "total_task": total_task,
        "processed_task": processed_task,
        "errors": errors,
    }
+
+
 # Use asyncio.TaskGroup to process I/O-intensive tasks efficiently
 async def run_tasks_with_asyncio_task_group(
     task_list: List[Any],