"""Tree-node lookup tool: queries over persona-tree nodes.

Features:
1. Fetch the constant nodes of a persona tree (global and local constants).
2. Fetch the nodes whose conditional probability reaches a threshold
   (sorted by conditional probability, top N returned).

Platform-library persona tree pipeline (the second output section, see
``build_platform_tree_section_items_split``):
``xiaohongshu/tree`` -> same conditional-probability computation as for the
account tree -> ``xiaohongshu/match_data`` filters topic points by match
score -> nodes that share a name with the account section are removed.
"""

import json
import sys
from pathlib import Path
from typing import Any, Iterator, Optional

# Make ``utils`` / ``tools`` resolvable both when this file is run directly and
# when it is loaded as part of the package (also lets IDEs follow the imports).
_root = Path(__file__).resolve().parent.parent
if str(_root) not in sys.path:
    sys.path.insert(0, str(_root))

from utils.conditional_ratio_calc import (  # noqa: E402
    build_node_post_index,
    build_node_post_index_from_tree_dir,
    calc_node_conditional_ratio,
    load_persona_trees_from_dir,
)
from tools.point_match import (  # noqa: E402
    DEFAULT_MATCH_THRESHOLD,
    match_derivation_to_post_points,
)

try:
    from agent.tools import tool, ToolResult, ToolContext
except ImportError:
    def tool(*args, **kwargs):
        """No-op stand-in for ``agent.tools.tool``: returns the function unchanged."""
        return lambda f: f

    # The agent package is optional when only main() exercises the core logic.
    ToolResult = None
    ToolContext = None

# Relative to this file: tools -> overall_derivation; ``input`` and ``output``
# live directly under overall_derivation.
_BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
_BASE_OUTPUT = Path(__file__).resolve().parent.parent / "output"


def _dimension_analysis_log_dir(account_name: str, post_id: str, log_id: str) -> Path:
    """Return the derivation-log directory.

    Layout: ``output/{account_name}/推导日志/{post_id}/{log_id}/``.
    """
    return _BASE_OUTPUT / account_name / "推导日志" / post_id / log_id


def _load_derived_dim_tree_node_names(
    account_name: str, post_id: str, log_id: str, round: int
) -> list[str]:
    """Read the ``tree_node_name`` of every entry in ``derived_dims``.

    The dimension-analysis file of the current round
    (``{round}_维度分析.json``) is preferred; if it is missing or unreadable the
    previous round's file (``{round-1}_维度分析.json``) is tried. Rounds below 1
    are never looked up.

    Args:
        account_name: Account whose output directory is searched.
        post_id: Post identifier (one directory level of the log path).
        log_id: Derivation-log identifier; blank means "no log".
        round: Current derivation round.

    Returns:
        The stripped, non-empty ``tree_node_name`` values of the first usable
        file, or an empty list when no usable file exists.
    """
    if not log_id or not str(log_id).strip():
        return []

    log_dir = _dimension_analysis_log_dir(account_name, post_id, str(log_id).strip())
    for r in (round, round - 1):
        if r < 1:
            continue
        path = log_dir / f"{r}_维度分析.json"
        if not path.is_file():
            continue
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception:
            # Deliberate best-effort: an unreadable file falls back to the
            # previous round instead of failing the whole lookup.
            continue
        if not isinstance(data, dict):
            # A top-level list/scalar has no ``derived_dims``; treat it like an
            # unreadable file rather than crashing on ``data.get``.
            continue

        dims = data.get("derived_dims") or []
        if not isinstance(dims, (list, tuple)):
            dims = []

        names: list[str] = []
        for d in dims:
            if not isinstance(d, dict):
                continue
            tn = d.get("tree_node_name")
            if tn is not None and str(tn).strip():
                names.append(str(tn).strip())
        return names
    return []


def _descendant_names_under_tree_nodes(
    account_name: str, anchor_node_names: list[str]
) -> tuple[set[str], dict[str, str]]:
    """Collect every descendant of the anchor nodes in the account's trees.

    Each dimension tree is walked depth-first from its root. A node is
    collected when it is not an anchor itself and at least one anchor lies on
    the path from the root to it. Its "belonging dimension" is the deepest
    such anchor (when anchors are nested, the deeper one wins).

    Args:
        account_name: Account whose persona trees are loaded.
        anchor_node_names: Anchor names (``derived_dims.tree_node_name``).

    Returns:
        ``(allowed, dim_map)``: the set of collected descendant names and a
        mapping descendant name -> deepest anchor name. Both are empty when no
        anchors are given.
    """
    if not anchor_node_names:
        return set(), {}

    anchors = set(anchor_node_names)
    allowed: set[str] = set()
    dim_map: dict[str, str] = {}

    def dfs(node_name: str, node_dict: dict, parent_deepest: Optional[str]) -> None:
        # The deepest anchor seen so far, including the current node itself.
        deepest = node_name if node_name in anchors else parent_deepest
        for cname, cnode in (node_dict.get("children") or {}).items():
            if not isinstance(cnode, dict):
                continue
            if cname not in anchors and deepest is not None:
                allowed.add(cname)
                dim_map[cname] = deepest
            dfs(cname, cnode, deepest)

    for dim_root_name, root in _load_trees(account_name):
        dfs(dim_root_name, root, None)
    return allowed, dim_map


