search_pattern.py (3.4 KB)

  1. import json
  2. import os
  3. from pathlib import Path
  4. from typing import Any, Dict, List, Optional
  5. from agent.tools import tool, ToolResult
  6. # 完整图数据库文件路径(包含 edges)
  7. GRAPH_FULL_DATA_PATH = os.getenv(
  8. "GRAPH_FULL_DATA_PATH",
  9. # str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_all_levels.json")
  10. str(Path(__file__).parent.parent / "data/library/apriori_analysis_post_level/frequent_itemsets_multi_depth_index.json")
  11. )
  12. # 缓存图数据,避免重复加载
  13. _graph_full_cache: Optional[Dict[str, Any]] = None
  14. def _load_graph_full() -> Dict[str, Any]:
  15. """加载完整图数据(带缓存,包含 edges)"""
  16. global _graph_full_cache
  17. if _graph_full_cache is None:
  18. with open(GRAPH_FULL_DATA_PATH, 'r', encoding='utf-8') as f:
  19. _graph_full_cache = json.load(f)
  20. return _graph_full_cache
  21. def _search_pattern(class_paths: List[str], top_k: int = 10) -> Dict[str, Any]:
  22. graph = _load_graph_full()
  23. index = graph.get("index", {})
  24. items = graph.get("items", {})
  25. results = {}
  26. for class_path in class_paths:
  27. groups = index.get(class_path, [])
  28. print(groups)
  29. results[class_path] = []
  30. for group in groups:
  31. print(group)
  32. results[class_path].append(items.get(group, []))
  33. return results
  34. @tool(
  35. description="根据类别路径查找包含这些类别的所有模式组合。返回每个类别所在的所有频繁项集。",
  36. display={
  37. "zh": {
  38. "name": "模式组合检索",
  39. "params": {
  40. "class_paths": "类别路径数组",
  41. "top_k": "每个类别返回的组合数量(默认10)",
  42. },
  43. },
  44. },
  45. )
  46. async def search_pattern(
  47. class_paths: List[str],
  48. top_k: int = 10
  49. ) -> ToolResult:
  50. """
  51. 根据类别路径查找包含这些类别的所有模式组合。
  52. Args:
  53. class_paths: 类别路径数组,如 ["关键点_形式_架构>策略>行为体验"]
  54. top_k: 每个类别返回前 K 个组合,默认 10
  55. Returns:
  56. ToolResult: 每个类别及其所在的所有组合
  57. """
  58. if not class_paths or len(class_paths) == 0:
  59. return ToolResult(
  60. title="模式组合检索失败",
  61. output="",
  62. error="请提供类别路径",
  63. )
  64. try:
  65. result = _search_pattern(class_paths, top_k)
  66. except FileNotFoundError:
  67. return ToolResult(
  68. title="模式组合检索失败",
  69. output="",
  70. error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
  71. )
  72. except Exception as e:
  73. return ToolResult(
  74. title="模式组合检索失败",
  75. output="",
  76. error=f"检索异常: {str(e)}",
  77. )
  78. # 统计总共找到的组合数
  79. total_patterns = sum(len(patterns) for patterns in result.values())
  80. # 限制返回数量
  81. limited_result = {}
  82. for class_path, patterns in result.items():
  83. limited_result[class_path] = patterns
  84. output = json.dumps(limited_result, ensure_ascii=False, indent=2)
  85. return ToolResult(
  86. title=f"模式组合检索 - {len(class_paths)} 个类别",
  87. output=output,
  88. long_term_memory=f"为 {len(class_paths)} 个类别检索到 {total_patterns} 个模式组合)",
  89. )
  90. if __name__ == "__main__":
  91. print(_search_pattern(["关键点_形式_架构>叙事>叙事形式"]))