find_tree_node.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. """
  2. 查找树节点 Tool - 人设树节点查询
  3. 功能:
  4. 1. 获取人设树的常量节点(全局常量、局部常量)
  5. 2. 获取符合条件概率阈值的节点(按条件概率排序返回 topN)
  6. """
  7. import importlib.util
  8. import json
  9. from pathlib import Path
  10. from typing import Any, Optional
  11. try:
  12. from agent.tools import tool, ToolResult, ToolContext
  13. except ImportError:
  14. def tool(*args, **kwargs):
  15. return lambda f: f
  16. ToolResult = None # 仅用 main() 测核心逻辑时可无 agent
  17. ToolContext = None
  18. # 加载同目录层级的 conditional_ratio_calc(不依赖包结构)
  19. _utils_dir = Path(__file__).resolve().parent.parent / "utils"
  20. _cond_spec = importlib.util.spec_from_file_location(
  21. "conditional_ratio_calc",
  22. _utils_dir / "conditional_ratio_calc.py",
  23. )
  24. _cond_mod = importlib.util.module_from_spec(_cond_spec)
  25. _cond_spec.loader.exec_module(_cond_mod)
  26. calc_node_conditional_ratio = _cond_mod.calc_node_conditional_ratio
  27. # 相对本文件:tools -> overall_derivation,input 在 overall_derivation 下
  28. _BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  29. def _tree_dir(account_name: str) -> Path:
  30. """人设树目录:../input/{account_name}/原始数据/tree/"""
  31. return _BASE_INPUT / account_name / "原始数据" / "tree"
  32. def _load_trees(account_name: str) -> list[tuple[str, dict]]:
  33. """加载该账号下所有维度的人设树。返回 [(维度名, 根节点 dict), ...]。"""
  34. td = _tree_dir(account_name)
  35. if not td.is_dir():
  36. return []
  37. result = []
  38. for p in td.glob("*.json"):
  39. try:
  40. with open(p, "r", encoding="utf-8") as f:
  41. data = json.load(f)
  42. for dim_name, root in data.items():
  43. if isinstance(root, dict):
  44. result.append((dim_name, root))
  45. break
  46. except Exception:
  47. continue
  48. return result
  49. def _iter_all_nodes(account_name: str):
  50. """遍历该账号下所有人设树节点,产出 (节点名称, 父节点名称, 节点 dict)。"""
  51. for dim_name, root in _load_trees(account_name):
  52. def walk(parent_name: str, node_dict: dict):
  53. for name, child in (node_dict.get("children") or {}).items():
  54. if not isinstance(child, dict):
  55. continue
  56. yield (name, parent_name, child)
  57. yield from walk(name, child)
  58. yield from walk(dim_name, root)
  59. # ---------------------------------------------------------------------------
  60. # 1. 获取人设树常量节点
  61. # ---------------------------------------------------------------------------
  62. def get_constant_nodes(account_name: str) -> list[dict[str, Any]]:
  63. """
  64. 获取人设树的常量节点。
  65. - 全局常量:_is_constant=True
  66. - 局部常量:_is_local_constant=True 且 _is_constant=False
  67. 返回列表项:节点名称、概率(_ratio)、常量类型。
  68. """
  69. result = []
  70. for node_name, _parent, node in _iter_all_nodes(account_name):
  71. is_const = node.get("_is_constant") is True
  72. is_local = node.get("_is_local_constant") is True
  73. if is_const:
  74. const_type = "全局常量"
  75. elif is_local and not is_const:
  76. const_type = "局部常量"
  77. else:
  78. continue
  79. ratio = node.get("_ratio")
  80. result.append({
  81. "节点名称": node_name,
  82. "概率": ratio,
  83. "常量类型": const_type,
  84. })
  85. result.sort(key=lambda x: (x["概率"] is None, -(x["概率"] or 0)))
  86. return result
  87. # ---------------------------------------------------------------------------
  88. # 2. 获取符合条件概率阈值的节点
  89. # ---------------------------------------------------------------------------
  90. def get_nodes_by_conditional_ratio(
  91. account_name: str,
  92. derived_list: list[tuple[str, str]],
  93. threshold: float,
  94. top_n: int,
  95. ) -> list[dict[str, Any]]:
  96. """
  97. 获取人设树中条件概率 >= threshold 的节点,按条件概率降序,返回前 top_n 个。
  98. derived_list: 已推导列表,每项 (已推导的选题点, 推导来源人设树节点);为空时使用节点自身的 _ratio 作为条件概率。
  99. 返回列表项:节点名称、条件概率、父节点名称。
  100. """
  101. base_dir = _BASE_INPUT
  102. scored: list[tuple[str, float, str]] = []
  103. if not derived_list:
  104. # derived_items 为空:条件概率取节点本身的 _ratio
  105. for node_name, parent_name, node in _iter_all_nodes(account_name):
  106. ratio = node.get("_ratio")
  107. if ratio is None:
  108. ratio = 0.0
  109. else:
  110. ratio = float(ratio)
  111. if ratio >= threshold:
  112. scored.append((node_name, ratio, parent_name))
  113. else:
  114. node_to_parent: dict[str, str] = {}
  115. for node_name, parent_name, _ in _iter_all_nodes(account_name):
  116. node_to_parent[node_name] = parent_name
  117. for node_name, parent_name in node_to_parent.items():
  118. ratio = calc_node_conditional_ratio(
  119. account_name, derived_list, node_name, base_dir=base_dir
  120. )
  121. if ratio >= threshold:
  122. scored.append((node_name, ratio, parent_name))
  123. scored.sort(key=lambda x: x[1], reverse=True)
  124. top = scored[:top_n]
  125. return [
  126. {"节点名称": name, "条件概率": ratio, "父节点名称": parent}
  127. for name, ratio, parent in top
  128. ]
  129. def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
  130. """将 agent 传入的 [{"topic": "x", "source_node": "y"}, ...] 转为 DerivedItem 列表。"""
  131. out = []
  132. for item in derived_items:
  133. if isinstance(item, dict):
  134. topic = item.get("topic") or item.get("已推导的选题点")
  135. source = item.get("source_node") or item.get("推导来源人设树节点")
  136. if topic is not None and source is not None:
  137. out.append((str(topic).strip(), str(source).strip()))
  138. elif isinstance(item, (list, tuple)) and len(item) >= 2:
  139. out.append((str(item[0]).strip(), str(item[1]).strip()))
  140. return out
  141. # ---------------------------------------------------------------------------
  142. # Agent Tools(参考 glob_tool 封装)
  143. # ---------------------------------------------------------------------------
  144. @tool(description="获取人设树的常量节点(全局常量、局部常量)。输入账号名,返回节点名称、概率、常量类型。")
  145. async def find_tree_constant_nodes(
  146. account_name: str,
  147. context: Optional[ToolContext] = None,
  148. ) -> ToolResult:
  149. """
  150. 获取人设树的常量节点。
  151. 读取该账号 input/{account_name}/原始数据/tree/ 下的人设树 JSON,
  152. 筛选 _is_constant=true(全局常量)或 _is_local_constant=true 且 _is_constant=false(局部常量)的节点,
  153. 返回:节点名称、概率(_ratio)、常量类型。
  154. """
  155. tree_dir = _tree_dir(account_name)
  156. if not tree_dir.is_dir():
  157. return ToolResult(
  158. title="人设树目录不存在",
  159. output=f"目录不存在: {tree_dir}",
  160. error="Directory not found",
  161. )
  162. try:
  163. items = get_constant_nodes(account_name)
  164. if not items:
  165. output = "未找到常量节点"
  166. else:
  167. lines = [f"- {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}" for x in items]
  168. output = "\n".join(lines)
  169. return ToolResult(
  170. title=f"常量节点 ({account_name})",
  171. output=output,
  172. metadata={"account_name": account_name, "count": len(items), "items": items},
  173. )
  174. except Exception as e:
  175. return ToolResult(
  176. title="获取常量节点失败",
  177. output=str(e),
  178. error=str(e),
  179. )
  180. @tool(
  181. description="获取人设树中条件概率不低于阈值的节点,按条件概率从高到低返回 topN。"
  182. "输入:账号名、已推导选题点列表(可为空)、条件概率阈值、topN。"
  183. "derived_items 为空时,条件概率使用节点自身的 _ratio。"
  184. )
  185. async def find_tree_nodes_by_conditional_ratio(
  186. account_name: str,
  187. derived_items: list[dict[str, str]],
  188. conditional_ratio_threshold: float,
  189. top_n: int = 20,
  190. context: Optional[ToolContext] = None,
  191. ) -> ToolResult:
  192. """
  193. 获取人设树中符合条件概率阈值的节点。
  194. derived_items:可为空;非空时每项为 {\"topic\": \"已推导选题点\", \"source_node\": \"推导来源人设树节点\"}。
  195. 当 derived_items 为空时,各节点的条件概率取其自身的 _ratio;非空时按已推导帖子集合计算条件概率。
  196. 返回:节点名称、条件概率、父节点名称,按条件概率降序最多 top_n 条。
  197. """
  198. tree_dir = _tree_dir(account_name)
  199. if not tree_dir.is_dir():
  200. return ToolResult(
  201. title="人设树目录不存在",
  202. output=f"目录不存在: {tree_dir}",
  203. error="Directory not found",
  204. )
  205. try:
  206. derived_list = _parse_derived_list(derived_items or [])
  207. items = get_nodes_by_conditional_ratio(
  208. account_name, derived_list, conditional_ratio_threshold, top_n
  209. )
  210. if not items:
  211. output = f"未找到条件概率 >= {conditional_ratio_threshold} 的节点"
  212. else:
  213. lines = [
  214. f"- {x['节点名称']}\t条件概率={x['条件概率']}\t父节点={x['父节点名称']}"
  215. for x in items
  216. ]
  217. output = "\n".join(lines)
  218. return ToolResult(
  219. title=f"条件概率节点 ({account_name}, 阈值={conditional_ratio_threshold})",
  220. output=output,
  221. metadata={
  222. "account_name": account_name,
  223. "threshold": conditional_ratio_threshold,
  224. "top_n": top_n,
  225. "count": len(items),
  226. "items": items,
  227. },
  228. )
  229. except Exception as e:
  230. return ToolResult(
  231. title="按条件概率查询节点失败",
  232. output=str(e),
  233. error=str(e),
  234. )
  235. def main() -> None:
  236. """本地测试:用家有大志账号测常量节点与条件概率节点,有 agent 时再跑一遍 tool 接口。"""
  237. import asyncio
  238. account_name = "家有大志"
  239. derived_items = [
  240. {"topic": "分享", "source_node": "分享"},
  241. ]
  242. conditional_ratio_threshold = 0.1
  243. top_n = 10
  244. # 1)常量节点
  245. constant_nodes = get_constant_nodes(account_name)
  246. print(f"账号: {account_name} — 常量节点共 {len(constant_nodes)} 个(前 50 个):")
  247. for x in constant_nodes[:50]:
  248. print(f" - {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}")
  249. print()
  250. # 2)条件概率节点(核心函数)
  251. derived_list = _parse_derived_list(derived_items)
  252. ratio_nodes = get_nodes_by_conditional_ratio(
  253. account_name, derived_list, conditional_ratio_threshold, top_n
  254. )
  255. print(f"条件概率节点 阈值={conditional_ratio_threshold}, top_n={top_n}, 共 {len(ratio_nodes)} 个:")
  256. for x in ratio_nodes:
  257. print(f" - {x['节点名称']}\t条件概率={x['条件概率']}\t父节点={x['父节点名称']}")
  258. print()
  259. # 3)有 agent 时通过 tool 接口再跑一遍
  260. if ToolResult is not None:
  261. async def run_tools():
  262. r1 = await find_tree_constant_nodes(account_name)
  263. print("--- find_tree_constant_nodes ---")
  264. print(r1.output[:200] + "..." if len(r1.output) > 200 else r1.output)
  265. r2 = await find_tree_nodes_by_conditional_ratio(
  266. account_name,
  267. derived_items=derived_items,
  268. conditional_ratio_threshold=conditional_ratio_threshold,
  269. top_n=top_n,
  270. )
  271. print("\n--- find_tree_nodes_by_conditional_ratio ---")
  272. print(r2.output)
  273. asyncio.run(run_tools())
  274. if __name__ == "__main__":
  275. main()