_utils.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
import json
from typing import Dict, List, Optional

from app.infra.internal.aigc_decode_server import AigcDecodeServer

from ._const import DecodeArticleConst
  5. class AigcDecodeUtils(DecodeArticleConst):
  6. decode_server = AigcDecodeServer()
  7. async def submit_decode_batch(
  8. self, posts: List[Dict], *, config_id: int = None, skip_completed: bool = False
  9. ) -> Dict[str, Dict]:
  10. """分批提交解构任务,返回 {content_id: {status, errorMessage}}"""
  11. cfg_id = config_id or self.CONFIG_ID
  12. result = {}
  13. for i in range(0, len(posts), self.SUBMIT_BATCH):
  14. batch = posts[i : i + self.SUBMIT_BATCH]
  15. response = await self.decode_server.submit_decode(
  16. config_id=cfg_id, posts=batch, skip_completed=skip_completed
  17. )
  18. if response.get("code") == 0:
  19. for item in response.get("data", []):
  20. result[item["channelContentId"]] = item
  21. else:
  22. # 整批失败,标记所有帖子为 FAILED
  23. for post in batch:
  24. cid = post["channelContentId"]
  25. result[cid] = {
  26. "channelContentId": cid,
  27. "status": "FAILED",
  28. "errorMessage": f"batch submit failed: {response}",
  29. }
  30. return result
  31. async def query_decode_results_batch(
  32. self, content_ids: List[str], *, config_id: int = None
  33. ) -> Dict[str, Dict]:
  34. """分批查询解构结果,返回 {content_id: {status, dataContent, html, errorMessage}}
  35. 当 API 调用失败时,对应条目 status 为 API_ERROR,调用方应保持 INIT 等待重试。
  36. """
  37. cfg_id = config_id or self.CONFIG_ID
  38. result = {}
  39. for i in range(0, len(content_ids), self.SUBMIT_BATCH):
  40. batch = content_ids[i : i + self.SUBMIT_BATCH]
  41. response = await self.decode_server.query_decode_results(
  42. config_id=cfg_id, channel_content_ids=batch
  43. )
  44. if response.get("code") == 0:
  45. for item in response.get("data", []):
  46. result[item["channelContentId"]] = item
  47. else:
  48. for cid in batch:
  49. result[cid] = {
  50. "channelContentId": cid,
  51. "status": "API_ERROR",
  52. "errorMessage": f"query API failed: {response}",
  53. }
  54. return result
  55. @staticmethod
  56. def extract_decode_result(result: Dict) -> Dict:
  57. """从解构结果中解析出灵感点、目的点、关键点、选题
  58. 兼容新旧两种数据格式:v1 有 final_normalization_rebuild 包裹层,v2 无
  59. """
  60. final_result = result.get("final_normalization_rebuild") or result
  61. inspiration_list = final_result.get("inspiration_final_result", {}).get(
  62. "最终灵感点列表", []
  63. )
  64. purpose_list = final_result.get("purpose_final_result", {}).get(
  65. "最终目的点列表", []
  66. )
  67. keypoint_list = final_result.get("keypoint_final", {}).get("最终关键点列表", [])
  68. topic_fusion = final_result.get("topic_fusion_result", {})
  69. topic_text = (
  70. topic_fusion.get("最终选题", {}).get("选题", "")
  71. if isinstance(topic_fusion.get("最终选题"), dict)
  72. else ""
  73. )
  74. def _join_points(items: list, key: str) -> str:
  75. parts = [str(p[key]) for p in items if isinstance(p, dict) and p.get(key)]
  76. return ",".join(parts)
  77. return {
  78. "inspiration": _join_points(inspiration_list, "灵感点"),
  79. "purpose": _join_points(purpose_list, "目的点"),
  80. "key_point": _join_points(keypoint_list, "关键点"),
  81. "topic": topic_text,
  82. }
  83. class AdPlatformArticlesDecodeUtils(AigcDecodeUtils):
  84. @staticmethod
  85. def format_images(images: str) -> List[str]:
  86. if not images or not images.strip():
  87. return []
  88. try:
  89. image_list = json.loads(images)
  90. except (json.JSONDecodeError, TypeError):
  91. return []
  92. if not isinstance(image_list, list):
  93. return []
  94. return [
  95. i.get("image_url")
  96. for i in image_list
  97. if isinstance(i, dict) and i.get("image_url")
  98. ]
  99. def prepare_posts(self, articles: List[Dict]) -> List[Dict]:
  100. posts = []
  101. for article in articles:
  102. images = self.format_images(article.get("article_images") or "")
  103. posts.append(
  104. {
  105. "channelContentId": article["wx_sn"],
  106. "title": article.get("article_title", ""),
  107. "bodyText": article.get("article_text", ""),
  108. "images": images,
  109. "video": None,
  110. "contentModal": self.ContentModal.LONG_ARTICLE,
  111. "channel": self.Channel.WECHAT,
  112. }
  113. )
  114. return posts
  115. class InnerArticlesDecodeUtils(AigcDecodeUtils):
  116. def prepare_posts(
  117. self, articles: List[Dict], produce_info_map: Dict[str, List[Dict]]
  118. ) -> List[Dict]:
  119. posts = []
  120. for article in articles:
  121. source_id = str(article["source_id"])
  122. produce_info = produce_info_map.get(source_id, [])
  123. # 收集图片:封面(coverimgurl) + produce COVER + produce IMAGE
  124. images = []
  125. if article.get("coverimgurl"):
  126. images.append(article["coverimgurl"])
  127. for pi in produce_info:
  128. if pi["produce_module_type"] == self.ProduceModuleType.COVER:
  129. images.append(pi["output"])
  130. for pi in produce_info:
  131. if pi["produce_module_type"] == self.ProduceModuleType.IMAGE:
  132. images.append(pi["output"])
  133. posts.append(
  134. {
  135. "title": article.get("title", ""),
  136. "bodyText": article.get("article_text", ""),
  137. "images": images,
  138. "video": None,
  139. "contentModal": self.ContentModal.LONG_ARTICLE,
  140. "channel": self.Channel.WECHAT,
  141. "channelContentId": source_id,
  142. }
  143. )
  144. return posts
# Public API of this module: the names exported by a star-import of it.
__all__ = [
    "AigcDecodeUtils",
    "AdPlatformArticlesDecodeUtils",
    "InnerArticlesDecodeUtils",
]