"""
Key-point retrieval tools: look up, in the graph database, everything related to the given points.

Used by the Agent at execution time to fetch related key-point data on demand.
"""

import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

from agent.tools import tool, ToolResult

# Path of the full graph database file (includes edges)
GRAPH_FULL_DATA_PATH = os.getenv(
    "GRAPH_FULL_DATA_PATH",
    # str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_all_levels.json")
    str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_max.json")
)

# Path of the lookup map used to resolve a point to its class path
GRAPH_FULL_MAP_DATA_PATH = os.getenv(
    "GRAPH_FULL_MAP_DATA_PATH",
    # str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_all_levels.json")
    str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_max_map.json")
)

# Cache the parsed graph data so the JSON files are read at most once per process
_graph_full_cache: Optional[Dict[str, Any]] = None
_graph_full_map_cache: Optional[Dict[str, Any]] = None


def _load_graph_full() -> Dict[str, Any]:
    """Load the full graph data (cached, includes edges).

    Raises:
        FileNotFoundError: If GRAPH_FULL_DATA_PATH does not exist.
    """
    global _graph_full_cache
    if _graph_full_cache is None:
        with open(GRAPH_FULL_DATA_PATH, 'r', encoding='utf-8') as f:
            _graph_full_cache = json.load(f)
    return _graph_full_cache


def _load_graph_full_map() -> Dict[str, Any]:
    """Load the point-to-class lookup map (cached).

    Raises:
        FileNotFoundError: If GRAPH_FULL_MAP_DATA_PATH does not exist.
    """
    global _graph_full_map_cache
    if _graph_full_map_cache is None:
        with open(GRAPH_FULL_MAP_DATA_PATH, 'r', encoding='utf-8') as f:
            _graph_full_map_cache = json.load(f)
    return _graph_full_map_cache


def _take_top(items: List[Any], top_k: Optional[int]) -> List[Any]:
    """Return at most ``top_k`` leading items.

    ``None`` keeps everything. A negative ``top_k`` yields an empty list
    instead of silently dropping items from the tail, which is what a raw
    ``items[:top_k]`` slice would do.
    """
    if top_k is None:
        return items
    return items[:max(top_k, 0)]


def _class_paths_for_accounts(
    dim_dict: Dict[str, Any],
    point_value: str,
    accounts: List[str],
) -> List[str]:
    """Collect the class paths registered for ``point_value`` under each account.

    Accounts that are missing from ``dim_dict`` (or whose entry is not a
    mapping) are skipped. Duplicates are kept, in account order.
    """
    class_paths = []
    for account in accounts:
        account_dict = dim_dict.get(account)
        if not isinstance(account_dict, dict):
            continue
        class_path = account_dict.get(point_value, "")
        if class_path:
            class_paths.append(class_path)
    return class_paths


