| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448 |
- """
- 关键点检索工具 - 根据输入的点在图数据库中查找所有关联的点
- 用于 Agent 执行时自主调取关联关键点数据。
- """
- import json
- import os
- from pathlib import Path
- from typing import Any, Dict, List, Optional
- from agent.tools import tool, ToolResult
# Directory two levels above this file; default data paths hang off it.
_PACKAGE_ROOT = Path(__file__).parent.parent

# Path of the full graph JSON (entries include their "edges").
# The GRAPH_FULL_DATA_PATH environment variable overrides the default.
# NOTE(review): item_graph_full_all_levels.json in the same directory was a
# previously referenced alternative input - confirm it exists before switching.
GRAPH_FULL_DATA_PATH = os.environ.get(
    "GRAPH_FULL_DATA_PATH",
    str(_PACKAGE_ROOT / "data/library/item_graph/item_graph_full_max.json"),
)

# Path of the companion map JSON.
# The GRAPH_FULL_MAP_DATA_PATH environment variable overrides the default.
GRAPH_FULL_MAP_DATA_PATH = os.environ.get(
    "GRAPH_FULL_MAP_DATA_PATH",
    str(_PACKAGE_ROOT / "data/library/item_graph/item_graph_full_max_map.json"),
)

# Parsed file contents are kept here so each JSON file is read at most once.
_graph_full_cache: Optional[Dict[str, Any]] = None
_graph_full_map_cache: Optional[Dict[str, Any]] = None
def _load_graph_full() -> Dict[str, Any]:
    """Return the parsed full-graph JSON, reading the file at most once.

    The first call opens ``GRAPH_FULL_DATA_PATH`` as UTF-8, parses it as JSON
    and stores the result in the module-level ``_graph_full_cache``. Later
    calls return that cached object without touching the file.

    Returns:
        The parsed content of the graph file.
    """
    global _graph_full_cache
    if _graph_full_cache is not None:
        # Already loaded: hand back the cached object.
        return _graph_full_cache
    with open(GRAPH_FULL_DATA_PATH, mode="r", encoding="utf-8") as handle:
        _graph_full_cache = json.load(handle)
    return _graph_full_cache
def _load_graph_full_map() -> Dict[str, Any]:
    """Return the parsed map JSON, reading the file at most once.

    The first call opens ``GRAPH_FULL_MAP_DATA_PATH`` as UTF-8, parses it as
    JSON and stores the result in the module-level ``_graph_full_map_cache``.
    Later calls return that cached object without touching the file.

    Returns:
        The parsed content of the map file.
    """
    global _graph_full_map_cache
    if _graph_full_map_cache is not None:
        # Already loaded: hand back the cached object.
        return _graph_full_map_cache
    with open(GRAPH_FULL_MAP_DATA_PATH, mode="r", encoding="utf-8") as handle:
        _graph_full_map_cache = json.load(handle)
    return _graph_full_map_cache
def _search_relation_class_by_class(class_paths: List[str], top_k: int = 5) -> List[Dict[str, Any]]:
    """Find, for each category, the categories most strongly linked to it.

    For every input category the graph entry's ``edges`` mapping is read, each
    edge is scored by ``edge["co_in_post"]["confidence"]`` (0.0 when absent or
    null), and the targets with the highest scores are kept. Ties keep the
    order in which the edges appear in the graph data.

    Args:
        class_paths: Category paths to look up, e.g.
            ["关键点_实质_理念>现象>社会>时空背景", "关键点_形式_架构>策略>行为体验"].
        top_k: Maximum number of related categories per input category.
            Values below zero are treated as zero so that a negative value
            can never silently drop items from the end of the ranking.

    Returns:
        One dict per input category, in input order, with keys
        ``input_class_path`` and ``related_class_paths`` (a list of category
        paths, best first). Unknown categories yield an empty list.
    """
    graph = _load_graph_full()
    limit = max(top_k, 0)
    results = []
    for class_path in class_paths:
        # `or {}` also covers entries whose value is an explicit JSON null.
        edges = (graph.get(class_path) or {}).get("edges") or {}
        scored = []
        for target_class_path, edge_data in edges.items():
            co_in_post = (edge_data or {}).get("co_in_post") or {}
            # A null confidence would make the sort raise; rank it as 0.0.
            confidence = co_in_post.get("confidence") or 0.0
            scored.append((target_class_path, confidence))
        # Stable sort: equal confidences keep their original edge order.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        results.append({
            "input_class_path": class_path,
            "related_class_paths": [target for target, _ in scored[:limit]],
        })
    return results