def _tree_dir(account_name: str) -> Path:
    """Return the persona-tree directory ``input/{account_name}/处理后数据/tree/``."""
    return _BASE_INPUT / account_name / "处理后数据" / "tree"


def _load_trees(account_name: str) -> list[tuple[str, dict]]:
    """Load the persona trees of every dimension of an account.

    Every ``*.json`` file in the tree directory contributes at most one tree:
    the first top-level entry whose value is a dict. Files that cannot be read
    or parsed are skipped.

    Returns:
        ``[(dimension name, root node dict), ...]``; empty when the directory
        does not exist.
    """
    td = _tree_dir(account_name)
    if not td.is_dir():
        return []

    result: list[tuple[str, dict]] = []
    for p in td.glob("*.json"):
        try:
            with open(p, "r", encoding="utf-8") as f:
                data = json.load(f)
            for dim_name, root in data.items():
                if isinstance(root, dict):
                    result.append((dim_name, root))
                    break
        except Exception:
            # Best-effort: one broken file must not hide the other trees.
            continue
    return result


def _iter_all_nodes(account_name: str) -> Iterator[tuple[str, str, dict]]:
    """Yield ``(node name, parent name, node dict)`` for every tree node.

    Nodes are produced depth-first, parent before children. The dimension root
    itself is not yielded; it only appears as the parent of its children.
    """

    def walk(parent_name: str, node_dict: dict) -> Iterator[tuple[str, str, dict]]:
        for name, child in (node_dict.get("children") or {}).items():
            if not isinstance(child, dict):
                continue
            yield (name, parent_name, child)
            yield from walk(name, child)

    for dim_name, root in _load_trees(account_name):
        yield from walk(dim_name, root)
# ---------------------------------------------------------------------------
# 1. Persona-tree constant nodes
# ---------------------------------------------------------------------------


def get_constant_nodes(account_name: str) -> list[dict[str, Any]]:
    """Return the constant nodes of the account's persona trees.

    - Global constant: ``_is_constant`` is ``True``.
    - Local constant: ``_is_local_constant`` is ``True`` and ``_is_constant``
      is not ``True``.

    Returns:
        One dict per constant node with the node name, its probability
        (``_ratio``) and the constant type, sorted by probability descending
        with missing probabilities last.
    """
    result: list[dict[str, Any]] = []
    for node_name, _parent, node in _iter_all_nodes(account_name):
        if node.get("_is_constant") is True:
            const_type = "全局常量"
        elif node.get("_is_local_constant") is True:
            const_type = "局部常量"
        else:
            continue
        result.append({
            "节点名称": node_name,
            "概率": node.get("_ratio"),
            "常量类型": const_type,
        })
    # ``None`` ratios sort after every numeric ratio.
    result.sort(key=lambda x: (x["概率"] is None, -(x["概率"] or 0)))
    return result


# ---------------------------------------------------------------------------
# 2. Nodes that reach a conditional-probability threshold
# ---------------------------------------------------------------------------


