# frequent_itemsets.py
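"""
Frequent-itemset API tool.

Wraps the frequent-itemset endpoint of pattern.aiddit.com, which is used to
look up the associated elements that co-occur with the given category nodes
in high-quality content.
"""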
import json
import logging

import httpx

from agent.tools import tool
from agent.tools.models import ToolResult
  10. logger = logging.getLogger(__name__)
  11. ITEMSETS_URL = "https://pattern.aiddit.com/api/pattern/tools/get_frequent_itemsets/execute"
  12. HEADERS = {
  13. "Accept": "*/*",
  14. "Accept-Language": "zh-CN,zh;q=0.9",
  15. "Connection": "keep-alive",
  16. "Content-Type": "application/json",
  17. "Origin": "https://pattern.aiddit.com",
  18. "Referer": "https://pattern.aiddit.com/execution/33",
  19. "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/145.0.0.0 Safari/537.36",
  20. }
  21. @tool(description="查询与指定分类节点在优质内容中共同出现的频繁项集(关联要素),用于扩展制作需求的关联维度")
  22. async def get_frequent_itemsets(
  23. entity_ids: list,
  24. top_n: int = 20,
  25. execution_id: int = 33,
  26. sort_by: str = "absolute_support",
  27. ) -> ToolResult:
  28. """
  29. 获取与指定分类节点关联的频繁项集。
  30. Args:
  31. entity_ids: 分类节点的 entity_id 列表(即搜索接口返回的 entity_id 字段,非 stable_id)
  32. top_n: 返回前 N 个项集,默认 20
  33. execution_id: 执行 ID,默认 33
  34. sort_by: 排序字段,默认 "absolute_support"
  35. """
  36. payload = {
  37. "execution_id": execution_id,
  38. "args": {
  39. "top_n": top_n,
  40. "category_ids": entity_ids,
  41. "sort_by": sort_by,
  42. },
  43. }
  44. try:
  45. import json as _json
  46. async with httpx.AsyncClient(timeout=30.0) as client:
  47. resp = await client.post(ITEMSETS_URL, json=payload, headers=HEADERS)
  48. resp.raise_for_status()
  49. outer = resp.json()
  50. # result 字段是 JSON 字符串,需要二次解析
  51. data = _json.loads(outer["result"])
  52. total = data.get("total", 0)
  53. groups = data.get("groups", {})
  54. # 收集所有 group 下的 itemsets
  55. all_itemsets = []
  56. for group_key, group in groups.items():
  57. for itemset in group.get("itemsets", []):
  58. itemset["_group"] = group_key
  59. all_itemsets.append(itemset)
  60. lines = [f"频繁项集查询 entity_ids={entity_ids},共 {total} 条,返回 {len(all_itemsets)} 条:\n"]
  61. for i, itemset in enumerate(all_itemsets, 1):
  62. itemset_id = itemset.get("id", "")
  63. item_count = itemset.get("item_count", "")
  64. support = itemset.get("support", 0)
  65. abs_support = itemset.get("absolute_support", "")
  66. lines.append(f"{i}. 项集ID={itemset_id} | 项数={item_count} | support={support:.4f} | abs={abs_support}")
  67. for elem in itemset.get("items", []):
  68. dim = elem.get("dimension", "")
  69. path = elem.get("category_path", "")
  70. ename = elem.get("element_name") or ""
  71. label = f"{path}({ename})" if ename else path
  72. lines.append(f" [{dim}] {label}")
  73. lines.append("")
  74. return ToolResult(
  75. title=f"频繁项集: entity_ids={entity_ids} → {total} 条",
  76. output="\n".join(lines),
  77. )
  78. return ToolResult(
  79. title="频繁项集查询失败",
  80. output=f"HTTP {e.response.status_code}: {e.response.text[:200]}",
  81. )
  82. except Exception as e:
  83. logger.exception("get_frequent_itemsets error")
  84. return ToolResult(title="频繁项集查询失败", output=f"错误: {e}")