_utils.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

from app.infra.internal.aigc_decode_server import AigcDecodeServer

from ._const import DecodeMaterialConst
  5. class MaterialDecodeUtils(DecodeMaterialConst):
  6. decode_server = AigcDecodeServer()
  7. async def submit_decode_batch(
  8. self, posts: List[Dict], *, skip_completed: bool = False
  9. ) -> Dict[str, Dict]:
  10. """分批提交素材解构任务"""
  11. result = {}
  12. for i in range(0, len(posts), self.SUBMIT_BATCH):
  13. batch = posts[i : i + self.SUBMIT_BATCH]
  14. response = await self.decode_server.submit_decode(
  15. config_id=self.CONFIG_ID, posts=batch, skip_completed=skip_completed
  16. )
  17. if response.get("code") == 0:
  18. for item in response.get("data", []):
  19. result[item["channelContentId"]] = item
  20. else:
  21. for post in batch:
  22. cid = post["channelContentId"]
  23. result[cid] = {
  24. "channelContentId": cid,
  25. "status": "FAILED",
  26. "errorMessage": f"batch submit failed: {response}",
  27. }
  28. if i + self.SUBMIT_BATCH < len(posts):
  29. await asyncio.sleep(self.API_INTERVAL)
  30. return result
  31. async def query_decode_results_batch(
  32. self, content_ids: List[str]
  33. ) -> Dict[str, Dict]:
  34. """分批查询素材解构结果"""
  35. result = {}
  36. for i in range(0, len(content_ids), self.SUBMIT_BATCH):
  37. batch = content_ids[i : i + self.SUBMIT_BATCH]
  38. response = await self.decode_server.query_decode_results(
  39. config_id=self.CONFIG_ID, channel_content_ids=batch
  40. )
  41. if response.get("code") == 0:
  42. for item in response.get("data", []):
  43. result[item["channelContentId"]] = item
  44. else:
  45. for cid in batch:
  46. result[cid] = {
  47. "channelContentId": cid,
  48. "status": "API_ERROR",
  49. "errorMessage": f"query API failed: {response}",
  50. }
  51. if i + self.SUBMIT_BATCH < len(content_ids):
  52. await asyncio.sleep(self.API_INTERVAL)
  53. return result
  54. @staticmethod
  55. def prepare_posts(materials: List[Dict]) -> List[Dict]:
  56. """将素材数据转换为 AIGC 解构 API 所需的 post 格式"""
  57. posts = []
  58. for m in materials:
  59. images = []
  60. cover = m.get("material_cover")
  61. if cover:
  62. images.append(cover)
  63. posts.append(
  64. {
  65. "channelContentId": str(m["material_id"]),
  66. "title": m.get("material_title", ""),
  67. "bodyText": "",
  68. "images": images,
  69. "video": None,
  70. "contentModal": DecodeMaterialConst.ContentModal.PICTURE_TEXT,
  71. "channel": DecodeMaterialConst.Channel.GROWTH_MATERIAL,
  72. }
  73. )
  74. return posts
  75. __all__ = ["MaterialDecodeUtils"]