| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368 |
- """
- 关键点检索工具 - 根据输入的点在图数据库中查找所有关联的点
- 用于 Agent 执行时自主调取关联关键点数据。
- """
- import json
- import os
- from pathlib import Path
- from typing import Any, Dict, List, Optional
- from agent.tools import tool, ToolResult
# Location of the full graph database JSON (the variant that includes "edges").
# Override it with the GRAPH_FULL_DATA_PATH environment variable; otherwise the
# file under data/library/item_graph next to this package is used.
# (An alternative data file, item_graph_full_all_levels.json, lives in the same folder.)
GRAPH_FULL_DATA_PATH = os.getenv(
    "GRAPH_FULL_DATA_PATH",
    str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_max.json"),
)

# Parsed graph, filled in lazily by the loader so the JSON file is read only once.
_graph_full_cache: Optional[Dict[str, Any]] = None
def _load_graph_full() -> Dict[str, Any]:
    """Return the full graph (including edges), reading the JSON file only on first use.

    The parsed object is memoized in the module-level ``_graph_full_cache``; later
    calls hand back that same object without touching the disk again.

    Raises:
        FileNotFoundError: if ``GRAPH_FULL_DATA_PATH`` does not exist.
    """
    global _graph_full_cache
    if _graph_full_cache is not None:
        return _graph_full_cache
    with open(GRAPH_FULL_DATA_PATH, "r", encoding="utf-8") as graph_file:
        _graph_full_cache = json.load(graph_file)
    return _graph_full_cache
def _remove_post_ids_recursively(data: Any) -> Any:
    """Return ``data`` with every ``post_ids`` / ``_post_ids`` key dropped at any depth.

    Dicts are rebuilt without those keys (recursing into the remaining values) and
    lists are rebuilt element by element; every other value (scalars, tuples, ...)
    is returned unchanged. Key order of the surviving entries is preserved.
    """
    if isinstance(data, list):
        return [_remove_post_ids_recursively(element) for element in data]
    if not isinstance(data, dict):
        return data
    return {
        key: _remove_post_ids_recursively(value)
        for key, value in data.items()
        if key not in ("post_ids", "_post_ids")
    }
def _remove_post_ids_from_edges(edges: Dict[str, Any]) -> Dict[str, Any]:
    """Strip every ``post_ids`` / ``_post_ids`` key from an edges mapping, at any depth.

    Thin wrapper over ``_remove_post_ids_recursively`` kept for readability at call sites.
    """
    stripped = _remove_post_ids_recursively(edges)
    return stripped
def _filter_top_edges_by_confidence(edges: Dict[str, Any], top_k: Optional[int] = 999) -> Dict[str, Any]:
    """
    Keep the ``top_k`` highest-confidence edges and strip all post-id fields.

    Each edge's confidence is read from ``edge_data["co_in_post"]["confidence"]``.
    Edges without that field, or whose value is null / non-numeric, rank as 0.0
    instead of breaking the sort. Entries whose value is not a dict are dropped.
    Edges with equal confidence keep their original relative order.

    Args:
        edges: Mapping of edge name -> edge data. An empty or falsy value yields ``{}``.
        top_k: Maximum number of edges to keep (default 999); ``None`` keeps all.

    Returns:
        A new dict of the kept edges in descending confidence order, with every
        ``post_ids`` / ``_post_ids`` key removed at any depth.
    """
    if not edges:
        return {}

    def confidence_of(edge_data: Dict[str, Any]) -> float:
        co_in_post = edge_data.get("co_in_post")
        if not isinstance(co_in_post, dict):
            return 0.0
        value = co_in_post.get("confidence", 0.0)
        # A JSON null (or any non-number) would make the sort below raise TypeError.
        return value if isinstance(value, (int, float)) else 0.0

    # Strip post ids first; note this also drops any edge literally named "post_ids"/"_post_ids".
    cleaned_edges = _remove_post_ids_from_edges(edges)
    ranked = [
        (edge_name, edge_data)
        for edge_name, edge_data in cleaned_edges.items()
        if isinstance(edge_data, dict)
    ]
    # list.sort is stable even with reverse=True, so ties keep their input order.
    ranked.sort(key=lambda item: confidence_of(item[1]), reverse=True)
    return dict(ranked[:top_k])