@tool(
    description="根据类别查找与该类别相关的其他类别,返回关系统计数据。支持单个或多个类别路径。每个类别独立返回其关联类别。",
    display={
        "zh": {
            "name": "类别关系检索",
            "params": {
                "class_paths": "类别路径数组",
                "top_k": "每个类别返回数量(默认5)",
            },
        },
    },
)
async def search_relation_class_by_class(
    class_paths: List[str],
    top_k: int = 5
) -> ToolResult:
    """Tool entry point: list the categories related to each given category.

    Delegates to ``_search_relation_class_by_class`` and wraps the outcome in
    a ``ToolResult``. Failures are reported through the result's ``error``
    field instead of being raised.

    Args:
        class_paths: Category paths, e.g. ["关键点_形式_架构>策略>行为体验"] or
            ["关键点_形式_架构>策略>行为体验", "目的点_意图_分享"].
        top_k: Number of related categories to return per input category.

    Returns:
        ToolResult: On success, ``output`` holds the helper's result as
        indented JSON (non-ASCII kept as-is) and ``long_term_memory`` holds a
        one-line summary with the total number of related categories.
    """
    failure_title = "类别关系检索失败"

    # Guard clause: nothing to look up.
    if not class_paths:
        return ToolResult(
            title=failure_title,
            output="",
            error="请提供类别路径",
        )

    try:
        related = _search_relation_class_by_class(class_paths, top_k)
    except FileNotFoundError:
        return ToolResult(
            title=failure_title,
            output="",
            error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
        )
    except Exception as exc:
        detail = str(exc)
        return ToolResult(
            title=failure_title,
            output="",
            error=f"检索异常: {detail}",
        )

    # Total number of related categories across all inputs.
    total_related = 0
    for entry in related:
        total_related += len(entry["related_class_paths"])

    category_count = len(class_paths)
    return ToolResult(
        title=f"类别关系检索 - {category_count} 个类别",
        output=json.dumps(related, ensure_ascii=False, indent=2),
        long_term_memory=f"为 {category_count} 个类别检索到 {total_related} 个关联类别",
    )
def _search_relation_point_by_point(points: List[Dict], top_k: int = 999) -> List[Dict[str, Any]]:
    """Find, for each input point, the graph nodes most strongly linked to it.

    Each point is first resolved to graph keys through the map data, read as
    ``map[point_type][dimension][account][point_value]`` for every account
    listed on the point. The ``edges`` of every resolved graph entry are then
    merged: a target's score is the mean of its
    ``edge["co_in_post"]["confidence"]`` values (0.0 when absent or null) over
    all resolved entries that link to it.

    Args:
        points: Point descriptors; the keys read are ``point_value``,
            ``point_type``, ``dimension`` and ``accounts`` (a list of account
            names). A missing or null ``accounts`` is treated as empty.
        top_k: Maximum number of related targets per point. Values below zero
            are treated as zero so that a negative value can never silently
            drop items from the end of the ranking.

    Returns:
        One dict per input point, in input order, with keys ``input_point``
        (the descriptor as given) and ``related_point_paths`` (target keys,
        highest mean confidence first; ties keep first-seen order).
    """
    graph = _load_graph_full()
    graph_map = _load_graph_full_map()
    limit = max(top_k, 0)
    all_results = []
    for point in points:
        point_value = point.get("point_value", "")
        type_dict = graph_map.get(point.get("point_type", "")) or {}
        dim_dict = type_dict.get(point.get("dimension", "")) or {}

        # Resolve the point to graph keys, one lookup per account.
        class_paths = []
        for account in point.get("accounts") or []:
            class_path = (dim_dict.get(account) or {}).get(point_value, "")
            if class_path:
                class_paths.append(class_path)

        # target -> [confidence sum, number of contributing edges]
        totals: Dict[str, List[Any]] = {}
        for class_path in class_paths:
            edges = (graph.get(class_path) or {}).get("edges") or {}
            for target_point, edge_data in edges.items():
                co_in_post = (edge_data or {}).get("co_in_post") or {}
                confidence = co_in_post.get("confidence") or 0.0
                bucket = totals.setdefault(target_point, [0.0, 0])
                bucket[0] += confidence
                bucket[1] += 1

        # Every bucket has count >= 1, so the division is always defined.
        ranked = [(target, total / count) for target, (total, count) in totals.items()]
        # Stable sort: equal averages keep their first-seen order.
        ranked.sort(key=lambda pair: pair[1], reverse=True)

        all_results.append({
            "input_point": point,
            "related_point_paths": [target for target, _ in ranked[:limit]],
        })
    return all_results
