find_tree_node.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813
  1. """
  2. 查找树节点 Tool - 人设树节点查询
  3. 功能:
  4. 1. 获取人设树的常量节点(全局常量、局部常量)
  5. 2. 获取符合条件概率阈值的节点(按条件概率排序返回 topN)
  6. 平台库人设树(第二节输出)流水线(由 build_platform_tree_section_items 聚合):
  7. xiaohongshu/tree → 与账号相同的条件概率计算 → xiaohongshu/match_data 按匹配分过滤选题点
  8. → 剔除与账号段同名的节点。
  9. """
  10. import json
  11. import sys
  12. from pathlib import Path
  13. from typing import Any, Optional
  14. # 保证直接运行或作为包加载时都能解析 utils / tools(IDE 可跳转)
  15. _root = Path(__file__).resolve().parent.parent
  16. if str(_root) not in sys.path:
  17. sys.path.insert(0, str(_root))
  18. from utils.conditional_ratio_calc import ( # noqa: E402
  19. build_node_post_index,
  20. build_node_post_index_from_tree_dir,
  21. calc_node_conditional_ratio,
  22. load_persona_trees_from_dir,
  23. )
  24. from tools.point_match import ( # noqa: E402
  25. DEFAULT_MATCH_THRESHOLD,
  26. match_derivation_to_post_points,
  27. )
  28. try:
  29. from agent.tools import tool, ToolResult, ToolContext
  30. except ImportError:
  31. def tool(*args, **kwargs):
  32. return lambda f: f
  33. ToolResult = None # 仅用 main() 测核心逻辑时可无 agent
  34. ToolContext = None
  35. # 相对本文件:tools -> overall_derivation,input / output 在 overall_derivation 下
  36. _BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  37. _BASE_OUTPUT = Path(__file__).resolve().parent.parent / "output"
  38. def _dimension_analysis_log_dir(account_name: str, post_id: str, log_id: str) -> Path:
  39. """推导日志目录:output/{account_name}/推导日志/{post_id}/{log_id}/"""
  40. return _BASE_OUTPUT / account_name / "推导日志" / post_id / log_id
  41. def _load_derived_dim_tree_node_names(
  42. account_name: str, post_id: str, log_id: str, round: int
  43. ) -> list[str]:
  44. """
  45. 读取当前轮次对应的维度分析 JSON(优先 {round}_维度分析.json,不存在则 {round-1}_维度分析.json),
  46. 返回 derived_dims 中每项的 tree_node_name(已推导出的维度节点,人设树中层次较高)。
  47. 无可用文件时返回空列表。
  48. """
  49. if not log_id or not str(log_id).strip():
  50. return []
  51. log_dir = _dimension_analysis_log_dir(account_name, post_id, str(log_id).strip())
  52. for r in (round, round - 1):
  53. if r < 1:
  54. continue
  55. path = log_dir / f"{r}_维度分析.json"
  56. if not path.is_file():
  57. continue
  58. try:
  59. with open(path, "r", encoding="utf-8") as f:
  60. data = json.load(f)
  61. except Exception:
  62. continue
  63. dims = data.get("derived_dims") or []
  64. names: list[str] = []
  65. for d in dims:
  66. if isinstance(d, dict):
  67. tn = d.get("tree_node_name")
  68. if tn is not None and str(tn).strip():
  69. names.append(str(tn).strip())
  70. return names
  71. return []
  72. def _descendant_names_under_tree_nodes(
  73. account_name: str, anchor_node_names: list[str]
  74. ) -> tuple[set[str], dict[str, str]]:
  75. """
  76. 在每个人设维度树根上 DFS,收集所有锚点(derived_dims.tree_node_name)之下的**全部后代**(不含锚点自身)。
  77. 同时记录「所属维度」:对路径上每个后代节点,取从维度根到该节点路径上**最深的**那个锚点
  78. (与原先沿父链向上找最近 derived_dim 一致;多个锚点呈祖孙时取更深者)。
  79. """
  80. if not anchor_node_names:
  81. return set(), {}
  82. S = set(anchor_node_names)
  83. allowed: set[str] = set()
  84. dim_map: dict[str, str] = {}
  85. for dim_root_name, root in _load_trees(account_name):
  86. def dfs(node_name: str, node_dict: dict, parent_deepest_s: Optional[str]) -> None:
  87. d_self = node_name if node_name in S else parent_deepest_s
  88. for cname, cnode in (node_dict.get("children") or {}).items():
  89. if not isinstance(cnode, dict):
  90. continue
  91. if cname not in S and d_self is not None:
  92. allowed.add(cname)
  93. dim_map[cname] = d_self
  94. dfs(cname, cnode, d_self)
  95. dfs(dim_root_name, root, None)
  96. return allowed, dim_map
  97. def _tree_dir(account_name: str) -> Path:
  98. """人设树目录:../input/{account_name}/处理后数据/tree/"""
  99. return _BASE_INPUT / account_name / "处理后数据" / "tree"
  100. def _load_trees(account_name: str) -> list[tuple[str, dict]]:
  101. """加载该账号下所有维度的人设树。返回 [(维度名, 根节点 dict), ...]。"""
  102. td = _tree_dir(account_name)
  103. if not td.is_dir():
  104. return []
  105. result = []
  106. for p in td.glob("*.json"):
  107. try:
  108. with open(p, "r", encoding="utf-8") as f:
  109. data = json.load(f)
  110. for dim_name, root in data.items():
  111. if isinstance(root, dict):
  112. result.append((dim_name, root))
  113. break
  114. except Exception:
  115. continue
  116. return result
  117. def _iter_all_nodes(account_name: str):
  118. """遍历该账号下所有人设树节点,产出 (节点名称, 父节点名称, 节点 dict)。"""
  119. for dim_name, root in _load_trees(account_name):
  120. def walk(parent_name: str, node_dict: dict):
  121. for name, child in (node_dict.get("children") or {}).items():
  122. if not isinstance(child, dict):
  123. continue
  124. yield (name, parent_name, child)
  125. yield from walk(name, child)
  126. yield from walk(dim_name, root)
  127. # ---------------------------------------------------------------------------
  128. # 1. 获取人设树常量节点
  129. # ---------------------------------------------------------------------------
  130. def get_constant_nodes(account_name: str) -> list[dict[str, Any]]:
  131. """
  132. 获取人设树的常量节点。
  133. - 全局常量:_is_constant=True
  134. - 局部常量:_is_local_constant=True 且 _is_constant=False
  135. 返回列表项:节点名称、概率(_ratio)、常量类型。
  136. """
  137. result = []
  138. for node_name, _parent, node in _iter_all_nodes(account_name):
  139. is_const = node.get("_is_constant") is True
  140. is_local = node.get("_is_local_constant") is True
  141. if is_const:
  142. const_type = "全局常量"
  143. elif is_local and not is_const:
  144. const_type = "局部常量"
  145. else:
  146. continue
  147. ratio = node.get("_ratio")
  148. result.append({
  149. "节点名称": node_name,
  150. "概率": ratio,
  151. "常量类型": const_type,
  152. })
  153. result.sort(key=lambda x: (x["概率"] is None, -(x["概率"] or 0)))
  154. return result
  155. # ---------------------------------------------------------------------------
  156. # 2. 获取符合条件概率阈值的节点
  157. # ---------------------------------------------------------------------------
  158. def get_nodes_by_conditional_ratio(
  159. account_name: str,
  160. derived_list: list[tuple[str, str]],
  161. threshold: float,
  162. top_n: int,
  163. allowed_node_names: Optional[set[str]] = None,
  164. node_belonging_dim: Optional[dict[str, str]] = None,
  165. ) -> list[dict[str, Any]]:
  166. """
  167. 获取人设树中条件概率 >= threshold 的节点,按条件概率降序,返回前 top_n 个。
  168. derived_list: 已推导列表,每项 (已推导的选题点, 推导来源人设树节点);为空时使用节点自身的 _ratio 作为条件概率。
  169. allowed_node_names: 若给定,仅保留节点名称在该集合内的结果。
  170. node_belonging_dim: 与 allowed 同步生成(见 _descendant_names_under_tree_nodes),节点名 -> 所属已推导维度;不传则所属维度均为「—」。
  171. 返回列表项:节点名称、条件概率、父节点名称、所属维度。
  172. """
  173. base_dir = _BASE_INPUT
  174. node_to_parent: dict[str, str] = {}
  175. if derived_list:
  176. for n, p, _ in _iter_all_nodes(account_name):
  177. node_to_parent[n] = p
  178. def dim_for(node_name: str) -> str:
  179. if not node_belonging_dim:
  180. return "—"
  181. return node_belonging_dim.get(node_name) or "—"
  182. scored: list[tuple[str, float, str, str]] = []
  183. if not derived_list:
  184. for node_name, parent_name, node in _iter_all_nodes(account_name):
  185. if allowed_node_names is not None and node_name not in allowed_node_names:
  186. continue
  187. ratio = node.get("_ratio")
  188. if ratio is None:
  189. ratio = 0.0
  190. else:
  191. ratio = float(ratio)
  192. if ratio >= threshold:
  193. scored.append((node_name, ratio, parent_name, dim_for(node_name)))
  194. else:
  195. node_post_index = build_node_post_index(account_name, base_dir)
  196. for node_name, parent_name in node_to_parent.items():
  197. if allowed_node_names is not None and node_name not in allowed_node_names:
  198. continue
  199. ratio = calc_node_conditional_ratio(
  200. account_name,
  201. derived_list,
  202. node_name,
  203. base_dir=base_dir,
  204. node_post_index=node_post_index,
  205. target_ratio=threshold,
  206. )
  207. if ratio >= threshold:
  208. scored.append((node_name, ratio, parent_name, dim_for(node_name)))
  209. scored.sort(key=lambda x: x[1], reverse=True)
  210. top = scored[:top_n]
  211. return [
  212. {
  213. "节点名称": name,
  214. "条件概率": ratio,
  215. "父节点名称": parent,
  216. "所属维度": dim,
  217. }
  218. for name, ratio, parent, dim in top
  219. ]
  220. def _platform_tree_dir() -> Path:
  221. """平台库人设树目录:../input/xiaohongshu/tree/"""
  222. return _BASE_INPUT / "xiaohongshu" / "tree"
  223. def get_platform_nodes_by_conditional_ratio(
  224. derived_list: list[tuple[str, str]],
  225. threshold: float,
  226. top_n: int,
  227. ) -> list[dict[str, Any]]:
  228. """
  229. 平台库人设树节点条件概率筛选,计算方式与 get_nodes_by_conditional_ratio 一致
  230. (同一套 calc_node_conditional_ratio / _post_ids 规则,索引来自 xiaohongshu/tree)。
  231. derived_list 为空时用节点 _ratio。
  232. """
  233. tree_dir = _platform_tree_dir()
  234. if not tree_dir.is_dir():
  235. return []
  236. scored: list[tuple[str, float, str, str]] = []
  237. if not derived_list:
  238. for dim_name, root in load_persona_trees_from_dir(tree_dir):
  239. def walk(parent_name: str, node_dict: dict) -> None:
  240. for name, child in (node_dict.get("children") or {}).items():
  241. if not isinstance(child, dict):
  242. continue
  243. ratio = child.get("_ratio")
  244. if ratio is None:
  245. r = 0.0
  246. else:
  247. r = float(ratio)
  248. if r >= threshold:
  249. scored.append((name, r, parent_name, dim_name))
  250. walk(name, child)
  251. walk(dim_name, root)
  252. else:
  253. node_post_index = build_node_post_index_from_tree_dir(tree_dir)
  254. node_to_parent_dim: dict[str, tuple[str, str]] = {}
  255. for dim_name, root in load_persona_trees_from_dir(tree_dir):
  256. def walk2(parent_name: str, node_dict: dict) -> None:
  257. for name, child in (node_dict.get("children") or {}).items():
  258. if not isinstance(child, dict):
  259. continue
  260. node_to_parent_dim[name] = (parent_name, dim_name)
  261. walk2(name, child)
  262. walk2(dim_name, root)
  263. for node_name, (parent_name, dim_name) in node_to_parent_dim.items():
  264. ratio = calc_node_conditional_ratio(
  265. "",
  266. derived_list,
  267. node_name,
  268. base_dir=_BASE_INPUT,
  269. node_post_index=node_post_index,
  270. target_ratio=threshold,
  271. )
  272. if ratio >= threshold:
  273. scored.append((node_name, ratio, parent_name, dim_name))
  274. scored.sort(key=lambda x: x[1], reverse=True)
  275. top = scored[:top_n]
  276. return [
  277. {
  278. "节点名称": name,
  279. "条件概率": ratio,
  280. "父节点名称": parent,
  281. "所属维度": dim,
  282. }
  283. for name, ratio, parent, dim in top
  284. ]
  285. def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
  286. """将 agent 传入的 [{"topic": "x", "source_node": "y"}, ...] 转为 DerivedItem 列表。"""
  287. out = []
  288. for item in derived_items:
  289. if isinstance(item, dict):
  290. topic = item.get("topic") or item.get("已推导的选题点")
  291. source = item.get("source_node") or item.get("推导来源人设树节点")
  292. if topic is not None and source is not None:
  293. out.append((str(topic).strip(), str(source).strip()))
  294. elif isinstance(item, (list, tuple)) and len(item) >= 2:
  295. out.append((str(item[0]).strip(), str(item[1]).strip()))
  296. return out
  297. # ---------------------------------------------------------------------------
  298. # 3. 平台库人设树辅助节点(基于帖子与平台库人设树匹配结果)
  299. # ---------------------------------------------------------------------------
  300. def _platform_match_topics_by_node(
  301. post_id: str,
  302. match_score_threshold: float,
  303. ) -> dict[tuple[str, str], dict[str, float]]:
  304. """
  305. 读取 xiaohongshu/match_data/{post_id}_匹配_all.json,
  306. 返回 (dimension, 人设树节点名) -> {帖子选题点: 最高分},仅收录 match_score >= match_score_threshold 的对。
  307. """
  308. out: dict[tuple[str, str], dict[str, float]] = {}
  309. if not post_id:
  310. return out
  311. path = _BASE_INPUT / "xiaohongshu" / "match_data" / f"{post_id}_匹配_all.json"
  312. if not path.is_file():
  313. return out
  314. try:
  315. with open(path, "r", encoding="utf-8") as f:
  316. data = json.load(f)
  317. except Exception:
  318. return out
  319. if not isinstance(data, list):
  320. return out
  321. thr = float(match_score_threshold)
  322. for item in data:
  323. if not isinstance(item, dict):
  324. continue
  325. topic = item.get("name")
  326. matches = item.get("match_personas")
  327. if topic is None or not isinstance(matches, list):
  328. continue
  329. topic_s = str(topic).strip()
  330. if not topic_s:
  331. continue
  332. for m in matches:
  333. if not isinstance(m, dict):
  334. continue
  335. name = m.get("name")
  336. dim = m.get("dimension")
  337. score = m.get("match_score")
  338. if name is None or dim is None or score is None:
  339. continue
  340. try:
  341. s = float(score)
  342. except Exception:
  343. continue
  344. if s < thr:
  345. continue
  346. key = (str(dim).strip(), str(name).strip())
  347. bucket = out.setdefault(key, {})
  348. prev = bucket.get(topic_s)
  349. if prev is None or s > prev:
  350. bucket[topic_s] = s
  351. return out
  352. def _platform_node_belonging_dim_from_anchor_nodes(
  353. anchor_node_names: list[str],
  354. ) -> dict[str, str]:
  355. """
  356. 计算平台库人设树中:节点名 -> 所属最深 derived_dim 锚点节点名。
  357. 逻辑与账号段 _descendant_names_under_tree_nodes 保持一致(但树结构来自 xiaohongshu/tree)。
  358. """
  359. if not anchor_node_names:
  360. return {}
  361. S = set(anchor_node_names)
  362. dim_map: dict[str, str] = {}
  363. tree_dir = _platform_tree_dir()
  364. if not tree_dir.is_dir():
  365. return {}
  366. for dim_root_name, root in load_persona_trees_from_dir(tree_dir):
  367. def dfs(node_name: str, node_dict: dict, parent_deepest_s: Optional[str]) -> None:
  368. d_self = node_name if node_name in S else parent_deepest_s
  369. for cname, cnode in (node_dict.get("children") or {}).items():
  370. if not isinstance(cnode, dict):
  371. continue
  372. if cname not in S and d_self is not None:
  373. dim_map[cname] = d_self
  374. dfs(cname, cnode, d_self)
  375. dfs(dim_root_name, root, None)
  376. return dim_map
  377. def _load_platform_match_nodes(
  378. post_id: str,
  379. derived_list: list[tuple[str, str]],
  380. conditional_ratio_threshold: float,
  381. match_score_threshold: float,
  382. top_n: int,
  383. node_belonging_dim_platform: Optional[dict[str, str]] = None,
  384. ) -> list[dict[str, Any]]:
  385. """
  386. 平台库人设树:先按与账号一致的条件概率筛选(get_platform_nodes_by_conditional_ratio),
  387. 再仅保留在 xiaohongshu 匹配文件中、且单条 match_score >= match_score_threshold 的帖子选题点;
  388. 无达标选题点匹配的节点丢弃。
  389. """
  390. candidates = get_platform_nodes_by_conditional_ratio(
  391. derived_list,
  392. float(conditional_ratio_threshold),
  393. int(top_n),
  394. )
  395. if not candidates or not post_id:
  396. return []
  397. topic_map = _platform_match_topics_by_node(post_id, float(match_score_threshold))
  398. out: list[dict[str, Any]] = []
  399. for row in candidates:
  400. lookup_dim = str(row.get("所属维度") or "").strip()
  401. name = str(row.get("节点名称") or "").strip()
  402. key = (lookup_dim, name)
  403. topics = topic_map.get(key) or {}
  404. if not topics:
  405. continue
  406. topic_items = sorted(topics.items(), key=lambda x: x[1], reverse=True)
  407. match_list = [{"帖子选题点": t, "匹配分数": sc} for t, sc in topic_items]
  408. out_dim = "—"
  409. if node_belonging_dim_platform is not None:
  410. out_dim = node_belonging_dim_platform.get(name) or "—"
  411. if out_dim == "—":
  412. continue
  413. row_out = dict(row)
  414. row_out["所属维度"] = out_dim
  415. row_out["帖子选题点匹配"] = match_list
  416. out.append(row_out)
  417. return out
  418. def build_platform_tree_section_items(
  419. post_id: str,
  420. derived_list: list[tuple[str, str]],
  421. conditional_ratio_threshold: float,
  422. match_score_threshold: float,
  423. top_n: int,
  424. exclude_node_names: set[str],
  425. node_belonging_dim_platform: Optional[dict[str, str]] = None,
  426. ) -> list[dict[str, Any]]:
  427. """
  428. 平台库人设树节点:条件概率 + xiaohongshu/match_data 匹配,并排除与账号段重复的节点名称。
  429. 供 find_tree_nodes_by_conditional_ratio 聚合输出使用。
  430. """
  431. if not post_id:
  432. return []
  433. plat = _load_platform_match_nodes(
  434. post_id=post_id,
  435. derived_list=derived_list,
  436. conditional_ratio_threshold=float(conditional_ratio_threshold),
  437. match_score_threshold=float(match_score_threshold),
  438. top_n=int(top_n),
  439. node_belonging_dim_platform=node_belonging_dim_platform,
  440. )
  441. ex = {str(n).strip() for n in exclude_node_names}
  442. return [
  443. p for p in plat
  444. if str(p.get("节点名称", "")).strip() not in ex
  445. ]
  446. # ---------------------------------------------------------------------------
  447. # Agent Tools(参考 glob_tool 封装)
  448. # ---------------------------------------------------------------------------
  449. @tool()
  450. async def find_tree_constant_nodes(
  451. account_name: str,
  452. post_id: str,
  453. ) -> ToolResult:
  454. """
  455. 获取人设树中的常量节点列表(全局常量与局部常量),并检查每个节点与帖子选题点的匹配情况。
  456. Args:
  457. account_name : 账号名,用于定位该账号的人设树数据。
  458. post_id : 帖子ID,用于加载帖子选题点并与各常量节点做匹配判断。
  459. Returns:
  460. ToolResult:
  461. - title: 结果标题。
  462. - output: 可读的节点列表文本(每行:节点名称、概率、常量类型、帖子匹配情况)。
  463. - 出错时 error 为错误信息。
  464. """
  465. tree_dir = _tree_dir(account_name)
  466. if not tree_dir.is_dir():
  467. return ToolResult(
  468. title="人设树目录不存在",
  469. output=f"目录不存在: {tree_dir}",
  470. error="Directory not found",
  471. )
  472. try:
  473. items = get_constant_nodes(account_name)
  474. # 批量匹配所有节点与帖子选题点
  475. if items and post_id:
  476. node_names = [x["节点名称"] for x in items]
  477. matched_results = await match_derivation_to_post_points(
  478. node_names, account_name, post_id, match_threshold=float(DEFAULT_MATCH_THRESHOLD)
  479. )
  480. node_match_map: dict[str, list] = {}
  481. for m in matched_results:
  482. node_match_map.setdefault(m["推导选题点"], []).append({
  483. "帖子选题点": m["帖子选题点"],
  484. "匹配分数": m["匹配分数"],
  485. })
  486. for item in items:
  487. matches = node_match_map.get(item["节点名称"], [])
  488. item["帖子选题点匹配"] = matches if matches else "无"
  489. if not items:
  490. output = "未找到常量节点"
  491. else:
  492. lines = []
  493. for x in items:
  494. match_info = x.get("帖子选题点匹配", "无")
  495. if isinstance(match_info, list):
  496. match_str = "、".join(f"{m['帖子选题点']}({m['匹配分数']})" for m in match_info)
  497. else:
  498. match_str = str(match_info)
  499. lines.append(f"- {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}\t帖子选题点匹配={match_str}")
  500. output = "\n".join(lines)
  501. return ToolResult(
  502. title=f"常量节点 ({account_name})",
  503. output=output,
  504. metadata={"account_name": account_name, "count": len(items)},
  505. )
  506. except Exception as e:
  507. return ToolResult(
  508. title="获取常量节点失败",
  509. output=str(e),
  510. error=str(e),
  511. )
  512. @tool()
  513. async def find_tree_nodes_by_conditional_ratio(
  514. account_name: str,
  515. post_id: str,
  516. derived_items: list[dict[str, str]],
  517. conditional_ratio_threshold: float,
  518. top_n: int = 100,
  519. round: int = 1,
  520. log_id: str = "",
  521. match_score_threshold: float = DEFAULT_MATCH_THRESHOLD,
  522. ) -> ToolResult:
  523. """
  524. 按条件概率阈值筛选节点:先账号人设树(优先使用),再平台库人设树;两段不合并。
  525. 条件概率计算对两棵树使用同一套规则(calc_node_conditional_ratio / 节点 _post_ids)。
  526. 「帖子选题点匹配」仅保留匹配分 >= match_score_threshold 的选题点;无达标匹配的节点不返回。
  527. Args:
  528. account_name : 账号名,用于定位该账号的人设树数据。
  529. post_id : 帖子ID,用于加载帖子选题点并与各节点做匹配判断。
  530. derived_items : 已推导选题点列表,可为空。非空时每项为字典,需含 topic(或「已推导的选题点」)与 source_node(或「推导来源人设树节点」)
  531. conditional_ratio_threshold : 条件概率阈值,仅返回条件概率 >= 该值的节点。
  532. top_n : 返回条数上限(账号段、平台段各自取前 top_n 条条件概率结果后再按匹配过滤)。
  533. round : 推导轮次。
  534. log_id : 推导日志ID
  535. match_score_threshold : 帖子选题点匹配分阈值,与 point_match 默认一致。
  536. Returns:
  537. ToolResult:
  538. - title: 结果标题。
  539. - output: 两段文本——先账号人设树,后平台库人设树;
  540. 账号侧匹配来自 input/{账号}/match_data;平台侧条件概率基于 input/xiaohongshu/tree,匹配来自 input/xiaohongshu/match_data。
  541. - 出错时 error 为错误信息。
  542. """
  543. tree_dir = _tree_dir(account_name)
  544. if not tree_dir.is_dir():
  545. return ToolResult(
  546. title="人设树目录不存在",
  547. output=f"目录不存在: {tree_dir}",
  548. error="Directory not found",
  549. )
  550. try:
  551. derived_list = _parse_derived_list(derived_items or [])
  552. allowed: Optional[set[str]] = None
  553. node_belonging_dim: dict[str, str] = {}
  554. node_belonging_dim_platform: Optional[dict[str, str]] = None
  555. dim_source = ""
  556. derived_dim_names: list[str] = []
  557. derived_items_len = len(derived_items or [])
  558. if log_id and str(log_id).strip():
  559. derived_dim_names = _load_derived_dim_tree_node_names(
  560. account_name, post_id, str(log_id).strip(), int(round)
  561. )
  562. if derived_dim_names:
  563. allowed, node_belonging_dim = _descendant_names_under_tree_nodes(
  564. account_name, derived_dim_names
  565. )
  566. node_belonging_dim_platform = _platform_node_belonging_dim_from_anchor_nodes(
  567. derived_dim_names
  568. )
  569. # 记录实际用到的维度分析文件(与读取逻辑一致)
  570. log_dir = _dimension_analysis_log_dir(account_name, post_id, str(log_id).strip())
  571. for r in (int(round), int(round) - 1):
  572. if r >= 1 and (log_dir / f"{r}_维度分析.json").is_file():
  573. dim_source = f"{r}_维度分析.json (derived_dims -> 全部后代)"
  574. break
  575. else:
  576. dim_source = "未读到 derived_dims(无对应维度分析文件或为空),未收窄"
  577. # 当 derived_items 太多时,用 derived_dim_names 作为条件概率计算锚点:
  578. # 将每个 derived_dim_names 的 name 都映射为 (topic=name, source_node=name)。
  579. if derived_items_len > 15 and derived_dim_names:
  580. derived_list = [(n, n) for n in derived_dim_names]
  581. # 1)账号人设树:按条件概率筛选;帖子选题点匹配仅走账号 match_data(match_derivation_to_post_points)
  582. items = get_nodes_by_conditional_ratio(
  583. account_name,
  584. derived_list,
  585. conditional_ratio_threshold,
  586. top_n,
  587. allowed_node_names=allowed,
  588. node_belonging_dim=node_belonging_dim if node_belonging_dim else None,
  589. )
  590. if items and post_id:
  591. node_names = [x["节点名称"] for x in items]
  592. matched_results = await match_derivation_to_post_points(
  593. node_names, account_name, post_id, match_threshold=float(match_score_threshold)
  594. )
  595. node_match_map: dict[str, list] = {}
  596. for m in matched_results:
  597. node_match_map.setdefault(m["推导选题点"], []).append({
  598. "帖子选题点": m["帖子选题点"],
  599. "匹配分数": m["匹配分数"],
  600. })
  601. for item in items:
  602. matches = node_match_map.get(item["节点名称"], [])
  603. item["帖子选题点匹配"] = matches if matches else "无"
  604. # [临时] 仅保留有帖子选题点匹配的记录(过滤掉「无」),方便后续删除
  605. items = [x for x in items if isinstance(x.get("帖子选题点匹配"), list)]
  606. # 2)平台库人设树(条件概率 + xiaohongshu 匹配文件;与账号节点同名则剔除)
  607. account_node_names = {str(x.get("节点名称", "")).strip() for x in items}
  608. platform_items: list[dict[str, Any]] = []
  609. if post_id:
  610. platform_items = build_platform_tree_section_items(
  611. post_id=post_id,
  612. derived_list=derived_list,
  613. conditional_ratio_threshold=float(conditional_ratio_threshold),
  614. match_score_threshold=float(match_score_threshold),
  615. top_n=top_n,
  616. exclude_node_names=account_node_names,
  617. node_belonging_dim_platform=node_belonging_dim_platform,
  618. )
  619. def _format_node_line(x: dict[str, Any]) -> str:
  620. match_info = x.get("帖子选题点匹配", "无")
  621. if isinstance(match_info, list):
  622. match_str = "、".join(f"{m['帖子选题点']}({m['匹配分数']})" for m in match_info)
  623. else:
  624. match_str = str(match_info)
  625. dim_label = x.get("所属维度", "—")
  626. return (
  627. f"- {x['节点名称']}\t条件概率={x['条件概率']}\t所属维度={dim_label}"
  628. f"\t帖子选题点匹配={match_str}"
  629. )
  630. lines: list[str] = []
  631. lines.append(
  632. "【优先使用】第一节为账号人设树中条件概率达标的节点;"
  633. "第二节为平台库人设树中条件概率达标的节点;"
  634. )
  635. lines.append("")
  636. lines.append("—— 账号人设树节点 ——")
  637. if not items:
  638. lines.append(f"(无:未找到条件概率 >= {conditional_ratio_threshold} 且与帖子选题点有匹配的节点)")
  639. else:
  640. lines.extend(_format_node_line(x) for x in items)
  641. lines.append("")
  642. lines.append("—— 平台库人设树节点 ——")
  643. if not platform_items:
  644. lines.append(
  645. "(无:未找到条件概率达标且存在达标帖子选题点匹配的节点)"
  646. )
  647. else:
  648. lines.extend(_format_node_line(x) for x in platform_items)
  649. output = "\n".join(lines)
  650. return ToolResult(
  651. title=f"条件概率节点 ({account_name}, 阈值={conditional_ratio_threshold})",
  652. output=output,
  653. metadata={
  654. "account_name": account_name,
  655. "threshold": conditional_ratio_threshold,
  656. "match_score_threshold": float(match_score_threshold),
  657. "top_n": top_n,
  658. "account_tree_count": len(items),
  659. "platform_tree_count": len(platform_items),
  660. "count": len(items) + len(platform_items),
  661. "round": int(round),
  662. "log_id": str(log_id).strip() if log_id else "",
  663. "dimension_filter": {
  664. "derived_dim_nodes": derived_dim_names,
  665. "allowed_descendant_count": len(allowed) if allowed is not None else None,
  666. "source": dim_source or ("未提供 log_id,未按维度收窄" if not (log_id and str(log_id).strip()) else ""),
  667. },
  668. },
  669. )
  670. except Exception as e:
  671. return ToolResult(
  672. title="按条件概率查询节点失败",
  673. output=str(e),
  674. error=str(e),
  675. )
  676. def main() -> None:
  677. """本地测试:用家有大志账号测常量节点与条件概率节点,有 agent 时再跑一遍 tool 接口。"""
  678. import asyncio
  679. account_name = "阿里多多酱"
  680. post_id = "6915dfc400000000070224d9"
  681. log_id = "20260320160445"
  682. round = 4
  683. # derived_items = [
  684. # {"topic": "分享", "source_node": "分享"},
  685. # {"topic": "叙事结构", "source_node": "叙事结构"},
  686. # ]
  687. derived_items = [{"topic":"推广","source_node":"推广"},{"topic":"视觉调性","source_node":"视觉调性"}]
  688. conditional_ratio_threshold = 0.2
  689. top_n = 2000
  690. # # 1)常量节点(核心函数,无匹配)
  691. # constant_nodes = get_constant_nodes(account_name)
  692. # print(f"账号: {account_name} — 常量节点共 {len(constant_nodes)} 个(前 50 个):")
  693. # for x in constant_nodes[:50]:
  694. # print(f" - {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}")
  695. # print()
  696. #
  697. # # 2)条件概率节点(核心函数)
  698. # derived_list = _parse_derived_list(derived_items)
  699. # ratio_nodes = get_nodes_by_conditional_ratio(
  700. # account_name, derived_list, conditional_ratio_threshold, top_n
  701. # )
  702. # print(f"条件概率节点 阈值={conditional_ratio_threshold}, top_n={top_n}, 共 {len(ratio_nodes)} 个:")
  703. # for x in ratio_nodes:
  704. # print(f" - {x['节点名称']}\t条件概率={x['条件概率']}\t父节点={x['父节点名称']}")
  705. # print()
  706. # 3)有 agent 时通过 tool 接口再跑一遍(含帖子选题点匹配)
  707. if ToolResult is not None:
  708. async def run_tools():
  709. # r1 = await find_tree_constant_nodes(account_name, post_id=post_id)
  710. # print("--- find_tree_constant_nodes ---")
  711. # print(r1.output[:2000] + "..." if len(r1.output) > 2000 else r1.output)
  712. r2 = await find_tree_nodes_by_conditional_ratio(
  713. account_name,
  714. post_id=post_id,
  715. derived_items=derived_items,
  716. conditional_ratio_threshold=conditional_ratio_threshold,
  717. top_n=top_n,
  718. round=round,
  719. log_id=log_id,
  720. )
  721. print("\n--- find_tree_nodes_by_conditional_ratio ---")
  722. print(r2.output)
  723. asyncio.run(run_tools())
  724. if __name__ == "__main__":
  725. main()