def _search_points_by_element_from_full(
    element_value: str,
    element_type: str,
    top_k: int = 9999
) -> Dict[str, Any]:
    """
    Find points in the full graph whose ``meta.elements`` contains ``element_value``
    and whose ``meta.dimension`` equals ``element_type``.

    Args:
        element_value: Element value, e.g. "标准化", "懒人妻子".
        element_type: Element type: "实质" / "形式" / "意图".
        top_k: Maximum number of points to return, highest ``frequency_in_posts``
            first. The same value also caps the number of edges kept per point.

    Returns:
        When nothing matches: ``{"found": False, "element_value", "element_type", "message"}``.
        Otherwise ``{"found": True, "element_value", "element_type", "total_matched_count",
        "returned_count", "matched_points"}``. ``total_matched_count`` counts every match
        before truncation, and each entry of ``matched_points`` carries its
        highest-confidence edges with post ids removed (same shape as the path lookup).
    """
    graph = _load_graph_full()

    # Collect raw matches first; edge filtering is deferred until after truncation.
    matches = []
    for point_name, point_data in graph.items():
        if not isinstance(point_data, dict):
            continue
        # "or {}" guards against explicit JSON nulls, which would crash .get / "in".
        meta = point_data.get("meta") or {}
        elements = meta.get("elements") or {}
        if meta.get("dimension") == element_type and element_value in elements:
            matches.append((point_name, point_data, meta, elements))

    if not matches:
        return {
            "found": False,
            "element_value": element_value,
            "element_type": element_type,
            "message": f"未找到匹配的点: element_value={element_value}, element_type={element_type}"
        }

    total_matched = len(matches)
    # Stable sort by frequency (null treated as 0), then keep only the top_k points.
    matches.sort(key=lambda m: m[2].get("frequency_in_posts", 0) or 0, reverse=True)

    matched_points = []
    for point_name, point_data, meta, elements in matches[:top_k]:
        top_edges = _filter_top_edges_by_confidence(point_data.get("edges", {}), top_k)
        matched_points.append({
            "point": point_name,
            "point_type": meta.get("point_type"),
            "dimension": meta.get("dimension"),
            "point_path": meta.get("path"),
            "frequency_in_posts": meta.get("frequency_in_posts", 0),
            "elements": elements,
            "edge_count": len(top_edges),
            "edges": top_edges
        })

    return {
        "found": True,
        "element_value": element_value,
        "element_type": element_type,
        "total_matched_count": total_matched,
        "returned_count": len(matched_points),
        "matched_points": matched_points
    }
def _search_point_by_path_from_full(path: str, top_k: int = 999) -> Dict[str, Any]:
    """
    Look up a single point in the full graph database by its exact path key.

    Args:
        path: Full point path, e.g. "关键点_形式_架构>逻辑>逻辑架构>组织逻辑>框架规划>结构设计".
        top_k: Maximum number of highest-confidence edges to include (default 999).

    Returns:
        ``{"found": False, "path", "message"}`` when ``path`` is not a key of the graph;
        otherwise ``{"found": True, "path", "point_type", "dimension", "point_path",
        "frequency_in_posts", "elements", "edge_count", "edges"}``, where ``edges``
        holds the kept edges with post ids removed.
    """
    graph = _load_graph_full()
    if path not in graph:
        return {
            "found": False,
            "path": path,
            "message": f"未找到路径: {path}"
        }

    point_data = graph[path]
    # "or {}" guards against an explicit JSON null for "meta".
    meta = point_data.get("meta") or {}
    top_edges = _filter_top_edges_by_confidence(point_data.get("edges", {}), top_k)

    return {
        "found": True,
        "path": path,
        "point_type": meta.get("point_type"),
        "dimension": meta.get("dimension"),
        "point_path": meta.get("path"),
        "frequency_in_posts": meta.get("frequency_in_posts"),
        "elements": meta.get("elements", {}),
        "edge_count": len(top_edges),
        "edges": top_edges
    }