@tool(
    description="根据点的信息查找与该点相关的其他点,返回关系数据。",
    display={
        "zh": {
            "name": "点关系检索",
            "params": {
                "points": [
                    {
                        "point_value": "点",
                        "point_type": "点类型",
                        "dimension": "维度",
                        "accounts": ["账号名称"]
                    }
                ],
                "top_k": "返回数量(默认999)",
            },
        },
    },
)
async def search_relation_point_by_point(
    points: List[Dict],
    top_k: int = 999
) -> ToolResult:
    """Tool entry point: list the graph nodes related to each given point.

    Delegates to ``_search_relation_point_by_point`` and wraps the outcome in
    a ``ToolResult``. Failures are reported through the result's ``error``
    field instead of being raised.

    Args:
        points: Point descriptors with the keys ``point_value``,
            ``point_type``, ``dimension`` and ``accounts``.
        top_k: Number of related targets to return per point.

    Returns:
        ToolResult: On success, ``output`` holds the helper's result as
        indented JSON (non-ASCII kept as-is).
    """
    if not points:
        return ToolResult(
            title="点关系检索失败",
            output="",
            error="请提供点信息",
        )
    try:
        result = _search_relation_point_by_point(points, top_k)
    except FileNotFoundError as exc:
        # The helper opens both the graph file and the map file, so report the
        # path the OS actually failed on; fall back to the graph path when the
        # exception carries no filename.
        missing_path = exc.filename or GRAPH_FULL_DATA_PATH
        return ToolResult(
            title="点关系检索失败",
            output="",
            error=f"图数据文件不存在: {missing_path}",
        )
    except Exception as exc:
        return ToolResult(
            title="点关系检索失败",
            output="",
            error=f"检索异常: {str(exc)}",
        )
    output = json.dumps(result, ensure_ascii=False, indent=2)
    return ToolResult(
        title=f"点关系检索 - {len(points)} 个点",
        output=output,
        long_term_memory="检索到与输入点相关的点",
    )
def _search_class_by_point(
    points: List[Dict]
) -> List[Dict[str, Any]]:
    """Resolve each point to the category path(s) recorded for it in the map.

    The map is read as ``map[point_type][dimension][account][point_value]``
    for every account listed on the point, which is the same layout
    ``_search_relation_point_by_point`` relies on. When no account yields a
    match, the direct lookup ``map[point_type][dimension][point_value]`` is
    tried as a fallback so that data without an account level keeps working.

    Args:
        points: Point descriptors; the keys read are ``point_value``,
            ``point_type``, ``dimension`` and ``accounts`` (a list of account
            names, optional).

    Returns:
        A list of ``{"point": <descriptor>, "class_path": <path>}`` records in
        input order. A point contributes one record per distinct category path
        found for it and no record when nothing matches.
    """
    graph_map = _load_graph_full_map()
    data = []
    for point in points:
        point_value = point.get("point_value", "")
        type_dict = graph_map.get(point.get("point_type", "")) or {}
        dim_dict = type_dict.get(point.get("dimension", "")) or {}

        # Account-level lookup; several accounts may map to the same category.
        found = []
        for account in point.get("accounts") or []:
            account_dict = dim_dict.get(account)
            if not isinstance(account_dict, dict):
                continue
            class_path = account_dict.get(point_value, "")
            if class_path and class_path not in found:
                found.append(class_path)

        if not found:
            # Fallback: direct lookup without an account level. A dict here
            # would be an account's sub-map, not a category path, so skip it.
            direct = dim_dict.get(point_value, "")
            if direct and not isinstance(direct, dict):
                found.append(direct)

        data.extend({"point": point, "class_path": class_path} for class_path in found)
    return data
