""" 查找树节点 Tool - 人设树节点查询 功能: 1. 获取人设树的常量节点(全局常量、局部常量) 2. 获取符合条件概率阈值的节点(按条件概率排序返回 topN) """ import json import sys from pathlib import Path from typing import Any, Optional # 保证直接运行或作为包加载时都能解析 utils / tools(IDE 可跳转) _root = Path(__file__).resolve().parent.parent if str(_root) not in sys.path: sys.path.insert(0, str(_root)) from utils.conditional_ratio_calc import calc_node_conditional_ratio from tools.point_match import match_derivation_to_post_points try: from agent.tools import tool, ToolResult, ToolContext except ImportError: def tool(*args, **kwargs): return lambda f: f ToolResult = None # 仅用 main() 测核心逻辑时可无 agent ToolContext = None # 相对本文件:tools -> overall_derivation,input 在 overall_derivation 下 _BASE_INPUT = Path(__file__).resolve().parent.parent / "input" def _tree_dir(account_name: str) -> Path: """人设树目录:../input/{account_name}/原始数据/tree/""" return _BASE_INPUT / account_name / "原始数据" / "tree" def _load_trees(account_name: str) -> list[tuple[str, dict]]: """加载该账号下所有维度的人设树。返回 [(维度名, 根节点 dict), ...]。""" td = _tree_dir(account_name) if not td.is_dir(): return [] result = [] for p in td.glob("*.json"): try: with open(p, "r", encoding="utf-8") as f: data = json.load(f) for dim_name, root in data.items(): if isinstance(root, dict): result.append((dim_name, root)) break except Exception: continue return result def _iter_all_nodes(account_name: str): """遍历该账号下所有人设树节点,产出 (节点名称, 父节点名称, 节点 dict)。""" for dim_name, root in _load_trees(account_name): def walk(parent_name: str, node_dict: dict): for name, child in (node_dict.get("children") or {}).items(): if not isinstance(child, dict): continue yield (name, parent_name, child) yield from walk(name, child) yield from walk(dim_name, root) # --------------------------------------------------------------------------- # 1. 
# ---------------------------------------------------------------------------
# 1. Constant nodes of the persona tree
# ---------------------------------------------------------------------------


def get_constant_nodes(account_name: str) -> list[dict[str, Any]]:
    """Return the constant nodes of every persona tree of ``account_name``.

    A node is classified as:

    - "全局常量" (global constant) when ``_is_constant`` is exactly ``True``;
    - "局部常量" (local constant) when ``_is_local_constant`` is exactly
      ``True`` and the node is not a global constant.

    All other nodes are skipped.

    Returns:
        One dict per constant node with keys "节点名称" (node name), "概率"
        (the node's ``_ratio``, possibly ``None``) and "常量类型" (constant
        type), sorted by ratio descending with ratio-less nodes last.
    """
    result: list[dict[str, Any]] = []
    for node_name, _parent, node in _iter_all_nodes(account_name):
        if node.get("_is_constant") is True:
            const_type = "全局常量"
        elif node.get("_is_local_constant") is True:
            # Only reached when the node is not a global constant.
            const_type = "局部常量"
        else:
            continue
        result.append({
            "节点名称": node_name,
            "概率": node.get("_ratio"),
            "常量类型": const_type,
        })
    # Nodes with a ratio come first (highest first); nodes without one go last.
    result.sort(key=lambda x: (x["概率"] is None, -(x["概率"] or 0)))
    return result


# ---------------------------------------------------------------------------
# 2. Nodes meeting a conditional-probability threshold
# ---------------------------------------------------------------------------


