_utils.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. import json
  2. from typing import Dict, List
  3. from app.infra.internal.aigc_decode_server import AigcDecodeServer
  4. from ._const import DecodeArticleConst
  5. class AigcDecodeUtils(DecodeArticleConst):
  6. decode_server = AigcDecodeServer()
  7. async def submit_decode_batch(
  8. self, posts: List[Dict]
  9. ) -> Dict[str, Dict]:
  10. """分批提交解构任务,返回 {channel_content_id: {status, errorMessage}}"""
  11. result = {}
  12. for i in range(0, len(posts), self.SUBMIT_BATCH):
  13. batch = posts[i : i + self.SUBMIT_BATCH]
  14. response = await self.decode_server.submit_decode(
  15. config_id=self.CONFIG_ID, posts=batch
  16. )
  17. if response.get("code") == 0:
  18. for item in response.get("data", []):
  19. result[item["channelContentId"]] = item
  20. else:
  21. # 整批失败,标记所有帖子为 FAILED
  22. for post in batch:
  23. cid = post["channelContentId"]
  24. result[cid] = {
  25. "channelContentId": cid,
  26. "status": "FAILED",
  27. "errorMessage": f"batch submit failed: {response}",
  28. }
  29. return result
  30. async def query_decode_results_batch(
  31. self, channel_content_ids: List[str]
  32. ) -> Dict[str, Dict]:
  33. """分批查询解构结果,返回 {channel_content_id: {status, dataContent, html, errorMessage}}
  34. 当 API 调用失败时,对应条目 status 为 API_ERROR,调用方应保持 INIT 等待重试。
  35. """
  36. result = {}
  37. for i in range(0, len(channel_content_ids), self.SUBMIT_BATCH):
  38. batch = channel_content_ids[i : i + self.SUBMIT_BATCH]
  39. response = await self.decode_server.query_decode_results(
  40. config_id=self.CONFIG_ID, channel_content_ids=batch
  41. )
  42. if response.get("code") == 0:
  43. for item in response.get("data", []):
  44. result[item["channelContentId"]] = item
  45. else:
  46. for cid in batch:
  47. result[cid] = {
  48. "channelContentId": cid,
  49. "status": "API_ERROR",
  50. "errorMessage": f"query API failed: {response}",
  51. }
  52. return result
  53. @staticmethod
  54. def extract_decode_result(result: Dict) -> Dict:
  55. """从解构结果中解析出灵感点、目的点、关键点、选题
  56. 兼容新旧两种数据格式:v1 有 final_normalization_rebuild 包裹层,v2 无
  57. """
  58. final_result = result.get("final_normalization_rebuild") or result
  59. inspiration_list = final_result.get("inspiration_final_result", {}).get(
  60. "最终灵感点列表", []
  61. )
  62. purpose_list = final_result.get("purpose_final_result", {}).get(
  63. "最终目的点列表", []
  64. )
  65. keypoint_list = final_result.get("keypoint_final", {}).get("最终关键点列表", [])
  66. topic_fusion = final_result.get("topic_fusion_result", {})
  67. topic_text = (
  68. topic_fusion.get("最终选题", {}).get("选题", "")
  69. if isinstance(topic_fusion.get("最终选题"), dict)
  70. else ""
  71. )
  72. def _join_points(items: list, key: str) -> str:
  73. parts = [
  74. str(p[key]) for p in items if isinstance(p, dict) and p.get(key)
  75. ]
  76. return ",".join(parts)
  77. return {
  78. "inspiration": _join_points(inspiration_list, "灵感点"),
  79. "purpose": _join_points(purpose_list, "目的点"),
  80. "key_point": _join_points(keypoint_list, "关键点"),
  81. "topic": topic_text,
  82. }
  83. class AdPlatformArticlesDecodeUtils(AigcDecodeUtils):
  84. @staticmethod
  85. def format_images(images: str) -> List[str]:
  86. if not images or not images.strip():
  87. return []
  88. try:
  89. image_list = json.loads(images)
  90. except (json.JSONDecodeError, TypeError):
  91. return []
  92. if not isinstance(image_list, list):
  93. return []
  94. return [
  95. i.get("image_url")
  96. for i in image_list
  97. if isinstance(i, dict) and i.get("image_url")
  98. ]
  99. def prepare_posts(self, articles: List[Dict]) -> List[Dict]:
  100. posts = []
  101. for article in articles:
  102. images = self.format_images(article.get("article_images") or "")
  103. posts.append(
  104. {
  105. "channelContentId": article["wx_sn"],
  106. "title": article.get("article_title", ""),
  107. "bodyText": article.get("article_text", ""),
  108. "images": images,
  109. "video": None,
  110. "contentModal": self.ContentModal.LONG_ARTICLE,
  111. "channel": self.Channel.WECHAT,
  112. }
  113. )
  114. return posts
  115. class InnerArticlesDecodeUtils(AigcDecodeUtils):
  116. def prepare_posts(
  117. self, articles: List[Dict], produce_info_map: Dict[str, List[Dict]]
  118. ) -> List[Dict]:
  119. posts = []
  120. for article in articles:
  121. wx_sn = article["wx_sn"]
  122. produce_info = produce_info_map.get(wx_sn, [])
  123. images = [
  124. i["output"]
  125. for i in produce_info
  126. if i["produce_module_type"]
  127. in (self.ProduceModuleType.COVER, self.ProduceModuleType.IMAGE)
  128. ]
  129. text_parts = [
  130. i["output"]
  131. for i in produce_info
  132. if i["produce_module_type"] == self.ProduceModuleType.CONTENT
  133. ]
  134. posts.append(
  135. {
  136. "channelContentId": wx_sn,
  137. "title": article.get("title", ""),
  138. "bodyText": "\n".join(text_parts),
  139. "images": images,
  140. "video": None,
  141. "contentModal": self.ContentModal.LONG_ARTICLE,
  142. "channel": self.Channel.WECHAT,
  143. }
  144. )
  145. return posts
  146. __all__ = [
  147. "AigcDecodeUtils",
  148. "AdPlatformArticlesDecodeUtils",
  149. "InnerArticlesDecodeUtils",
  150. ]