def get_nodes_by_conditional_ratio(
    account_name: str,
    derived_list: list[tuple[str, str]],
    threshold: float,
    top_n: int,
    allowed_node_names: Optional[set[str]] = None,
    node_belonging_dim: Optional[dict[str, str]] = None,
) -> list[dict[str, Any]]:
    """Return the account-tree nodes with conditional probability >= threshold.

    Args:
        account_name: Account whose persona trees are scanned.
        derived_list: Already derived items, each ``(derived topic point,
            source persona-tree node)``. When empty, each node's own ``_ratio``
            is used as its conditional probability.
        threshold: Minimum conditional probability (inclusive).
        top_n: Maximum number of rows returned; values <= 0 yield an empty
            list (same clamping as ``get_platform_nodes_by_conditional_ratio``).
        allowed_node_names: When given, only nodes whose name is in this set
            are kept.
        node_belonging_dim: Node name -> derived dimension it belongs to
            (built together with ``allowed_node_names``); without it every
            row's dimension is "—".

    Returns:
        Rows sorted by conditional probability descending, each with node
        name, conditional probability, parent node name and belonging
        dimension.
    """
    limit = max(0, int(top_n))
    if limit == 0:
        # Nothing can be returned, so skip loading and scoring the trees.
        return []

    base_dir = _BASE_INPUT

    def dim_for(node_name: str) -> str:
        if not node_belonging_dim:
            return "—"
        return node_belonging_dim.get(node_name) or "—"

    def is_allowed(node_name: str) -> bool:
        return allowed_node_names is None or node_name in allowed_node_names

    scored: list[tuple[str, float, str, str]] = []
    if not derived_list:
        # No derived context: fall back to each node's stored ``_ratio``.
        for node_name, parent_name, node in _iter_all_nodes(account_name):
            if not is_allowed(node_name):
                continue
            raw = node.get("_ratio")
            ratio = 0.0 if raw is None else float(raw)
            if ratio >= threshold:
                scored.append((node_name, ratio, parent_name, dim_for(node_name)))
    else:
        # Node names are de-duplicated here; the last occurrence decides the parent.
        node_to_parent: dict[str, str] = {}
        for n, p, _ in _iter_all_nodes(account_name):
            node_to_parent[n] = p

        node_post_index = build_node_post_index(account_name, base_dir)
        for node_name, parent_name in node_to_parent.items():
            if not is_allowed(node_name):
                continue
            ratio = calc_node_conditional_ratio(
                account_name,
                derived_list,
                node_name,
                base_dir=base_dir,
                node_post_index=node_post_index,
                target_ratio=threshold,
            )
            if ratio >= threshold:
                scored.append((node_name, ratio, parent_name, dim_for(node_name)))

    scored.sort(key=lambda x: x[1], reverse=True)
    return [
        {
            "节点名称": name,
            "条件概率": ratio,
            "父节点名称": parent,
            "所属维度": dim,
        }
        for name, ratio, parent, dim in scored[:limit]
    ]


def _platform_tree_dir() -> Path:
    """Return the platform-library persona-tree directory ``input/xiaohongshu/tree/``."""
    return _BASE_INPUT / "xiaohongshu" / "tree"


def _collect_platform_scored_tuples(
    derived_list: list[tuple[str, str]],
    threshold: float,
    max_nodes: int = 12000,
) -> list[tuple[str, float, str, str]]:
    """Collect every platform-tree node with conditional probability >= threshold.

    Args:
        derived_list: Already derived ``(topic point, source node)`` pairs;
            when empty each node's ``_ratio`` is used instead.
        threshold: Minimum conditional probability (inclusive).
        max_nodes: Upper bound on the number of returned tuples; applied after
            the global sort so the highest-scoring nodes are kept. Values <= 0
            disable the cap.

    Returns:
        ``(node name, conditional probability, parent name, dimension name)``
        tuples sorted by conditional probability descending; empty when the
        platform tree directory does not exist.
    """
    tree_dir = _platform_tree_dir()
    if not tree_dir.is_dir():
        return []

    thr = float(threshold)
    scored: list[tuple[str, float, str, str]] = []

    if not derived_list:
        def collect(parent_name: str, node_dict: dict, dim_name: str) -> None:
            for name, child in (node_dict.get("children") or {}).items():
                if not isinstance(child, dict):
                    continue
                raw = child.get("_ratio")
                r = 0.0 if raw is None else float(raw)
                if r >= thr:
                    scored.append((name, r, parent_name, dim_name))
                collect(name, child, dim_name)

        for dim_name, root in load_persona_trees_from_dir(tree_dir):
            collect(dim_name, root, dim_name)
    else:
        node_post_index = build_node_post_index_from_tree_dir(tree_dir)
        # Node name -> (parent name, dimension name); later occurrences overwrite.
        node_to_parent_dim: dict[str, tuple[str, str]] = {}

        def index(parent_name: str, node_dict: dict, dim_name: str) -> None:
            for name, child in (node_dict.get("children") or {}).items():
                if not isinstance(child, dict):
                    continue
                node_to_parent_dim[name] = (parent_name, dim_name)
                index(name, child, dim_name)

        for dim_name, root in load_persona_trees_from_dir(tree_dir):
            index(dim_name, root, dim_name)

        for node_name, (parent_name, dim_name) in node_to_parent_dim.items():
            ratio = calc_node_conditional_ratio(
                "",
                derived_list,
                node_name,
                base_dir=_BASE_INPUT,
                node_post_index=node_post_index,
                target_ratio=thr,
            )
            if ratio >= thr:
                scored.append((node_name, ratio, parent_name, dim_name))

    scored.sort(key=lambda x: x[1], reverse=True)
    if max_nodes > 0 and len(scored) > max_nodes:
        scored = scored[:max_nodes]
    return scored