def get_nodes_by_conditional_ratio(
    account_name: str,
    derived_list: list[tuple[str, str]],
    threshold: float,
    top_n: int,
) -> list[dict[str, Any]]:
    """Return persona-tree nodes whose conditional probability is >= ``threshold``.

    Args:
        account_name: Account whose persona trees are searched.
        derived_list: Already-derived items, each ``(derived_topic_point,
            source_tree_node)``. When empty, each node's own ``_ratio``
            (``None`` treated as 0.0) is used as its conditional probability;
            otherwise ``calc_node_conditional_ratio`` computes it.
        threshold: Minimum conditional probability (inclusive).
        top_n: Maximum number of results. Negative values yield no results;
            ``None`` means no limit.

    Returns:
        Dicts with keys "节点名称" (node name), "条件概率" (conditional
        probability) and "父节点名称" (parent node name), sorted by
        conditional probability descending (stable for ties).
    """
    scored: list[tuple[str, float, str]] = []
    if not derived_list:
        # No derived items: the conditional probability is the node's own _ratio.
        for node_name, parent_name, node in _iter_all_nodes(account_name):
            raw = node.get("_ratio")
            ratio = 0.0 if raw is None else float(raw)
            if ratio >= threshold:
                scored.append((node_name, ratio, parent_name))
    else:
        # Node names are deduplicated here; if a name occurs more than once,
        # the parent of its last occurrence is reported.
        node_to_parent: dict[str, str] = {}
        for node_name, parent_name, _ in _iter_all_nodes(account_name):
            node_to_parent[node_name] = parent_name
        for node_name, parent_name in node_to_parent.items():
            ratio = calc_node_conditional_ratio(
                account_name, derived_list, node_name, base_dir=_BASE_INPUT
            )
            if ratio >= threshold:
                scored.append((node_name, ratio, parent_name))

    scored.sort(key=lambda x: x[1], reverse=True)
    # A negative slice bound would drop items from the END rather than mean
    # "none", and a float bound would raise; normalize it first.
    limit = None if top_n is None else max(int(top_n), 0)
    return [
        {"节点名称": name, "条件概率": ratio, "父节点名称": parent}
        for name, ratio, parent in scored[:limit]
    ]


def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
    """Normalize agent-supplied derived items into ``(topic, source_node)`` tuples.

    Accepted item shapes:

    - dicts with "topic" / "source_node" (or the aliases "已推导的选题点" /
      "推导来源人设树节点");
    - lists/tuples with at least two elements ``(topic, source_node, ...)``.

    Values are converted to ``str`` and stripped. Items of any other shape,
    dict items missing either value, and items whose topic or source is blank
    after stripping are skipped.
    """
    out: list[tuple[str, str]] = []
    for item in derived_items or ():
        if isinstance(item, dict):
            topic = item.get("topic") or item.get("已推导的选题点")
            source = item.get("source_node") or item.get("推导来源人设树节点")
            if topic is None or source is None:
                continue
        elif isinstance(item, (list, tuple)) and len(item) >= 2:
            topic, source = item[0], item[1]
        else:
            continue
        topic_s, source_s = str(topic).strip(), str(source).strip()
        if topic_s and source_s:
            out.append((topic_s, source_s))
    return out


# ---------------------------------------------------------------------------
# Agent tools (wrapped in the same way as glob_tool)
# ---------------------------------------------------------------------------


def _missing_tree_dir_result(account_name: str) -> Any:
    """Return an error ToolResult when the account's tree directory is missing, else None."""
    tree_dir = _tree_dir(account_name)
    if tree_dir.is_dir():
        return None
    return ToolResult(
        title="人设树目录不存在",
        output=f"目录不存在: {tree_dir}",
        error="Directory not found",
    )


async def _attach_post_point_matches(
    items: list[dict[str, Any]], account_name: str, post_id: str
) -> None:
    """Set each item's "帖子选题点匹配" (post topic-point matches) in place.

    The value is a list of ``{"帖子选题点": ..., "匹配分数": ...}`` dicts for
    items whose "节点名称" matched at least one post topic point, otherwise the
    string "无" (none). All node names are matched in one batched call to
    ``match_derivation_to_post_points``, which is only made when both ``items``
    and ``post_id`` are non-empty.
    """
    node_match_map: dict[str, list[dict[str, Any]]] = {}
    if items and post_id:
        node_names = [x["节点名称"] for x in items]
        matched_results = await match_derivation_to_post_points(node_names, account_name, post_id)
        for m in matched_results:
            node_match_map.setdefault(m["推导选题点"], []).append({
                "帖子选题点": m["帖子选题点"],
                "匹配分数": m["匹配分数"],
            })
    for item in items:
        item["帖子选题点匹配"] = node_match_map.get(item["节点名称"]) or "无"


def _format_match_info(match_info: Any) -> str:
    """Render a "帖子选题点匹配" value as text: "point(score)" joined by "、", else str()."""
    if isinstance(match_info, list):
        return "、".join(f"{m['帖子选题点']}({m['匹配分数']})" for m in match_info)
    return str(match_info)


