search_library copy.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. """
  2. 关键点检索工具 - 根据输入的点在图数据库中查找所有关联的点
  3. 用于 Agent 执行时自主调取关联关键点数据。
  4. """
  5. import json
  6. import os
  7. from pathlib import Path
  8. from typing import Any, Dict, List, Optional
  9. from agent.tools import tool, ToolResult
  10. # 完整图数据库文件路径(包含 edges)
  11. GRAPH_FULL_DATA_PATH = os.getenv(
  12. "GRAPH_FULL_DATA_PATH",
  13. # str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_all_levels.json")
  14. str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_max.json")
  15. )
  16. # 缓存图数据,避免重复加载
  17. _graph_full_cache: Optional[Dict[str, Any]] = None
  18. def _load_graph_full() -> Dict[str, Any]:
  19. """加载完整图数据(带缓存,包含 edges)"""
  20. global _graph_full_cache
  21. if _graph_full_cache is None:
  22. with open(GRAPH_FULL_DATA_PATH, 'r', encoding='utf-8') as f:
  23. _graph_full_cache = json.load(f)
  24. return _graph_full_cache
  25. def _remove_post_ids_recursively(data: Any) -> Any:
  26. """递归移除所有字典中的 post_ids 和 _post_ids 字段"""
  27. if isinstance(data, dict):
  28. cleaned = {}
  29. for key, value in data.items():
  30. # 跳过 post_ids 和 _post_ids 字段
  31. if key in ("post_ids", "_post_ids"):
  32. continue
  33. # 递归处理值
  34. cleaned[key] = _remove_post_ids_recursively(value)
  35. return cleaned
  36. elif isinstance(data, list):
  37. # 处理列表中的每个元素
  38. return [_remove_post_ids_recursively(item) for item in data]
  39. else:
  40. # 其他类型直接返回
  41. return data
  42. def _remove_post_ids_from_edges(edges: Dict[str, Any]) -> Dict[str, Any]:
  43. """移除 edges 中所有的 post_ids 和 _post_ids 字段(递归)"""
  44. return _remove_post_ids_recursively(edges)
  45. def _filter_top_edges_by_confidence(edges: Dict[str, Any], top_k: int = 999) -> Dict[str, Any]:
  46. """
  47. 筛选置信度最高的前 K 条边,并移除所有 post_ids 字段
  48. Args:
  49. edges: 边数据字典
  50. top_k: 返回前 K 条边,默认 10
  51. Returns:
  52. 筛选后的边数据字典(已移除所有 post_ids 和 _post_ids 字段)
  53. """
  54. if not edges:
  55. return {}
  56. # 先移除 _post_ids
  57. cleaned_edges = _remove_post_ids_from_edges(edges)
  58. # 提取所有边及其置信度
  59. edge_list = []
  60. for edge_name, edge_data in cleaned_edges.items():
  61. if isinstance(edge_data, dict):
  62. # 从 co_in_post 中提取置信度
  63. confidence = 0.0
  64. if "co_in_post" in edge_data and isinstance(edge_data["co_in_post"], dict):
  65. confidence = edge_data["co_in_post"].get("confidence", 0.0)
  66. edge_list.append({
  67. "name": edge_name,
  68. "data": edge_data,
  69. "confidence": confidence
  70. })
  71. # 按置信度降序排序
  72. edge_list.sort(key=lambda x: x["confidence"], reverse=True)
  73. # 取前 top_k 条
  74. top_edges = edge_list[:top_k]
  75. # 重新构建边字典
  76. result_edges = {}
  77. for edge_item in top_edges:
  78. result_edges[edge_item["name"]] = edge_item["data"]
  79. return result_edges
  80. def _search_points_by_element_from_full(
  81. element_value: str,
  82. element_type: str,
  83. top_k: int = 9999
  84. ) -> Dict[str, Any]:
  85. """
  86. 根据元素值和类型在完整图数据库的 elements 字段中查找匹配的点
  87. Args:
  88. element_value: 元素值,如 "标准化", "懒人妻子"
  89. element_type: 元素类型,"实质" / "形式" / "意图"
  90. top_k: 返回前 K 个点(按频率排序)
  91. Returns:
  92. 包含匹配点完整信息的字典(每个点包括置信度最高的10条边,已移除 _post_ids)
  93. """
  94. graph = _load_graph_full()
  95. matched_points = []
  96. # 遍历图中所有点
  97. for point_name, point_data in graph.items():
  98. meta = point_data.get("meta", {})
  99. elements = meta.get("elements", {})
  100. dimension = meta.get("dimension")
  101. # 检查:元素值在 elements 中 AND dimension 匹配 element_type
  102. if element_value in elements and dimension == element_type:
  103. # 筛选置信度最高的前10条边,并移除 _post_ids
  104. top_edges = _filter_top_edges_by_confidence(point_data.get("edges", {}), top_k)
  105. # 返回结构与 search_point_by_path_from_full_all_levels 保持一致
  106. point_info = {
  107. "point": point_name,
  108. "point_type": meta.get("point_type"),
  109. "dimension": dimension,
  110. "point_path": meta.get("path"),
  111. "frequency_in_posts": meta.get("frequency_in_posts", 0),
  112. "elements": elements,
  113. "edge_count": len(top_edges),
  114. "edges": top_edges
  115. }
  116. matched_points.append(point_info)
  117. if not matched_points:
  118. return {
  119. "found": False,
  120. "element_value": element_value,
  121. "element_type": element_type,
  122. "message": f"未找到匹配的点: element_value={element_value}, element_type={element_type}"
  123. }
  124. # 按频率降序排序,取前 top_k 个
  125. matched_points.sort(key=lambda x: x["frequency_in_posts"], reverse=True)
  126. matched_points = matched_points[:top_k]
  127. return {
  128. "found": True,
  129. "element_value": element_value,
  130. "element_type": element_type,
  131. "total_matched_count": len(matched_points),
  132. "returned_count": len(matched_points),
  133. "matched_points": matched_points
  134. }
  135. def _search_point_by_path_from_full(path: str) -> Dict[str, Any]:
  136. """
  137. 根据完整路径在完整图数据库中查找点
  138. Args:
  139. path: 点的完整路径,如 "关键点_形式_架构>逻辑>逻辑架构>组织逻辑>框架规划>结构设计"
  140. Returns:
  141. 包含该点完整信息的字典(包括置信度最高的10条边,已移除 _post_ids)
  142. """
  143. graph = _load_graph_full()
  144. if path not in graph:
  145. return {
  146. "found": False,
  147. "path": path,
  148. "message": f"未找到路径: {path}"
  149. }
  150. point_data = graph[path]
  151. meta = point_data.get("meta", {})
  152. # 筛选置信度最高的前10条边,并移除 _post_ids
  153. top_edges = _filter_top_edges_by_confidence(point_data.get("edges", {}))
  154. return {
  155. "found": True,
  156. "path": path,
  157. "point_type": meta.get("point_type"),
  158. "dimension": meta.get("dimension"),
  159. "point_path": meta.get("path"),
  160. "frequency_in_posts": meta.get("frequency_in_posts"),
  161. "elements": meta.get("elements", {}),
  162. "edge_count": len(top_edges),
  163. "edges": top_edges
  164. }
  165. @tool(
  166. description="根据点值和点类型在完整图数据库中查找匹配的点,返回包含置信度最高10条边的完整数据。",
  167. display={
  168. "zh": {
  169. "name": "元素类型完整检索",
  170. "params": [
  171. {
  172. "point_value": "点的值",
  173. "point_type": "点的类型(实质/形式/意图)"
  174. }
  175. ],
  176. },
  177. },
  178. )
  179. async def search_class_by_point(
  180. element_value: str,
  181. element_type: str
  182. ) -> ToolResult:
  183. """
  184. 根据元素值和类型在完整图数据库中检索点,返回包含置信度最高10条边的完整数据。
  185. Args:
  186. element_value: 元素名称,如 "标准化", "懒人妻子"
  187. element_type: 元素类型,"实质" / "形式" / "意图"
  188. top_k: 返回前 K 个点,默认 10
  189. Returns:
  190. ToolResult: 匹配点的完整数据(包括置信度最高的10条边)
  191. """
  192. if not element_value:
  193. return ToolResult(
  194. title="元素类型检索失败",
  195. output="",
  196. error="请提供元素值",
  197. )
  198. if element_type not in ["实质", "形式", "意图"]:
  199. return ToolResult(
  200. title="元素类型检索失败",
  201. output="",
  202. error=f"元素类型必须是 '实质'、'形式' 或 '意图',当前值: {element_type}",
  203. )
  204. try:
  205. result = _search_points_by_element_from_full(element_value, element_type, top_k)
  206. except FileNotFoundError:
  207. return ToolResult(
  208. title="元素类型检索失败",
  209. output="",
  210. error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
  211. )
  212. except Exception as e:
  213. return ToolResult(
  214. title="元素类型检索失败",
  215. output="",
  216. error=f"检索异常: {str(e)}",
  217. )
  218. if not result["found"]:
  219. return ToolResult(
  220. title="元素类型检索",
  221. output=json.dumps(
  222. {
  223. "message": result["message"],
  224. "element_value": element_value,
  225. "element_type": element_type
  226. },
  227. ensure_ascii=False,
  228. indent=2
  229. ),
  230. )
  231. # 格式化输出
  232. output_data = {
  233. "element_value": result["element_value"],
  234. "element_type": result["element_type"],
  235. "total_matched_count": result["total_matched_count"],
  236. "returned_count": result["returned_count"],
  237. "matched_points": result["matched_points"]
  238. }
  239. output = json.dumps(output_data, ensure_ascii=False, indent=2)
  240. return ToolResult(
  241. title=f"元素类型检索 - {element_value} ({element_type})",
  242. output=output,
  243. long_term_memory=f"检索到 {result['returned_count']} 个匹配点,元素值: {element_value}, 类型: {element_type}",
  244. )
  245. @tool(
  246. description="根据完整路径在完整图数据库中查找点,返回包含置信度最高10条边的完整数据。",
  247. display={
  248. "zh": {
  249. "name": "路径完整检索",
  250. "params": {
  251. "path": "点的完整路径",
  252. },
  253. },
  254. },
  255. )
  256. async def search_point_by_path_from_full_all_levels(path: str) -> ToolResult:
  257. """
  258. 根据完整路径在完整图数据库中检索点,返回包含置信度最高10条边的完整数据。
  259. Args:
  260. path: 点的完整路径,如 "关键点_形式_架构>逻辑>逻辑架构>组织逻辑>框架规划>结构设计"
  261. Returns:
  262. ToolResult: 点的完整数据(包括置信度最高的10条边)
  263. """
  264. if not path:
  265. return ToolResult(
  266. title="路径检索失败",
  267. output="",
  268. error="请提供路径",
  269. )
  270. try:
  271. result = _search_point_by_path_from_full(path)
  272. except FileNotFoundError:
  273. return ToolResult(
  274. title="路径检索失败",
  275. output="",
  276. error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
  277. )
  278. except Exception as e:
  279. return ToolResult(
  280. title="路径检索失败",
  281. output="",
  282. error=f"检索异常: {str(e)}",
  283. )
  284. if not result["found"]:
  285. return ToolResult(
  286. title="路径检索",
  287. output=json.dumps(
  288. {"message": result["message"], "path": path},
  289. ensure_ascii=False,
  290. indent=2
  291. ),
  292. )
  293. # 格式化输出
  294. output_data = {
  295. "path": result["path"],
  296. "point_type": result["point_type"],
  297. "dimension": result["dimension"],
  298. "point_path": result["point_path"],
  299. "frequency_in_posts": result["frequency_in_posts"],
  300. "elements": result["elements"],
  301. "edge_count": result["edge_count"],
  302. "edges": result["edges"]
  303. }
  304. output = json.dumps(output_data, ensure_ascii=False, indent=2)
  305. return ToolResult(
  306. title=f"路径检索 - {path}",
  307. output=output,
  308. long_term_memory=f"检索到路径 {path} 的完整信息,包含置信度最高的 {result['edge_count']} 条边",
  309. )