def _search_relation_class_by_class(class_paths: List[str], top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Find the classes related to each of the given classes.

    Args:
        class_paths: Class paths, e.g. ["关键点_实质_理念>现象>社会>时空背景", "关键点_形式_架构>策略>行为体验"]
        top_k: Number of related classes to keep per input class, default 5

    Returns:
        One entry per input class, holding the input class path and the paths of
        its related classes ordered by ``co_in_post`` confidence, highest first.
    """
    graph = _load_graph_full()

    results = []
    for class_path in class_paths:
        edges = graph.get(class_path, {}).get("edges", {})

        # Collect every related class together with its confidence
        related_classes = [
            {
                "class_path": target_class_path,
                "confidence": edge_data.get("co_in_post", {}).get("confidence", 0.0),
            }
            for target_class_path, edge_data in edges.items()
        ]

        # Sort by confidence, descending, and keep the first top_k
        related_classes.sort(key=lambda x: x["confidence"], reverse=True)
        related_classes = _take_top(related_classes, top_k)

        results.append({
            "input_class_path": class_path,
            "related_class_paths": [item["class_path"] for item in related_classes]
        })

    return results


@tool(
    description="根据类别查找与该类别相关的其他类别,返回关系统计数据。支持单个或多个类别路径。每个类别独立返回其关联类别。",
    display={
        "zh": {
            "name": "类别关系检索",
            "params": {
                "class_paths": "类别路径数组",
                "top_k": "每个类别返回数量(默认5)",
            },
        },
    },
)
async def search_relation_class_by_class(
    class_paths: List[str],
    top_k: int = 5
) -> ToolResult:
    """
    Find the classes related to each of the given classes.

    Args:
        class_paths: Class paths, e.g. ["关键点_形式_架构>策略>行为体验"] or ["关键点_形式_架构>策略>行为体验", "目的点_意图_分享"]
        top_k: Number of related classes to return per input class, default 5

    Returns:
        ToolResult: Each input class with its related classes, as JSON in ``output``;
        on failure ``error`` describes the problem.
    """
    if not class_paths:
        return ToolResult(
            title="类别关系检索失败",
            output="",
            error="请提供类别路径",
        )

    try:
        result = _search_relation_class_by_class(class_paths, top_k)
    except FileNotFoundError:
        return ToolResult(
            title="类别关系检索失败",
            output="",
            error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
        )
    except Exception as e:
        return ToolResult(
            title="类别关系检索失败",
            output="",
            error=f"检索异常: {str(e)}",
        )

    # Total number of related classes found across all inputs
    total_related = sum(len(item["related_class_paths"]) for item in result)

    output = json.dumps(result, ensure_ascii=False, indent=2)
    return ToolResult(
        title=f"类别关系检索 - {len(class_paths)} 个类别",
        output=output,
        long_term_memory=f"为 {len(class_paths)} 个类别检索到 {total_related} 个关联类别",
    )


def _search_relation_point_by_point(points: List[Dict], top_k: int = 999) -> List[Dict[str, Any]]:
    """
    Find the points related to each of the given points.

    Each point is first resolved to its class path(s) through the lookup map
    (point_type -> dimension -> account -> point_value), then the edges of those
    class paths are aggregated.

    Args:
        points: Point descriptors with the keys "point_value", "point_type",
            "dimension" and "accounts"
        top_k: Number of related points to keep per input point

    Returns:
        One entry per input point, holding the input point and the related point
        paths ordered by average ``co_in_post`` confidence, highest first.
    """
    graph = _load_graph_full()
    graph_map = _load_graph_full_map()

    all_results = []
    for point in points:
        point_value = point.get("point_value", "")
        point_type = point.get("point_type", "")
        dimension = point.get("dimension", "")
        # Treat an explicit null the same as a missing account list
        accounts = point.get("accounts") or []

        # Resolve the point to its class paths through the map
        dim_dict = graph_map.get(point_type, {}).get(dimension, {})
        class_paths = _class_paths_for_accounts(dim_dict, point_value, accounts)

        # Aggregate the edges of every resolved class path per target point
        confidence_sums: Dict[str, float] = {}
        source_counts: Dict[str, int] = {}
        for class_path in class_paths:
            edges = graph.get(class_path, {}).get("edges", {})
            for target_point, edge_data in edges.items():
                confidence = edge_data.get("co_in_post", {}).get("confidence", 0.0)
                confidence_sums[target_point] = confidence_sums.get(target_point, 0.0) + confidence
                source_counts[target_point] = source_counts.get(target_point, 0) + 1

        # Average the confidence over the contributing sources
        related_points = [
            {
                "point_path": target_point,
                "confidence": total / source_counts[target_point],
            }
            for target_point, total in confidence_sums.items()
        ]

        # Sort by confidence, descending, and keep the first top_k
        related_points.sort(key=lambda x: x["confidence"], reverse=True)
        related_points = _take_top(related_points, top_k)

        all_results.append({
            "input_point": point,
            "related_point_paths": [item["point_path"] for item in related_points]
        })

    return all_results


@tool(
    description="根据点的信息查找与该点相关的其他点,返回关系数据。",
    display={
        "zh": {
            "name": "点关系检索",
            "params": {
                "points": [
                    {
                        "point_value": "点",
                        "point_type": "点类型",
                        "dimension": "维度",
                        "accounts": ["账号名称"]
                    }
                ],
                "top_k": "返回数量(默认999)",
            },
        },
    },
)
async def search_relation_point_by_point(
    points: List[Dict],
    top_k: int = 999
) -> ToolResult:
    """
    Find the points related to each of the given points.

    Args:
        points: Point descriptors
        top_k: Number of related points to return per input point, default 999

    Returns:
        ToolResult: Each input point with its related points, as JSON in ``output``;
        on failure ``error`` describes the problem.
    """
    if not points:
        return ToolResult(
            title="点关系检索失败",
            output="",
            error="请提供点信息",
        )

    try:
        result = _search_relation_point_by_point(points, top_k)
    except FileNotFoundError as e:
        # Two files are loaded here, so report the one that is actually missing
        return ToolResult(
            title="点关系检索失败",
            output="",
            error=f"图数据文件不存在: {e.filename or GRAPH_FULL_DATA_PATH}",
        )
    except Exception as e:
        return ToolResult(
            title="点关系检索失败",
            output="",
            error=f"检索异常: {str(e)}",
        )

    output = json.dumps(result, ensure_ascii=False, indent=2)
    return ToolResult(
        title=f"点关系检索 - {len(points)} 个点",
        output=output,
        long_term_memory="检索到与输入点相关的点",
    )


def _search_class_by_point(
    points: List[Dict]
) -> List[Dict[str, Any]]:
    """
    Find the class path each of the given points belongs to.

    A direct ``dimension -> point_value`` lookup is tried first. When that does
    not yield a class path string, the point's ``accounts`` are consulted
    (``dimension -> account -> point_value``), which is the layout the
    point-relation lookup in this module reads from the same map file.
    NOTE(review): the map file's schema is not visible here; confirm which
    layout is authoritative.

    Args:
        points: Point descriptors with the keys "point_value", "point_type",
            "dimension" and optionally "accounts"

    Returns:
        One entry per point that could be resolved, holding the input point and
        its class path. With several accounts, the first match wins.
    """
    graph_map = _load_graph_full_map()

    data = []
    for point in points:
        point_value = point.get("point_value", "")
        point_type = point.get("point_type", "")
        dimension = point.get("dimension", "")
        dim_dict = graph_map.get(point_type, {}).get(dimension, {})

        class_path = dim_dict.get(point_value, "")
        if not (class_path and isinstance(class_path, str)):
            account_matches = _class_paths_for_accounts(
                dim_dict, point_value, point.get("accounts") or []
            )
            class_path = account_matches[0] if account_matches else ""

        if class_path:
            data.append({
                "point": point,
                "class_path": class_path
            })

    return data


@tool(
    description="根据点的属性查找该点所属的类别。",
    display={
        "zh": {
            "name": "点类别查询",
            "params": {
                "points": [
                    {
                        "point_value": "点",
                        "point_type": "点类型",
                        "dimension": "维度",
                        "accounts": [
                            "账号名称"
                        ]
                    }
                ]
            },
        },
    },
)
async def search_class_by_point(
    points: List[Dict]
) -> ToolResult:
    """
    Find the class each of the given points belongs to.

    Args:
        points: Point descriptors

    Returns:
        ToolResult: The resolved points with their class paths, as JSON in
        ``output``; on failure ``error`` describes the problem.
    """
    if not points:
        return ToolResult(
            title="点类别查询失败",
            output="",
            error="请提供 points, point_type 和 dimension",
        )

    try:
        result = _search_class_by_point(points)
    except FileNotFoundError:
        return ToolResult(
            title="点类别查询失败",
            output="",
            error=f"图数据文件不存在: {GRAPH_FULL_MAP_DATA_PATH}",
        )
    except Exception as e:
        return ToolResult(
            title="点类别查询失败",
            output="",
            error=f"检索异常: {str(e)}",
        )

    output = json.dumps(result, ensure_ascii=False, indent=2)
    return ToolResult(
        title=f"点类别查询 - {len(points)} 个点",
        output=output,
        long_term_memory=f"查询到 {len(result)} 个点的类别信息",
    )


def _search_point_by_class(class_paths: List[str]) -> List[Dict[str, Any]]:
    """
    Find all points that belong to each of the given classes.

    Args:
        class_paths: Class paths, e.g. ["关键点_形式"]

    Returns:
        One entry per input class, holding the class path and the keys of that
        class's ``meta.elements`` mapping (empty when the class is unknown).
    """
    graph = _load_graph_full()

    return [
        {
            "class_path": class_path,
            "points": list(graph.get(class_path, {}).get("meta", {}).get("elements", {}).keys())
        }
        for class_path in class_paths
    ]


@tool(
    description="根据类别查找属于该类别的所有点。支持单个或多个类别路径。",
    display={
        "zh": {
            "name": "类别点检索",
            "params": {
                "class_paths": "类别路径数组"
            },
        },
    },
)
async def search_point_by_class(
    class_paths: List[str]
) -> ToolResult:
    """
    Find all points that belong to each of the given classes.

    Args:
        class_paths: Class paths, e.g. ["关键点_形式"]

    Returns:
        ToolResult: Each input class with its points, as JSON in ``output``;
        on failure ``error`` describes the problem.
    """
    if not class_paths:
        return ToolResult(
            title="类别点检索失败",
            output="",
            error="请提供类别路径",
        )

    try:
        result = _search_point_by_class(class_paths)
    except FileNotFoundError:
        # Report the cause instead of returning a bare failure title
        return ToolResult(
            title="类别点检索失败",
            output="",
            error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
        )
    except Exception as e:
        return ToolResult(
            title="类别点检索失败",
            output="",
            error=f"检索异常: {str(e)}",
        )

    output = json.dumps(result, ensure_ascii=False, indent=2)
    return ToolResult(
        title=f"类别点检索 - {len(class_paths)} 个类别",
        output=output,
        long_term_memory="检索到类别的点数据"
    )


if __name__ == "__main__":
    print(_search_point_by_class(["关键点_形式_架构>策略>行为体验"]))