import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

from agent.tools import tool, ToolResult

# Path of the JSON data file searched by this tool. It can be overridden with
# the GRAPH_FULL_DATA_PATH environment variable. The default points at the
# Apriori multi-depth frequent-itemset index. The file is read as a dict with
# top-level "index" and "items" mappings (see _search_pattern).
GRAPH_FULL_DATA_PATH = os.getenv(
    "GRAPH_FULL_DATA_PATH",
    str(
        Path(__file__).parent.parent
        / "data/library/apriori_analysis_post_level/frequent_itemsets_multi_depth_index.json"
    ),
)

# Parsed file contents. After the first successful load the JSON is not
# re-read for the lifetime of the process.
_graph_full_cache: Optional[Dict[str, Any]] = None


def _load_graph_full() -> Dict[str, Any]:
    """Return the parsed data file, loading and caching it on first use.

    Raises:
        FileNotFoundError: If GRAPH_FULL_DATA_PATH does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    global _graph_full_cache
    if _graph_full_cache is None:
        with open(GRAPH_FULL_DATA_PATH, "r", encoding="utf-8") as f:
            _graph_full_cache = json.load(f)
    return _graph_full_cache


def _search_pattern(
    class_paths: List[str], top_k: Optional[int] = 10
) -> Dict[str, Any]:
    """Look up the pattern combinations that contain each class path.

    For every class path, the group keys listed under ``index[class_path]``
    are resolved through ``items``. A key missing from ``items`` resolves to
    an empty list, and a class path missing from ``index`` yields an empty
    result list.

    Args:
        class_paths: Class paths to look up.
        top_k: Maximum number of combinations kept per class path, in index
            order. ``None`` means no limit. Values below 0 are treated as 0.

    Returns:
        Mapping of each class path to its list of resolved combinations.

    Raises:
        FileNotFoundError / json.JSONDecodeError: Propagated from
            _load_graph_full.
    """
    graph = _load_graph_full()
    index = graph.get("index", {})
    items = graph.get("items", {})

    results: Dict[str, Any] = {}
    for class_path in class_paths:
        group_keys = list(index.get(class_path, []))
        if top_k is not None:
            # Clamp so a negative top_k cannot turn into "drop the last N".
            group_keys = group_keys[: max(top_k, 0)]
        results[class_path] = [items.get(key, []) for key in group_keys]
    return results


@tool(
    description="根据类别路径查找包含这些类别的所有模式组合。返回每个类别所在的所有频繁项集。",
    display={
        "zh": {
            "name": "模式组合检索",
            "params": {
                "class_paths": "类别路径数组",
                "top_k": "每个类别返回的组合数量(默认10)",
            },
        },
    },
)
async def search_pattern(
    class_paths: List[str], top_k: int = 10
) -> ToolResult:
    """Find the pattern combinations that contain the given class paths.

    Args:
        class_paths: Class paths, e.g. ["关键点_形式_架构>策略>行为体验"].
            A single path passed as a plain string is also accepted.
        top_k: Number of combinations returned per class path (default 10).

    Returns:
        ToolResult whose ``output`` is a JSON object mapping each class path
        to at most ``top_k`` combinations. ``long_term_memory`` reports the
        total number of combinations found before truncation.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(class_paths, str):
        class_paths = [class_paths]
    if not class_paths:
        return ToolResult(
            title="模式组合检索失败",
            output="",
            error="请提供类别路径",
        )

    try:
        # Fetch every match (no limit) so the total found can be reported;
        # truncation to top_k happens below.
        # NOTE(review): the first call reads and parses the JSON file
        # synchronously inside this coroutine, blocking the event loop.
        result = _search_pattern(class_paths, top_k=None)
    except FileNotFoundError:
        return ToolResult(
            title="模式组合检索失败",
            output="",
            error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
        )
    except Exception as e:
        return ToolResult(
            title="模式组合检索失败",
            output="",
            error=f"检索异常: {str(e)}",
        )

    # Total number of combinations found, across all class paths.
    total_patterns = sum(len(patterns) for patterns in result.values())

    # Keep at most top_k combinations per class path.
    limit = max(top_k, 0)
    limited_result = {
        class_path: patterns[:limit] for class_path, patterns in result.items()
    }

    output = json.dumps(limited_result, ensure_ascii=False, indent=2)
    return ToolResult(
        title=f"模式组合检索 - {len(class_paths)} 个类别",
        output=output,
        long_term_memory=f"为 {len(class_paths)} 个类别检索到 {total_patterns} 个模式组合",
    )


if __name__ == "__main__":
    print(_search_pattern(["关键点_形式_架构>叙事>叙事形式"]))