import json
from collections import defaultdict
from typing import Dict, List

from app.core.database import DatabaseManager
from app.core.observability import LogService
from app.infra.shared import run_tasks_with_asyncio_task_group

from ._const import DecodeCardConst
from ._mapper import CardDecodeTaskMapper
from ._utils import CardDecodeUtils


class FetchCardDecodeResults(DecodeCardConst):
    """Fetch results for pending card-decode tasks and persist each outcome.

    Pending tasks are loaded through ``CardDecodeTaskMapper.fetch_pending_tasks``,
    grouped by ``config_id``, split into batches of ``SUBMIT_BATCH`` and handed to
    ``run_tasks_with_asyncio_task_group``. Every task in a batch is then updated
    according to the status returned by ``CardDecodeUtils.query_decode_results_batch``.

    ``SUBMIT_BATCH``, ``TaskStatus`` and ``QueryStatus`` are read from ``self``;
    presumably they are defined on ``DecodeCardConst`` — confirm in ``_const``.
    """

    def __init__(self, pool: DatabaseManager, log_service: LogService):
        self.pool = pool
        self.log_service = log_service
        self.mapper = CardDecodeTaskMapper(self.pool)
        self.tool = CardDecodeUtils()

    @staticmethod
    def _group_tasks_by_config(tasks: List[Dict]) -> Dict[int, List[Dict]]:
        """Group task dicts by their ``"config_id"`` key, keeping input order within each group."""
        grouped = defaultdict(list)
        for task in tasks:
            grouped[task["config_id"]].append(task)
        return dict(grouped)

    @staticmethod
    def _split_batches(tasks: List[Dict], size: int) -> List[List[Dict]]:
        """Split *tasks* into consecutive slices of at most *size* items."""
        return [tasks[i : i + size] for i in range(0, len(tasks), size)]

    async def _mark_missing(self, source_id, config_id: int) -> None:
        """Mark a task FAILED and log it when its source_id is absent from the query response."""
        await self.mapper.update_task_status_by_source_id(
            source_id=source_id,
            config_id=config_id,
            new_status=self.TaskStatus.FAILED,
            remark="卡片解构任务在结果查询中未返回",
        )
        await self.log_service.log(
            contents={
                "task": "fetch_card_decode_results",
                "source_id": source_id,
                "config_id": config_id,
                "status": "fail",
                "message": "source_id not in query response",
            }
        )

    async def _apply_result(self, source_id, config_id: int, result: Dict) -> None:
        """Update one task according to the ``"status"`` field of its query result.

        - ``"API_ERROR"``: the task is left untouched.
        - SUCCESS: ``dataContent`` (defaulting to ``"{}"``) and ``html`` are stored
          as a JSON string via ``set_decode_result``.
        - PENDING / RUNNING: nothing is recorded.
        - FAILED: the task is marked FAILED with the reported ``errorMessage``.
        - Anything else is logged together with the raw result.
        """
        status = result.get("status")
        if status == "API_ERROR":
            # No DB update here; presumably the task stays pending and is
            # re-queried on a later run — confirm against fetch_pending_tasks.
            return

        if status == self.QueryStatus.SUCCESS:
            # A falsy/missing dataContent is stored as the empty JSON object string.
            data_content = result.get("dataContent") or "{}"
            html = result.get("html")
            await self.mapper.set_decode_result(
                source_id=source_id,
                config_id=config_id,
                result=json.dumps(
                    {"dataContent": data_content, "html": html},
                    ensure_ascii=False,
                ),
                remark="卡片解构结果获取成功",
            )
        elif status in (self.QueryStatus.PENDING, self.QueryStatus.RUNNING):
            # Still in progress upstream; nothing to record yet.
            pass
        elif status == self.QueryStatus.FAILED:
            await self.mapper.update_task_status_by_source_id(
                source_id=source_id,
                config_id=config_id,
                new_status=self.TaskStatus.FAILED,
                remark=f"卡片解构任务失败: {result.get('errorMessage', '')}",
            )
        else:
            await self.log_service.log(
                contents={
                    "task": "fetch_card_decode_results",
                    "source_id": source_id,
                    "config_id": config_id,
                    "status": "unknown",
                    "message": f"unexpected query status: {status}",
                    "data": result,
                }
            )

    async def _process_batch(self, tasks: List[Dict], config_id: int):
        """Query decode results for one batch of tasks sharing *config_id* and apply them.

        ``query_decode_results_batch`` is expected to return a mapping keyed by
        source_id; a task whose entry is missing or falsy is marked FAILED.
        """
        source_ids = [t["source_id"] for t in tasks]
        results = await self.tool.query_decode_results_batch(
            source_ids, config_id=config_id
        )
        for task in tasks:
            source_id = task["source_id"]
            result = results.get(source_id)
            if not result:
                await self._mark_missing(source_id, config_id)
                continue
            await self._apply_result(source_id, config_id, result)

    async def deal(self):
        """Entry point: process all pending tasks, one config group at a time.

        Batches within a config group are run through
        ``run_tasks_with_asyncio_task_group``. Config groups are awaited
        sequentially. A summary line is logged at the end, or a "nothing to do"
        line if there are no pending tasks.
        """
        pending_tasks = await self.mapper.fetch_pending_tasks()
        if not pending_tasks:
            await self.log_service.log(
                contents={
                    "task": "fetch_card_decode_results",
                    "message": "No more card tasks to fetch",
                }
            )
            return

        grouped = self._group_tasks_by_config(pending_tasks)
        for config_id, tasks in grouped.items():
            batches = self._split_batches(tasks, self.SUBMIT_BATCH)
            await run_tasks_with_asyncio_task_group(
                task_list=[
                    {"batch": b, "config_id": config_id} for b in batches
                ],
                # config_id is bound as a default so the lambda captures this
                # iteration's value rather than the loop variable.
                handler=lambda item, cid=config_id: self._process_batch(
                    item["batch"], cid
                ),
                description="批量查询卡片解构结果",
                unit="batch",
            )

        await self.log_service.log(
            contents={
                "task": "fetch_card_decode_results",
                "message": f"Processed {len(pending_tasks)} pending card tasks across {len(grouped)} configs",
            }
        )


__all__ = ["FetchCardDecodeResults"]