find_tree_node.py 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969
  1. """
  2. 查找树节点 Tool - 人设树节点查询
  3. 功能:
  4. 1. 获取人设树的常量节点(全局常量、局部常量)
  5. 2. 获取符合条件概率阈值的节点(按条件概率排序返回 topN)
  6. 平台库人设树(第二节输出)流水线(由 build_platform_tree_section_items 聚合):
  7. xiaohongshu/tree → 与账号相同的条件概率计算 → xiaohongshu/match_data 按匹配分过滤选题点
  8. → 剔除与账号段同名的节点。
  9. """
  10. import json
  11. import sys
  12. from pathlib import Path
  13. from typing import Any, Optional
  14. # 保证直接运行或作为包加载时都能解析 utils / tools(IDE 可跳转)
  15. _root = Path(__file__).resolve().parent.parent
  16. if str(_root) not in sys.path:
  17. sys.path.insert(0, str(_root))
  18. from utils.conditional_ratio_calc import ( # noqa: E402
  19. build_node_post_index,
  20. build_node_post_index_from_tree_dir,
  21. calc_node_conditional_ratio,
  22. load_persona_trees_from_dir,
  23. )
  24. try:
  25. from agent.tools import tool, ToolResult, ToolContext
  26. except ImportError:
  27. def tool(*args, **kwargs):
  28. return lambda f: f
  29. ToolResult = None # 仅用 main() 测核心逻辑时可无 agent
  30. ToolContext = None
  31. # 相对本文件:tools -> overall_derivation,input / output 在 overall_derivation 下
  32. _BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  33. _BASE_OUTPUT = Path(__file__).resolve().parent.parent / "output"
  34. def _dimension_analysis_log_dir(account_name: str, post_id: str, log_id: str) -> Path:
  35. """推导日志目录:output/{account_name}/推导日志/{post_id}/{log_id}/"""
  36. return _BASE_OUTPUT / account_name / "推导日志" / post_id / log_id
  37. def _load_derived_dim_tree_node_names(
  38. account_name: str, post_id: str, log_id: str, round: int
  39. ) -> list[str]:
  40. """
  41. 读取当前轮次对应的维度分析 JSON(优先 {round}_维度分析.json,不存在则 {round-1}_维度分析.json),
  42. 返回 derived_dims 中每项的 tree_node_name(已推导出的维度节点,人设树中层次较高)。
  43. 无可用文件时返回空列表。
  44. """
  45. if not log_id or not str(log_id).strip():
  46. return []
  47. log_dir = _dimension_analysis_log_dir(account_name, post_id, str(log_id).strip())
  48. for r in (round, round - 1):
  49. if r < 1:
  50. continue
  51. path = log_dir / f"{r}_维度分析.json"
  52. if not path.is_file():
  53. continue
  54. try:
  55. with open(path, "r", encoding="utf-8") as f:
  56. data = json.load(f)
  57. except Exception:
  58. continue
  59. dims = data.get("derived_dims") or []
  60. names: list[str] = []
  61. for d in dims:
  62. if isinstance(d, dict):
  63. tn = d.get("tree_node_name")
  64. if tn is not None and str(tn).strip():
  65. names.append(str(tn).strip())
  66. return names
  67. return []
  68. def _descendant_names_under_tree_nodes(
  69. account_name: str, anchor_node_names: list[str]
  70. ) -> tuple[set[str], dict[str, str]]:
  71. """
  72. 在每个人设维度树根上 DFS,收集所有锚点(derived_dims.tree_node_name)之下的**全部后代**(不含锚点自身)。
  73. 同时记录「所属维度」:对路径上每个后代节点,取从维度根到该节点路径上**最深的**那个锚点
  74. (与原先沿父链向上找最近 derived_dim 一致;多个锚点呈祖孙时取更深者)。
  75. """
  76. if not anchor_node_names:
  77. return set(), {}
  78. S = set(anchor_node_names)
  79. allowed: set[str] = set()
  80. dim_map: dict[str, str] = {}
  81. for dim_root_name, root in _load_trees(account_name):
  82. def dfs(node_name: str, node_dict: dict, parent_deepest_s: Optional[str]) -> None:
  83. d_self = node_name if node_name in S else parent_deepest_s
  84. for cname, cnode in (node_dict.get("children") or {}).items():
  85. if not isinstance(cnode, dict):
  86. continue
  87. if cname not in S and d_self is not None:
  88. allowed.add(cname)
  89. dim_map[cname] = d_self
  90. dfs(cname, cnode, d_self)
  91. dfs(dim_root_name, root, None)
  92. return allowed, dim_map
  93. def _direct_child_names_under_tree_nodes(
  94. account_name: str, anchor_node_names: list[str]
  95. ) -> tuple[set[str], dict[str, str]]:
  96. """
  97. 仅收集锚点(derived_dim_names)的一层子节点(不含锚点自身,不含孙子层)。
  98. 返回:
  99. - allowed: 可返回节点名集合(仅一层孩子)
  100. - dim_map: 节点名 -> 所属锚点(其直接父锚点)
  101. """
  102. if not anchor_node_names:
  103. return set(), {}
  104. S = {str(n).strip() for n in anchor_node_names if str(n).strip()}
  105. if not S:
  106. return set(), {}
  107. allowed: set[str] = set()
  108. dim_map: dict[str, str] = {}
  109. for dim_root_name, root in _load_trees(account_name):
  110. def dfs(node_name: str, node_dict: dict) -> None:
  111. children = node_dict.get("children") or {}
  112. if node_name in S:
  113. for cname, cnode in children.items():
  114. if not isinstance(cnode, dict):
  115. continue
  116. if cname in S:
  117. continue
  118. allowed.add(cname)
  119. dim_map[cname] = node_name
  120. for cname, cnode in children.items():
  121. if isinstance(cnode, dict):
  122. dfs(cname, cnode)
  123. dfs(dim_root_name, root)
  124. return allowed, dim_map
  125. def _tree_dir(account_name: str) -> Path:
  126. """人设树目录:../input/{account_name}/处理后数据/tree/"""
  127. return _BASE_INPUT / account_name / "处理后数据" / "tree"
  128. def _load_trees(account_name: str) -> list[tuple[str, dict]]:
  129. """加载该账号下所有维度的人设树。返回 [(维度名, 根节点 dict), ...]。"""
  130. td = _tree_dir(account_name)
  131. if not td.is_dir():
  132. return []
  133. result = []
  134. for p in td.glob("*.json"):
  135. try:
  136. with open(p, "r", encoding="utf-8") as f:
  137. data = json.load(f)
  138. for dim_name, root in data.items():
  139. if isinstance(root, dict):
  140. result.append((dim_name, root))
  141. break
  142. except Exception:
  143. continue
  144. return result
  145. def _build_node_parent_map(account_name: str) -> dict[str, str]:
  146. """构建账号人设树节点名 -> 父节点名映射。"""
  147. mp: dict[str, str] = {}
  148. for n, p, _ in _iter_all_nodes(account_name):
  149. mp[n] = p
  150. return mp
  151. def _iter_all_nodes(account_name: str):
  152. """遍历该账号下所有人设树节点,产出 (节点名称, 父节点名称, 节点 dict)。"""
  153. for dim_name, root in _load_trees(account_name):
  154. def walk(parent_name: str, node_dict: dict):
  155. for name, child in (node_dict.get("children") or {}).items():
  156. if not isinstance(child, dict):
  157. continue
  158. yield (name, parent_name, child)
  159. yield from walk(name, child)
  160. yield from walk(dim_name, root)
  161. # ---------------------------------------------------------------------------
  162. # 1. 获取人设树常量节点
  163. # ---------------------------------------------------------------------------
  164. def get_constant_nodes(account_name: str) -> list[dict[str, Any]]:
  165. """
  166. 获取人设树的常量节点。
  167. - 全局常量:_is_constant=True
  168. - 局部常量:_is_local_constant=True 且 _is_constant=False
  169. 返回列表项:节点名称、概率(_ratio)、常量类型。
  170. """
  171. result = []
  172. for node_name, _parent, node in _iter_all_nodes(account_name):
  173. is_const = node.get("_is_constant") is True
  174. is_local = node.get("_is_local_constant") is True
  175. if is_const:
  176. const_type = "全局常量"
  177. elif is_local and not is_const:
  178. const_type = "局部常量"
  179. else:
  180. continue
  181. ratio = node.get("_ratio")
  182. result.append({
  183. "节点名称": node_name,
  184. "概率": ratio,
  185. "常量类型": const_type,
  186. })
  187. result.sort(key=lambda x: (x["概率"] is None, -(x["概率"] or 0)))
  188. return result
  189. # ---------------------------------------------------------------------------
  190. # 2. 获取符合条件概率阈值的节点
  191. # ---------------------------------------------------------------------------
  192. def get_nodes_by_conditional_ratio(
  193. account_name: str,
  194. derived_list: list[tuple[str, str]],
  195. threshold: float,
  196. top_n: int,
  197. allowed_node_names: Optional[set[str]] = None,
  198. node_belonging_dim: Optional[dict[str, str]] = None,
  199. ) -> list[dict[str, Any]]:
  200. """
  201. 获取人设树中条件概率 >= threshold 的节点,按条件概率降序,返回前 top_n 个。
  202. derived_list: 已推导列表,每项 (已推导的选题点, 推导来源人设树节点);为空时使用节点自身的 _ratio 作为条件概率。
  203. allowed_node_names: 若给定,仅保留节点名称在该集合内的结果。
  204. node_belonging_dim: 与 allowed 同步生成(见 _descendant_names_under_tree_nodes),节点名 -> 所属已推导维度;不传则所属维度均为「—」。
  205. 返回列表项:节点名称、条件概率、父节点名称、所属维度。
  206. """
  207. base_dir = _BASE_INPUT
  208. node_to_parent: dict[str, str] = {}
  209. if derived_list:
  210. for n, p, _ in _iter_all_nodes(account_name):
  211. node_to_parent[n] = p
  212. def dim_for(node_name: str) -> str:
  213. if not node_belonging_dim:
  214. return "—"
  215. return node_belonging_dim.get(node_name) or "—"
  216. scored: list[tuple[str, float, str, str]] = []
  217. if not derived_list:
  218. for node_name, parent_name, node in _iter_all_nodes(account_name):
  219. if allowed_node_names is not None and node_name not in allowed_node_names:
  220. continue
  221. ratio = node.get("_ratio")
  222. if ratio is None:
  223. ratio = 0.0
  224. else:
  225. ratio = float(ratio)
  226. if ratio >= threshold:
  227. scored.append((node_name, ratio, parent_name, dim_for(node_name)))
  228. else:
  229. node_post_index = build_node_post_index(account_name, base_dir)
  230. for node_name, parent_name in node_to_parent.items():
  231. if allowed_node_names is not None and node_name not in allowed_node_names:
  232. continue
  233. ratio = calc_node_conditional_ratio(
  234. account_name,
  235. derived_list,
  236. node_name,
  237. base_dir=base_dir,
  238. node_post_index=node_post_index,
  239. target_ratio=threshold,
  240. )
  241. if ratio >= threshold:
  242. scored.append((node_name, ratio, parent_name, dim_for(node_name)))
  243. scored.sort(key=lambda x: x[1], reverse=True)
  244. top = scored[:top_n]
  245. return [
  246. {
  247. "节点名称": name,
  248. "条件概率": ratio,
  249. "父节点名称": parent,
  250. "所属维度": dim,
  251. }
  252. for name, ratio, parent, dim in top
  253. ]
  254. def _platform_tree_dir() -> Path:
  255. """平台库人设树目录:../input/xiaohongshu/tree/"""
  256. return _BASE_INPUT / "xiaohongshu" / "tree"
  257. def _collect_platform_scored_tuples(
  258. derived_list: list[tuple[str, str]],
  259. threshold: float,
  260. max_nodes: int = 12000,
  261. ) -> list[tuple[str, float, str, str]]:
  262. """
  263. 平台库人设树:条件概率 >= threshold 的节点全量收集,按条件概率降序。
  264. max_nodes 防止极端大树占满内存;截断发生在全局排序之后(保留高分段)。
  265. """
  266. tree_dir = _platform_tree_dir()
  267. if not tree_dir.is_dir():
  268. return []
  269. thr = float(threshold)
  270. scored: list[tuple[str, float, str, str]] = []
  271. if not derived_list:
  272. for dim_name, root in load_persona_trees_from_dir(tree_dir):
  273. def walk(parent_name: str, node_dict: dict) -> None:
  274. for name, child in (node_dict.get("children") or {}).items():
  275. if not isinstance(child, dict):
  276. continue
  277. ratio = child.get("_ratio")
  278. r = 0.0 if ratio is None else float(ratio)
  279. if r >= thr:
  280. scored.append((name, r, parent_name, dim_name))
  281. walk(name, child)
  282. walk(dim_name, root)
  283. else:
  284. node_post_index = build_node_post_index_from_tree_dir(tree_dir)
  285. node_to_parent_dim: dict[str, tuple[str, str]] = {}
  286. for dim_name, root in load_persona_trees_from_dir(tree_dir):
  287. def walk2(parent_name: str, node_dict: dict) -> None:
  288. for name, child in (node_dict.get("children") or {}).items():
  289. if not isinstance(child, dict):
  290. continue
  291. node_to_parent_dim[name] = (parent_name, dim_name)
  292. walk2(name, child)
  293. walk2(dim_name, root)
  294. for node_name, (parent_name, dim_name) in node_to_parent_dim.items():
  295. ratio = calc_node_conditional_ratio(
  296. "",
  297. derived_list,
  298. node_name,
  299. base_dir=_BASE_INPUT,
  300. node_post_index=node_post_index,
  301. target_ratio=thr,
  302. )
  303. if ratio >= thr:
  304. scored.append((node_name, ratio, parent_name, dim_name))
  305. scored.sort(key=lambda x: x[1], reverse=True)
  306. if max_nodes > 0 and len(scored) > max_nodes:
  307. scored = scored[:max_nodes]
  308. return scored
  309. def get_platform_nodes_by_conditional_ratio(
  310. derived_list: list[tuple[str, str]],
  311. threshold: float,
  312. top_n: int,
  313. ) -> list[dict[str, Any]]:
  314. """
  315. 平台库人设树节点条件概率筛选,计算方式与 get_nodes_by_conditional_ratio 一致
  316. (同一套 calc_node_conditional_ratio / _post_ids 规则,索引来自 xiaohongshu/tree)。
  317. derived_list 为空时用节点 _ratio。
  318. """
  319. n = max(0, int(top_n))
  320. scored = _collect_platform_scored_tuples(derived_list, threshold)
  321. top = scored[:n]
  322. return [
  323. {
  324. "节点名称": name,
  325. "条件概率": ratio,
  326. "父节点名称": parent,
  327. "所属维度": dim,
  328. }
  329. for name, ratio, parent, dim in top
  330. ]
  331. def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
  332. """将 agent 传入的 [{"topic": "x", "source_node": "y"}, ...] 转为 DerivedItem 列表。"""
  333. out = []
  334. for item in derived_items:
  335. if isinstance(item, dict):
  336. topic = item.get("topic") or item.get("已推导的选题点")
  337. source = item.get("source_node") or item.get("推导来源人设树节点")
  338. if topic is not None and source is not None:
  339. out.append((str(topic).strip(), str(source).strip()))
  340. elif isinstance(item, (list, tuple)) and len(item) >= 2:
  341. out.append((str(item[0]).strip(), str(item[1]).strip()))
  342. return out
  343. # ---------------------------------------------------------------------------
  344. # 3. 平台库人设树辅助节点(基于帖子与平台库人设树匹配结果)
  345. # ---------------------------------------------------------------------------
  346. def _platform_match_topics_by_node(
  347. post_id: str,
  348. match_score_threshold: float,
  349. ) -> dict[tuple[str, str], dict[str, float]]:
  350. """
  351. 读取 xiaohongshu/match_data/{post_id}_匹配_all.json,
  352. 返回 (dimension, 人设树节点名) -> {帖子选题点: 最高分},仅收录 match_score >= match_score_threshold 的对。
  353. """
  354. out: dict[tuple[str, str], dict[str, float]] = {}
  355. if not post_id:
  356. return out
  357. path = _BASE_INPUT / "xiaohongshu" / "match_data" / f"{post_id}_匹配_all.json"
  358. if not path.is_file():
  359. return out
  360. try:
  361. with open(path, "r", encoding="utf-8") as f:
  362. data = json.load(f)
  363. except Exception:
  364. return out
  365. if not isinstance(data, list):
  366. return out
  367. thr = float(match_score_threshold)
  368. for item in data:
  369. if not isinstance(item, dict):
  370. continue
  371. topic = item.get("name")
  372. matches = item.get("match_personas")
  373. if topic is None or not isinstance(matches, list):
  374. continue
  375. topic_s = str(topic).strip()
  376. if not topic_s:
  377. continue
  378. for m in matches:
  379. if not isinstance(m, dict):
  380. continue
  381. name = m.get("name")
  382. dim = m.get("dimension")
  383. score = m.get("match_score")
  384. if name is None or dim is None or score is None:
  385. continue
  386. try:
  387. s = float(score)
  388. except Exception:
  389. continue
  390. if s < thr:
  391. continue
  392. key = (str(dim).strip(), str(name).strip())
  393. bucket = out.setdefault(key, {})
  394. prev = bucket.get(topic_s)
  395. if prev is None or s > prev:
  396. bucket[topic_s] = s
  397. return out
  398. def _platform_node_belonging_dim_from_anchor_nodes(
  399. anchor_node_names: list[str],
  400. ) -> dict[str, str]:
  401. """
  402. 计算平台库人设树中:节点名 -> 所属最深 derived_dim 锚点节点名。
  403. 逻辑与账号段 _descendant_names_under_tree_nodes 保持一致(但树结构来自 xiaohongshu/tree)。
  404. """
  405. if not anchor_node_names:
  406. return {}
  407. S = set(anchor_node_names)
  408. dim_map: dict[str, str] = {}
  409. tree_dir = _platform_tree_dir()
  410. if not tree_dir.is_dir():
  411. return {}
  412. for dim_root_name, root in load_persona_trees_from_dir(tree_dir):
  413. def dfs(node_name: str, node_dict: dict, parent_deepest_s: Optional[str]) -> None:
  414. d_self = node_name if node_name in S else parent_deepest_s
  415. for cname, cnode in (node_dict.get("children") or {}).items():
  416. if not isinstance(cnode, dict):
  417. continue
  418. if cname not in S and d_self is not None:
  419. dim_map[cname] = d_self
  420. dfs(cname, cnode, d_self)
  421. dfs(dim_root_name, root, None)
  422. return dim_map
  423. def _platform_direct_child_dim_map_from_anchor_nodes(
  424. anchor_node_names: list[str],
  425. ) -> dict[str, str]:
  426. """
  427. 计算平台库人设树中:仅锚点的一层子节点 -> 其直接父锚点。
  428. """
  429. if not anchor_node_names:
  430. return {}
  431. S = {str(n).strip() for n in anchor_node_names if str(n).strip()}
  432. if not S:
  433. return {}
  434. dim_map: dict[str, str] = {}
  435. tree_dir = _platform_tree_dir()
  436. if not tree_dir.is_dir():
  437. return {}
  438. for dim_root_name, root in load_persona_trees_from_dir(tree_dir):
  439. def dfs(node_name: str, node_dict: dict) -> None:
  440. children = node_dict.get("children") or {}
  441. if node_name in S:
  442. for cname, cnode in children.items():
  443. if not isinstance(cnode, dict):
  444. continue
  445. if cname in S:
  446. continue
  447. dim_map[cname] = node_name
  448. for cname, cnode in children.items():
  449. if isinstance(cnode, dict):
  450. dfs(cname, cnode)
  451. dfs(dim_root_name, root)
  452. return dim_map
  453. def _extend_derived_dim_names_with_source_children(
  454. account_name: str,
  455. derived_dim_names: list[str],
  456. derived_items: list[dict[str, str]],
  457. ) -> list[str]:
  458. """
  459. 用 derived_items 的 source_node 扩展已推导维度:
  460. - 若 source_node 是当前任一已推导维度节点的直接孩子,则将 source_node 加入已推导维度集合。
  461. - 迭代执行直到不再新增,保证链式一层扩展可被纳入。
  462. """
  463. if not derived_dim_names:
  464. return []
  465. anchors = {str(n).strip() for n in derived_dim_names if str(n).strip()}
  466. if not anchors:
  467. return []
  468. parent_map = _build_node_parent_map(account_name)
  469. source_nodes: set[str] = set()
  470. for item in derived_items or []:
  471. if not isinstance(item, dict):
  472. continue
  473. s = item.get("source_node") or item.get("推导来源人设树节点")
  474. if s is None:
  475. continue
  476. ss = str(s).strip()
  477. if ss:
  478. source_nodes.add(ss)
  479. changed = True
  480. while changed:
  481. changed = False
  482. for s in source_nodes:
  483. p = parent_map.get(s)
  484. if p in anchors and s not in anchors:
  485. anchors.add(s)
  486. changed = True
  487. return sorted(anchors)
  488. def _load_platform_nodes_split(
  489. post_id: str,
  490. derived_list: list[tuple[str, str]],
  491. conditional_ratio_threshold: float,
  492. match_score_threshold: float,
  493. top_n: int,
  494. node_belonging_dim_platform: Optional[dict[str, str]] = None,
  495. ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
  496. """
  497. 平台库人设树:用 _collect_platform_scored_tuples 得到条件概率达标的节点,
  498. 再按 xiaohongshu/match_data 分为「有帖子选题点匹配 / 无匹配」两类,**两类各自按条件概率取 Top 池**(同一全局 TopN 不会挤掉另一类),
  499. 最后分别组装返回:
  500. - matched:有 match_score >= match_score_threshold 的帖子选题点匹配的节点
  501. - unmatched:无达标帖子选题点匹配的节点
  502. 两组均要求节点在 node_belonging_dim_platform 中有有效的所属维度(不为「—」)。
  503. """
  504. matched: list[dict[str, Any]] = []
  505. unmatched: list[dict[str, Any]] = []
  506. topic_map: dict[tuple[str, str], dict[str, float]] = {}
  507. if post_id:
  508. topic_map = _platform_match_topics_by_node(post_id, float(match_score_threshold))
  509. # 维度标签可能与树侧不完全一致:保留一个按节点名聚合的兜底索引,避免误判为“无匹配”。
  510. topic_map_by_name: dict[str, dict[str, float]] = {}
  511. for (_dim, n), topics in topic_map.items():
  512. bucket = topic_map_by_name.setdefault(str(n).strip(), {})
  513. for t, sc in (topics or {}).items():
  514. prev = bucket.get(t)
  515. if prev is None or sc > prev:
  516. bucket[t] = sc
  517. # 有 match_data 命中与无命中两类分开按条件概率取 Top,避免混在一个全局 TopN 里挤掉某一类。
  518. all_scored = _collect_platform_scored_tuples(
  519. derived_list,
  520. float(conditional_ratio_threshold),
  521. )
  522. if not all_scored:
  523. return matched, unmatched
  524. matched_tuples: list[tuple[str, float, str, str]] = []
  525. unmatched_tuples: list[tuple[str, float, str, str]] = []
  526. for name, ratio, parent, dim in all_scored:
  527. lookup_dim = str(dim).strip()
  528. key = (lookup_dim, str(name).strip())
  529. topics = topic_map.get(key) or topic_map_by_name.get(str(name).strip()) or {}
  530. if topics:
  531. matched_tuples.append((name, ratio, parent, dim))
  532. else:
  533. unmatched_tuples.append((name, ratio, parent, dim))
  534. _pool = max(int(top_n), min(2000, max(500, int(top_n) * 5)))
  535. matched_tuples = matched_tuples[:_pool]
  536. unmatched_tuples = unmatched_tuples[:_pool]
  537. def _emit_tuple_rows(
  538. tuples: list[tuple[str, float, str, str]],
  539. *,
  540. has_topics: bool,
  541. ) -> None:
  542. for name, ratio, parent, dim in tuples:
  543. row = {
  544. "节点名称": name,
  545. "条件概率": ratio,
  546. "父节点名称": parent,
  547. "所属维度": dim,
  548. }
  549. name_s = str(row.get("节点名称") or "").strip()
  550. out_dim = "—"
  551. if node_belonging_dim_platform is not None:
  552. out_dim = node_belonging_dim_platform.get(name_s) or "—"
  553. if node_belonging_dim_platform is not None and out_dim == "—":
  554. continue
  555. row_out = dict(row)
  556. row_out["所属维度"] = out_dim
  557. lookup_dim = str(row.get("所属维度") or "").strip()
  558. key2 = (lookup_dim, name_s)
  559. topics = topic_map.get(key2) or topic_map_by_name.get(name_s) or {}
  560. if has_topics:
  561. if not topics:
  562. continue
  563. topic_items = sorted(topics.items(), key=lambda x: x[1], reverse=True)
  564. row_out["帖子选题点匹配"] = [{"帖子选题点": t, "匹配分数": sc} for t, sc in topic_items]
  565. matched.append(row_out)
  566. else:
  567. if topics:
  568. continue
  569. row_out["帖子选题点匹配"] = "无"
  570. unmatched.append(row_out)
  571. _emit_tuple_rows(matched_tuples, has_topics=True)
  572. _emit_tuple_rows(unmatched_tuples, has_topics=False)
  573. return matched, unmatched
  574. def build_platform_tree_section_items_split(
  575. post_id: str,
  576. derived_list: list[tuple[str, str]],
  577. conditional_ratio_threshold: float,
  578. match_score_threshold: float,
  579. top_n: int,
  580. exclude_node_names: set[str],
  581. node_belonging_dim_platform: Optional[dict[str, str]] = None,
  582. ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
  583. """
  584. 平台库人设树节点:条件概率 + xiaohongshu/match_data 匹配,排除与账号段重复的节点名称,
  585. 返回 (有帖子选题点匹配的节点列表, 无帖子选题点匹配的节点列表)。
  586. 供 find_tree_nodes_by_conditional_ratio 聚合输出使用。
  587. """
  588. if not post_id:
  589. return [], []
  590. ex = {str(n).strip() for n in exclude_node_names}
  591. matched, unmatched = _load_platform_nodes_split(
  592. post_id=post_id,
  593. derived_list=derived_list,
  594. conditional_ratio_threshold=float(conditional_ratio_threshold),
  595. match_score_threshold=float(match_score_threshold),
  596. top_n=int(top_n),
  597. node_belonging_dim_platform=node_belonging_dim_platform,
  598. )
  599. matched_filtered = [p for p in matched if str(p.get("节点名称", "")).strip() not in ex]
  600. unmatched_filtered = [p for p in unmatched if str(p.get("节点名称", "")).strip() not in ex]
  601. return matched_filtered, unmatched_filtered
  602. # ---------------------------------------------------------------------------
  603. # Agent Tools(参考 glob_tool 封装)
  604. # ---------------------------------------------------------------------------
  605. @tool()
  606. async def find_tree_constant_nodes(
  607. account_name: str,
  608. post_id: str,
  609. ) -> ToolResult:
  610. """
  611. 获取人设树中的常量节点列表(全局常量与局部常量)。
  612. Args:
  613. account_name : 账号名,用于定位该账号的人设树数据。
  614. post_id : 帖子ID(保留参数,当前版本暂不使用)。
  615. Returns:
  616. ToolResult:
  617. - title: 结果标题。
  618. - output: 可读的节点列表文本(每行:节点名称、概率、常量类型)。
  619. - 出错时 error 为错误信息。
  620. """
  621. tree_dir = _tree_dir(account_name)
  622. if not tree_dir.is_dir():
  623. return ToolResult(
  624. title="人设树目录不存在",
  625. output=f"目录不存在: {tree_dir}",
  626. error="Directory not found",
  627. )
  628. try:
  629. items = get_constant_nodes(account_name)
  630. if not items:
  631. output = "未找到常量节点"
  632. else:
  633. lines = [f"- {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}" for x in items]
  634. output = "\n".join(lines)
  635. return ToolResult(
  636. title=f"常量节点 ({account_name})",
  637. output=output,
  638. metadata={"account_name": account_name, "count": len(items)},
  639. )
  640. except Exception as e:
  641. return ToolResult(
  642. title="获取常量节点失败",
  643. output=str(e),
  644. error=str(e),
  645. )
  646. @tool()
  647. async def find_tree_nodes_by_conditional_ratio(
  648. account_name: str,
  649. post_id: str,
  650. derived_items: list[dict[str, str]],
  651. conditional_ratio_threshold: float,
  652. top_n: int = 100,
  653. round: int = 1,
  654. log_id: str = "",
  655. match_score_threshold: float = 0.7,
  656. ) -> ToolResult:
  657. """
  658. 按条件概率阈值筛选节点,第一节为账号人设树节点(优先使用),第二节为平台库人设树节点。
  659. Args:
  660. account_name : 账号名,用于定位该账号的人设树数据。
  661. post_id : 帖子ID,用于定位推导日志目录(维度分析文件)。
  662. derived_items : 已推导选题点列表,可为空。非空时每项为字典,需含 topic(或「已推导的选题点」)与 source_node(或「推导来源人设树节点」)
  663. conditional_ratio_threshold : 条件概率阈值,仅返回条件概率 >= 该值的节点。
  664. top_n : 最终返回总条数上限,按 账号60%/平台40% 分配。
  665. round : 推导轮次。
  666. log_id : 推导日志ID
  667. match_score_threshold : 帖子选题点匹配分阈值(保留参数,当前版本暂不使用)。
  668. Returns:
  669. ToolResult:
  670. - title: 结果标题。
  671. - output: 两段文本——先账号人设树,后平台库人设树;
  672. 平台侧条件概率基于 input/xiaohongshu/tree。
  673. - 出错时 error 为错误信息。
  674. """
  675. tree_dir = _tree_dir(account_name)
  676. if not tree_dir.is_dir():
  677. return ToolResult(
  678. title="人设树目录不存在",
  679. output=f"目录不存在: {tree_dir}",
  680. error="Directory not found",
  681. )
  682. try:
  683. derived_list = _parse_derived_list(derived_items or [])
  684. allowed: Optional[set[str]] = None
  685. node_belonging_dim: dict[str, str] = {}
  686. node_belonging_dim_platform: Optional[dict[str, str]] = None
  687. dim_source = ""
  688. derived_dim_names: list[str] = []
  689. derived_items_len = len(derived_items or [])
  690. if log_id and str(log_id).strip():
  691. derived_dim_names = _load_derived_dim_tree_node_names(
  692. account_name, post_id, str(log_id).strip(), int(round)
  693. )
  694. if derived_dim_names:
  695. derived_dim_names = _extend_derived_dim_names_with_source_children(
  696. account_name=account_name,
  697. derived_dim_names=derived_dim_names,
  698. derived_items=derived_items or [],
  699. )
  700. allowed, node_belonging_dim = _direct_child_names_under_tree_nodes(
  701. account_name, derived_dim_names
  702. )
  703. node_belonging_dim_platform = _platform_direct_child_dim_map_from_anchor_nodes(
  704. derived_dim_names
  705. )
  706. # 记录实际用到的维度分析文件(与读取逻辑一致)
  707. log_dir = _dimension_analysis_log_dir(account_name, post_id, str(log_id).strip())
  708. for r in (int(round), int(round) - 1):
  709. if r >= 1 and (log_dir / f"{r}_维度分析.json").is_file():
  710. dim_source = f"{r}_维度分析.json (derived_dims + source_child -> 仅一层子节点)"
  711. break
  712. else:
  713. dim_source = "未读到 derived_dims(无对应维度分析文件或为空),未收窄"
  714. # 当 derived_items 太多时,用 derived_dim_names 作为条件概率计算锚点:
  715. # 将每个 derived_dim_names 的 name 都映射为 (topic=name, source_node=name)。
  716. if derived_items_len > 15 and derived_dim_names:
  717. derived_list = [(n, n) for n in derived_dim_names]
  718. # 1)账号人设树:按条件概率筛选
  719. items = get_nodes_by_conditional_ratio(
  720. account_name,
  721. derived_list,
  722. conditional_ratio_threshold,
  723. top_n,
  724. allowed_node_names=allowed,
  725. node_belonging_dim=node_belonging_dim if node_belonging_dim else None,
  726. )
  727. # 账号配额:占 top_n 的 60%
  728. account_quota = int(top_n * 0.6 + 0.5)
  729. items = items[:account_quota]
  730. # 2)平台库人设树:按条件概率筛选,排除账号段已有节点名
  731. platform_quota = top_n - account_quota
  732. account_all_names = {str(x.get("节点名称", "")).strip() for x in items}
  733. platform_items: list[dict[str, Any]] = []
  734. all_platform_scored = _collect_platform_scored_tuples(
  735. derived_list, float(conditional_ratio_threshold)
  736. )
  737. for name, ratio, parent, dim in all_platform_scored:
  738. if str(name).strip() in account_all_names:
  739. continue
  740. out_dim = "—"
  741. if node_belonging_dim_platform is not None:
  742. out_dim = node_belonging_dim_platform.get(str(name).strip()) or "—"
  743. if node_belonging_dim_platform is not None and out_dim == "—":
  744. continue
  745. platform_items.append({
  746. "节点名称": name,
  747. "条件概率": ratio,
  748. "父节点名称": parent,
  749. "所属维度": out_dim,
  750. })
  751. if len(platform_items) >= platform_quota:
  752. break
  753. def _format_node_line(x: dict[str, Any]) -> str:
  754. dim_label = x.get("所属维度", "—")
  755. return f"- {x['节点名称']}\t条件概率={x['条件概率']}\t所属维度={dim_label}"
  756. lines: list[str] = []
  757. lines.append(
  758. "【优先使用】第一节为账号人设树中条件概率达标的节点;"
  759. "第二节为平台库人设树中条件概率达标的节点;"
  760. )
  761. lines.append("")
  762. lines.append("—— 账号人设树节点 ——")
  763. if not items:
  764. lines.append(f"(无:未找到条件概率 >= {conditional_ratio_threshold} 的节点)")
  765. else:
  766. lines.extend(_format_node_line(x) for x in items)
  767. lines.append("")
  768. lines.append("—— 平台库人设树节点 ——")
  769. if not platform_items:
  770. lines.append("(无:未找到条件概率达标的节点)")
  771. else:
  772. lines.extend(_format_node_line(x) for x in platform_items)
  773. output = "\n".join(lines)
  774. return ToolResult(
  775. title=f"条件概率节点 ({account_name}, 阈值={conditional_ratio_threshold})",
  776. output=output,
  777. metadata={
  778. "account_name": account_name,
  779. "threshold": conditional_ratio_threshold,
  780. "top_n": top_n,
  781. "quota": {
  782. "account_quota": account_quota,
  783. "platform_quota": platform_quota,
  784. },
  785. "account_tree_count": len(items),
  786. "platform_tree_count": len(platform_items),
  787. "count": len(items) + len(platform_items),
  788. "round": int(round),
  789. "log_id": str(log_id).strip() if log_id else "",
  790. "dimension_filter": {
  791. "derived_dim_nodes": derived_dim_names,
  792. "allowed_descendant_count": len(allowed) if allowed is not None else None,
  793. "source": dim_source or ("未提供 log_id,未按维度收窄" if not (log_id and str(log_id).strip()) else ""),
  794. },
  795. },
  796. )
  797. except Exception as e:
  798. return ToolResult(
  799. title="按条件概率查询节点失败",
  800. output=str(e),
  801. error=str(e),
  802. )
  803. def main() -> None:
  804. """本地测试:用家有大志账号测常量节点与条件概率节点,有 agent 时再跑一遍 tool 接口。"""
  805. import asyncio
  806. account_name = "家有大志"
  807. post_id = "68fb6a5c000000000302e5de"
  808. log_id = "20260324141307"
  809. round = 4
  810. # derived_items = [
  811. # {"topic": "分享", "source_node": "分享"},
  812. # {"topic": "叙事结构", "source_node": "叙事结构"},
  813. # ]
  814. derived_items = [{"topic":"推广","source_node":"推广"},{"topic":"视觉调性","source_node":"视觉调性"}]
  815. conditional_ratio_threshold = 0.001
  816. top_n = 100
  817. # # 1)常量节点(核心函数,无匹配)
  818. # constant_nodes = get_constant_nodes(account_name)
  819. # print(f"账号: {account_name} — 常量节点共 {len(constant_nodes)} 个(前 50 个):")
  820. # for x in constant_nodes[:50]:
  821. # print(f" - {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}")
  822. # print()
  823. #
  824. # # 2)条件概率节点(核心函数)
  825. # derived_list = _parse_derived_list(derived_items)
  826. # ratio_nodes = get_nodes_by_conditional_ratio(
  827. # account_name, derived_list, conditional_ratio_threshold, top_n
  828. # )
  829. # print(f"条件概率节点 阈值={conditional_ratio_threshold}, top_n={top_n}, 共 {len(ratio_nodes)} 个:")
  830. # for x in ratio_nodes:
  831. # print(f" - {x['节点名称']}\t条件概率={x['条件概率']}\t父节点={x['父节点名称']}")
  832. # print()
  833. # 3)有 agent 时通过 tool 接口再跑一遍
  834. if ToolResult is not None:
  835. async def run_tools():
  836. r1 = await find_tree_constant_nodes(account_name, post_id=post_id)
  837. print("--- find_tree_constant_nodes ---")
  838. print(r1.output[:2000] + "..." if len(r1.output) > 2000 else r1.output)
  839. r2 = await find_tree_nodes_by_conditional_ratio(
  840. account_name,
  841. post_id=post_id,
  842. derived_items=derived_items,
  843. conditional_ratio_threshold=conditional_ratio_threshold,
  844. top_n=top_n,
  845. round=round,
  846. log_id=log_id,
  847. )
  848. print("\n--- find_tree_nodes_by_conditional_ratio ---")
  849. print(r2.output)
  850. asyncio.run(run_tools())
  851. if __name__ == "__main__":
  852. main()