# Tool: list the constant nodes of an account's persona trees and annotate
# each one with its matches against the post's topic points. The docstring
# below is kept in its original language because the @tool decorator may
# introspect it (NOTE(review): confirm against agent.tools).
@tool(
    description="获取指定账号人设树中的常量节点(全局常量、局部常量),并检查每个节点与帖子选题点的匹配情况。"
    "功能:根据账号名查询该账号人设树中所有常量节点,同时对每个节点判断是否匹配帖子选题点,匹配结果直接包含在返回数据中。"
    "参数:account_name 为账号名;post_id 为帖子ID,用于加载帖子选题点并做匹配判断。"
    "返回:ToolResult,output 为可读的节点列表文本,metadata.items 为列表,每项含「节点名称」「概率」「常量类型」「帖子选题点匹配」=无/匹配结果(无匹配时为「无」,有匹配时为匹配列表,每项含帖子选题点与匹配分数)。"
)
async def find_tree_constant_nodes(
    account_name: str,
    post_id: str,
    context: Optional[ToolContext] = None,
) -> ToolResult:
    """
    获取人设树中的常量节点列表(全局常量与局部常量),并检查每个节点与帖子选题点的匹配情况。

    参数
    -------
    account_name : 账号名,用于定位该账号的人设树数据。
    post_id : 帖子ID,用于加载帖子选题点并与各常量节点做匹配判断。
    context : 可选,Agent 工具上下文。

    返回
    -------
    ToolResult:
        - title: 结果标题。
        - output: 可读的节点列表文本(每行:节点名称、概率、常量类型、帖子匹配情况)。
        - metadata: 含 account_name、count、items;items 为列表,每项为 {"节点名称": str, "概率": 数值或 None, "常量类型": "全局常量"|"局部常量", "帖子选题点匹配": 无匹配时为 "无",有匹配时为 list[{"帖子选题点": str, "匹配分数": float}]}。
        - 出错时 error 为错误信息。
    """
    missing = _missing_tree_dir_result(account_name)
    if missing is not None:
        return missing
    try:
        items = get_constant_nodes(account_name)
        await _attach_post_point_matches(items, account_name, post_id)
        if not items:
            output = "未找到常量节点"
        else:
            output = "\n".join(
                f"- {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}"
                f"\t帖子选题点匹配={_format_match_info(x.get('帖子选题点匹配', '无'))}"
                for x in items
            )
        return ToolResult(
            title=f"常量节点 ({account_name})",
            output=output,
            # "items" is promised by the tool description/docstring.
            metadata={"account_name": account_name, "count": len(items), "items": items},
        )
    except Exception as e:
        # Tool boundary: report any failure to the agent instead of raising.
        return ToolResult(
            title="获取常量节点失败",
            output=str(e),
            error=str(e),
        )


