find_tree_node.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841
  1. """
  2. 查找树节点 Tool - 人设树节点查询
  3. 功能:
  4. 1. 获取人设树的常量节点(全局常量、局部常量)
  5. 2. 获取符合条件概率阈值的节点(按条件概率排序返回 topN)
  6. 平台库人设树(第二节输出)流水线(由 build_platform_tree_section_items 聚合):
  7. xiaohongshu/tree → 与账号相同的条件概率计算 → xiaohongshu/match_data 按匹配分过滤选题点
  8. → 剔除与账号段同名的节点。
  9. """
  10. import json
  11. import sys
  12. from pathlib import Path
  13. from typing import Any, Optional
  14. # 保证直接运行或作为包加载时都能解析 utils / tools(IDE 可跳转)
  15. _root = Path(__file__).resolve().parent.parent
  16. if str(_root) not in sys.path:
  17. sys.path.insert(0, str(_root))
  18. from utils.conditional_ratio_calc import ( # noqa: E402
  19. build_node_post_index,
  20. build_node_post_index_from_tree_dir,
  21. calc_node_conditional_ratio,
  22. load_persona_trees_from_dir,
  23. )
  24. try:
  25. from agent.tools import tool, ToolResult, ToolContext
  26. except ImportError:
  27. def tool(*args, **kwargs):
  28. return lambda f: f
  29. ToolResult = None # 仅用 main() 测核心逻辑时可无 agent
  30. ToolContext = None
  31. # 相对本文件:tools -> overall_derivation,input / output 在 overall_derivation 下
  32. _BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  33. _BASE_OUTPUT = Path(__file__).resolve().parent.parent / "output"
  34. def _dimension_analysis_log_dir(account_name: str, post_id: str, log_id: str) -> Path:
  35. """推导日志目录:output/{account_name}/推导日志/{post_id}/{log_id}/"""
  36. return _BASE_OUTPUT / account_name / "推导日志" / post_id / log_id
  37. def _load_derived_dim_tree_node_names(
  38. account_name: str, post_id: str, log_id: str, round: int
  39. ) -> list[str]:
  40. """
  41. 读取当前轮次对应的维度分析 JSON(优先 {round}_维度分析.json,不存在则 {round-1}_维度分析.json),
  42. 返回 derived_dims 中每项的 tree_node_name(已推导出的维度节点,人设树中层次较高)。
  43. 无可用文件时返回空列表。
  44. """
  45. if not log_id or not str(log_id).strip():
  46. return []
  47. log_dir = _dimension_analysis_log_dir(account_name, post_id, str(log_id).strip())
  48. for r in (round, round - 1):
  49. if r < 1:
  50. continue
  51. path = log_dir / f"{r}_维度分析.json"
  52. if not path.is_file():
  53. continue
  54. try:
  55. with open(path, "r", encoding="utf-8") as f:
  56. data = json.load(f)
  57. except Exception:
  58. continue
  59. dims = data.get("derived_dims") or []
  60. names: list[str] = []
  61. for d in dims:
  62. if isinstance(d, dict):
  63. tn = d.get("tree_node_name")
  64. if tn is not None and str(tn).strip():
  65. names.append(str(tn).strip())
  66. return names
  67. return []
  68. def _descendant_names_under_tree_nodes(
  69. account_name: str, anchor_node_names: list[str]
  70. ) -> tuple[set[str], dict[str, str]]:
  71. """
  72. 在每个人设维度树根上 DFS,收集所有锚点(derived_dims.tree_node_name)之下的**全部后代**(不含锚点自身)。
  73. 同时记录「所属维度」:对路径上每个后代节点,取从维度根到该节点路径上**最深的**那个锚点
  74. (与原先沿父链向上找最近 derived_dim 一致;多个锚点呈祖孙时取更深者)。
  75. """
  76. if not anchor_node_names:
  77. return set(), {}
  78. S = set(anchor_node_names)
  79. allowed: set[str] = set()
  80. dim_map: dict[str, str] = {}
  81. for dim_root_name, root in _load_trees(account_name):
  82. def dfs(node_name: str, node_dict: dict, parent_deepest_s: Optional[str]) -> None:
  83. d_self = node_name if node_name in S else parent_deepest_s
  84. for cname, cnode in (node_dict.get("children") or {}).items():
  85. if not isinstance(cnode, dict):
  86. continue
  87. if cname not in S and d_self is not None:
  88. allowed.add(cname)
  89. dim_map[cname] = d_self
  90. dfs(cname, cnode, d_self)
  91. dfs(dim_root_name, root, None)
  92. return allowed, dim_map
  93. def _tree_dir(account_name: str) -> Path:
  94. """人设树目录:../input/{account_name}/处理后数据/tree/"""
  95. return _BASE_INPUT / account_name / "处理后数据" / "tree"
  96. def _load_trees(account_name: str) -> list[tuple[str, dict]]:
  97. """加载该账号下所有维度的人设树。返回 [(维度名, 根节点 dict), ...]。"""
  98. td = _tree_dir(account_name)
  99. if not td.is_dir():
  100. return []
  101. result = []
  102. for p in td.glob("*.json"):
  103. try:
  104. with open(p, "r", encoding="utf-8") as f:
  105. data = json.load(f)
  106. for dim_name, root in data.items():
  107. if isinstance(root, dict):
  108. result.append((dim_name, root))
  109. break
  110. except Exception:
  111. continue
  112. return result
  113. def _iter_all_nodes(account_name: str):
  114. """遍历该账号下所有人设树节点,产出 (节点名称, 父节点名称, 节点 dict)。"""
  115. for dim_name, root in _load_trees(account_name):
  116. def walk(parent_name: str, node_dict: dict):
  117. for name, child in (node_dict.get("children") or {}).items():
  118. if not isinstance(child, dict):
  119. continue
  120. yield (name, parent_name, child)
  121. yield from walk(name, child)
  122. yield from walk(dim_name, root)
  123. # ---------------------------------------------------------------------------
  124. # 1. 获取人设树常量节点
  125. # ---------------------------------------------------------------------------
  126. def get_constant_nodes(account_name: str) -> list[dict[str, Any]]:
  127. """
  128. 获取人设树的常量节点。
  129. - 全局常量:_is_constant=True
  130. - 局部常量:_is_local_constant=True 且 _is_constant=False
  131. 返回列表项:节点名称、概率(_ratio)、常量类型。
  132. """
  133. result = []
  134. for node_name, _parent, node in _iter_all_nodes(account_name):
  135. is_const = node.get("_is_constant") is True
  136. is_local = node.get("_is_local_constant") is True
  137. if is_const:
  138. const_type = "全局常量"
  139. elif is_local and not is_const:
  140. const_type = "局部常量"
  141. else:
  142. continue
  143. ratio = node.get("_ratio")
  144. result.append({
  145. "节点名称": node_name,
  146. "概率": ratio,
  147. "常量类型": const_type,
  148. })
  149. result.sort(key=lambda x: (x["概率"] is None, -(x["概率"] or 0)))
  150. return result
  151. # ---------------------------------------------------------------------------
  152. # 2. 获取符合条件概率阈值的节点
  153. # ---------------------------------------------------------------------------
  154. def get_nodes_by_conditional_ratio(
  155. account_name: str,
  156. derived_list: list[tuple[str, str]],
  157. threshold: float,
  158. top_n: int,
  159. allowed_node_names: Optional[set[str]] = None,
  160. node_belonging_dim: Optional[dict[str, str]] = None,
  161. ) -> list[dict[str, Any]]:
  162. """
  163. 获取人设树中条件概率 >= threshold 的节点,按条件概率降序,返回前 top_n 个。
  164. derived_list: 已推导列表,每项 (已推导的选题点, 推导来源人设树节点);为空时使用节点自身的 _ratio 作为条件概率。
  165. allowed_node_names: 若给定,仅保留节点名称在该集合内的结果。
  166. node_belonging_dim: 与 allowed 同步生成(见 _descendant_names_under_tree_nodes),节点名 -> 所属已推导维度;不传则所属维度均为「—」。
  167. 返回列表项:节点名称、条件概率、父节点名称、所属维度。
  168. """
  169. base_dir = _BASE_INPUT
  170. node_to_parent: dict[str, str] = {}
  171. if derived_list:
  172. for n, p, _ in _iter_all_nodes(account_name):
  173. node_to_parent[n] = p
  174. def dim_for(node_name: str) -> str:
  175. if not node_belonging_dim:
  176. return "—"
  177. return node_belonging_dim.get(node_name) or "—"
  178. scored: list[tuple[str, float, str, str]] = []
  179. if not derived_list:
  180. for node_name, parent_name, node in _iter_all_nodes(account_name):
  181. if allowed_node_names is not None and node_name not in allowed_node_names:
  182. continue
  183. ratio = node.get("_ratio")
  184. if ratio is None:
  185. ratio = 0.0
  186. else:
  187. ratio = float(ratio)
  188. if ratio >= threshold:
  189. scored.append((node_name, ratio, parent_name, dim_for(node_name)))
  190. else:
  191. node_post_index = build_node_post_index(account_name, base_dir)
  192. for node_name, parent_name in node_to_parent.items():
  193. if allowed_node_names is not None and node_name not in allowed_node_names:
  194. continue
  195. ratio = calc_node_conditional_ratio(
  196. account_name,
  197. derived_list,
  198. node_name,
  199. base_dir=base_dir,
  200. node_post_index=node_post_index,
  201. target_ratio=threshold,
  202. )
  203. if ratio >= threshold:
  204. scored.append((node_name, ratio, parent_name, dim_for(node_name)))
  205. scored.sort(key=lambda x: x[1], reverse=True)
  206. top = scored[:top_n]
  207. return [
  208. {
  209. "节点名称": name,
  210. "条件概率": ratio,
  211. "父节点名称": parent,
  212. "所属维度": dim,
  213. }
  214. for name, ratio, parent, dim in top
  215. ]
  216. def _platform_tree_dir() -> Path:
  217. """平台库人设树目录:../input/xiaohongshu/tree/"""
  218. return _BASE_INPUT / "xiaohongshu" / "tree"
  219. def _collect_platform_scored_tuples(
  220. derived_list: list[tuple[str, str]],
  221. threshold: float,
  222. max_nodes: int = 12000,
  223. ) -> list[tuple[str, float, str, str]]:
  224. """
  225. 平台库人设树:条件概率 >= threshold 的节点全量收集,按条件概率降序。
  226. max_nodes 防止极端大树占满内存;截断发生在全局排序之后(保留高分段)。
  227. """
  228. tree_dir = _platform_tree_dir()
  229. if not tree_dir.is_dir():
  230. return []
  231. thr = float(threshold)
  232. scored: list[tuple[str, float, str, str]] = []
  233. if not derived_list:
  234. for dim_name, root in load_persona_trees_from_dir(tree_dir):
  235. def walk(parent_name: str, node_dict: dict) -> None:
  236. for name, child in (node_dict.get("children") or {}).items():
  237. if not isinstance(child, dict):
  238. continue
  239. ratio = child.get("_ratio")
  240. r = 0.0 if ratio is None else float(ratio)
  241. if r >= thr:
  242. scored.append((name, r, parent_name, dim_name))
  243. walk(name, child)
  244. walk(dim_name, root)
  245. else:
  246. node_post_index = build_node_post_index_from_tree_dir(tree_dir)
  247. node_to_parent_dim: dict[str, tuple[str, str]] = {}
  248. for dim_name, root in load_persona_trees_from_dir(tree_dir):
  249. def walk2(parent_name: str, node_dict: dict) -> None:
  250. for name, child in (node_dict.get("children") or {}).items():
  251. if not isinstance(child, dict):
  252. continue
  253. node_to_parent_dim[name] = (parent_name, dim_name)
  254. walk2(name, child)
  255. walk2(dim_name, root)
  256. for node_name, (parent_name, dim_name) in node_to_parent_dim.items():
  257. ratio = calc_node_conditional_ratio(
  258. "",
  259. derived_list,
  260. node_name,
  261. base_dir=_BASE_INPUT,
  262. node_post_index=node_post_index,
  263. target_ratio=thr,
  264. )
  265. if ratio >= thr:
  266. scored.append((node_name, ratio, parent_name, dim_name))
  267. scored.sort(key=lambda x: x[1], reverse=True)
  268. if max_nodes > 0 and len(scored) > max_nodes:
  269. scored = scored[:max_nodes]
  270. return scored
  271. def get_platform_nodes_by_conditional_ratio(
  272. derived_list: list[tuple[str, str]],
  273. threshold: float,
  274. top_n: int,
  275. ) -> list[dict[str, Any]]:
  276. """
  277. 平台库人设树节点条件概率筛选,计算方式与 get_nodes_by_conditional_ratio 一致
  278. (同一套 calc_node_conditional_ratio / _post_ids 规则,索引来自 xiaohongshu/tree)。
  279. derived_list 为空时用节点 _ratio。
  280. """
  281. n = max(0, int(top_n))
  282. scored = _collect_platform_scored_tuples(derived_list, threshold)
  283. top = scored[:n]
  284. return [
  285. {
  286. "节点名称": name,
  287. "条件概率": ratio,
  288. "父节点名称": parent,
  289. "所属维度": dim,
  290. }
  291. for name, ratio, parent, dim in top
  292. ]
  293. def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
  294. """将 agent 传入的 [{"topic": "x", "source_node": "y"}, ...] 转为 DerivedItem 列表。"""
  295. out = []
  296. for item in derived_items:
  297. if isinstance(item, dict):
  298. topic = item.get("topic") or item.get("已推导的选题点")
  299. source = item.get("source_node") or item.get("推导来源人设树节点")
  300. if topic is not None and source is not None:
  301. out.append((str(topic).strip(), str(source).strip()))
  302. elif isinstance(item, (list, tuple)) and len(item) >= 2:
  303. out.append((str(item[0]).strip(), str(item[1]).strip()))
  304. return out
  305. # ---------------------------------------------------------------------------
  306. # 3. 平台库人设树辅助节点(基于帖子与平台库人设树匹配结果)
  307. # ---------------------------------------------------------------------------
  308. def _platform_match_topics_by_node(
  309. post_id: str,
  310. match_score_threshold: float,
  311. ) -> dict[tuple[str, str], dict[str, float]]:
  312. """
  313. 读取 xiaohongshu/match_data/{post_id}_匹配_all.json,
  314. 返回 (dimension, 人设树节点名) -> {帖子选题点: 最高分},仅收录 match_score >= match_score_threshold 的对。
  315. """
  316. out: dict[tuple[str, str], dict[str, float]] = {}
  317. if not post_id:
  318. return out
  319. path = _BASE_INPUT / "xiaohongshu" / "match_data" / f"{post_id}_匹配_all.json"
  320. if not path.is_file():
  321. return out
  322. try:
  323. with open(path, "r", encoding="utf-8") as f:
  324. data = json.load(f)
  325. except Exception:
  326. return out
  327. if not isinstance(data, list):
  328. return out
  329. thr = float(match_score_threshold)
  330. for item in data:
  331. if not isinstance(item, dict):
  332. continue
  333. topic = item.get("name")
  334. matches = item.get("match_personas")
  335. if topic is None or not isinstance(matches, list):
  336. continue
  337. topic_s = str(topic).strip()
  338. if not topic_s:
  339. continue
  340. for m in matches:
  341. if not isinstance(m, dict):
  342. continue
  343. name = m.get("name")
  344. dim = m.get("dimension")
  345. score = m.get("match_score")
  346. if name is None or dim is None or score is None:
  347. continue
  348. try:
  349. s = float(score)
  350. except Exception:
  351. continue
  352. if s < thr:
  353. continue
  354. key = (str(dim).strip(), str(name).strip())
  355. bucket = out.setdefault(key, {})
  356. prev = bucket.get(topic_s)
  357. if prev is None or s > prev:
  358. bucket[topic_s] = s
  359. return out
  360. def _platform_node_belonging_dim_from_anchor_nodes(
  361. anchor_node_names: list[str],
  362. ) -> dict[str, str]:
  363. """
  364. 计算平台库人设树中:节点名 -> 所属最深 derived_dim 锚点节点名。
  365. 逻辑与账号段 _descendant_names_under_tree_nodes 保持一致(但树结构来自 xiaohongshu/tree)。
  366. """
  367. if not anchor_node_names:
  368. return {}
  369. S = set(anchor_node_names)
  370. dim_map: dict[str, str] = {}
  371. tree_dir = _platform_tree_dir()
  372. if not tree_dir.is_dir():
  373. return {}
  374. for dim_root_name, root in load_persona_trees_from_dir(tree_dir):
  375. def dfs(node_name: str, node_dict: dict, parent_deepest_s: Optional[str]) -> None:
  376. d_self = node_name if node_name in S else parent_deepest_s
  377. for cname, cnode in (node_dict.get("children") or {}).items():
  378. if not isinstance(cnode, dict):
  379. continue
  380. if cname not in S and d_self is not None:
  381. dim_map[cname] = d_self
  382. dfs(cname, cnode, d_self)
  383. dfs(dim_root_name, root, None)
  384. return dim_map
  385. def _load_platform_nodes_split(
  386. post_id: str,
  387. derived_list: list[tuple[str, str]],
  388. conditional_ratio_threshold: float,
  389. match_score_threshold: float,
  390. top_n: int,
  391. node_belonging_dim_platform: Optional[dict[str, str]] = None,
  392. ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
  393. """
  394. 平台库人设树:用 _collect_platform_scored_tuples 得到条件概率达标的节点,
  395. 再按 xiaohongshu/match_data 分为「有帖子选题点匹配 / 无匹配」两类,**两类各自按条件概率取 Top 池**(同一全局 TopN 不会挤掉另一类),
  396. 最后分别组装返回:
  397. - matched:有 match_score >= match_score_threshold 的帖子选题点匹配的节点
  398. - unmatched:无达标帖子选题点匹配的节点
  399. 两组均要求节点在 node_belonging_dim_platform 中有有效的所属维度(不为「—」)。
  400. """
  401. matched: list[dict[str, Any]] = []
  402. unmatched: list[dict[str, Any]] = []
  403. topic_map: dict[tuple[str, str], dict[str, float]] = {}
  404. if post_id:
  405. topic_map = _platform_match_topics_by_node(post_id, float(match_score_threshold))
  406. # 维度标签可能与树侧不完全一致:保留一个按节点名聚合的兜底索引,避免误判为“无匹配”。
  407. topic_map_by_name: dict[str, dict[str, float]] = {}
  408. for (_dim, n), topics in topic_map.items():
  409. bucket = topic_map_by_name.setdefault(str(n).strip(), {})
  410. for t, sc in (topics or {}).items():
  411. prev = bucket.get(t)
  412. if prev is None or sc > prev:
  413. bucket[t] = sc
  414. # 有 match_data 命中与无命中两类分开按条件概率取 Top,避免混在一个全局 TopN 里挤掉某一类。
  415. all_scored = _collect_platform_scored_tuples(
  416. derived_list,
  417. float(conditional_ratio_threshold),
  418. )
  419. if not all_scored:
  420. return matched, unmatched
  421. matched_tuples: list[tuple[str, float, str, str]] = []
  422. unmatched_tuples: list[tuple[str, float, str, str]] = []
  423. for name, ratio, parent, dim in all_scored:
  424. lookup_dim = str(dim).strip()
  425. key = (lookup_dim, str(name).strip())
  426. topics = topic_map.get(key) or topic_map_by_name.get(str(name).strip()) or {}
  427. if topics:
  428. matched_tuples.append((name, ratio, parent, dim))
  429. else:
  430. unmatched_tuples.append((name, ratio, parent, dim))
  431. _pool = max(int(top_n), min(2000, max(500, int(top_n) * 5)))
  432. matched_tuples = matched_tuples[:_pool]
  433. unmatched_tuples = unmatched_tuples[:_pool]
  434. def _emit_tuple_rows(
  435. tuples: list[tuple[str, float, str, str]],
  436. *,
  437. has_topics: bool,
  438. ) -> None:
  439. for name, ratio, parent, dim in tuples:
  440. row = {
  441. "节点名称": name,
  442. "条件概率": ratio,
  443. "父节点名称": parent,
  444. "所属维度": dim,
  445. }
  446. name_s = str(row.get("节点名称") or "").strip()
  447. out_dim = "—"
  448. if node_belonging_dim_platform is not None:
  449. out_dim = node_belonging_dim_platform.get(name_s) or "—"
  450. if node_belonging_dim_platform is not None and out_dim == "—":
  451. continue
  452. row_out = dict(row)
  453. row_out["所属维度"] = out_dim
  454. lookup_dim = str(row.get("所属维度") or "").strip()
  455. key2 = (lookup_dim, name_s)
  456. topics = topic_map.get(key2) or topic_map_by_name.get(name_s) or {}
  457. if has_topics:
  458. if not topics:
  459. continue
  460. topic_items = sorted(topics.items(), key=lambda x: x[1], reverse=True)
  461. row_out["帖子选题点匹配"] = [{"帖子选题点": t, "匹配分数": sc} for t, sc in topic_items]
  462. matched.append(row_out)
  463. else:
  464. if topics:
  465. continue
  466. row_out["帖子选题点匹配"] = "无"
  467. unmatched.append(row_out)
  468. _emit_tuple_rows(matched_tuples, has_topics=True)
  469. _emit_tuple_rows(unmatched_tuples, has_topics=False)
  470. return matched, unmatched
  471. def build_platform_tree_section_items_split(
  472. post_id: str,
  473. derived_list: list[tuple[str, str]],
  474. conditional_ratio_threshold: float,
  475. match_score_threshold: float,
  476. top_n: int,
  477. exclude_node_names: set[str],
  478. node_belonging_dim_platform: Optional[dict[str, str]] = None,
  479. ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
  480. """
  481. 平台库人设树节点:条件概率 + xiaohongshu/match_data 匹配,排除与账号段重复的节点名称,
  482. 返回 (有帖子选题点匹配的节点列表, 无帖子选题点匹配的节点列表)。
  483. 供 find_tree_nodes_by_conditional_ratio 聚合输出使用。
  484. """
  485. if not post_id:
  486. return [], []
  487. ex = {str(n).strip() for n in exclude_node_names}
  488. matched, unmatched = _load_platform_nodes_split(
  489. post_id=post_id,
  490. derived_list=derived_list,
  491. conditional_ratio_threshold=float(conditional_ratio_threshold),
  492. match_score_threshold=float(match_score_threshold),
  493. top_n=int(top_n),
  494. node_belonging_dim_platform=node_belonging_dim_platform,
  495. )
  496. matched_filtered = [p for p in matched if str(p.get("节点名称", "")).strip() not in ex]
  497. unmatched_filtered = [p for p in unmatched if str(p.get("节点名称", "")).strip() not in ex]
  498. return matched_filtered, unmatched_filtered
  499. # ---------------------------------------------------------------------------
  500. # Agent Tools(参考 glob_tool 封装)
  501. # ---------------------------------------------------------------------------
  502. @tool()
  503. async def find_tree_constant_nodes(
  504. account_name: str,
  505. post_id: str,
  506. ) -> ToolResult:
  507. """
  508. 获取人设树中的常量节点列表(全局常量与局部常量)。
  509. Args:
  510. account_name : 账号名,用于定位该账号的人设树数据。
  511. post_id : 帖子ID(保留参数,当前版本暂不使用)。
  512. Returns:
  513. ToolResult:
  514. - title: 结果标题。
  515. - output: 可读的节点列表文本(每行:节点名称、概率、常量类型)。
  516. - 出错时 error 为错误信息。
  517. """
  518. tree_dir = _tree_dir(account_name)
  519. if not tree_dir.is_dir():
  520. return ToolResult(
  521. title="人设树目录不存在",
  522. output=f"目录不存在: {tree_dir}",
  523. error="Directory not found",
  524. )
  525. try:
  526. items = get_constant_nodes(account_name)
  527. if not items:
  528. output = "未找到常量节点"
  529. else:
  530. lines = [f"- {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}" for x in items]
  531. output = "\n".join(lines)
  532. return ToolResult(
  533. title=f"常量节点 ({account_name})",
  534. output=output,
  535. metadata={"account_name": account_name, "count": len(items)},
  536. )
  537. except Exception as e:
  538. return ToolResult(
  539. title="获取常量节点失败",
  540. output=str(e),
  541. error=str(e),
  542. )
  543. @tool()
  544. async def find_tree_nodes_by_conditional_ratio(
  545. account_name: str,
  546. post_id: str,
  547. derived_items: list[dict[str, str]],
  548. conditional_ratio_threshold: float,
  549. top_n: int = 100,
  550. round: int = 1,
  551. log_id: str = "",
  552. match_score_threshold: float = 0.7,
  553. ) -> ToolResult:
  554. """
  555. 按条件概率阈值筛选节点,第一节为账号人设树节点(优先使用),第二节为平台库人设树节点。
  556. Args:
  557. account_name : 账号名,用于定位该账号的人设树数据。
  558. post_id : 帖子ID,用于定位推导日志目录(维度分析文件)。
  559. derived_items : 已推导选题点列表,可为空。非空时每项为字典,需含 topic(或「已推导的选题点」)与 source_node(或「推导来源人设树节点」)
  560. conditional_ratio_threshold : 条件概率阈值,仅返回条件概率 >= 该值的节点。
  561. top_n : 最终返回总条数上限,按 账号60%/平台40% 分配。
  562. round : 推导轮次。
  563. log_id : 推导日志ID
  564. match_score_threshold : 帖子选题点匹配分阈值(保留参数,当前版本暂不使用)。
  565. Returns:
  566. ToolResult:
  567. - title: 结果标题。
  568. - output: 两段文本——先账号人设树,后平台库人设树;
  569. 平台侧条件概率基于 input/xiaohongshu/tree。
  570. - 出错时 error 为错误信息。
  571. """
  572. tree_dir = _tree_dir(account_name)
  573. if not tree_dir.is_dir():
  574. return ToolResult(
  575. title="人设树目录不存在",
  576. output=f"目录不存在: {tree_dir}",
  577. error="Directory not found",
  578. )
  579. try:
  580. derived_list = _parse_derived_list(derived_items or [])
  581. allowed: Optional[set[str]] = None
  582. node_belonging_dim: dict[str, str] = {}
  583. node_belonging_dim_platform: Optional[dict[str, str]] = None
  584. dim_source = ""
  585. derived_dim_names: list[str] = []
  586. derived_items_len = len(derived_items or [])
  587. if log_id and str(log_id).strip():
  588. derived_dim_names = _load_derived_dim_tree_node_names(
  589. account_name, post_id, str(log_id).strip(), int(round)
  590. )
  591. if derived_dim_names:
  592. allowed, node_belonging_dim = _descendant_names_under_tree_nodes(
  593. account_name, derived_dim_names
  594. )
  595. node_belonging_dim_platform = _platform_node_belonging_dim_from_anchor_nodes(
  596. derived_dim_names
  597. )
  598. # 记录实际用到的维度分析文件(与读取逻辑一致)
  599. log_dir = _dimension_analysis_log_dir(account_name, post_id, str(log_id).strip())
  600. for r in (int(round), int(round) - 1):
  601. if r >= 1 and (log_dir / f"{r}_维度分析.json").is_file():
  602. dim_source = f"{r}_维度分析.json (derived_dims -> 全部后代)"
  603. break
  604. else:
  605. dim_source = "未读到 derived_dims(无对应维度分析文件或为空),未收窄"
  606. # 当 derived_items 太多时,用 derived_dim_names 作为条件概率计算锚点:
  607. # 将每个 derived_dim_names 的 name 都映射为 (topic=name, source_node=name)。
  608. if derived_items_len > 15 and derived_dim_names:
  609. derived_list = [(n, n) for n in derived_dim_names]
  610. # 1)账号人设树:按条件概率筛选
  611. items = get_nodes_by_conditional_ratio(
  612. account_name,
  613. derived_list,
  614. conditional_ratio_threshold,
  615. top_n,
  616. allowed_node_names=allowed,
  617. node_belonging_dim=node_belonging_dim if node_belonging_dim else None,
  618. )
  619. # 账号配额:占 top_n 的 60%
  620. account_quota = int(top_n * 0.6 + 0.5)
  621. items = items[:account_quota]
  622. # 2)平台库人设树:按条件概率筛选,排除账号段已有节点名
  623. platform_quota = top_n - account_quota
  624. account_all_names = {str(x.get("节点名称", "")).strip() for x in items}
  625. platform_items: list[dict[str, Any]] = []
  626. all_platform_scored = _collect_platform_scored_tuples(
  627. derived_list, float(conditional_ratio_threshold)
  628. )
  629. for name, ratio, parent, dim in all_platform_scored:
  630. if str(name).strip() in account_all_names:
  631. continue
  632. out_dim = "—"
  633. if node_belonging_dim_platform is not None:
  634. out_dim = node_belonging_dim_platform.get(str(name).strip()) or "—"
  635. if node_belonging_dim_platform is not None and out_dim == "—":
  636. continue
  637. platform_items.append({
  638. "节点名称": name,
  639. "条件概率": ratio,
  640. "父节点名称": parent,
  641. "所属维度": out_dim,
  642. })
  643. if len(platform_items) >= platform_quota:
  644. break
  645. def _format_node_line(x: dict[str, Any]) -> str:
  646. dim_label = x.get("所属维度", "—")
  647. return f"- {x['节点名称']}\t条件概率={x['条件概率']}\t所属维度={dim_label}"
  648. lines: list[str] = []
  649. lines.append(
  650. "【优先使用】第一节为账号人设树中条件概率达标的节点;"
  651. "第二节为平台库人设树中条件概率达标的节点;"
  652. )
  653. lines.append("")
  654. lines.append("—— 账号人设树节点 ——")
  655. if not items:
  656. lines.append(f"(无:未找到条件概率 >= {conditional_ratio_threshold} 的节点)")
  657. else:
  658. lines.extend(_format_node_line(x) for x in items)
  659. lines.append("")
  660. lines.append("—— 平台库人设树节点 ——")
  661. if not platform_items:
  662. lines.append("(无:未找到条件概率达标的节点)")
  663. else:
  664. lines.extend(_format_node_line(x) for x in platform_items)
  665. output = "\n".join(lines)
  666. return ToolResult(
  667. title=f"条件概率节点 ({account_name}, 阈值={conditional_ratio_threshold})",
  668. output=output,
  669. metadata={
  670. "account_name": account_name,
  671. "threshold": conditional_ratio_threshold,
  672. "top_n": top_n,
  673. "quota": {
  674. "account_quota": account_quota,
  675. "platform_quota": platform_quota,
  676. },
  677. "account_tree_count": len(items),
  678. "platform_tree_count": len(platform_items),
  679. "count": len(items) + len(platform_items),
  680. "round": int(round),
  681. "log_id": str(log_id).strip() if log_id else "",
  682. "dimension_filter": {
  683. "derived_dim_nodes": derived_dim_names,
  684. "allowed_descendant_count": len(allowed) if allowed is not None else None,
  685. "source": dim_source or ("未提供 log_id,未按维度收窄" if not (log_id and str(log_id).strip()) else ""),
  686. },
  687. },
  688. )
  689. except Exception as e:
  690. return ToolResult(
  691. title="按条件概率查询节点失败",
  692. output=str(e),
  693. error=str(e),
  694. )
  695. def main() -> None:
  696. """本地测试:用家有大志账号测常量节点与条件概率节点,有 agent 时再跑一遍 tool 接口。"""
  697. import asyncio
  698. account_name = "家有大志"
  699. post_id = "68fb6a5c000000000302e5de"
  700. log_id = "20260324141307"
  701. round = 4
  702. # derived_items = [
  703. # {"topic": "分享", "source_node": "分享"},
  704. # {"topic": "叙事结构", "source_node": "叙事结构"},
  705. # ]
  706. derived_items = [{"topic":"推广","source_node":"推广"},{"topic":"视觉调性","source_node":"视觉调性"}]
  707. conditional_ratio_threshold = 0.2
  708. top_n = 100
  709. # # 1)常量节点(核心函数,无匹配)
  710. # constant_nodes = get_constant_nodes(account_name)
  711. # print(f"账号: {account_name} — 常量节点共 {len(constant_nodes)} 个(前 50 个):")
  712. # for x in constant_nodes[:50]:
  713. # print(f" - {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}")
  714. # print()
  715. #
  716. # # 2)条件概率节点(核心函数)
  717. # derived_list = _parse_derived_list(derived_items)
  718. # ratio_nodes = get_nodes_by_conditional_ratio(
  719. # account_name, derived_list, conditional_ratio_threshold, top_n
  720. # )
  721. # print(f"条件概率节点 阈值={conditional_ratio_threshold}, top_n={top_n}, 共 {len(ratio_nodes)} 个:")
  722. # for x in ratio_nodes:
  723. # print(f" - {x['节点名称']}\t条件概率={x['条件概率']}\t父节点={x['父节点名称']}")
  724. # print()
  725. # 3)有 agent 时通过 tool 接口再跑一遍
  726. if ToolResult is not None:
  727. async def run_tools():
  728. r1 = await find_tree_constant_nodes(account_name, post_id=post_id)
  729. print("--- find_tree_constant_nodes ---")
  730. print(r1.output[:2000] + "..." if len(r1.output) > 2000 else r1.output)
  731. r2 = await find_tree_nodes_by_conditional_ratio(
  732. account_name,
  733. post_id=post_id,
  734. derived_items=derived_items,
  735. conditional_ratio_threshold=conditional_ratio_threshold,
  736. top_n=top_n,
  737. round=round,
  738. log_id=log_id,
  739. )
  740. print("\n--- find_tree_nodes_by_conditional_ratio ---")
  741. print(r2.output)
  742. asyncio.run(run_tools())
  743. if __name__ == "__main__":
  744. main()