@tool(
    description="根据点值和点类型在完整图数据库中查找匹配的点,返回包含置信度最高10条边的完整数据。",
    display={
        "zh": {
            "name": "元素类型完整检索",
            # Keyed by the real parameter names, matching the path-lookup tool.
            "params": {
                "element_value": "点的值",
                "element_type": "点的类型(实质/形式/意图)",
                "top_k": "返回点的数量上限",
            },
        },
    },
)
async def search_class_by_point(
    element_value: str,
    element_type: str,
    top_k: int = 10,
) -> ToolResult:
    """
    Search the full graph database for points by element value and type.

    Args:
        element_value: Element name, e.g. "标准化", "懒人妻子".
        element_type: Element type: "实质" / "形式" / "意图".
        top_k: Maximum number of points to return (default 10); it also caps the
            number of highest-confidence edges kept per point.

    Returns:
        ToolResult whose ``output`` is a JSON document with the matched points and
        their top edges. Invalid arguments, a missing data file, or any other lookup
        error are reported via ``error``; no match yields a JSON ``message`` instead.
    """
    if not element_value:
        return ToolResult(
            title="元素类型检索失败",
            output="",
            error="请提供元素值",
        )
    if element_type not in ["实质", "形式", "意图"]:
        return ToolResult(
            title="元素类型检索失败",
            output="",
            error=f"元素类型必须是 '实质'、'形式' 或 '意图',当前值: {element_type}",
        )
    # top_k was previously referenced without being defined (NameError on every call).
    if not isinstance(top_k, int) or isinstance(top_k, bool) or top_k < 1:
        return ToolResult(
            title="元素类型检索失败",
            output="",
            error=f"top_k 必须是正整数,当前值: {top_k}",
        )

    try:
        result = _search_points_by_element_from_full(element_value, element_type, top_k)
    except FileNotFoundError:
        return ToolResult(
            title="元素类型检索失败",
            output="",
            error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
        )
    except Exception as e:
        # Tool boundary: report any other failure to the agent instead of raising.
        return ToolResult(
            title="元素类型检索失败",
            output="",
            error=f"检索异常: {str(e)}",
        )

    if not result["found"]:
        return ToolResult(
            title="元素类型检索",
            output=json.dumps(
                {
                    "message": result["message"],
                    "element_value": element_value,
                    "element_type": element_type
                },
                ensure_ascii=False,
                indent=2
            ),
        )

    output_data = {
        "element_value": result["element_value"],
        "element_type": result["element_type"],
        "total_matched_count": result["total_matched_count"],
        "returned_count": result["returned_count"],
        "matched_points": result["matched_points"]
    }
    output = json.dumps(output_data, ensure_ascii=False, indent=2)
    return ToolResult(
        title=f"元素类型检索 - {element_value} ({element_type})",
        output=output,
        long_term_memory=f"检索到 {result['returned_count']} 个匹配点,元素值: {element_value}, 类型: {element_type}",
    )
@tool(
    description="根据完整路径在完整图数据库中查找点,返回包含置信度最高10条边的完整数据。",
    display={
        "zh": {
            "name": "路径完整检索",
            "params": {
                "path": "点的完整路径",
            },
        },
    },
)
async def search_point_by_path_from_full_all_levels(path: str) -> ToolResult:
    """
    Look up a single point by its full path in the full graph database.

    Args:
        path: Full point path, e.g. "关键点_形式_架构>逻辑>逻辑架构>组织逻辑>框架规划>结构设计".

    Returns:
        ToolResult whose ``output`` is a JSON document with the point's metadata and its
        highest-confidence edges (post ids removed). An empty path, a missing data file,
        or any other lookup error is reported via ``error``; an unknown path yields a
        JSON ``message`` instead.
    """
    failure_title = "路径检索失败"
    if not path:
        return ToolResult(title=failure_title, output="", error="请提供路径")

    try:
        result = _search_point_by_path_from_full(path)
    except FileNotFoundError:
        return ToolResult(
            title=failure_title,
            output="",
            error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
        )
    except Exception as e:
        # Tool boundary: surface the failure to the agent rather than raising.
        return ToolResult(title=failure_title, output="", error=f"检索异常: {str(e)}")

    if not result["found"]:
        not_found_payload = {"message": result["message"], "path": path}
        return ToolResult(
            title="路径检索",
            output=json.dumps(not_found_payload, ensure_ascii=False, indent=2),
        )

    # Copy the lookup result minus its internal "found" flag, preserving key order.
    exported_fields = (
        "path",
        "point_type",
        "dimension",
        "point_path",
        "frequency_in_posts",
        "elements",
        "edge_count",
        "edges",
    )
    payload = {field: result[field] for field in exported_fields}

    return ToolResult(
        title=f"路径检索 - {path}",
        output=json.dumps(payload, ensure_ascii=False, indent=2),
        long_term_memory=f"检索到路径 {path} 的完整信息,包含置信度最高的 {result['edge_count']} 条边",
    )
|