# Tool: filter persona-tree nodes by conditional probability (top N, sorted
# descending) and annotate each one with its post topic-point matches. The
# docstring below is kept in its original language because the @tool
# decorator may introspect it (NOTE(review): confirm against agent.tools).
@tool(
    description="按条件概率从人设树中筛选节点,返回达到阈值且按条件概率排序的前 topN 条,并检查每个节点与帖子选题点的匹配情况。"
    "功能:根据账号与已推导选题点(可选),筛选人设树中条件概率不低于阈值的节点,同时对每个节点判断是否匹配帖子选题点,匹配结果直接包含在返回数据中。"
    "参数:account_name 为账号名;post_id 为帖子ID,用于加载帖子选题点并做匹配判断;derived_items 为已推导选题点列表,每项含 topic(或已推导的选题点)与 source_node(或推导来源人设树节点),可为空,为空时条件概率使用节点自身的 _ratio;conditional_ratio_threshold 为条件概率阈值;top_n 为返回条数上限,默认 100。"
    "返回:ToolResult,output 为可读的节点列表文本,metadata.items 为列表,每项含「节点名称」「条件概率」「父节点名称」「帖子选题点匹配」=无/匹配结果(无匹配时为「无」,有匹配时为匹配列表,每项含帖子选题点与匹配分数)。"
)
async def find_tree_nodes_by_conditional_ratio(
    account_name: str,
    post_id: str,
    derived_items: list[dict[str, str]],
    conditional_ratio_threshold: float,
    top_n: int = 100,
    context: Optional[ToolContext] = None,
) -> ToolResult:
    """
    按条件概率阈值从人设树筛选节点,返回最多 top_n 条(按条件概率降序),并检查每个节点与帖子选题点的匹配情况。

    参数
    -------
    account_name : 账号名,用于定位该账号的人设树数据。
    post_id : 帖子ID,用于加载帖子选题点并与各节点做匹配判断。
    derived_items : 已推导选题点列表,可为空。非空时每项为字典,需含 topic(或「已推导的选题点」)与 source_node(或「推导来源人设树节点」);为空时各节点的条件概率取其自身 _ratio。
    conditional_ratio_threshold : 条件概率阈值,仅返回条件概率 >= 该值的节点。
    top_n : 返回条数上限,默认 100。
    context : 可选,Agent 工具上下文。

    返回
    -------
    ToolResult:
        - title: 结果标题。
        - output: 可读的节点列表文本(每行:节点名称、条件概率、父节点名称、帖子匹配情况)。
        - metadata: 含 account_name、threshold、top_n、count、items;
          items 为列表,每项为 {"节点名称": str, "条件概率": float, "父节点名称": str, "帖子选题点匹配": 无匹配时为 "无",有匹配时为 list[{"帖子选题点": str, "匹配分数": float}]}。
        - 出错时 error 为错误信息。
    """
    missing = _missing_tree_dir_result(account_name)
    if missing is not None:
        return missing
    try:
        derived_list = _parse_derived_list(derived_items or [])
        items = get_nodes_by_conditional_ratio(
            account_name, derived_list, conditional_ratio_threshold, top_n
        )
        await _attach_post_point_matches(items, account_name, post_id)

        # [TEMPORARY] Keep only nodes with at least one post topic-point match
        # (drop "无"). Kept as a single statement so it is easy to delete later.
        items = [x for x in items if isinstance(x.get("帖子选题点匹配"), list)]

        if not items:
            output = f"未找到条件概率 >= {conditional_ratio_threshold} 的节点"
        else:
            output = "\n".join(
                f"- {x['节点名称']}\t条件概率={x['条件概率']}\t父节点={x['父节点名称']}"
                f"\t帖子选题点匹配={_format_match_info(x.get('帖子选题点匹配', '无'))}"
                for x in items
            )
        return ToolResult(
            title=f"条件概率节点 ({account_name}, 阈值={conditional_ratio_threshold})",
            output=output,
            metadata={
                "account_name": account_name,
                "threshold": conditional_ratio_threshold,
                "top_n": top_n,
                "count": len(items),
                # "items" is promised by the tool description/docstring.
                "items": items,
            },
        )
    except Exception as e:
        # Tool boundary: report any failure to the agent instead of raising.
        return ToolResult(
            title="按条件概率查询节点失败",
            output=str(e),
            error=str(e),
        )


def main() -> None:
    """Local smoke test against the "家有大志" account.

    When the agent framework is importable, both tools are run (including
    post topic-point matching). Otherwise the core functions are run directly,
    since the tools cannot build a ToolResult without the framework.
    """
    import asyncio

    account_name = "家有大志"
    post_id = "68fb6a5c000000000302e5de"
    derived_items = [
        {"source_node": "分享", "topic": "分享"},
        {"source_node": "叙事结构", "topic": "叙事结构"},
        {"source_node": "图片文字", "topic": "图片文字"},
        {"source_node": "补充说明式", "topic": "补充说明式"},
        {"source_node": "幽默化标题", "topic": "幽默化标题"},
        {"source_node": "标题", "topic": "标题"},
    ]
    conditional_ratio_threshold = 0.01
    top_n = 1000

    if ToolResult is None:
        # No agent framework: exercise the core functions (no post matching).
        constant_nodes = get_constant_nodes(account_name)
        print(f"账号: {account_name} — 常量节点共 {len(constant_nodes)} 个(前 50 个):")
        for x in constant_nodes[:50]:
            print(f"  - {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}")
        print()

        ratio_nodes = get_nodes_by_conditional_ratio(
            account_name,
            _parse_derived_list(derived_items),
            conditional_ratio_threshold,
            top_n,
        )
        print(f"条件概率节点 阈值={conditional_ratio_threshold}, top_n={top_n}, 共 {len(ratio_nodes)} 个:")
        for x in ratio_nodes:
            print(f"  - {x['节点名称']}\t条件概率={x['条件概率']}\t父节点={x['父节点名称']}")
        return

    async def run_tools() -> None:
        r1 = await find_tree_constant_nodes(account_name, post_id=post_id)
        print("--- find_tree_constant_nodes ---")
        out1 = r1.output
        print(out1[:2000] + "..." if len(out1) > 2000 else out1)

        r2 = await find_tree_nodes_by_conditional_ratio(
            account_name,
            post_id=post_id,
            derived_items=derived_items,
            conditional_ratio_threshold=conditional_ratio_threshold,
            top_n=top_n,
        )
        print("\n--- find_tree_nodes_by_conditional_ratio ---")
        print(r2.output)

    asyncio.run(run_tools())


if __name__ == "__main__":
    main()