| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- import json
- import os
- from pathlib import Path
- from typing import Any, Dict, List, Optional
- from agent.tools import tool, ToolResult
# Path of the JSON data file loaded by _load_graph_full().
# It can be overridden through the GRAPH_FULL_DATA_PATH environment variable;
# otherwise it defaults to the multi-depth frequent-itemset index located
# relative to this module's grandparent directory.
# NOTE(review): the original comment described this as the "full graph
# database (including edges)", but the code in this file only reads the
# "index" and "items" keys of the loaded document - confirm the intended file.
GRAPH_FULL_DATA_PATH = os.getenv(
    "GRAPH_FULL_DATA_PATH",
    # Previous default, kept for reference:
    # str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_all_levels.json")
    str(Path(__file__).parent.parent / "data/library/apriori_analysis_post_level/frequent_itemsets_multi_depth_index.json")
)
# Module-level cache of the parsed JSON document so the file is not re-read
# on every lookup. None means "not loaded yet".
_graph_full_cache: Optional[Dict[str, Any]] = None
def _load_graph_full() -> Dict[str, Any]:
    """Return the parsed JSON document stored at GRAPH_FULL_DATA_PATH.

    The file is read and decoded as UTF-8 JSON on the first call; the result
    is kept in the module-level ``_graph_full_cache`` and handed back on
    subsequent calls without touching the disk again.

    Returns:
        The decoded JSON document.

    Raises:
        FileNotFoundError: if GRAPH_FULL_DATA_PATH does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    global _graph_full_cache

    # Fast path: something has already been loaded.
    if _graph_full_cache is not None:
        return _graph_full_cache

    with open(GRAPH_FULL_DATA_PATH, 'r', encoding='utf-8') as handle:
        _graph_full_cache = json.load(handle)
    return _graph_full_cache
def _search_pattern(class_paths: List[str], top_k: int = 10) -> Dict[str, Any]:
    """Look up, for each class path, the item entries of every group indexed under it.

    The loaded document is read through two top-level keys:
    ``"index"`` (class path -> iterable of group keys) and
    ``"items"`` (group key -> stored entry). Either key may be absent, in
    which case it is treated as empty.

    Args:
        class_paths: Class paths to look up. A path that is missing from the
            index maps to an empty list in the result.
        top_k: Accepted for signature compatibility; this helper does not
            truncate its result and returns every match.

    Returns:
        A dict mapping each requested class path to a list with one entry per
        indexed group. A group key that is missing from ``"items"``
        contributes an empty list as its entry.

    Raises:
        FileNotFoundError: propagated from ``_load_graph_full`` when the data
            file does not exist.
    """
    graph = _load_graph_full()
    index = graph.get("index", {})
    items = graph.get("items", {})

    # No debug printing here: this runs inside a tool call and the index
    # entries can be large.
    return {
        class_path: [items.get(group, []) for group in index.get(class_path, [])]
        for class_path in class_paths
    }
@tool(
    description="根据类别路径查找包含这些类别的所有模式组合。返回每个类别所在的所有频繁项集。",
    display={
        "zh": {
            "name": "模式组合检索",
            "params": {
                "class_paths": "类别路径数组",
                "top_k": "每个类别返回的组合数量(默认10)",
            },
        },
    },
)
async def search_pattern(
    class_paths: List[str],
    top_k: int = 10
) -> ToolResult:
    """Find the pattern combinations (frequent itemsets) that contain the given class paths.

    Args:
        class_paths: Class paths to look up, e.g. ["关键点_形式_架构>策略>行为体验"].
        top_k: Maximum number of combinations returned per class path
            (default 10). Negative values are treated as 0.

    Returns:
        ToolResult: on success, ``output`` is a JSON object mapping each class
        path to at most ``top_k`` combinations, and ``long_term_memory``
        reports the total number of combinations found before truncation.
        On failure (empty input, missing data file, or any other exception
        raised during the lookup) ``output`` is empty and ``error`` describes
        the problem.
    """
    # An empty list and None are both rejected by the same truthiness check.
    if not class_paths:
        return ToolResult(
            title="模式组合检索失败",
            output="",
            error="请提供类别路径",
        )
    try:
        result = _search_pattern(class_paths, top_k)
    except FileNotFoundError:
        return ToolResult(
            title="模式组合检索失败",
            output="",
            error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
        )
    except Exception as e:
        # Tool boundary: report the failure to the caller instead of raising.
        return ToolResult(
            title="模式组合检索失败",
            output="",
            error=f"检索异常: {str(e)}",
        )

    # Count everything that was found, before the per-class limit is applied.
    total_patterns = sum(len(patterns) for patterns in result.values())

    # Enforce the documented per-class limit. A negative top_k would otherwise
    # slice from the end of the list, so clamp it to 0.
    limit = max(top_k, 0)
    limited_result = {
        class_path: patterns[:limit]
        for class_path, patterns in result.items()
    }

    output = json.dumps(limited_result, ensure_ascii=False, indent=2)
    return ToolResult(
        title=f"模式组合检索 - {len(class_paths)} 个类别",
        output=output,
        long_term_memory=f"为 {len(class_paths)} 个类别检索到 {total_patterns} 个模式组合",
    )
if __name__ == "__main__":
    # Manual smoke run: look up one sample class path and print the raw result.
    sample_paths = ["关键点_形式_架构>叙事>叙事形式"]
    lookup = _search_pattern(sample_paths)
    print(lookup)
|