search_library.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
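"""
Key-point retrieval tools: given an input point, look up all of its related points in the graph database.
Used by the Agent to fetch related key-point data on its own while it is executing.
"""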
  5. import json
  6. import os
  7. from pathlib import Path
  8. from typing import Any, Dict, List, Optional
  9. from agent.tools import tool, ToolResult
  10. # 图数据库文件路径
  11. GRAPH_DATA_PATH = os.getenv(
  12. "GRAPH_DATA_PATH",
  13. str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_point_type_only_all_levels.json")
  14. )
  15. # 完整图数据库文件路径(包含 edges)
  16. GRAPH_FULL_DATA_PATH = os.getenv(
  17. "GRAPH_FULL_DATA_PATH",
  18. str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_all_levels.json")
  19. )
  20. # 缓存图数据,避免重复加载
  21. _graph_cache: Optional[Dict[str, Any]] = None
  22. _graph_full_cache: Optional[Dict[str, Any]] = None
  23. def _load_graph() -> Dict[str, Any]:
  24. """加载图数据(带缓存)"""
  25. global _graph_cache
  26. if _graph_cache is None:
  27. with open(GRAPH_DATA_PATH, 'r', encoding='utf-8') as f:
  28. _graph_cache = json.load(f)
  29. return _graph_cache
  30. def _load_graph_full() -> Dict[str, Any]:
  31. """加载完整图数据(带缓存,包含 edges)"""
  32. global _graph_full_cache
  33. if _graph_full_cache is None:
  34. with open(GRAPH_FULL_DATA_PATH, 'r', encoding='utf-8') as f:
  35. _graph_full_cache = json.load(f)
  36. return _graph_full_cache
  37. def _remove_post_ids_from_edges(edges: Dict[str, Any]) -> Dict[str, Any]:
  38. """移除 edges 中的 _post_ids 字段"""
  39. if not edges:
  40. return edges
  41. cleaned_edges = {}
  42. for edge_name, edge_data in edges.items():
  43. if isinstance(edge_data, dict):
  44. # 移除 _post_ids 字段
  45. cleaned_data = {k: v for k, v in edge_data.items() if k != "_post_ids"}
  46. cleaned_edges[edge_name] = cleaned_data
  47. else:
  48. cleaned_edges[edge_name] = edge_data
  49. return cleaned_edges
  50. def _search_points_by_element_from_full(
  51. element_value: str,
  52. element_type: str,
  53. top_k: int = 10
  54. ) -> Dict[str, Any]:
  55. """
  56. 根据元素值和类型在完整图数据库的 elements 字段中查找匹配的点
  57. Args:
  58. element_value: 元素值,如 "标准化", "懒人妻子"
  59. element_type: 元素类型,"实质" / "形式" / "意图"
  60. top_k: 返回前 K 个点(按频率排序)
  61. Returns:
  62. 包含匹配点完整信息的字典(包括 edges,已移除 _post_ids)
  63. """
  64. graph = _load_graph_full()
  65. matched_points = []
  66. # 遍历图中所有点
  67. for point_name, point_data in graph.items():
  68. meta = point_data.get("meta", {})
  69. elements = meta.get("elements", {})
  70. dimension = meta.get("dimension")
  71. # 检查:元素值在 elements 中 AND dimension 匹配 element_type
  72. if element_value in elements and dimension == element_type:
  73. # 移除 edges 中的 _post_ids
  74. cleaned_edges = _remove_post_ids_from_edges(point_data.get("edges", {}))
  75. # 返回结构与 search_point_by_path_from_full_all_levels 保持一致
  76. point_info = {
  77. "point": point_name,
  78. "point_type": meta.get("point_type"),
  79. "dimension": dimension,
  80. "point_path": meta.get("path"),
  81. "frequency_in_posts": meta.get("frequency_in_posts", 0),
  82. "elements": elements,
  83. "edge_count": len(cleaned_edges),
  84. "edges": cleaned_edges
  85. }
  86. matched_points.append(point_info)
  87. if not matched_points:
  88. return {
  89. "found": False,
  90. "element_value": element_value,
  91. "element_type": element_type,
  92. "message": f"未找到匹配的点: element_value={element_value}, element_type={element_type}"
  93. }
  94. # 按频率降序排序,取前 top_k 个
  95. matched_points.sort(key=lambda x: x["frequency_in_posts"], reverse=True)
  96. matched_points = matched_points[:top_k]
  97. return {
  98. "found": True,
  99. "element_value": element_value,
  100. "element_type": element_type,
  101. "total_matched_count": len(matched_points),
  102. "returned_count": len(matched_points),
  103. "matched_points": matched_points
  104. }
  105. def _search_point_by_path_from_full(path: str) -> Dict[str, Any]:
  106. """
  107. 根据完整路径在完整图数据库中查找点
  108. Args:
  109. path: 点的完整路径,如 "关键点_形式_架构>逻辑>逻辑架构>组织逻辑>框架规划>结构设计"
  110. Returns:
  111. 包含该点完整信息的字典(包括 edges,已移除 _post_ids)
  112. """
  113. graph = _load_graph_full()
  114. if path not in graph:
  115. return {
  116. "found": False,
  117. "path": path,
  118. "message": f"未找到路径: {path}"
  119. }
  120. point_data = graph[path]
  121. meta = point_data.get("meta", {})
  122. # 移除 edges 中的 _post_ids
  123. cleaned_edges = _remove_post_ids_from_edges(point_data.get("edges", {}))
  124. return {
  125. "found": True,
  126. "path": path,
  127. "point_type": meta.get("point_type"),
  128. "dimension": meta.get("dimension"),
  129. "point_path": meta.get("path"),
  130. "frequency_in_posts": meta.get("frequency_in_posts"),
  131. "elements": meta.get("elements", {}),
  132. "edge_count": len(cleaned_edges),
  133. "edges": cleaned_edges
  134. }
  135. @tool(
  136. description="根据元素值和类型在完整图数据库中查找匹配的点,返回包含边信息的完整数据。",
  137. display={
  138. "zh": {
  139. "name": "元素类型完整检索",
  140. "params": {
  141. "element_value": "元素值",
  142. "element_type": "元素类型(实质/形式/意图)",
  143. "top_k": "返回数量(默认10)",
  144. },
  145. },
  146. },
  147. )
  148. async def search_point_by_element_from_full_all_levels(
  149. element_value: str,
  150. element_type: str,
  151. top_k: int = 10
  152. ) -> ToolResult:
  153. """
  154. 根据元素值和类型在完整图数据库中检索点,返回包含边信息的完整数据。
  155. Args:
  156. element_value: 元素名称,如 "标准化", "懒人妻子"
  157. element_type: 元素类型,"实质" / "形式" / "意图"
  158. top_k: 返回前 K 个点,默认 10
  159. Returns:
  160. ToolResult: 匹配点的完整数据(包括 edges)
  161. """
  162. if not element_value:
  163. return ToolResult(
  164. title="元素类型检索失败",
  165. output="",
  166. error="请提供元素值",
  167. )
  168. if element_type not in ["实质", "形式", "意图"]:
  169. return ToolResult(
  170. title="元素类型检索失败",
  171. output="",
  172. error=f"元素类型必须是 '实质'、'形式' 或 '意图',当前值: {element_type}",
  173. )
  174. try:
  175. result = _search_points_by_element_from_full(element_value, element_type, top_k)
  176. except FileNotFoundError:
  177. return ToolResult(
  178. title="元素类型检索失败",
  179. output="",
  180. error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
  181. )
  182. except Exception as e:
  183. return ToolResult(
  184. title="元素类型检索失败",
  185. output="",
  186. error=f"检索异常: {str(e)}",
  187. )
  188. if not result["found"]:
  189. return ToolResult(
  190. title="元素类型检索",
  191. output=json.dumps(
  192. {
  193. "message": result["message"],
  194. "element_value": element_value,
  195. "element_type": element_type
  196. },
  197. ensure_ascii=False,
  198. indent=2
  199. ),
  200. )
  201. # 格式化输出
  202. output_data = {
  203. "element_value": result["element_value"],
  204. "element_type": result["element_type"],
  205. "total_matched_count": result["total_matched_count"],
  206. "returned_count": result["returned_count"],
  207. "matched_points": result["matched_points"]
  208. }
  209. output = json.dumps(output_data, ensure_ascii=False, indent=2)
  210. return ToolResult(
  211. title=f"元素类型检索 - {element_value} ({element_type})",
  212. output=output,
  213. long_term_memory=f"检索到 {result['returned_count']} 个匹配点,元素值: {element_value}, 类型: {element_type}",
  214. )
  215. @tool(
  216. description="根据完整路径在完整图数据库中查找点,返回包含边信息的完整数据。",
  217. display={
  218. "zh": {
  219. "name": "路径完整检索",
  220. "params": {
  221. "path": "点的完整路径",
  222. },
  223. },
  224. },
  225. )
  226. async def search_point_by_path_from_full_all_levels(path: str) -> ToolResult:
  227. """
  228. 根据完整路径在完整图数据库中检索点,返回包含边信息的完整数据。
  229. Args:
  230. path: 点的完整路径,如 "关键点_形式_架构>逻辑>逻辑架构>组织逻辑>框架规划>结构设计"
  231. Returns:
  232. ToolResult: 点的完整数据(包括 edges)
  233. """
  234. if not path:
  235. return ToolResult(
  236. title="路径检索失败",
  237. output="",
  238. error="请提供路径",
  239. )
  240. try:
  241. result = _search_point_by_path_from_full(path)
  242. except FileNotFoundError:
  243. return ToolResult(
  244. title="路径检索失败",
  245. output="",
  246. error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
  247. )
  248. except Exception as e:
  249. return ToolResult(
  250. title="路径检索失败",
  251. output="",
  252. error=f"检索异常: {str(e)}",
  253. )
  254. if not result["found"]:
  255. return ToolResult(
  256. title="路径检索",
  257. output=json.dumps(
  258. {"message": result["message"], "path": path},
  259. ensure_ascii=False,
  260. indent=2
  261. ),
  262. )
  263. # 格式化输出
  264. output_data = {
  265. "path": result["path"],
  266. "point_type": result["point_type"],
  267. "dimension": result["dimension"],
  268. "point_path": result["point_path"],
  269. "frequency_in_posts": result["frequency_in_posts"],
  270. "elements": result["elements"],
  271. "edge_count": result["edge_count"],
  272. "edges": result["edges"]
  273. }
  274. output = json.dumps(output_data, ensure_ascii=False, indent=2)
  275. return ToolResult(
  276. title=f"路径检索 - {path}",
  277. output=output,
  278. long_term_memory=f"检索到路径 {path} 的完整信息,包含 {result['edge_count']} 条边",
  279. )