_utils.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. import json
  2. from typing import Dict, List
  3. from app.infra.internal.aigc_decode_server import AigcDecodeServer
  4. from ._const import DecodeArticleConst
  5. class AigcDecodeUtils(DecodeArticleConst):
  6. decode_server = AigcDecodeServer()
  7. async def submit_decode_batch(
  8. self, posts: List[Dict], *, config_id: int = None, skip_completed: bool = False
  9. ) -> Dict[str, Dict]:
  10. """分批提交解构任务,返回 {channel_content_id: {status, errorMessage}}"""
  11. cfg_id = config_id or self.CONFIG_ID
  12. result = {}
  13. for i in range(0, len(posts), self.SUBMIT_BATCH):
  14. batch = posts[i : i + self.SUBMIT_BATCH]
  15. response = await self.decode_server.submit_decode(
  16. config_id=cfg_id, posts=batch, skip_completed=skip_completed
  17. )
  18. if response.get("code") == 0:
  19. for item in response.get("data", []):
  20. result[item["channelContentId"]] = item
  21. else:
  22. # 整批失败,标记所有帖子为 FAILED
  23. for post in batch:
  24. cid = post["channelContentId"]
  25. result[cid] = {
  26. "channelContentId": cid,
  27. "status": "FAILED",
  28. "errorMessage": f"batch submit failed: {response}",
  29. }
  30. return result
  31. async def query_decode_results_batch(
  32. self, channel_content_ids: List[str], *, config_id: int = None
  33. ) -> Dict[str, Dict]:
  34. """分批查询解构结果,返回 {channel_content_id: {status, dataContent, html, errorMessage}}
  35. 当 API 调用失败时,对应条目 status 为 API_ERROR,调用方应保持 INIT 等待重试。
  36. """
  37. cfg_id = config_id or self.CONFIG_ID
  38. result = {}
  39. for i in range(0, len(channel_content_ids), self.SUBMIT_BATCH):
  40. batch = channel_content_ids[i : i + self.SUBMIT_BATCH]
  41. response = await self.decode_server.query_decode_results(
  42. config_id=cfg_id, channel_content_ids=batch
  43. )
  44. if response.get("code") == 0:
  45. for item in response.get("data", []):
  46. result[item["channelContentId"]] = item
  47. else:
  48. for cid in batch:
  49. result[cid] = {
  50. "channelContentId": cid,
  51. "status": "API_ERROR",
  52. "errorMessage": f"query API failed: {response}",
  53. }
  54. return result
  55. @staticmethod
  56. def extract_decode_result(result: Dict) -> Dict:
  57. """从解构结果中解析出灵感点、目的点、关键点、选题
  58. 兼容新旧两种数据格式:v1 有 final_normalization_rebuild 包裹层,v2 无
  59. """
  60. final_result = result.get("final_normalization_rebuild") or result
  61. inspiration_list = final_result.get("inspiration_final_result", {}).get(
  62. "最终灵感点列表", []
  63. )
  64. purpose_list = final_result.get("purpose_final_result", {}).get(
  65. "最终目的点列表", []
  66. )
  67. keypoint_list = final_result.get("keypoint_final", {}).get("最终关键点列表", [])
  68. topic_fusion = final_result.get("topic_fusion_result", {})
  69. topic_text = (
  70. topic_fusion.get("最终选题", {}).get("选题", "")
  71. if isinstance(topic_fusion.get("最终选题"), dict)
  72. else ""
  73. )
  74. def _join_points(items: list, key: str) -> str:
  75. parts = [
  76. str(p[key]) for p in items if isinstance(p, dict) and p.get(key)
  77. ]
  78. return ",".join(parts)
  79. return {
  80. "inspiration": _join_points(inspiration_list, "灵感点"),
  81. "purpose": _join_points(purpose_list, "目的点"),
  82. "key_point": _join_points(keypoint_list, "关键点"),
  83. "topic": topic_text,
  84. }
  85. class AdPlatformArticlesDecodeUtils(AigcDecodeUtils):
  86. @staticmethod
  87. def format_images(images: str) -> List[str]:
  88. if not images or not images.strip():
  89. return []
  90. try:
  91. image_list = json.loads(images)
  92. except (json.JSONDecodeError, TypeError):
  93. return []
  94. if not isinstance(image_list, list):
  95. return []
  96. return [
  97. i.get("image_url")
  98. for i in image_list
  99. if isinstance(i, dict) and i.get("image_url")
  100. ]
  101. def prepare_posts(self, articles: List[Dict]) -> List[Dict]:
  102. posts = []
  103. for article in articles:
  104. images = self.format_images(article.get("article_images") or "")
  105. posts.append(
  106. {
  107. "channelContentId": article["wx_sn"],
  108. "title": article.get("article_title", ""),
  109. "bodyText": article.get("article_text", ""),
  110. "images": images,
  111. "video": None,
  112. "contentModal": self.ContentModal.LONG_ARTICLE,
  113. "channel": self.Channel.WECHAT,
  114. }
  115. )
  116. return posts
  117. class InnerArticlesDecodeUtils(AigcDecodeUtils):
  118. def prepare_posts(
  119. self, articles: List[Dict], produce_info_map: Dict[str, List[Dict]]
  120. ) -> List[Dict]:
  121. posts = []
  122. for article in articles:
  123. wx_sn = article["wx_sn"]
  124. produce_info = produce_info_map.get(wx_sn, [])
  125. # 收集图片:封面(coverimgurl) + produce COVER + produce IMAGE
  126. images = []
  127. if article.get("coverimgurl"):
  128. images.append(article["coverimgurl"])
  129. for pi in produce_info:
  130. if pi["produce_module_type"] == self.ProduceModuleType.COVER:
  131. images.append(pi["output"])
  132. for pi in produce_info:
  133. if pi["produce_module_type"] == self.ProduceModuleType.IMAGE:
  134. images.append(pi["output"])
  135. posts.append(
  136. {
  137. "title": article.get("title", ""),
  138. "bodyText": article.get("article_text", ""),
  139. "images": images,
  140. "video": None,
  141. "contentModal": self.ContentModal.LONG_ARTICLE,
  142. "channel": self.Channel.WECHAT,
  143. "channelContentId": wx_sn,
  144. }
  145. )
  146. return posts
  147. __all__ = [
  148. "AigcDecodeUtils",
  149. "AdPlatformArticlesDecodeUtils",
  150. "InnerArticlesDecodeUtils",
  151. ]