# find_tree_node.py (21 KB)
  1. """
  2. 查找树节点 Tool - 人设树节点查询
  3. 功能:
  4. 1. 获取人设树的常量节点(全局常量、局部常量)
  5. 2. 获取符合条件概率阈值的节点(按条件概率排序返回 topN)
  6. """
  7. import json
  8. import sys
  9. from pathlib import Path
  10. from typing import Any, Optional
  11. # 保证直接运行或作为包加载时都能解析 utils / tools(IDE 可跳转)
  12. _root = Path(__file__).resolve().parent.parent
  13. if str(_root) not in sys.path:
  14. sys.path.insert(0, str(_root))
  15. from utils.conditional_ratio_calc import calc_node_conditional_ratio # noqa: E402
  16. from tools.point_match import match_derivation_to_post_points # noqa: E402
  17. try:
  18. from agent.tools import tool, ToolResult, ToolContext
  19. except ImportError:
  20. def tool(*args, **kwargs):
  21. return lambda f: f
  22. ToolResult = None # 仅用 main() 测核心逻辑时可无 agent
  23. ToolContext = None
  24. # 相对本文件:tools -> overall_derivation,input / output 在 overall_derivation 下
  25. _BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  26. _BASE_OUTPUT = Path(__file__).resolve().parent.parent / "output"
  27. def _dimension_analysis_log_dir(account_name: str, post_id: str, log_id: str) -> Path:
  28. """推导日志目录:output/{account_name}/推导日志/{post_id}/{log_id}/"""
  29. return _BASE_OUTPUT / account_name / "推导日志" / post_id / log_id
  30. def _load_derived_dim_tree_node_names(
  31. account_name: str, post_id: str, log_id: str, round: int
  32. ) -> list[str]:
  33. """
  34. 读取当前轮次对应的维度分析 JSON(优先 {round}_维度分析.json,不存在则 {round-1}_维度分析.json),
  35. 返回 derived_dims 中每项的 tree_node_name(已推导出的维度节点,人设树中层次较高)。
  36. 无可用文件时返回空列表。
  37. """
  38. if not log_id or not str(log_id).strip():
  39. return []
  40. log_dir = _dimension_analysis_log_dir(account_name, post_id, str(log_id).strip())
  41. for r in (round, round - 1):
  42. if r < 1:
  43. continue
  44. path = log_dir / f"{r}_维度分析.json"
  45. if not path.is_file():
  46. continue
  47. try:
  48. with open(path, "r", encoding="utf-8") as f:
  49. data = json.load(f)
  50. except Exception:
  51. continue
  52. dims = data.get("derived_dims") or []
  53. names: list[str] = []
  54. for d in dims:
  55. if isinstance(d, dict):
  56. tn = d.get("tree_node_name")
  57. if tn is not None and str(tn).strip():
  58. names.append(str(tn).strip())
  59. return names
  60. return []
  61. def _descendant_names_under_tree_nodes(
  62. account_name: str, anchor_node_names: list[str]
  63. ) -> tuple[set[str], dict[str, str]]:
  64. """
  65. 在每个人设维度树根上 DFS,收集所有锚点(derived_dims.tree_node_name)之下的**全部后代**(不含锚点自身)。
  66. 同时记录「所属维度」:对路径上每个后代节点,取从维度根到该节点路径上**最深的**那个锚点
  67. (与原先沿父链向上找最近 derived_dim 一致;多个锚点呈祖孙时取更深者)。
  68. Returns:
  69. (allowed 节点名集合, 节点名 -> 所属已推导维度树节点名)
  70. """
  71. if not anchor_node_names:
  72. return set(), {}
  73. S = set(anchor_node_names)
  74. allowed: set[str] = set()
  75. dim_map: dict[str, str] = {}
  76. for dim_root_name, root in _load_trees(account_name):
  77. def dfs(node_name: str, node_dict: dict, parent_deepest_s: Optional[str]) -> None:
  78. d_self = node_name if node_name in S else parent_deepest_s
  79. for cname, cnode in (node_dict.get("children") or {}).items():
  80. if not isinstance(cnode, dict):
  81. continue
  82. if cname not in S and d_self is not None:
  83. allowed.add(cname)
  84. dim_map[cname] = d_self
  85. dfs(cname, cnode, d_self)
  86. dfs(dim_root_name, root, None)
  87. return allowed, dim_map
  88. def _tree_dir(account_name: str) -> Path:
  89. """人设树目录:../input/{account_name}/原始数据/tree/"""
  90. return _BASE_INPUT / account_name / "原始数据" / "tree"
  91. def _load_trees(account_name: str) -> list[tuple[str, dict]]:
  92. """加载该账号下所有维度的人设树。返回 [(维度名, 根节点 dict), ...]。"""
  93. td = _tree_dir(account_name)
  94. if not td.is_dir():
  95. return []
  96. result = []
  97. for p in td.glob("*.json"):
  98. try:
  99. with open(p, "r", encoding="utf-8") as f:
  100. data = json.load(f)
  101. for dim_name, root in data.items():
  102. if isinstance(root, dict):
  103. result.append((dim_name, root))
  104. break
  105. except Exception:
  106. continue
  107. return result
  108. def _iter_all_nodes(account_name: str):
  109. """遍历该账号下所有人设树节点,产出 (节点名称, 父节点名称, 节点 dict)。"""
  110. for dim_name, root in _load_trees(account_name):
  111. def walk(parent_name: str, node_dict: dict):
  112. for name, child in (node_dict.get("children") or {}).items():
  113. if not isinstance(child, dict):
  114. continue
  115. yield (name, parent_name, child)
  116. yield from walk(name, child)
  117. yield from walk(dim_name, root)
  118. # ---------------------------------------------------------------------------
  119. # 1. 获取人设树常量节点
  120. # ---------------------------------------------------------------------------
  121. def get_constant_nodes(account_name: str) -> list[dict[str, Any]]:
  122. """
  123. 获取人设树的常量节点。
  124. - 全局常量:_is_constant=True
  125. - 局部常量:_is_local_constant=True 且 _is_constant=False
  126. 返回列表项:节点名称、概率(_ratio)、常量类型。
  127. """
  128. result = []
  129. for node_name, _parent, node in _iter_all_nodes(account_name):
  130. is_const = node.get("_is_constant") is True
  131. is_local = node.get("_is_local_constant") is True
  132. if is_const:
  133. const_type = "全局常量"
  134. elif is_local and not is_const:
  135. const_type = "局部常量"
  136. else:
  137. continue
  138. ratio = node.get("_ratio")
  139. result.append({
  140. "节点名称": node_name,
  141. "概率": ratio,
  142. "常量类型": const_type,
  143. })
  144. result.sort(key=lambda x: (x["概率"] is None, -(x["概率"] or 0)))
  145. return result
  146. # ---------------------------------------------------------------------------
  147. # 2. 获取符合条件概率阈值的节点
  148. # ---------------------------------------------------------------------------
  149. def get_nodes_by_conditional_ratio(
  150. account_name: str,
  151. derived_list: list[tuple[str, str]],
  152. threshold: float,
  153. top_n: int,
  154. allowed_node_names: Optional[set[str]] = None,
  155. node_belonging_dim: Optional[dict[str, str]] = None,
  156. ) -> list[dict[str, Any]]:
  157. """
  158. 获取人设树中条件概率 >= threshold 的节点,按条件概率降序,返回前 top_n 个。
  159. derived_list: 已推导列表,每项 (已推导的选题点, 推导来源人设树节点);为空时使用节点自身的 _ratio 作为条件概率。
  160. allowed_node_names: 若给定,仅保留节点名称在该集合内的结果。
  161. node_belonging_dim: 与 allowed 同步生成(见 _descendant_names_under_tree_nodes),节点名 -> 所属已推导维度;不传则所属维度均为「—」。
  162. 返回列表项:节点名称、条件概率、父节点名称、所属维度。
  163. """
  164. base_dir = _BASE_INPUT
  165. node_to_parent: dict[str, str] = {}
  166. if derived_list:
  167. for n, p, _ in _iter_all_nodes(account_name):
  168. node_to_parent[n] = p
  169. def dim_for(node_name: str) -> str:
  170. if not node_belonging_dim:
  171. return "—"
  172. return node_belonging_dim.get(node_name) or "—"
  173. scored: list[tuple[str, float, str, str]] = []
  174. if not derived_list:
  175. for node_name, parent_name, node in _iter_all_nodes(account_name):
  176. if allowed_node_names is not None and node_name not in allowed_node_names:
  177. continue
  178. ratio = node.get("_ratio")
  179. if ratio is None:
  180. ratio = 0.0
  181. else:
  182. ratio = float(ratio)
  183. if ratio >= threshold:
  184. scored.append((node_name, ratio, parent_name, dim_for(node_name)))
  185. else:
  186. for node_name, parent_name in node_to_parent.items():
  187. if allowed_node_names is not None and node_name not in allowed_node_names:
  188. continue
  189. ratio = calc_node_conditional_ratio(
  190. account_name, derived_list, node_name, base_dir=base_dir
  191. )
  192. if ratio >= threshold:
  193. scored.append((node_name, ratio, parent_name, dim_for(node_name)))
  194. scored.sort(key=lambda x: x[1], reverse=True)
  195. top = scored[:top_n]
  196. return [
  197. {
  198. "节点名称": name,
  199. "条件概率": ratio,
  200. "父节点名称": parent,
  201. "所属维度": dim,
  202. }
  203. for name, ratio, parent, dim in top
  204. ]
  205. def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
  206. """将 agent 传入的 [{"topic": "x", "source_node": "y"}, ...] 转为 DerivedItem 列表。"""
  207. out = []
  208. for item in derived_items:
  209. if isinstance(item, dict):
  210. topic = item.get("topic") or item.get("已推导的选题点")
  211. source = item.get("source_node") or item.get("推导来源人设树节点")
  212. if topic is not None and source is not None:
  213. out.append((str(topic).strip(), str(source).strip()))
  214. elif isinstance(item, (list, tuple)) and len(item) >= 2:
  215. out.append((str(item[0]).strip(), str(item[1]).strip()))
  216. return out
  217. # ---------------------------------------------------------------------------
  218. # Agent Tools(参考 glob_tool 封装)
  219. # ---------------------------------------------------------------------------
  220. @tool()
  221. async def find_tree_constant_nodes(
  222. account_name: str,
  223. post_id: str,
  224. ) -> ToolResult:
  225. """
  226. 获取人设树中的常量节点列表(全局常量与局部常量),并检查每个节点与帖子选题点的匹配情况。
  227. Args:
  228. account_name : 账号名,用于定位该账号的人设树数据。
  229. post_id : 帖子ID,用于加载帖子选题点并与各常量节点做匹配判断。
  230. Returns:
  231. ToolResult:
  232. - title: 结果标题。
  233. - output: 可读的节点列表文本(每行:节点名称、概率、常量类型、帖子匹配情况)。
  234. - 出错时 error 为错误信息。
  235. """
  236. tree_dir = _tree_dir(account_name)
  237. if not tree_dir.is_dir():
  238. return ToolResult(
  239. title="人设树目录不存在",
  240. output=f"目录不存在: {tree_dir}",
  241. error="Directory not found",
  242. )
  243. try:
  244. items = get_constant_nodes(account_name)
  245. # 批量匹配所有节点与帖子选题点
  246. if items and post_id:
  247. node_names = [x["节点名称"] for x in items]
  248. matched_results = await match_derivation_to_post_points(node_names, account_name, post_id)
  249. node_match_map: dict[str, list] = {}
  250. for m in matched_results:
  251. node_match_map.setdefault(m["推导选题点"], []).append({
  252. "帖子选题点": m["帖子选题点"],
  253. "匹配分数": m["匹配分数"],
  254. })
  255. for item in items:
  256. matches = node_match_map.get(item["节点名称"], [])
  257. item["帖子选题点匹配"] = matches if matches else "无"
  258. if not items:
  259. output = "未找到常量节点"
  260. else:
  261. lines = []
  262. for x in items:
  263. match_info = x.get("帖子选题点匹配", "无")
  264. if isinstance(match_info, list):
  265. match_str = "、".join(f"{m['帖子选题点']}({m['匹配分数']})" for m in match_info)
  266. else:
  267. match_str = str(match_info)
  268. lines.append(f"- {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}\t帖子选题点匹配={match_str}")
  269. output = "\n".join(lines)
  270. return ToolResult(
  271. title=f"常量节点 ({account_name})",
  272. output=output,
  273. metadata={"account_name": account_name, "count": len(items)},
  274. )
  275. except Exception as e:
  276. return ToolResult(
  277. title="获取常量节点失败",
  278. output=str(e),
  279. error=str(e),
  280. )
  281. @tool()
  282. async def find_tree_nodes_by_conditional_ratio(
  283. account_name: str,
  284. post_id: str,
  285. derived_items: list[dict[str, str]],
  286. conditional_ratio_threshold: float,
  287. top_n: int = 100,
  288. round: int = 1,
  289. log_id: str = "",
  290. ) -> ToolResult:
  291. """
  292. 按条件概率阈值从人设树筛选节点,返回最多 top_n 条(按条件概率降序),并检查每个节点与帖子选题点的匹配情况。
  293. Args:
  294. account_name : 账号名,用于定位该账号的人设树数据。
  295. post_id : 帖子ID,用于加载帖子选题点并与各节点做匹配判断。
  296. derived_items : 已推导选题点列表,可为空。非空时每项为字典,需含 topic(或「已推导的选题点」)与 source_node(或「推导来源人设树节点」)
  297. conditional_ratio_threshold : 条件概率阈值,仅返回条件概率 >= 该值的节点。
  298. top_n : 返回条数上限。
  299. round : 推导轮次。
  300. log_id : 推导日志ID
  301. Returns:
  302. ToolResult:
  303. - title: 结果标题。
  304. - output: 可读的节点列表文本(每行:节点名称、条件概率、父节点、所属维度、帖子匹配情况)。
  305. - 出错时 error 为错误信息。
  306. """
  307. tree_dir = _tree_dir(account_name)
  308. if not tree_dir.is_dir():
  309. return ToolResult(
  310. title="人设树目录不存在",
  311. output=f"目录不存在: {tree_dir}",
  312. error="Directory not found",
  313. )
  314. try:
  315. derived_list = _parse_derived_list(derived_items or [])
  316. allowed: Optional[set[str]] = None
  317. node_belonging_dim: dict[str, str] = {}
  318. dim_source = ""
  319. derived_dim_names: list[str] = []
  320. if log_id and str(log_id).strip():
  321. derived_dim_names = _load_derived_dim_tree_node_names(
  322. account_name, post_id, str(log_id).strip(), int(round)
  323. )
  324. if derived_dim_names:
  325. allowed, node_belonging_dim = _descendant_names_under_tree_nodes(
  326. account_name, derived_dim_names
  327. )
  328. # 记录实际用到的维度分析文件(与读取逻辑一致)
  329. log_dir = _dimension_analysis_log_dir(account_name, post_id, str(log_id).strip())
  330. for r in (int(round), int(round) - 1):
  331. if r >= 1 and (log_dir / f"{r}_维度分析.json").is_file():
  332. dim_source = f"{r}_维度分析.json (derived_dims -> 全部后代)"
  333. break
  334. else:
  335. dim_source = "未读到 derived_dims(无对应维度分析文件或为空),未收窄"
  336. items = get_nodes_by_conditional_ratio(
  337. account_name,
  338. derived_list,
  339. conditional_ratio_threshold,
  340. top_n,
  341. allowed_node_names=allowed,
  342. node_belonging_dim=node_belonging_dim if node_belonging_dim else None,
  343. )
  344. # 批量匹配所有节点与帖子选题点
  345. if items and post_id:
  346. node_names = [x["节点名称"] for x in items]
  347. matched_results = await match_derivation_to_post_points(node_names, account_name, post_id)
  348. node_match_map: dict[str, list] = {}
  349. for m in matched_results:
  350. node_match_map.setdefault(m["推导选题点"], []).append({
  351. "帖子选题点": m["帖子选题点"],
  352. "匹配分数": m["匹配分数"],
  353. })
  354. for item in items:
  355. matches = node_match_map.get(item["节点名称"], [])
  356. item["帖子选题点匹配"] = matches if matches else "无"
  357. # [临时] 仅保留有帖子选题点匹配的记录(过滤掉「无」),方便后续删除
  358. items = [x for x in items if isinstance(x.get("帖子选题点匹配"), list)]
  359. if not items:
  360. output = f"未找到条件概率 >= {conditional_ratio_threshold} 的节点"
  361. else:
  362. lines = []
  363. for x in items:
  364. match_info = x.get("帖子选题点匹配", "无")
  365. if isinstance(match_info, list):
  366. match_str = "、".join(f"{m['帖子选题点']}({m['匹配分数']})" for m in match_info)
  367. else:
  368. match_str = str(match_info)
  369. dim_label = x.get("所属维度", "—")
  370. lines.append(
  371. f"- {x['节点名称']}\t条件概率={x['条件概率']}\t父节点={x['父节点名称']}\t所属维度={dim_label}\t帖子选题点匹配={match_str}"
  372. )
  373. output = "\n".join(lines)
  374. return ToolResult(
  375. title=f"条件概率节点 ({account_name}, 阈值={conditional_ratio_threshold})",
  376. output=output,
  377. metadata={
  378. "account_name": account_name,
  379. "threshold": conditional_ratio_threshold,
  380. "top_n": top_n,
  381. "count": len(items),
  382. "round": int(round),
  383. "log_id": str(log_id).strip() if log_id else "",
  384. "dimension_filter": {
  385. "derived_dim_nodes": derived_dim_names,
  386. "allowed_descendant_count": len(allowed) if allowed is not None else None,
  387. "source": dim_source or ("未提供 log_id,未按维度收窄" if not (log_id and str(log_id).strip()) else ""),
  388. },
  389. },
  390. )
  391. except Exception as e:
  392. return ToolResult(
  393. title="按条件概率查询节点失败",
  394. output=str(e),
  395. error=str(e),
  396. )
  397. def main() -> None:
  398. """本地测试:用家有大志账号测常量节点与条件概率节点,有 agent 时再跑一遍 tool 接口。"""
  399. import asyncio
  400. account_name = "家有大志"
  401. post_id = "68fb6a5c000000000302e5de"
  402. # derived_items = [
  403. # {"topic": "分享", "source_node": "分享"},
  404. # {"topic": "叙事结构", "source_node": "叙事结构"},
  405. # ]
  406. derived_items = [{"topic":"分享","source_node":"分享"},{"topic":"叙事结构","source_node":"叙事编排"},{"topic":"幽默化标题","source_node":"幽默化标题"},{"source_node":"叙事结构","topic":"叙事结构"},{"topic":"夸张堆叠","source_node":"夸张转化"},{"topic":"居家生活场景","source_node":"生活场景"},{"topic":"图片文字","source_node":"图片文字"},{"source_node":"补充说明式","topic":"补充说明式"},{"topic":"标题","source_node":"标题"},{"topic":"递进式","source_node":"递进式"},{"source_node":"荒诞夸张","topic":"夸张堆叠"},{"topic":"图片文字","source_node":"图文组合"},{"topic":"补充说明式","source_node":"解读说明"},{"source_node":"分步骤","topic":"递进式"},{"source_node":"视觉证据","topic":"拖鞋物证"},{"topic":"鞋架","source_node":"家居器具"},{"topic":"柴犬形象","source_node":"形象演绎"},{"topic":"叙事结构","source_node":"结构编排"},{"topic":"图片文字","source_node":"图文编排"},{"source_node":"形象","topic":"柴犬形象"},{"source_node":"版面结构","topic":"叙事结构"},{"source_node":"夸张变形","topic":"夸张堆叠"},{"source_node":"夸张造型","topic":"夸张堆叠"},{"topic":"夸张堆叠","source_node":"夸张穿戴法"}]
  407. conditional_ratio_threshold = 0.2
  408. top_n = 2000
  409. # # 1)常量节点(核心函数,无匹配)
  410. # constant_nodes = get_constant_nodes(account_name)
  411. # print(f"账号: {account_name} — 常量节点共 {len(constant_nodes)} 个(前 50 个):")
  412. # for x in constant_nodes[:50]:
  413. # print(f" - {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}")
  414. # print()
  415. #
  416. # # 2)条件概率节点(核心函数)
  417. # derived_list = _parse_derived_list(derived_items)
  418. # ratio_nodes = get_nodes_by_conditional_ratio(
  419. # account_name, derived_list, conditional_ratio_threshold, top_n
  420. # )
  421. # print(f"条件概率节点 阈值={conditional_ratio_threshold}, top_n={top_n}, 共 {len(ratio_nodes)} 个:")
  422. # for x in ratio_nodes:
  423. # print(f" - {x['节点名称']}\t条件概率={x['条件概率']}\t父节点={x['父节点名称']}")
  424. # print()
  425. # 3)有 agent 时通过 tool 接口再跑一遍(含帖子选题点匹配)
  426. if ToolResult is not None:
  427. async def run_tools():
  428. # r1 = await find_tree_constant_nodes(account_name, post_id=post_id)
  429. # print("--- find_tree_constant_nodes ---")
  430. # print(r1.output[:2000] + "..." if len(r1.output) > 2000 else r1.output)
  431. r2 = await find_tree_nodes_by_conditional_ratio(
  432. account_name,
  433. post_id=post_id,
  434. derived_items=derived_items,
  435. conditional_ratio_threshold=conditional_ratio_threshold,
  436. top_n=top_n,
  437. round=6,
  438. log_id="20260318172724",
  439. )
  440. print("\n--- find_tree_nodes_by_conditional_ratio ---")
  441. print(r2.output)
  442. asyncio.run(run_tools())
  443. if __name__ == "__main__":
  444. main()