search_library.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. """
  2. 关键点检索工具 - 根据输入的点在图数据库中查找所有关联的点
  3. 用于 Agent 执行时自主调取关联关键点数据。
  4. """
  5. import json
  6. import os
  7. from pathlib import Path
  8. from typing import Any, Dict, List, Optional
  9. from agent.tools import tool, ToolResult
  10. # 完整图数据库文件路径(包含 edges)
  11. GRAPH_FULL_DATA_PATH = os.getenv(
  12. "GRAPH_FULL_DATA_PATH",
  13. # str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_all_levels.json")
  14. str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_max.json")
  15. )
  16. GRAPH_FULL_MAP_DATA_PATH = os.getenv(
  17. "GRAPH_FULL_MAP_DATA_PATH",
  18. # str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_all_levels.json")
  19. str(Path(__file__).parent.parent / "data/library/item_graph/item_graph_full_max_map.json")
  20. )
  21. # 缓存图数据,避免重复加载
  22. _graph_full_cache: Optional[Dict[str, Any]] = None
  23. _graph_full_map_cache: Optional[Dict[str, Any]] = None
  24. def _load_graph_full() -> Dict[str, Any]:
  25. """加载完整图数据(带缓存,包含 edges)"""
  26. global _graph_full_cache
  27. if _graph_full_cache is None:
  28. with open(GRAPH_FULL_DATA_PATH, 'r', encoding='utf-8') as f:
  29. _graph_full_cache = json.load(f)
  30. return _graph_full_cache
  31. def _load_graph_full_map() -> Dict[str, Any]:
  32. """加载完整图数据(带缓存,包含 edges)"""
  33. global _graph_full_map_cache
  34. if _graph_full_map_cache is None:
  35. with open(GRAPH_FULL_MAP_DATA_PATH, 'r', encoding='utf-8') as f:
  36. _graph_full_map_cache = json.load(f)
  37. return _graph_full_map_cache
  38. def _search_relation_class_by_class(class_paths: List[str], top_k: int = 5) -> Dict[str, Any]:
  39. """
  40. 根据类别查找与该类别相关的其他类别
  41. Args:
  42. class_paths: 类别名称列表,如 ["关键点_实质_理念>现象>社会>时空背景", "关键点_形式_架构>策略>行为体验"]
  43. top_k: 每个类别返回前 K 个关联类别,默认 5
  44. Returns:
  45. 包含每个类别及其关联类别的列表
  46. """
  47. graph = _load_graph_full()
  48. results = []
  49. for class_path in class_paths:
  50. value = graph.get(class_path, {})
  51. edges = value.get("edges", {})
  52. # 提取所有相关类别及其置信度
  53. related_classes = []
  54. for target_class_path, edge_data in edges.items():
  55. co_in_post = edge_data.get("co_in_post", {})
  56. confidence = co_in_post.get("confidence", 0.0)
  57. related_classes.append({
  58. "class_path": target_class_path,
  59. "confidence": confidence
  60. })
  61. # 按置信度降序排序,取前 top_k 个
  62. related_classes.sort(key=lambda x: x["confidence"], reverse=True)
  63. related_classes = related_classes[:top_k]
  64. results.append({
  65. "input_class_path": class_path,
  66. "related_class_paths": [item["class_path"] for item in related_classes]
  67. })
  68. return results
  69. @tool(
  70. description="根据类别查找与该类别相关的其他类别,返回关系统计数据。支持单个或多个类别路径。每个类别独立返回其关联类别。",
  71. display={
  72. "zh": {
  73. "name": "类别关系检索",
  74. "params": {
  75. "class_paths": "类别路径数组",
  76. "top_k": "每个类别返回数量(默认5)",
  77. },
  78. },
  79. },
  80. )
  81. async def search_relation_class_by_class(
  82. class_paths: List[str],
  83. top_k: int = 5
  84. ) -> ToolResult:
  85. """
  86. 根据类别查找与该类别相关的其他类别。
  87. Args:
  88. class_paths: 类别路径数组,如 ["关键点_形式_架构>策略>行为体验"] 或 ["关键点_形式_架构>策略>行为体验", "目的点_意图_分享"]
  89. top_k: 每个类别返回前 K 个相关类别,默认 5
  90. Returns:
  91. ToolResult: 每个类别及其关联类别
  92. """
  93. if not class_paths or len(class_paths) == 0:
  94. return ToolResult(
  95. title="类别关系检索失败",
  96. output="",
  97. error="请提供类别路径",
  98. )
  99. try:
  100. result = _search_relation_class_by_class(class_paths, top_k)
  101. except FileNotFoundError:
  102. return ToolResult(
  103. title="类别关系检索失败",
  104. output="",
  105. error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
  106. )
  107. except Exception as e:
  108. return ToolResult(
  109. title="类别关系检索失败",
  110. output="",
  111. error=f"检索异常: {str(e)}",
  112. )
  113. # 统计总共找到的关联类别数
  114. total_related = sum(len(item["related_class_paths"]) for item in result)
  115. output = json.dumps(result, ensure_ascii=False, indent=2)
  116. return ToolResult(
  117. title=f"类别关系检索 - {len(class_paths)} 个类别",
  118. output=output,
  119. long_term_memory=f"为 {len(class_paths)} 个类别检索到 {total_related} 个关联类别",
  120. )
  121. def _search_relation_point_by_point(points: List[Dict], top_k: int = 999) -> Dict[str, Any]:
  122. """
  123. 根据点的信息查找与该点相关的其他点
  124. Args:
  125. points: 点信息列表
  126. top_k: 返回前 K 个相关点
  127. Returns:
  128. 包含相关点及其关系数据的字典
  129. """
  130. graph = _load_graph_full()
  131. graph_map = _load_graph_full_map()
  132. all_results = []
  133. for point in points:
  134. point_value = point.get("point_value", "")
  135. point_type = point.get("point_type", "")
  136. dimension = point.get("dimension", "")
  137. accounts = point.get("accounts", [])
  138. # 先通过 map 找到完整的 point_path
  139. class_paths = []
  140. type_dict = graph_map.get(point_type, {})
  141. dim_dict = type_dict.get(dimension, {})
  142. for account in accounts:
  143. account_dict = dim_dict.get(account, {})
  144. class_path = account_dict.get(point_value, "")
  145. if class_path:
  146. class_paths.append(class_path)
  147. # 用找到的 point_paths 查找关联点
  148. all_related_points = {}
  149. for class_path in class_paths:
  150. if class_path not in graph:
  151. continue
  152. point_data = graph[class_path]
  153. edges = point_data.get("edges", {})
  154. # 提取所有相关点及其置信度
  155. for target_point, edge_data in edges.items():
  156. co_in_post = edge_data.get("co_in_post", {})
  157. confidence = co_in_post.get("confidence", 0.0)
  158. # 聚合相同目标点的数据
  159. if target_point not in all_related_points:
  160. all_related_points[target_point] = {
  161. "confidence": 0.0,
  162. "source_count": 0
  163. }
  164. all_related_points[target_point]["confidence"] += confidence
  165. all_related_points[target_point]["source_count"] += 1
  166. # 计算平均置信度并排序
  167. related_points = []
  168. for target_point, data in all_related_points.items():
  169. avg_confidence = data["confidence"] / data["source_count"] if data["source_count"] > 0 else 0.0
  170. related_points.append({
  171. "point_path": target_point,
  172. "confidence": avg_confidence
  173. })
  174. # 按置信度降序排序
  175. related_points.sort(key=lambda x: x["confidence"], reverse=True)
  176. related_points = related_points[:top_k]
  177. all_results.append({
  178. "input_point": point,
  179. "related_point_paths": [item["point_path"] for item in related_points]
  180. })
  181. return all_results
  182. @tool(
  183. description="根据点的信息查找与该点相关的其他点,返回关系数据。",
  184. display={
  185. "zh": {
  186. "name": "点关系检索",
  187. "params": {
  188. "points": [
  189. {
  190. "point_value": "点",
  191. "point_type": "点类型",
  192. "dimension": "维度",
  193. "accounts": ["账号名称"]
  194. }
  195. ],
  196. "top_k": "返回数量(默认999)",
  197. },
  198. },
  199. },
  200. )
  201. async def search_relation_point_by_point(
  202. points: List[Dict],
  203. top_k: int = 999
  204. ) -> ToolResult:
  205. """
  206. 根据点的信息查找与该点相关的其他点。
  207. Args:
  208. points: 点信息列表
  209. top_k: 返回前 K 个相关点,默认 999
  210. Returns:
  211. ToolResult: 相关点及其关系数据
  212. """
  213. if not points or len(points) == 0:
  214. return ToolResult(
  215. title="点关系检索失败",
  216. output="",
  217. error="请提供点信息",
  218. )
  219. try:
  220. result = _search_relation_point_by_point(points, top_k)
  221. except FileNotFoundError:
  222. return ToolResult(
  223. title="点关系检索失败",
  224. output="",
  225. error=f"图数据文件不存在: {GRAPH_FULL_DATA_PATH}",
  226. )
  227. except Exception as e:
  228. return ToolResult(
  229. title="点关系检索失败",
  230. output="",
  231. error=f"检索异常: {str(e)}",
  232. )
  233. output = json.dumps(result, ensure_ascii=False, indent=2)
  234. return ToolResult(
  235. title=f"点关系检索 - {len(points)} 个点",
  236. output=output,
  237. long_term_memory=f"检索到与输入点相关的点",
  238. )
  239. def _search_class_by_point(
  240. points: List[Dict]
  241. ) -> Dict[str, Any]:
  242. graph = _load_graph_full_map()
  243. data = []
  244. for point in points:
  245. point_value = point.get("point_value", "")
  246. point_type = point.get("point_type", "")
  247. dimension = point.get("dimension", "")
  248. type_dict = graph.get(point_type, {})
  249. dim_dict = type_dict.get(dimension, {})
  250. class_path = dim_dict.get(point_value, "")
  251. if class_path:
  252. data.append({
  253. "point": point,
  254. "class_path": class_path
  255. })
  256. return data
  257. @tool(
  258. description="根据点的属性查找该点所属的类别。",
  259. display={
  260. "zh": {
  261. "name": "点类别查询",
  262. "params": {
  263. "points": [
  264. {
  265. "point_value": "点",
  266. "point_type": "点类型",
  267. "dimension": "维度",
  268. "accounts": [
  269. "账号名称"
  270. ]
  271. }
  272. ]
  273. },
  274. },
  275. },
  276. )
  277. async def search_class_by_point(
  278. points: List[Dict]
  279. ) -> ToolResult:
  280. if not points:
  281. return ToolResult(
  282. title="点类别查询失败",
  283. output="",
  284. error="请提供 points, point_type 和 dimension",
  285. )
  286. try:
  287. result = _search_class_by_point(points)
  288. except FileNotFoundError:
  289. return ToolResult(
  290. title="点类别查询失败",
  291. output="",
  292. error=f"图数据文件不存在: {GRAPH_FULL_MAP_DATA_PATH}",
  293. )
  294. except Exception as e:
  295. return ToolResult(
  296. title="点类别查询失败",
  297. output="",
  298. error=f"检索异常: {str(e)}",
  299. )
  300. output = json.dumps(result, ensure_ascii=False, indent=2)
  301. return ToolResult(
  302. title=f"点类别查询 - {len(points)} 个点",
  303. output=output,
  304. long_term_memory=f"查询到 {len(result)} 个点的类别信息",
  305. )
  306. def _search_point_by_class(class_paths: List[str]) -> Dict[str, Any]:
  307. """
  308. 根据类别查找属于该类别的所有点
  309. Args:
  310. class_paths: 类别路径列表,如 ["关键点_形式"]
  311. top_k: 返回前 K 个点(按频率排序)
  312. Returns:
  313. 包含该类别所有点的字典
  314. """
  315. graph = _load_graph_full()
  316. data = []
  317. for class_path in class_paths:
  318. points = graph.get(class_path, {})
  319. data.append({
  320. "class_path": class_path,
  321. "points": list(points.get("meta", {}).get("elements", {}).keys())
  322. })
  323. return data
  324. @tool(
  325. description="根据类别查找属于该类别的所有点。支持单个或多个类别路径。",
  326. display={
  327. "zh": {
  328. "name": "类别点检索",
  329. "params": {
  330. "class_paths": "类别路径数组"
  331. },
  332. },
  333. },
  334. )
  335. async def search_point_by_class(
  336. class_paths: List[str]
  337. ) -> ToolResult:
  338. """
  339. 根据类别查找属于该类别的所有点。
  340. Args:
  341. class_paths: 类别路径数组,如 ["关键点_形式"]
  342. top_k: 返回前 K 个点,默认 999
  343. Returns:
  344. ToolResult: 该类别的所有点
  345. """
  346. if not class_paths or len(class_paths) == 0:
  347. return ToolResult(
  348. title="类别点检索失败",
  349. output="",
  350. error="请提供类别路径",
  351. )
  352. try:
  353. result = _search_point_by_class(class_paths)
  354. except FileNotFoundError:
  355. return ToolResult(
  356. title="类别点检索失败",
  357. output=""
  358. )
  359. except Exception as e:
  360. return ToolResult(
  361. title="类别点检索失败",
  362. output=""
  363. )
  364. output = json.dumps(result, ensure_ascii=False, indent=2)
  365. return ToolResult(
  366. title=f"类别点检索 - {len(class_paths)} 个类别",
  367. output=output,
  368. long_term_memory=f"检索到类别的点数据"
  369. )
  370. if __name__ == "__main__":
  371. print(_search_point_by_class(["关键点_形式_架构>策略>行为体验"]))