def get_platform_nodes_by_conditional_ratio(
    derived_list: list[tuple[str, str]],
    threshold: float,
    top_n: int,
) -> list[dict[str, Any]]:
    """Return the top platform-tree nodes by conditional probability.

    Uses the same scoring as ``get_nodes_by_conditional_ratio`` but on the
    trees under ``xiaohongshu/tree``. When ``derived_list`` is empty the
    node's ``_ratio`` is used. ``top_n`` values <= 0 yield an empty list.
    """
    limit = max(0, int(top_n))
    if limit == 0:
        # Avoid scoring the whole platform tree when nothing can be returned.
        return []

    scored = _collect_platform_scored_tuples(derived_list, threshold)
    return [
        {
            "节点名称": name,
            "条件概率": ratio,
            "父节点名称": parent,
            "所属维度": dim,
        }
        for name, ratio, parent, dim in scored[:limit]
    ]


def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
    """Convert agent-supplied derived items into ``(topic, source_node)`` tuples.

    Dict items are read through ``topic`` / ``source_node`` or their Chinese
    aliases; list/tuple items use their first two elements. Dict items missing
    either field and items of any other type are skipped. ``None`` is treated
    as an empty list.
    """
    out: list[tuple[str, str]] = []
    for item in derived_items or []:
        if isinstance(item, dict):
            topic = item.get("topic") or item.get("已推导的选题点")
            source = item.get("source_node") or item.get("推导来源人设树节点")
            if topic is not None and source is not None:
                out.append((str(topic).strip(), str(source).strip()))
        elif isinstance(item, (list, tuple)) and len(item) >= 2:
            out.append((str(item[0]).strip(), str(item[1]).strip()))
    return out


# ---------------------------------------------------------------------------
# 3.
# Platform-library persona-tree helper nodes (based on post / platform-tree match results)
# ---------------------------------------------------------------------------


def _platform_match_topics_by_node(
    post_id: str,
    match_score_threshold: float,
) -> dict[tuple[str, str], dict[str, float]]:
    """Read the platform match file of a post and index it by tree node.

    The file is ``input/xiaohongshu/match_data/{post_id}_匹配_all.json``. Only
    pairs with ``match_score >= match_score_threshold`` are kept.

    Returns:
        ``(dimension, persona-tree node name) -> {post topic point: best
        score}``. Empty when ``post_id`` is blank, the file is missing or
        unreadable, or its top level is not a list.
    """
    out: dict[tuple[str, str], dict[str, float]] = {}
    if not post_id:
        return out

    path = _BASE_INPUT / "xiaohongshu" / "match_data" / f"{post_id}_匹配_all.json"
    if not path.is_file():
        return out
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception:
        # Best-effort: an unreadable match file means "no matches".
        return out
    if not isinstance(data, list):
        return out

    thr = float(match_score_threshold)
    for item in data:
        if not isinstance(item, dict):
            continue
        topic = item.get("name")
        matches = item.get("match_personas")
        if topic is None or not isinstance(matches, list):
            continue
        topic_s = str(topic).strip()
        if not topic_s:
            continue
        for m in matches:
            if not isinstance(m, dict):
                continue
            name = m.get("name")
            dim = m.get("dimension")
            score = m.get("match_score")
            if name is None or dim is None or score is None:
                continue
            try:
                s = float(score)
            except (TypeError, ValueError, OverflowError):
                continue
            if s < thr:
                continue
            bucket = out.setdefault((str(dim).strip(), str(name).strip()), {})
            prev = bucket.get(topic_s)
            # Keep only the best score per (node, topic point) pair.
            if prev is None or s > prev:
                bucket[topic_s] = s
    return out


def _platform_node_belonging_dim_from_anchor_nodes(
    anchor_node_names: list[str],
) -> dict[str, str]:
    """Map platform-tree nodes to the deepest anchor above them.

    Same rule as the account-side ``_descendant_names_under_tree_nodes`` but
    on the trees under ``xiaohongshu/tree``: a non-anchor node is mapped to
    the deepest anchor found on the path from its dimension root.

    Returns:
        Node name -> anchor name; empty when there are no anchors or the
        platform tree directory does not exist.
    """
    if not anchor_node_names:
        return {}
    tree_dir = _platform_tree_dir()
    if not tree_dir.is_dir():
        return {}

    anchors = set(anchor_node_names)
    dim_map: dict[str, str] = {}

    def dfs(node_name: str, node_dict: dict, parent_deepest: Optional[str]) -> None:
        deepest = node_name if node_name in anchors else parent_deepest
        for cname, cnode in (node_dict.get("children") or {}).items():
            if not isinstance(cnode, dict):
                continue
            if cname not in anchors and deepest is not None:
                dim_map[cname] = deepest
            dfs(cname, cnode, deepest)

    for dim_root_name, root in load_persona_trees_from_dir(tree_dir):
        dfs(dim_root_name, root, None)
    return dim_map