-
@tool(
    description="根据点的属性查找该点所属的类别。",
    display={
        "zh": {
            "name": "点类别查询",
            "params": {
                "points": [
                    {
                        "point_value": "点",
                        "point_type": "点类型",
                        "dimension": "维度",
                        "accounts": [
                            "账号名称"
                        ]
                    }
                ]
            },
        },
    },
)
async def search_class_by_point(
    points: List[Dict]
) -> ToolResult:
    """Tool entry point: look up the category each given point belongs to.

    Delegates to ``_search_class_by_point`` and wraps the outcome in a
    ``ToolResult``. Failures are reported through the result's ``error``
    field instead of being raised.

    Args:
        points: Point descriptors to resolve.

    Returns:
        ToolResult: On success, ``output`` holds the helper's result as
        indented JSON (non-ASCII kept as-is) and ``long_term_memory`` states
        how many records were found.
    """
    failure_title = "点类别查询失败"

    # Guard clause: nothing to look up.
    if not points:
        return ToolResult(
            title=failure_title,
            output="",
            error="请提供 points, point_type 和 dimension",
        )

    try:
        matches = _search_class_by_point(points)
    except FileNotFoundError:
        return ToolResult(
            title=failure_title,
            output="",
            error=f"图数据文件不存在: {GRAPH_FULL_MAP_DATA_PATH}",
        )
    except Exception as exc:
        detail = str(exc)
        return ToolResult(
            title=failure_title,
            output="",
            error=f"检索异常: {detail}",
        )

    serialized = json.dumps(matches, ensure_ascii=False, indent=2)
    return ToolResult(
        title=f"点类别查询 - {len(points)} 个点",
        output=serialized,
        long_term_memory=f"查询到 {len(matches)} 个点的类别信息",
    )
def _search_point_by_class(class_paths: List[str]) -> List[Dict[str, Any]]:
    """List the points recorded under each given category.

    For every category the graph entry's ``meta["elements"]`` is read and its
    keys are returned as the category's points. A missing or null entry,
    ``meta`` or ``elements`` yields an empty point list instead of an error.

    Args:
        class_paths: Category paths to look up, e.g. ["关键点_形式"].

    Returns:
        One dict per input category, in input order, with keys ``class_path``
        and ``points`` (the element names in the order stored in the graph).
    """
    graph = _load_graph_full()
    data = []
    for class_path in class_paths:
        # `or {}` also covers values that are an explicit JSON null.
        node = graph.get(class_path) or {}
        elements = (node.get("meta") or {}).get("elements") or {}
        data.append({
            "class_path": class_path,
            "points": list(elements),
        })
    return data
@tool(
    description="根据类别查找属于该类别的所有点。支持单个或多个类别路径。",
    display={
        "zh": {
            "name": "类别点检索",
            "params": {
                "class_paths": "类别路径数组"
            },
        },
    },
)
async def search_point_by_class(
    class_paths: List[str]
) -> ToolResult:
    """Tool entry point: list all points that belong to each given category.

    Delegates to ``_search_point_by_class`` and wraps the outcome in a
    ``ToolResult``. Failures are reported through the result's ``error``
    field instead of being raised, so callers can tell a failed lookup from
    an empty one.

    Args:
        class_paths: Category paths, e.g. ["关键点_形式"].

    Returns:
        ToolResult: On success, ``output`` holds the helper's result as
        indented JSON (non-ASCII kept as-is).
    """
    if not class_paths:
        return ToolResult(
            title="类别点检索失败",
            output="",
            error="请提供类别路径",
        )
    try:
        result = _search_point_by_class(class_paths)
    except FileNotFoundError:
        # Only the graph file is opened on this path.
        return ToolResult(
            title="类别点检索失败",
            output="",
            error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
        )
    except Exception as exc:
        # Surface the cause; an empty output alone would hide the failure.
        return ToolResult(
            title="类别点检索失败",
            output="",
            error=f"检索异常: {str(exc)}",
        )
    output = json.dumps(result, ensure_ascii=False, indent=2)
    return ToolResult(
        title=f"类别点检索 - {len(class_paths)} 个类别",
        output=output,
        long_term_memory="检索到类别的点数据"
    )
if __name__ == "__main__":
    # Manual smoke check: print the points stored under one sample category.
    sample_class_paths = ["关键点_形式_架构>策略>行为体验"]
    print(_search_point_by_class(sample_class_paths))
|