search_library.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. """
  2. 关键点检索工具 - 根据输入的点在图数据库中查找所有关联的点
  3. 用于 Agent 执行时自主调取关联关键点数据。
  4. """
  5. import json
  6. import os
  7. from pathlib import Path
  8. from typing import Any, Dict, List, Optional
  9. from agent.tools import tool, ToolResult
  10. # 图数据库文件路径
  11. GRAPH_DATA_PATH = os.getenv(
  12. "GRAPH_DATA_PATH",
  13. str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_point_type_only_all_levels.json")
  14. )
  15. # 完整图数据库文件路径(包含 edges)
  16. GRAPH_FULL_DATA_PATH = os.getenv(
  17. "GRAPH_FULL_DATA_PATH",
  18. # str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_all_levels.json")
  19. str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_max.json")
  20. )
  21. # 缓存图数据,避免重复加载
  22. _graph_cache: Optional[Dict[str, Any]] = None
  23. _graph_full_cache: Optional[Dict[str, Any]] = None
  24. def _load_graph() -> Dict[str, Any]:
  25. """加载图数据(带缓存)"""
  26. global _graph_cache
  27. if _graph_cache is None:
  28. with open(GRAPH_DATA_PATH, 'r', encoding='utf-8') as f:
  29. _graph_cache = json.load(f)
  30. return _graph_cache
  31. def _load_graph_full() -> Dict[str, Any]:
  32. """加载完整图数据(带缓存,包含 edges)"""
  33. global _graph_full_cache
  34. if _graph_full_cache is None:
  35. with open(GRAPH_FULL_DATA_PATH, 'r', encoding='utf-8') as f:
  36. _graph_full_cache = json.load(f)
  37. return _graph_full_cache
  38. def _remove_post_ids_recursively(data: Any) -> Any:
  39. """递归移除所有字典中的 post_ids 和 _post_ids 字段"""
  40. if isinstance(data, dict):
  41. cleaned = {}
  42. for key, value in data.items():
  43. # 跳过 post_ids 和 _post_ids 字段
  44. if key in ("post_ids", "_post_ids"):
  45. continue
  46. # 递归处理值
  47. cleaned[key] = _remove_post_ids_recursively(value)
  48. return cleaned
  49. elif isinstance(data, list):
  50. # 处理列表中的每个元素
  51. return [_remove_post_ids_recursively(item) for item in data]
  52. else:
  53. # 其他类型直接返回
  54. return data
  55. def _remove_post_ids_from_edges(edges: Dict[str, Any]) -> Dict[str, Any]:
  56. """移除 edges 中所有的 post_ids 和 _post_ids 字段(递归)"""
  57. return _remove_post_ids_recursively(edges)
  58. def _filter_top_edges_by_confidence(edges: Dict[str, Any], top_k: int = 5) -> Dict[str, Any]:
  59. """
  60. 筛选置信度最高的前 K 条边,并移除所有 post_ids 字段
  61. Args:
  62. edges: 边数据字典
  63. top_k: 返回前 K 条边,默认 10
  64. Returns:
  65. 筛选后的边数据字典(已移除所有 post_ids 和 _post_ids 字段)
  66. """
  67. if not edges:
  68. return {}
  69. # 先移除 _post_ids
  70. cleaned_edges = _remove_post_ids_from_edges(edges)
  71. # 提取所有边及其置信度
  72. edge_list = []
  73. for edge_name, edge_data in cleaned_edges.items():
  74. if isinstance(edge_data, dict):
  75. # 从 co_in_post 中提取置信度
  76. confidence = 0.0
  77. if "co_in_post" in edge_data and isinstance(edge_data["co_in_post"], dict):
  78. confidence = edge_data["co_in_post"].get("confidence", 0.0)
  79. edge_list.append({
  80. "name": edge_name,
  81. "data": edge_data,
  82. "confidence": confidence
  83. })
  84. # 按置信度降序排序
  85. edge_list.sort(key=lambda x: x["confidence"], reverse=True)
  86. # 取前 top_k 条
  87. top_edges = edge_list[:top_k]
  88. # 重新构建边字典
  89. result_edges = {}
  90. for edge_item in top_edges:
  91. result_edges[edge_item["name"]] = edge_item["data"]
  92. return result_edges
  93. def _search_points_by_element_from_full(
  94. element_value: str,
  95. element_type: str,
  96. top_k: int = 9999
  97. ) -> Dict[str, Any]:
  98. """
  99. 根据元素值和类型在完整图数据库的 elements 字段中查找匹配的点
  100. Args:
  101. element_value: 元素值,如 "标准化", "懒人妻子"
  102. element_type: 元素类型,"实质" / "形式" / "意图"
  103. top_k: 返回前 K 个点(按频率排序)
  104. Returns:
  105. 包含匹配点完整信息的字典(每个点包括置信度最高的10条边,已移除 _post_ids)
  106. """
  107. graph = _load_graph_full()
  108. matched_points = []
  109. # 遍历图中所有点
  110. for point_name, point_data in graph.items():
  111. meta = point_data.get("meta", {})
  112. elements = meta.get("elements", {})
  113. dimension = meta.get("dimension")
  114. # 检查:元素值在 elements 中 AND dimension 匹配 element_type
  115. if element_value in elements and dimension == element_type:
  116. # 筛选置信度最高的前10条边,并移除 _post_ids
  117. top_edges = _filter_top_edges_by_confidence(point_data.get("edges", {}), top_k)
  118. # 返回结构与 search_point_by_path_from_full_all_levels 保持一致
  119. point_info = {
  120. "point": point_name,
  121. "point_type": meta.get("point_type"),
  122. "dimension": dimension,
  123. "point_path": meta.get("path"),
  124. "frequency_in_posts": meta.get("frequency_in_posts", 0),
  125. "elements": elements,
  126. "edge_count": len(top_edges),
  127. "edges": top_edges
  128. }
  129. matched_points.append(point_info)
  130. if not matched_points:
  131. return {
  132. "found": False,
  133. "element_value": element_value,
  134. "element_type": element_type,
  135. "message": f"未找到匹配的点: element_value={element_value}, element_type={element_type}"
  136. }
  137. # 按频率降序排序,取前 top_k 个
  138. matched_points.sort(key=lambda x: x["frequency_in_posts"], reverse=True)
  139. matched_points = matched_points[:top_k]
  140. return {
  141. "found": True,
  142. "element_value": element_value,
  143. "element_type": element_type,
  144. "total_matched_count": len(matched_points),
  145. "returned_count": len(matched_points),
  146. "matched_points": matched_points
  147. }
  148. def _search_point_by_path_from_full(path: str) -> Dict[str, Any]:
  149. """
  150. 根据完整路径在完整图数据库中查找点
  151. Args:
  152. path: 点的完整路径,如 "关键点_形式_架构>逻辑>逻辑架构>组织逻辑>框架规划>结构设计"
  153. Returns:
  154. 包含该点完整信息的字典(包括置信度最高的10条边,已移除 _post_ids)
  155. """
  156. graph = _load_graph_full()
  157. if path not in graph:
  158. return {
  159. "found": False,
  160. "path": path,
  161. "message": f"未找到路径: {path}"
  162. }
  163. point_data = graph[path]
  164. meta = point_data.get("meta", {})
  165. # 筛选置信度最高的前10条边,并移除 _post_ids
  166. top_edges = _filter_top_edges_by_confidence(point_data.get("edges", {}))
  167. return {
  168. "found": True,
  169. "path": path,
  170. "point_type": meta.get("point_type"),
  171. "dimension": meta.get("dimension"),
  172. "point_path": meta.get("path"),
  173. "frequency_in_posts": meta.get("frequency_in_posts"),
  174. "elements": meta.get("elements", {}),
  175. "edge_count": len(top_edges),
  176. "edges": top_edges
  177. }
  178. @tool(
  179. description="根据元素值和类型在完整图数据库中查找匹配的点,返回包含置信度最高10条边的完整数据。",
  180. display={
  181. "zh": {
  182. "name": "元素类型完整检索",
  183. "params": {
  184. "element_value": "元素值",
  185. "element_type": "元素类型(实质/形式/意图)",
  186. "top_k": "返回数量(默认10)",
  187. },
  188. },
  189. },
  190. )
  191. async def search_point_by_element_from_full_all_levels(
  192. element_value: str,
  193. element_type: str,
  194. top_k: int = 10
  195. ) -> ToolResult:
  196. """
  197. 根据元素值和类型在完整图数据库中检索点,返回包含置信度最高10条边的完整数据。
  198. Args:
  199. element_value: 元素名称,如 "标准化", "懒人妻子"
  200. element_type: 元素类型,"实质" / "形式" / "意图"
  201. top_k: 返回前 K 个点,默认 10
  202. Returns:
  203. ToolResult: 匹配点的完整数据(包括置信度最高的10条边)
  204. """
  205. if not element_value:
  206. return ToolResult(
  207. title="元素类型检索失败",
  208. output="",
  209. error="请提供元素值",
  210. )
  211. if element_type not in ["实质", "形式", "意图"]:
  212. return ToolResult(
  213. title="元素类型检索失败",
  214. output="",
  215. error=f"元素类型必须是 '实质'、'形式' 或 '意图',当前值: {element_type}",
  216. )
  217. try:
  218. result = _search_points_by_element_from_full(element_value, element_type, top_k)
  219. except FileNotFoundError:
  220. return ToolResult(
  221. title="元素类型检索失败",
  222. output="",
  223. error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
  224. )
  225. except Exception as e:
  226. return ToolResult(
  227. title="元素类型检索失败",
  228. output="",
  229. error=f"检索异常: {str(e)}",
  230. )
  231. if not result["found"]:
  232. return ToolResult(
  233. title="元素类型检索",
  234. output=json.dumps(
  235. {
  236. "message": result["message"],
  237. "element_value": element_value,
  238. "element_type": element_type
  239. },
  240. ensure_ascii=False,
  241. indent=2
  242. ),
  243. )
  244. # 格式化输出
  245. output_data = {
  246. "element_value": result["element_value"],
  247. "element_type": result["element_type"],
  248. "total_matched_count": result["total_matched_count"],
  249. "returned_count": result["returned_count"],
  250. "matched_points": result["matched_points"]
  251. }
  252. output = json.dumps(output_data, ensure_ascii=False, indent=2)
  253. return ToolResult(
  254. title=f"元素类型检索 - {element_value} ({element_type})",
  255. output=output,
  256. long_term_memory=f"检索到 {result['returned_count']} 个匹配点,元素值: {element_value}, 类型: {element_type}",
  257. )
  258. @tool(
  259. description="根据完整路径在完整图数据库中查找点,返回包含置信度最高10条边的完整数据。",
  260. display={
  261. "zh": {
  262. "name": "路径完整检索",
  263. "params": {
  264. "path": "点的完整路径",
  265. },
  266. },
  267. },
  268. )
  269. async def search_point_by_path_from_full_all_levels(path: str) -> ToolResult:
  270. """
  271. 根据完整路径在完整图数据库中检索点,返回包含置信度最高10条边的完整数据。
  272. Args:
  273. path: 点的完整路径,如 "关键点_形式_架构>逻辑>逻辑架构>组织逻辑>框架规划>结构设计"
  274. Returns:
  275. ToolResult: 点的完整数据(包括置信度最高的10条边)
  276. """
  277. if not path:
  278. return ToolResult(
  279. title="路径检索失败",
  280. output="",
  281. error="请提供路径",
  282. )
  283. try:
  284. result = _search_point_by_path_from_full(path)
  285. except FileNotFoundError:
  286. return ToolResult(
  287. title="路径检索失败",
  288. output="",
  289. error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
  290. )
  291. except Exception as e:
  292. return ToolResult(
  293. title="路径检索失败",
  294. output="",
  295. error=f"检索异常: {str(e)}",
  296. )
  297. if not result["found"]:
  298. return ToolResult(
  299. title="路径检索",
  300. output=json.dumps(
  301. {"message": result["message"], "path": path},
  302. ensure_ascii=False,
  303. indent=2
  304. ),
  305. )
  306. # 格式化输出
  307. output_data = {
  308. "path": result["path"],
  309. "point_type": result["point_type"],
  310. "dimension": result["dimension"],
  311. "point_path": result["point_path"],
  312. "frequency_in_posts": result["frequency_in_posts"],
  313. "elements": result["elements"],
  314. "edge_count": result["edge_count"],
  315. "edges": result["edges"]
  316. }
  317. output = json.dumps(output_data, ensure_ascii=False, indent=2)
  318. return ToolResult(
  319. title=f"路径检索 - {path}",
  320. output=output,
  321. long_term_memory=f"检索到路径 {path} 的完整信息,包含置信度最高的 {result['edge_count']} 条边",
  322. )