def _load_platform_nodes_split(
    post_id: str,
    derived_list: list[tuple[str, str]],
    conditional_ratio_threshold: float,
    match_score_threshold: float,
    top_n: int,
    node_belonging_dim_platform: Optional[dict[str, str]] = None,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Split qualifying platform-tree nodes into matched / unmatched rows.

    Nodes come from ``_collect_platform_scored_tuples`` (already sorted by
    conditional probability). They are divided by whether
    ``xiaohongshu/match_data`` holds a post topic point matching them with a
    score >= ``match_score_threshold``. Each group is truncated to its own
    pool, so one group can never crowd out the other.

    When ``node_belonging_dim_platform`` is given, rows without an entry in
    it are dropped and the remaining rows report that entry as their
    dimension; when it is ``None`` every row is kept with dimension "—".

    Returns:
        ``(matched, unmatched)`` row lists. Matched rows carry a list of
        ``{topic point, score}`` dicts sorted by score descending; unmatched
        rows carry "无".
    """
    matched: list[dict[str, Any]] = []
    unmatched: list[dict[str, Any]] = []

    topic_map: dict[tuple[str, str], dict[str, float]] = {}
    if post_id:
        topic_map = _platform_match_topics_by_node(post_id, float(match_score_threshold))

    # Dimension labels in match_data may not equal the tree's labels exactly, so
    # keep a fallback index keyed by node name only to avoid false "no match".
    topic_map_by_name: dict[str, dict[str, float]] = {}
    for (_dim, n), topics in topic_map.items():
        bucket = topic_map_by_name.setdefault(str(n).strip(), {})
        for t, sc in (topics or {}).items():
            prev = bucket.get(t)
            if prev is None or sc > prev:
                bucket[t] = sc

    all_scored = _collect_platform_scored_tuples(
        derived_list,
        float(conditional_ratio_threshold),
    )
    if not all_scored:
        return matched, unmatched

    def topics_for(name: str, dim: str) -> dict[str, float]:
        name_s = str(name).strip()
        return (
            topic_map.get((str(dim).strip(), name_s))
            or topic_map_by_name.get(name_s)
            or {}
        )

    # Classify first; every entry keeps the topics that decided its group.
    matched_entries: list[tuple[str, float, str, dict[str, float]]] = []
    unmatched_entries: list[tuple[str, float, str, dict[str, float]]] = []
    for name, ratio, parent, dim in all_scored:
        topics = topics_for(name, dim)
        target = matched_entries if topics else unmatched_entries
        target.append((name, ratio, parent, topics))

    # Per-group pool: at least top_n, otherwise 5 * top_n bounded to [500, 2000].
    pool = max(int(top_n), min(2000, max(500, int(top_n) * 5)))

    def emit(entries: list[tuple[str, float, str, dict[str, float]]]) -> None:
        for name, ratio, parent, topics in entries:
            out_dim = "—"
            if node_belonging_dim_platform is not None:
                out_dim = node_belonging_dim_platform.get(str(name).strip()) or "—"
                if out_dim == "—":
                    # A belonging map was supplied but this node is not under any anchor.
                    continue
            row: dict[str, Any] = {
                "节点名称": name,
                "条件概率": ratio,
                "父节点名称": parent,
                "所属维度": out_dim,
            }
            if topics:
                ranked = sorted(topics.items(), key=lambda x: x[1], reverse=True)
                row["帖子选题点匹配"] = [
                    {"帖子选题点": t, "匹配分数": sc} for t, sc in ranked
                ]
                matched.append(row)
            else:
                row["帖子选题点匹配"] = "无"
                unmatched.append(row)

    emit(matched_entries[:pool])
    emit(unmatched_entries[:pool])
    return matched, unmatched


def build_platform_tree_section_items_split(
    post_id: str,
    derived_list: list[tuple[str, str]],
    conditional_ratio_threshold: float,
    match_score_threshold: float,
    top_n: int,
    exclude_node_names: set[str],
    node_belonging_dim_platform: Optional[dict[str, str]] = None,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Build the platform-tree section rows, excluding given node names.

    Combines conditional probability with ``xiaohongshu/match_data`` matching
    (see ``_load_platform_nodes_split``) and removes every row whose stripped
    node name is in ``exclude_node_names``.

    Returns:
        ``(rows with a post topic-point match, rows without one)``; two empty
        lists when ``post_id`` is blank.
    """
    if not post_id:
        return [], []

    excluded = {str(n).strip() for n in exclude_node_names}
    matched, unmatched = _load_platform_nodes_split(
        post_id=post_id,
        derived_list=derived_list,
        conditional_ratio_threshold=float(conditional_ratio_threshold),
        match_score_threshold=float(match_score_threshold),
        top_n=int(top_n),
        node_belonging_dim_platform=node_belonging_dim_platform,
    )

    def keep(row: dict[str, Any]) -> bool:
        return str(row.get("节点名称", "")).strip() not in excluded

    return [p for p in matched if keep(p)], [p for p in unmatched if keep(p)]


# ---------------------------------------------------------------------------
# Agent tools (wrapped in the same way as glob_tool)
# ---------------------------------------------------------------------------


async def _attach_post_point_matches(
    items: list[dict[str, Any]],
    account_name: str,
    post_id: str,
    match_threshold: float,
) -> None:
    """Annotate each row in place with its matching post topic points.

    All node names are matched in one call to
    ``match_derivation_to_post_points``; a row gets the list of its matches,
    or "无" when it has none.
    """
    node_names = [x["节点名称"] for x in items]
    matched_results = await match_derivation_to_post_points(
        node_names, account_name, post_id, match_threshold=float(match_threshold)
    )
    node_match_map: dict[str, list] = {}
    for m in matched_results:
        node_match_map.setdefault(m["推导选题点"], []).append({
            "帖子选题点": m["帖子选题点"],
            "匹配分数": m["匹配分数"],
        })
    for item in items:
        matches = node_match_map.get(item["节点名称"], [])
        item["帖子选题点匹配"] = matches if matches else "无"


def _has_post_point_match(row: dict[str, Any]) -> bool:
    """Return True when the row carries a list of post topic-point matches."""
    return isinstance(row.get("帖子选题点匹配"), list)


def _format_match_info(row: dict[str, Any]) -> str:
    """Render a row's match annotation as display text."""
    match_info = row.get("帖子选题点匹配", "无")
    if isinstance(match_info, list):
        return "、".join(f"{m['帖子选题点']}({m['匹配分数']})" for m in match_info)
    return str(match_info)


# Lists the account's global and local constant nodes and, when post_id is
# given, annotates each with the post topic points it matches.
# NOTE(review): the docstring below is left in its original language because
# the @tool decorator presumably exposes it as the tool description at
# runtime; confirm against agent.tools before translating it.
@tool()
async def find_tree_constant_nodes(
    account_name: str,
    post_id: str,
) -> ToolResult:
    """
    获取人设树中的常量节点列表(全局常量与局部常量),并检查每个节点与帖子选题点的匹配情况。

    Args:
        account_name : 账号名,用于定位该账号的人设树数据。
        post_id : 帖子ID,用于加载帖子选题点并与各常量节点做匹配判断。

    Returns:
        ToolResult:
            - title: 结果标题。
            - output: 可读的节点列表文本(每行:节点名称、概率、常量类型、帖子匹配情况)。
            - 出错时 error 为错误信息。
    """
    tree_dir = _tree_dir(account_name)
    if not tree_dir.is_dir():
        return ToolResult(
            title="人设树目录不存在",
            output=f"目录不存在: {tree_dir}",
            error="Directory not found",
        )
    try:
        items = get_constant_nodes(account_name)
        # Match all nodes against the post's topic points in one batch.
        if items and post_id:
            await _attach_post_point_matches(
                items, account_name, post_id, DEFAULT_MATCH_THRESHOLD
            )

        if not items:
            output = "未找到常量节点"
        else:
            output = "\n".join(
                f"- {x['节点名称']}\t概率={x['概率']}\t{x['常量类型']}"
                f"\t帖子选题点匹配={_format_match_info(x)}"
                for x in items
            )
        return ToolResult(
            title=f"常量节点 ({account_name})",
            output=output,
            metadata={"account_name": account_name, "count": len(items)},
        )
    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        return ToolResult(
            title="获取常量节点失败",
            output=str(e),
            error=str(e),
        )


# Two-section node listing: account tree first (60% of top_n), platform tree
# second (40%), each split evenly between rows with and without a post
# topic-point match. top_n is normalized to a non-negative int up front so a
# negative or float value cannot corrupt the quota arithmetic or slicing.
# NOTE(review): the docstring below is left in its original language because
# the @tool decorator presumably exposes it as the tool description at
# runtime; confirm against agent.tools before translating it.
@tool()
async def find_tree_nodes_by_conditional_ratio(
    account_name: str,
    post_id: str,
    derived_items: list[dict[str, str]],
    conditional_ratio_threshold: float,
    top_n: int = 100,
    round: int = 1,
    log_id: str = "",
    match_score_threshold: float = DEFAULT_MATCH_THRESHOLD,
) -> ToolResult:
    """
    按条件概率阈值筛选节点:先账号人设树(优先使用),再平台库人设树;两段不合并。
    条件概率计算对两棵树使用同一套规则(calc_node_conditional_ratio / 节点 _post_ids)。

    返回结果按以下配额分配(合计 top_n 条):
    - 账号人设树节点占 60%,其中有帖子选题点匹配的记录和无帖子选题点匹配的记录各占一半;
    - 平台库人设树节点占 40%,其中有帖子选题点匹配的记录和无帖子选题点匹配的记录各占一半。
    「帖子选题点匹配」仅收录匹配分 >= match_score_threshold 的选题点。

    Args:
        account_name : 账号名,用于定位该账号的人设树数据。
        post_id : 帖子ID,用于加载帖子选题点并与各节点做匹配判断。
        derived_items : 已推导选题点列表,可为空。非空时每项为字典,需含 topic(或「已推导的选题点」)与 source_node(或「推导来源人设树节点」)
        conditional_ratio_threshold : 条件概率阈值,仅返回条件概率 >= 该值的节点。
        top_n : 最终返回总条数上限,按 账号60%/平台40%、有匹配/无匹配各半 分配。
        round : 推导轮次。
        log_id : 推导日志ID
        match_score_threshold : 帖子选题点匹配分阈值,与 point_match 默认一致。

    Returns:
        ToolResult:
            - title: 结果标题。
            - output: 两段文本——先账号人设树,后平台库人设树;
              账号侧匹配来自 input/{账号}/match_data;平台侧条件概率基于 input/xiaohongshu/tree,匹配来自 input/xiaohongshu/match_data。
            - 出错时 error 为错误信息。
    """
    tree_dir = _tree_dir(account_name)
    if not tree_dir.is_dir():
        return ToolResult(
            title="人设树目录不存在",
            output=f"目录不存在: {tree_dir}",
            error="Directory not found",
        )
    try:
        # Normalize once: quotas and slices below assume a non-negative int.
        top_n = max(0, int(top_n))
        log_id_s = str(log_id).strip() if log_id else ""

        derived_list = _parse_derived_list(derived_items or [])
        allowed: Optional[set[str]] = None
        node_belonging_dim: dict[str, str] = {}
        node_belonging_dim_platform: Optional[dict[str, str]] = None
        dim_source = ""
        derived_dim_names: list[str] = []
        derived_items_len = len(derived_items or [])

        if log_id_s:
            derived_dim_names = _load_derived_dim_tree_node_names(
                account_name, post_id, log_id_s, int(round)
            )
            if derived_dim_names:
                allowed, node_belonging_dim = _descendant_names_under_tree_nodes(
                    account_name, derived_dim_names
                )
                node_belonging_dim_platform = _platform_node_belonging_dim_from_anchor_nodes(
                    derived_dim_names
                )
                # Report which dimension-analysis file exists (current round
                # first, then the previous one).
                # NOTE(review): the loader skips files it cannot parse, so this
                # may name a different file than the one actually read.
                log_dir = _dimension_analysis_log_dir(account_name, post_id, log_id_s)
                for r in (int(round), int(round) - 1):
                    if r >= 1 and (log_dir / f"{r}_维度分析.json").is_file():
                        dim_source = f"{r}_维度分析.json (derived_dims -> 全部后代)"
                        break
            else:
                dim_source = "未读到 derived_dims(无对应维度分析文件或为空),未收窄"

        # With too many derived items, use the derived dimension nodes as the
        # conditional-probability anchors instead: each name becomes
        # (topic=name, source_node=name).
        if derived_items_len > 15 and derived_dim_names:
            derived_list = [(n, n) for n in derived_dim_names]

        # 1) Account persona tree: filter by conditional probability; post
        #    topic-point matching goes through match_derivation_to_post_points.
        items = get_nodes_by_conditional_ratio(
            account_name,
            derived_list,
            conditional_ratio_threshold,
            top_n,
            allowed_node_names=allowed,
            node_belonging_dim=node_belonging_dim if node_belonging_dim else None,
        )
        if items and post_id:
            await _attach_post_point_matches(
                items, account_name, post_id, match_score_threshold
            )

        # Account quota: 60% of top_n (rounded half up), split between rows
        # with and without a match; the odd row goes to "without".
        account_quota = int(top_n * 0.6 + 0.5)
        account_with_n = account_quota // 2
        account_without_n = account_quota - account_with_n
        items_with_match = [x for x in items if _has_post_point_match(x)]
        items_without_match = [x for x in items if not _has_post_point_match(x)]
        items = items_with_match[:account_with_n] + items_without_match[:account_without_n]

        # 2) Platform-library persona tree (conditional probability plus the
        #    xiaohongshu match file). Platform quota: the remaining 40%.
        platform_quota = top_n - account_quota
        platform_with_n = platform_quota // 2
        platform_without_n = platform_quota - platform_with_n

        # Platform "matched" rows exclude names that already have a match on
        # the account side; platform "unmatched" rows exclude every name
        # already listed in the account section.
        account_matched_names = {
            str(x.get("节点名称", "")).strip() for x in items if _has_post_point_match(x)
        }
        account_all_names = {str(x.get("节点名称", "")).strip() for x in items}

        platform_items: list[dict[str, Any]] = []
        if post_id:
            p_matched_raw, p_unmatched_raw = _load_platform_nodes_split(
                post_id=post_id,
                derived_list=derived_list,
                conditional_ratio_threshold=float(conditional_ratio_threshold),
                match_score_threshold=float(match_score_threshold),
                top_n=top_n,
                node_belonging_dim_platform=node_belonging_dim_platform,
            )
            p_matched = [
                p for p in p_matched_raw
                if str(p.get("节点名称", "")).strip() not in account_matched_names
            ]
            p_unmatched = [
                p for p in p_unmatched_raw
                if str(p.get("节点名称", "")).strip() not in account_all_names
            ]
            platform_items = p_matched[:platform_with_n] + p_unmatched[:platform_without_n]

        def _format_node_line(x: dict[str, Any]) -> str:
            dim_label = x.get("所属维度", "—")
            return (
                f"- {x['节点名称']}\t条件概率={x['条件概率']}\t所属维度={dim_label}"
                f"\t帖子选题点匹配={_format_match_info(x)}"
            )

        lines: list[str] = [
            "【优先使用】第一节为账号人设树中条件概率达标的节点(占60%配额,有/无帖子匹配各半);"
            "第二节为平台库人设树中条件概率达标的节点(占40%配额,有/无帖子匹配各半);",
            "",
            "—— 账号人设树节点 ——",
        ]
        if not items:
            lines.append(f"(无:未找到条件概率 >= {conditional_ratio_threshold} 的节点)")
        else:
            lines.extend(_format_node_line(x) for x in items)
        lines.append("")
        lines.append("—— 平台库人设树节点 ——")
        if not platform_items:
            lines.append("(无:未找到条件概率达标的节点)")
        else:
            lines.extend(_format_node_line(x) for x in platform_items)
        output = "\n".join(lines)

        account_with_count = sum(1 for x in items if _has_post_point_match(x))
        platform_with_count = sum(1 for x in platform_items if _has_post_point_match(x))
        return ToolResult(
            title=f"条件概率节点 ({account_name}, 阈值={conditional_ratio_threshold})",
            output=output,
            metadata={
                "account_name": account_name,
                "threshold": conditional_ratio_threshold,
                "match_score_threshold": float(match_score_threshold),
                "top_n": top_n,
                "quota": {
                    "account_quota": account_quota,
                    "account_with_match": account_with_count,
                    "account_without_match": len(items) - account_with_count,
                    "platform_quota": platform_quota,
                    "platform_with_match": platform_with_count,
                    "platform_without_match": len(platform_items) - platform_with_count,
                },
                "account_tree_count": len(items),
                "platform_tree_count": len(platform_items),
                "count": len(items) + len(platform_items),
                "round": int(round),
                "log_id": log_id_s,
                "dimension_filter": {
                    "derived_dim_nodes": derived_dim_names,
                    "allowed_descendant_count": len(allowed) if allowed is not None else None,
                    "source": dim_source or ("未提供 log_id,未按维度收窄" if not log_id_s else ""),
                },
            },
        )
    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        return ToolResult(
            title="按条件概率查询节点失败",
            output=str(e),
            error=str(e),
        )


def main() -> None:
    """Local smoke test against a sample account.

    Runs both tools through their agent interface (including post topic-point
    matching) and prints the results. Does nothing when the ``agent`` package
    is unavailable, because ``ToolResult`` is then ``None``.
    """
    import asyncio

    account_name = "家有大志"
    post_id = "68fb6a5c000000000302e5de"
    log_id = "20260319134630"
    round_no = 4  # not named ``round`` to avoid shadowing the builtin
    derived_items = [
        {"topic": "推广", "source_node": "推广"},
        {"topic": "视觉调性", "source_node": "视觉调性"},
    ]
    conditional_ratio_threshold = 0.2
    top_n = 200

    if ToolResult is None:
        return

    async def run_tools() -> None:
        r1 = await find_tree_constant_nodes(account_name, post_id=post_id)
        print("--- find_tree_constant_nodes ---")
        print(r1.output[:2000] + "..." if len(r1.output) > 2000 else r1.output)
        r2 = await find_tree_nodes_by_conditional_ratio(
            account_name,
            post_id=post_id,
            derived_items=derived_items,
            conditional_ratio_threshold=conditional_ratio_threshold,
            top_n=top_n,
            round=round_no,
            log_id=log_id,
        )
        print("\n--- find_tree_nodes_by_conditional_ratio ---")
        print(r2.output)

    asyncio.run(run_tools())


if __name__ == "__main__":
    main()