find_tree_node.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. """
  2. 查找树节点 Tool - 人设树节点查询
  3. 功能:
  4. 1. 获取人设树的常量节点(全局常量、局部常量)
  5. 2. 获取符合条件概率阈值的节点(按条件概率排序返回 topN)
  6. """
  7. import importlib.util
  8. import json
  9. from pathlib import Path
  10. from typing import Any, Optional
  11. try:
  12. from agent.tools import tool, ToolResult, ToolContext
  13. except ImportError:
  14. def tool(*args, **kwargs):
  15. return lambda f: f
  16. ToolResult = None # 仅用 main() 测核心逻辑时可无 agent
  17. ToolContext = None
  18. # 加载同目录层级的 conditional_ratio_calc(不依赖包结构)
  19. _utils_dir = Path(__file__).resolve().parent.parent / "utils"
  20. _cond_spec = importlib.util.spec_from_file_location(
  21. "conditional_ratio_calc",
  22. _utils_dir / "conditional_ratio_calc.py",
  23. )
  24. _cond_mod = importlib.util.module_from_spec(_cond_spec)
  25. _cond_spec.loader.exec_module(_cond_mod)
  26. calc_node_conditional_ratio = _cond_mod.calc_node_conditional_ratio
  27. # 相对本文件:tools -> overall_derivation,input 在 overall_derivation 下
  28. _BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  29. # 加载 point_match(用于检查节点是否匹配帖子选题点)
  30. _point_match_spec = importlib.util.spec_from_file_location(
  31. "point_match",
  32. Path(__file__).resolve().parent / "point_match.py",
  33. )
  34. _point_match_mod = importlib.util.module_from_spec(_point_match_spec)
  35. _point_match_spec.loader.exec_module(_point_match_mod)
  36. _match_derivation_to_post_points = _point_match_mod.match_derivation_to_post_points
  37. def _tree_dir(account_name: str) -> Path:
  38. """人设树目录:../input/{account_name}/原始数据/tree/"""
  39. return _BASE_INPUT / account_name / "原始数据" / "tree"
  40. def _load_trees(account_name: str) -> list[tuple[str, dict]]:
  41. """加载该账号下所有维度的人设树。返回 [(维度名, 根节点 dict), ...]。"""
  42. td = _tree_dir(account_name)
  43. if not td.is_dir():
  44. return []
  45. result = []
  46. for p in td.glob("*.json"):
  47. try:
  48. with open(p, "r", encoding="utf-8") as f:
  49. data = json.load(f)
  50. for dim_name, root in data.items():
  51. if isinstance(root, dict):
  52. result.append((dim_name, root))
  53. break
  54. except Exception:
  55. continue
  56. return result
  57. def _iter_all_nodes(account_name: str):
  58. """遍历该账号下所有人设树节点,产出 (节点名称, 父节点名称, 节点 dict)。"""
  59. for dim_name, root in _load_trees(account_name):
  60. def walk(parent_name: str, node_dict: dict):
  61. for name, child in (node_dict.get("children") or {}).items():
  62. if not isinstance(child, dict):
  63. continue
  64. yield (name, parent_name, child)
  65. yield from walk(name, child)
  66. yield from walk(dim_name, root)
  67. # ---------------------------------------------------------------------------
  68. # 1. 获取人设树常量节点
  69. # ---------------------------------------------------------------------------
  70. def get_constant_nodes(account_name: str) -> list[dict[str, Any]]:
  71. """
  72. 获取人设树的常量节点。
  73. - 全局常量:_is_constant=True
  74. - 局部常量:_is_local_constant=True 且 _is_constant=False
  75. 返回列表项:节点名称、概率(_ratio)、常量类型。
  76. """
  77. result = []
  78. for node_name, _parent, node in _iter_all_nodes(account_name):
  79. is_const = node.get("_is_constant") is True
  80. is_local = node.get("_is_local_constant") is True
  81. if is_const:
  82. const_type = "全局常量"
  83. elif is_local and not is_const:
  84. const_type = "局部常量"
  85. else:
  86. continue
  87. ratio = node.get("_ratio")
  88. result.append({
  89. "节点名称": node_name,
  90. "概率": ratio,
  91. "常量类型": const_type,
  92. })
  93. result.sort(key=lambda x: (x["概率"] is None, -(x["概率"] or 0)))
  94. return result
  95. # ---------------------------------------------------------------------------
  96. # 2. 获取符合条件概率阈值的节点
  97. # ---------------------------------------------------------------------------
  98. def get_nodes_by_conditional_ratio(
  99. account_name: str,
  100. derived_list: list[tuple[str, str]],
  101. threshold: float,
  102. top_n: int,
  103. ) -> list[dict[str, Any]]:
  104. """
  105. 获取人设树中条件概率 >= threshold 的节点,按条件概率降序,返回前 top_n 个。
  106. derived_list: 已推导列表,每项 (已推导的选题点, 推导来源人设树节点);为空时使用节点自身的 _ratio 作为条件概率。
  107. 返回列表项:节点名称、条件概率、父节点名称。
  108. """
  109. base_dir = _BASE_INPUT
  110. scored: list[tuple[str, float, str]] = []
  111. if not derived_list:
  112. # derived_items 为空:条件概率取节点本身的 _ratio
  113. for node_name, parent_name, node in _iter_all_nodes(account_name):
  114. ratio = node.get("_ratio")
  115. if ratio is None:
  116. ratio = 0.0
  117. else:
  118. ratio = float(ratio)
  119. if ratio >= threshold:
  120. scored.append((node_name, ratio, parent_name))
  121. else:
  122. node_to_parent: dict[str, str] = {}
  123. for node_name, parent_name, _ in _iter_all_nodes(account_name):
  124. node_to_parent[node_name] = parent_name
  125. for node_name, parent_name in node_to_parent.items():
  126. ratio = calc_node_conditional_ratio(
  127. account_name, derived_list, node_name, base_dir=base_dir
  128. )
  129. if ratio >= threshold:
  130. scored.append((node_name, ratio, parent_name))
  131. scored.sort(key=lambda x: x[1], reverse=True)
  132. top = scored[:top_n]
  133. return [
  134. {"节点名称": name, "条件概率": ratio, "父节点名称": parent}
  135. for name, ratio, parent in top
  136. ]
  137. def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
  138. """将 agent 传入的 [{"topic": "x", "source_node": "y"}, ...] 转为 DerivedItem 列表。"""
  139. out = []
  140. for item in derived_items:
  141. if isinstance(item, dict):
  142. topic = item.get("topic") or item.get("已推导的选题点")
  143. source = item.get("source_node") or item.get("推导来源人设树节点")
  144. if topic is not None and source is not None:
  145. out.append((str(topic).strip(), str(source).strip()))
  146. elif isinstance(item, (list, tuple)) and len(item) >= 2:
  147. out.append((str(item[0]).strip(), str(item[1]).strip()))
  148. return out
  149. # ---------------------------------------------------------------------------
  150. # Agent Tools(参考 glob_tool 封装)
  151. # ---------------------------------------------------------------------------
  152. @tool(
  153. description="获取指定账号人设树中的常量节点(全局常量、局部常量),并检查每个节点与帖子选题点的匹配情况。"
  154. "功能:根据账号名查询该账号人设树中所有常量节点,同时对每个节点判断是否匹配帖子选题点,匹配结果直接包含在返回数据中。"
  155. "参数:account_name 为账号名;post_id 为帖子ID,用于加载帖子选题点并做匹配判断。"
  156. "返回:ToolResult,output 为可读的节点列表文本,metadata.items 为列表,每项含「节点名称」「概率」「常量类型」「帖子选题点匹配」(超过阈值的匹配列表,每项含帖子选题点与匹配分数;若无匹配则为字符串'无匹配帖子选题点')。"
  157. )
  158. async def find_tree_constant_nodes(
  159. account_name: str,
  160. post_id: str,
  161. context: Optional[ToolContext] = None,
  162. ) -> ToolResult:
  163. """
  164. 获取人设树中的常量节点列表(全局常量与局部常量),并检查每个节点与帖子选题点的匹配情况。
  165. 参数
  166. -------
  167. account_name : 账号名,用于定位该账号的人设树数据。
  168. post_id : 帖子ID,用于加载帖子选题点并与各常量节点做匹配判断。
  169. context : 可选,Agent 工具上下文。
  170. 返回
  171. -------
  172. ToolResult:
  173. - title: 结果标题。
  174. - output: 可读的节点列表文本(每行:节点名称、概率、常量类型、帖子匹配情况)。
  175. - metadata: 含 account_name、count、items;items 为列表,每项为
  176. {"节点名称": str, "概率": 数值或 None, "常量类型": "全局常量"|"局部常量",
  177. "帖子选题点匹配": list[{"帖子选题点": str, "匹配分数": float}] 或 "无匹配帖子选题点"}。
  178. - 出错时 error 为错误信息。
  179. """
  180. tree_dir = _tree_dir(account_name)
  181. if not tree_dir.is_dir():
  182. return ToolResult(
  183. title="人设树目录不存在",
  184. output=f"目录不存在: {tree_dir}",
  185. error="Directory not found",
  186. )
  187. try:
  188. items = get_constant_nodes(account_name)
  189. # 批量匹配所有节点与帖子选题点
  190. if items and post_id:
  191. node_names = [x["节点名称"] for x in items]
  192. matched_results = await _match_derivation_to_post_points(node_names, account_name, post_id)
  193. node_match_map: dict[str, list] = {}
  194. for m in matched_results:
  195. node_match_map.setdefault(m["推导选题点"], []).append({
  196. "帖子选题点": m["帖子选题点"],
  197. "匹配分数": m["匹配分数"],
  198. })
  199. for item in items:
  200. matches = node_match_map.get(item["节点名称"], [])
  201. item["帖子选题点匹配"] = matches if matches else "无匹配帖子选题点"
  202. if not items:
  203. output = "未找到常量节点"
  204. else:
  205. lines = []
  206. for x in items:
  207. match_info = x.get("帖子选题点匹配", "未查询")
  208. if isinstance(match_info, list):
  209. match_str = "、".join(f"{m['帖子选题点']}({m['匹配分数']})" for m in match_info)
  210. else:
  211. match_str = str(match_info)
  212. lines.append(f"- {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}\t帖子匹配={match_str}")
  213. output = "\n".join(lines)
  214. return ToolResult(
  215. title=f"常量节点 ({account_name})",
  216. output=output,
  217. metadata={"account_name": account_name, "count": len(items), "items": items},
  218. )
  219. except Exception as e:
  220. return ToolResult(
  221. title="获取常量节点失败",
  222. output=str(e),
  223. error=str(e),
  224. )
  225. @tool(
  226. description="按条件概率从人设树中筛选节点,返回达到阈值且按条件概率排序的前 topN 条,并检查每个节点与帖子选题点的匹配情况。"
  227. "功能:根据账号与已推导选题点(可选),筛选人设树中条件概率不低于阈值的节点,同时对每个节点判断是否匹配帖子选题点,匹配结果直接包含在返回数据中。"
  228. "参数:account_name 为账号名;post_id 为帖子ID,用于加载帖子选题点并做匹配判断;derived_items 为已推导选题点列表,每项含 topic(或已推导的选题点)与 source_node(或推导来源人设树节点),可为空,为空时条件概率使用节点自身的 _ratio;conditional_ratio_threshold 为条件概率阈值;top_n 为返回条数上限,默认 100。"
  229. "返回:ToolResult,output 为可读的节点列表文本,metadata.items 为列表,每项含「节点名称」「条件概率」「父节点名称」「帖子选题点匹配」(超过阈值的匹配列表,每项含帖子选题点与匹配分数;若无匹配则为字符串'无匹配帖子选题点')。"
  230. )
  231. async def find_tree_nodes_by_conditional_ratio(
  232. account_name: str,
  233. post_id: str,
  234. derived_items: list[dict[str, str]],
  235. conditional_ratio_threshold: float,
  236. top_n: int = 100,
  237. context: Optional[ToolContext] = None,
  238. ) -> ToolResult:
  239. """
  240. 按条件概率阈值从人设树筛选节点,返回最多 top_n 条(按条件概率降序),并检查每个节点与帖子选题点的匹配情况。
  241. 参数
  242. -------
  243. account_name : 账号名,用于定位该账号的人设树数据。
  244. post_id : 帖子ID,用于加载帖子选题点并与各节点做匹配判断。
  245. derived_items : 已推导选题点列表,可为空。非空时每项为字典,需含 topic(或「已推导的选题点」)与 source_node(或「推导来源人设树节点」);为空时各节点的条件概率取其自身 _ratio。
  246. conditional_ratio_threshold : 条件概率阈值,仅返回条件概率 >= 该值的节点。
  247. top_n : 返回条数上限,默认 100。
  248. context : 可选,Agent 工具上下文。
  249. 返回
  250. -------
  251. ToolResult:
  252. - title: 结果标题。
  253. - output: 可读的节点列表文本(每行:节点名称、条件概率、父节点名称、帖子匹配情况)。
  254. - metadata: 含 account_name、threshold、top_n、count、items;
  255. items 为列表,每项为 {"节点名称": str, "条件概率": float, "父节点名称": str,
  256. "帖子选题点匹配": list[{"帖子选题点": str, "匹配分数": float}] 或 "无匹配帖子选题点"}。
  257. - 出错时 error 为错误信息。
  258. """
  259. tree_dir = _tree_dir(account_name)
  260. if not tree_dir.is_dir():
  261. return ToolResult(
  262. title="人设树目录不存在",
  263. output=f"目录不存在: {tree_dir}",
  264. error="Directory not found",
  265. )
  266. try:
  267. derived_list = _parse_derived_list(derived_items or [])
  268. items = get_nodes_by_conditional_ratio(
  269. account_name, derived_list, conditional_ratio_threshold, top_n
  270. )
  271. # 批量匹配所有节点与帖子选题点
  272. if items and post_id:
  273. node_names = [x["节点名称"] for x in items]
  274. matched_results = await _match_derivation_to_post_points(node_names, account_name, post_id)
  275. node_match_map: dict[str, list] = {}
  276. for m in matched_results:
  277. node_match_map.setdefault(m["推导选题点"], []).append({
  278. "帖子选题点": m["帖子选题点"],
  279. "匹配分数": m["匹配分数"],
  280. })
  281. for item in items:
  282. matches = node_match_map.get(item["节点名称"], [])
  283. item["帖子选题点匹配"] = matches if matches else "无匹配帖子选题点"
  284. if not items:
  285. output = f"未找到条件概率 >= {conditional_ratio_threshold} 的节点"
  286. else:
  287. lines = []
  288. for x in items:
  289. match_info = x.get("帖子选题点匹配", "未查询")
  290. if isinstance(match_info, list):
  291. match_str = "、".join(f"{m['帖子选题点']}({m['匹配分数']})" for m in match_info)
  292. else:
  293. match_str = str(match_info)
  294. lines.append(
  295. f"- {x['节点名称']}\t条件概率={x['条件概率']}\t父节点={x['父节点名称']}\t帖子匹配={match_str}"
  296. )
  297. output = "\n".join(lines)
  298. return ToolResult(
  299. title=f"条件概率节点 ({account_name}, 阈值={conditional_ratio_threshold})",
  300. output=output,
  301. metadata={
  302. "account_name": account_name,
  303. "threshold": conditional_ratio_threshold,
  304. "top_n": top_n,
  305. "count": len(items),
  306. "items": items,
  307. },
  308. )
  309. except Exception as e:
  310. return ToolResult(
  311. title="按条件概率查询节点失败",
  312. output=str(e),
  313. error=str(e),
  314. )
  315. def main() -> None:
  316. """本地测试:用家有大志账号测常量节点与条件概率节点,有 agent 时再跑一遍 tool 接口。"""
  317. import asyncio
  318. account_name = "家有大志"
  319. post_id = "68fb6a5c000000000302e5de"
  320. derived_items = [
  321. {"topic": "分享", "source_node": "分享"},
  322. ]
  323. conditional_ratio_threshold = 0.1
  324. top_n = 10
  325. # 1)常量节点(核心函数,无匹配)
  326. constant_nodes = get_constant_nodes(account_name)
  327. print(f"账号: {account_name} — 常量节点共 {len(constant_nodes)} 个(前 50 个):")
  328. for x in constant_nodes[:50]:
  329. print(f" - {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}")
  330. print()
  331. # 2)条件概率节点(核心函数)
  332. derived_list = _parse_derived_list(derived_items)
  333. ratio_nodes = get_nodes_by_conditional_ratio(
  334. account_name, derived_list, conditional_ratio_threshold, top_n
  335. )
  336. print(f"条件概率节点 阈值={conditional_ratio_threshold}, top_n={top_n}, 共 {len(ratio_nodes)} 个:")
  337. for x in ratio_nodes:
  338. print(f" - {x['节点名称']}\t条件概率={x['条件概率']}\t父节点={x['父节点名称']}")
  339. print()
  340. # 3)有 agent 时通过 tool 接口再跑一遍(含帖子选题点匹配)
  341. if ToolResult is not None:
  342. async def run_tools():
  343. r1 = await find_tree_constant_nodes(account_name, post_id=post_id)
  344. print("--- find_tree_constant_nodes ---")
  345. print(r1.output[:200] + "..." if len(r1.output) > 200 else r1.output)
  346. r2 = await find_tree_nodes_by_conditional_ratio(
  347. account_name,
  348. post_id=post_id,
  349. derived_items=derived_items,
  350. conditional_ratio_threshold=conditional_ratio_threshold,
  351. top_n=top_n,
  352. )
  353. print("\n--- find_tree_nodes_by_conditional_ratio ---")
  354. print(r2.output)
  355. asyncio.run(run_tools())
  356. if __name__ == "__main__":
  357. main()