# fetch_decode_results.py
  1. import json
  2. from typing import List, Dict
  3. from app.core.database import DatabaseManager
  4. from app.core.observability import LogService
  5. from app.infra.shared import run_tasks_with_asyncio_task_group
  6. from ._const import DecodeArticleConst
  7. from ._mapper import ArticlesDecodeTaskMapper
  8. from ._utils import AigcDecodeUtils
  9. class FetchDecodeResults(DecodeArticleConst):
  10. def __init__(self, pool: DatabaseManager, log_service: LogService):
  11. self.pool = pool
  12. self.log_service = log_service
  13. self.mapper = ArticlesDecodeTaskMapper(self.pool)
  14. self.tool = AigcDecodeUtils()
  15. async def _process_batch(self, tasks: List[Dict]):
  16. source_ids = [t["source_id"] for t in tasks]
  17. results = await self.tool.query_decode_results_batch(source_ids)
  18. for task in tasks:
  19. source_id = task["source_id"]
  20. result = results.get(source_id)
  21. if not result:
  22. await self.mapper.update_task_status_by_source_id(
  23. source_id=source_id,
  24. ori_status=self.TaskStatus.INIT,
  25. new_status=self.TaskStatus.FAILED,
  26. remark="解构任务在结果查询中未返回,可能不存在",
  27. )
  28. await self.log_service.log(
  29. contents={
  30. "task": "fetch_decode_results_v2",
  31. "source_id": source_id,
  32. "status": "fail",
  33. "message": "source_id not in query response",
  34. }
  35. )
  36. continue
  37. status = result.get("status")
  38. if status == "API_ERROR":
  39. # 查询 API 调用失败,保持 INIT 等待重试
  40. continue
  41. elif status == self.QueryStatus.SUCCESS:
  42. data_content = result.get("dataContent") or "{}"
  43. html = result.get("html")
  44. await self.mapper.set_decode_result(
  45. source_id=source_id,
  46. result=json.dumps(
  47. {"dataContent": data_content, "html": html},
  48. ensure_ascii=False,
  49. ),
  50. remark="解构结果获取成功",
  51. )
  52. elif status in (self.QueryStatus.PENDING, self.QueryStatus.RUNNING):
  53. pass
  54. elif status == self.QueryStatus.FAILED:
  55. await self.mapper.update_task_status_by_source_id(
  56. source_id=source_id,
  57. ori_status=self.TaskStatus.INIT,
  58. new_status=self.TaskStatus.FAILED,
  59. remark=f"解构任务失败: {result.get('errorMessage', '')}",
  60. )
  61. else:
  62. await self.log_service.log(
  63. contents={
  64. "task": "fetch_decode_results_v2",
  65. "source_id": source_id,
  66. "status": "unknown",
  67. "message": f"unexpected query status: {status}",
  68. "data": result,
  69. }
  70. )
  71. async def deal(self):
  72. pending_tasks = await self.mapper.fetch_pending_tasks()
  73. if not pending_tasks:
  74. await self.log_service.log(
  75. contents={
  76. "task": "fetch_decode_results_v2",
  77. "message": "No more tasks to fetch",
  78. }
  79. )
  80. return
  81. # 拆成多个批次,并发查询
  82. batches = [
  83. pending_tasks[i : i + self.SUBMIT_BATCH]
  84. for i in range(0, len(pending_tasks), self.SUBMIT_BATCH)
  85. ]
  86. await run_tasks_with_asyncio_task_group(
  87. task_list=batches,
  88. handler=self._process_batch,
  89. description="批量查询解构结果",
  90. unit="batch",
  91. )
  92. await self.log_service.log(
  93. contents={
  94. "task": "fetch_decode_results_v2",
  95. "message": f"Processed {len(pending_tasks)} pending tasks in {len(batches)} batches",
  96. }
  97. )
  98. __all__ = ["FetchDecodeResults"]