_utils.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. from typing import Dict, List
  2. from app.infra.internal.aigc_decode_server import AigcDecodeServer
  3. from ._const import DecodeMaterialConst
  4. class MaterialDecodeUtils(DecodeMaterialConst):
  5. decode_server = AigcDecodeServer()
  6. async def submit_decode_batch(
  7. self, posts: List[Dict], *, skip_completed: bool = False
  8. ) -> Dict[str, Dict]:
  9. """分批提交素材解构任务"""
  10. result = {}
  11. for i in range(0, len(posts), self.SUBMIT_BATCH):
  12. batch = posts[i : i + self.SUBMIT_BATCH]
  13. response = await self.decode_server.submit_decode(
  14. config_id=self.CONFIG_ID, posts=batch, skip_completed=skip_completed
  15. )
  16. if response.get("code") == 0:
  17. for item in response.get("data", []):
  18. result[item["channelContentId"]] = item
  19. else:
  20. for post in batch:
  21. cid = post["channelContentId"]
  22. result[cid] = {
  23. "channelContentId": cid,
  24. "status": "FAILED",
  25. "errorMessage": f"batch submit failed: {response}",
  26. }
  27. return result
  28. async def query_decode_results_batch(
  29. self, content_ids: List[str]
  30. ) -> Dict[str, Dict]:
  31. """分批查询素材解构结果"""
  32. result = {}
  33. for i in range(0, len(content_ids), self.SUBMIT_BATCH):
  34. batch = content_ids[i : i + self.SUBMIT_BATCH]
  35. response = await self.decode_server.query_decode_results(
  36. config_id=self.CONFIG_ID, channel_content_ids=batch
  37. )
  38. if response.get("code") == 0:
  39. for item in response.get("data", []):
  40. result[item["channelContentId"]] = item
  41. else:
  42. for cid in batch:
  43. result[cid] = {
  44. "channelContentId": cid,
  45. "status": "API_ERROR",
  46. "errorMessage": f"query API failed: {response}",
  47. }
  48. return result
  49. @staticmethod
  50. def prepare_posts(materials: List[Dict]) -> List[Dict]:
  51. """将素材数据转换为 AIGC 解构 API 所需的 post 格式"""
  52. posts = []
  53. for m in materials:
  54. images = []
  55. cover = m.get("material_cover")
  56. if cover:
  57. images.append(cover)
  58. posts.append(
  59. {
  60. "channelContentId": str(m["material_id"]),
  61. "title": m.get("material_title", ""),
  62. "bodyText": "",
  63. "images": images,
  64. "video": None,
  65. "contentModal": DecodeMaterialConst.ContentModal.PICTURE_TEXT,
  66. "channel": DecodeMaterialConst.Channel.GROWTH_MATERIAL,
  67. }
  68. )
  69. return posts
  70. __all__ = ["MaterialDecodeUtils"]