""" 查找 Pattern Tool - 从 pattern 库中获取符合条件概率阈值的 pattern 功能: - 账号:读取 input/{账号}/处理后数据/pattern/pattern.json,条件概率基于账号人设树; 元素与帖子选题点匹配走账号 match_data / point_match,并支持人设树子节点、兄弟节点扩展。 - 平台库:读取 input/xiaohongshu/pattern/processed_edge_data.json,条件概率基于 xiaohongshu/tree; 元素匹配仅使用 input/xiaohongshu/match_data/{post_id}_匹配_all.json。 """ import json import sys from pathlib import Path from typing import Any # 保证直接运行或作为包加载时都能解析 utils / tools(IDE 可跳转) _root = Path(__file__).resolve().parent.parent if str(_root) not in sys.path: sys.path.insert(0, str(_root)) from utils.conditional_ratio_calc import ( build_node_index_for_tree_dir, calc_pattern_conditional_ratio, calc_pattern_conditional_ratio_with_index, ) from tools.point_match import ( DEFAULT_MATCH_THRESHOLD, _load_match_data, match_derivation_to_post_points, ) from tools.find_tree_node import _load_trees try: from agent.tools import tool, ToolResult, ToolContext except ImportError: def tool(*args, **kwargs): return lambda f: f ToolResult = None # 仅用 main() 测核心逻辑时可无 agent ToolContext = None # 与 pattern_data_process 一致的 key 定义 TOP_KEYS = [ "depth_max_with_name", "depth_mixed", "depth_max_concrete", "depth2_medium", "depth1_abstract", "depth_max_minus_1", "depth_max_minus_2", "depth_3", "depth_4", ] SUB_KEYS = ["two_x", "one_x", "zero_x"] _BASE_INPUT = Path(__file__).resolve().parent.parent / "input" # 排序时「已推导选题点 ↔ pattern 元素」在 match_data 中的高分优先阈值(与账号段原逻辑一致) _MATCH_PRIOR_MIN_SCORE = 0.8 _PLATFORM_TREE_DIR = _BASE_INPUT / "xiaohongshu" / "tree" _PLATFORM_PATTERN_FILE = _BASE_INPUT / "xiaohongshu" / "pattern" / "processed_edge_data.json" def _build_node_info(account_name: str) -> dict[str, dict]: """ 构建人设树节点信息映射: node_name -> { "type": 节点 _type("class" / "ID" 等), "children": 子节点名称列表(仅分类节点有值), "siblings": 兄弟节点名称列表(不含自身), } """ node_info: dict[str, dict] = {} def _walk(node_dict: dict): children_dict = node_dict.get("children") or {} child_entries = [(n, c) for n, c in children_dict.items() if isinstance(c, dict)] child_names = [n for n, _ in child_entries] for name, child in child_entries: sub_children = child.get("children") or {} sub_child_names = [n for n, c in sub_children.items() if isinstance(c, dict)] node_info[name] = { "type": child.get("_type", ""), "children": sub_child_names, "siblings": [n for n in child_names if n != name], } _walk(child) for _dim_name, root in _load_trees(account_name): _walk(root) return node_info def _pattern_file(account_name: str) -> Path: """pattern 库文件:../input/{account_name}/处理后数据/pattern/pattern.json""" return _BASE_INPUT / account_name / "处理后数据" / "pattern" / "pattern.json" def _platform_pattern_file() -> Path: """平台库 pattern:../input/xiaohongshu/pattern/processed_edge_data.json""" return _PLATFORM_PATTERN_FILE def _slim_pattern(p: dict) -> tuple[float, int, list[str], int]: """提取 name 列表(去重保序)、support、length、post_count。""" names = [item["name"] for item in (p.get("items") or [])] seen = set() unique = [] for n in names: if n not in seen: seen.add(n) unique.append(n) support = round(float(p.get("support", 0)), 4) length = int(p.get("length", 0)) post_count = int(p.get("post_count", 0)) return support, length, unique, post_count def _merge_and_dedupe(patterns: list[dict]) -> list[dict]: """ 按 items 的 name 集合去重(不区分顺序),留 support 最大; 输出格式保留 s、l、i(nameA+nameB+nameC)及 post_count,供条件概率计算使用。 """ key_to_best: dict[tuple, tuple[float, int, int]] = {} for p in patterns: support, length, unique, post_count = _slim_pattern(p) if not unique: continue key = tuple(sorted(unique)) if key not in key_to_best or support > key_to_best[key][0]: key_to_best[key] = (support, length, post_count) out = [] for k, (s, l, post_count) in key_to_best.items(): out.append({ "s": s, "l": l, "i": "+".join(k), "post_count": post_count, }) out.sort(key=lambda x: x["s"] * x["l"], reverse=True) return out def _load_and_merge_patterns(account_name: str) -> list[dict]: """读取 pattern 库 JSON,按 TOP_KEYS/SUB_KEYS 合并为列表并做合并、去重。""" path = _pattern_file(account_name) if not path.is_file(): return [] with open(path, "r", encoding="utf-8") as f: data = json.load(f) all_patterns = [] for top in TOP_KEYS: if top not in data: continue block = data[top] for sub in SUB_KEYS: all_patterns.extend(block.get(sub) or []) return _merge_and_dedupe(all_patterns) def _load_and_merge_platform_patterns() -> list[dict]: """读取平台库 pattern JSON,结构与账号库相同,合并去重。""" path = _platform_pattern_file() if not path.is_file(): return [] with open(path, "r", encoding="utf-8") as f: data = json.load(f) all_patterns = [] for top in TOP_KEYS: if top not in data: continue block = data[top] for sub in SUB_KEYS: all_patterns.extend(block.get(sub) or []) return _merge_and_dedupe(all_patterns) def _load_platform_match_pair_lookup(post_id: str) -> dict[tuple[str, str], float]: """ xiaohongshu/match_data/{post_id}_匹配_all.json -> (帖子选题点, 人设树节点名) -> 最高 match_score(跨 dimension 合并)。 """ lookup: dict[tuple[str, str], float] = {} if not post_id: return lookup path = _BASE_INPUT / "xiaohongshu" / "match_data" / f"{post_id}_匹配_all.json" if not path.is_file(): return lookup try: with open(path, "r", encoding="utf-8") as f: data = json.load(f) except Exception: return lookup if not isinstance(data, list): return lookup for item in data: if not isinstance(item, dict): continue topic = item.get("name") personas = item.get("match_personas") if topic is None or not isinstance(personas, list): continue topic_s = str(topic).strip() if not topic_s: continue for mp in personas: if not isinstance(mp, dict): continue elem = mp.get("name") score = mp.get("match_score") if elem is None or score is None: continue elem_s = str(elem).strip() if not elem_s: continue try: sc = float(score) except (TypeError, ValueError): continue key = (topic_s, elem_s) if key not in lookup or sc > lookup[key]: lookup[key] = sc return lookup def _platform_element_post_match_map( post_id: str, match_score_threshold: float, ) -> dict[str, dict[str, float]]: """ 平台库:节点名称(不区分 dimension)-> {帖子选题点: 最高分}, 仅保留 match_score >= match_score_threshold 的对。 """ out: dict[str, dict[str, float]] = {} if not post_id: return out path = _BASE_INPUT / "xiaohongshu" / "match_data" / f"{post_id}_匹配_all.json" if not path.is_file(): return out try: with open(path, "r", encoding="utf-8") as f: data = json.load(f) except Exception: return out if not isinstance(data, list): return out thr = float(match_score_threshold) for item in data: if not isinstance(item, dict): continue topic = item.get("name") personas = item.get("match_personas") if topic is None or not isinstance(personas, list): continue topic_s = str(topic).strip() if not topic_s: continue for mp in personas: if not isinstance(mp, dict): continue elem = mp.get("name") score = mp.get("match_score") if elem is None or score is None: continue try: sc = float(score) except (TypeError, ValueError): continue if sc < thr: continue elem_s = str(elem).strip() if not elem_s: continue bucket = out.setdefault(elem_s, {}) prev = bucket.get(topic_s) if prev is None or sc > prev: bucket[topic_s] = sc return out def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]: """将 agent 传入的 [{"topic": "x", "source_node": "y"}, ...] 转为 DerivedItem 列表。""" out = [] for item in derived_items: if isinstance(item, dict): topic = item.get("topic") or item.get("已推导的选题点") source = item.get("source_node") or item.get("推导来源人设树节点") if topic is not None and source is not None: out.append((str(topic).strip(), str(source).strip())) elif isinstance(item, (list, tuple)) and len(item) >= 2: out.append((str(item[0]).strip(), str(item[1]).strip())) return out def get_patterns_by_conditional_ratio( account_name: str, derived_list: list[tuple[str, str]], conditional_ratio_threshold: float, top_n: int, post_id: str = "", ) -> list[dict[str, Any]]: """ 从 pattern 库中获取条件概率 >= 阈值的 pattern,按以下优先级排序后返回 top_n 条: 1. pattern 元素中直接包含已推导选题点(topic)的排最前; 2. pattern 元素与任意已推导选题点的匹配分 >= 0.8 的次之(从 match_data 文件读取, key 为 (帖子选题点, 人设树节点),pattern 元素视为人设树节点); 3. 按条件概率降序; 4. 按 length 降序。 derived_list 为空时,条件概率使用 pattern 自身的 support(s)。 返回每项:pattern名称(nameA+nameB+nameC)、条件概率。 """ merged = _load_and_merge_patterns(account_name) if not merged: return [] base_dir = _BASE_INPUT scored: list[tuple[dict, float]] = [] if not derived_list: # derived_items 为空:条件概率取 pattern 本身的 support (s) for p in merged: ratio = float(p.get("s", 0)) if ratio >= conditional_ratio_threshold: scored.append((p, ratio)) else: for p in merged: ratio = calc_pattern_conditional_ratio( account_name, derived_list, p, base_dir=base_dir ) if ratio >= conditional_ratio_threshold: scored.append((p, ratio)) derived_topics = {topic for topic, _ in derived_list} if derived_list else set() # 次优先:从 match_data 文件加载 (帖子选题点, 人设树节点) -> 匹配分, # 用已推导选题点(topic)作为帖子选题点,pattern 元素作为人设树节点, # 检查是否存在匹配分 >= 0.8 的组合。 match_lookup: dict[tuple[str, str], float] = {} if derived_topics and post_id: match_lookup = _load_match_data(account_name, post_id) def _sort_key(x: tuple[dict, float]) -> tuple: p, ratio = x elements = set(p["i"].split("+")) has_derived = bool(elements & derived_topics) has_high_match = False if not has_derived and match_lookup: for elem in elements: for dt in derived_topics: if match_lookup.get((dt, elem), 0.0) >= _MATCH_PRIOR_MIN_SCORE: has_high_match = True break if has_high_match: break return (not has_derived, not has_high_match, -ratio, -p["l"]) scored.sort(key=_sort_key) result = [] for p, ratio in scored[:top_n]: result.append({ "pattern名称": p["i"], "条件概率": round(ratio, 6), }) return result def get_platform_patterns_by_conditional_ratio( derived_list: list[tuple[str, str]], conditional_ratio_threshold: float, top_n: int, post_id: str = "", ) -> list[dict[str, Any]]: """ 平台库 pattern:数据来自 xiaohongshu/pattern/processed_edge_data.json, 条件概率基于 xiaohongshu/tree 的节点索引(与账号侧 calc_pattern 规则一致)。 排序优先级规则与 get_patterns_by_conditional_ratio 一致,高分参照 xiaohongshu/match_data。 """ merged = _load_and_merge_platform_patterns() if not merged: return [] platform_index = build_node_index_for_tree_dir(_PLATFORM_TREE_DIR) scored: list[tuple[dict, float]] = [] if not derived_list: for p in merged: ratio = float(p.get("s", 0)) if ratio >= conditional_ratio_threshold: scored.append((p, ratio)) else: for p in merged: ratio = calc_pattern_conditional_ratio_with_index(derived_list, p, platform_index) if ratio >= conditional_ratio_threshold: scored.append((p, ratio)) derived_topics = {topic for topic, _ in derived_list} if derived_list else set() match_lookup: dict[tuple[str, str], float] = {} if derived_topics and post_id: match_lookup = _load_platform_match_pair_lookup(post_id) def _sort_key(x: tuple[dict, float]) -> tuple: p, ratio = x elements = set(p["i"].split("+")) has_derived = bool(elements & derived_topics) has_high_match = False if not has_derived and match_lookup: for elem in elements: for dt in derived_topics: if match_lookup.get((dt, elem), 0.0) >= _MATCH_PRIOR_MIN_SCORE: has_high_match = True break if has_high_match: break return (not has_derived, not has_high_match, -ratio, -p["l"]) scored.sort(key=_sort_key) result = [] for p, ratio in scored[:top_n]: result.append({ "pattern名称": p["i"], "条件概率": round(ratio, 6), }) return result def _attach_platform_pattern_post_matches( items: list[dict[str, Any]], post_id: str, match_score_threshold: float, ) -> None: """就地写入 帖子选题点匹配:仅使用 xiaohongshu/match_data,元素为节点名(跨 dimension 聚合)。""" if not items or not post_id: for it in items: it["帖子选题点匹配"] = "无" return elem_map = _platform_element_post_match_map(post_id, float(match_score_threshold)) for item in items: pattern_matches: list[dict[str, Any]] = [] for elem in item["pattern名称"].split("+"): elem = elem.strip() if not elem: continue for post_topic, sc in (elem_map.get(elem) or {}).items(): pattern_matches.append({ "pattern元素": elem, "帖子选题点": post_topic, "匹配分数": round(sc, 6), }) distinct_post_points = len({m["帖子选题点"] for m in pattern_matches}) item["帖子选题点匹配"] = ( pattern_matches if distinct_post_points >= 2 else "无" ) @tool() async def find_pattern( account_name: str, post_id: str, derived_items: list[dict[str, str]], conditional_ratio_threshold: float, top_n: int = 100, match_score_threshold: float = DEFAULT_MATCH_THRESHOLD, ) -> ToolResult: """ 按条件概率阈值从 pattern 库筛选:第一节为账号 pattern,第二节为平台库 pattern(xiaohongshu/pattern)。 账号段帖子匹配走账号 match_data + point_match;平台段元素匹配仅走 xiaohongshu/match_data。 Args: account_name : 账号名,用于定位该账号的 pattern 库。 post_id : 帖子ID。 derived_items : 已推导选题点列表,可为空。 conditional_ratio_threshold : 条件概率阈值。 top_n : 账号段与平台段各自最多返回条数(各自经匹配过滤后可能更少)。 match_score_threshold : 帖子选题点匹配分阈值。 Returns: ToolResult:output 分「账号 pattern」「平台库 pattern」两段;平台段已排除与账号段 pattern 名称完全相同的项。 """ def _split_by_post_match( items: list[dict[str, Any]], ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: matched: list[dict[str, Any]] = [] unmatched: list[dict[str, Any]] = [] for x in items: if isinstance(x.get("帖子选题点匹配"), list): matched.append(x) else: unmatched.append(x) return matched, unmatched def _pick_with_quota( items: list[dict[str, Any]], target_count: int, ) -> list[dict[str, Any]]: return items[:max(0, int(target_count))] def _mix_by_ratio( items: list[dict[str, Any]], target_count: int, ) -> list[dict[str, Any]]: if target_count <= 0: return [] matched, unmatched = _split_by_post_match(items) matched_quota = target_count // 2 unmatched_quota = target_count - matched_quota selected = _pick_with_quota(matched, matched_quota) selected.extend(_pick_with_quota(unmatched, unmatched_quota)) if len(selected) < target_count: selected_names = {str(x.get("pattern名称", "")) for x in selected} fallback_pool = [ x for x in items if str(x.get("pattern名称", "")) not in selected_names ] selected.extend(_pick_with_quota(fallback_pool, target_count - len(selected))) return selected pattern_path = _pattern_file(account_name) if not pattern_path.is_file(): return ToolResult( title="Pattern 库不存在", output=f"pattern 文件不存在: {pattern_path}", error="Pattern file not found", ) try: derived_list = _parse_derived_list(derived_items or []) thr = float(match_score_threshold) total_top_n = max(0, int(top_n)) account_top_n = int(total_top_n * 0.6) platform_top_n = total_top_n - account_top_n # 候选池适当放大,避免按“有/无匹配”分桶后数量不足 candidate_top_n = max(total_top_n * 4, total_top_n + 100) # ---------- 账号 pattern(原逻辑:match_data + 子节点/兄弟扩展)---------- items_account = get_patterns_by_conditional_ratio( account_name, derived_list, conditional_ratio_threshold, candidate_top_n, post_id ) if not post_id: for item in items_account: item["帖子选题点匹配"] = "无" if items_account and post_id: all_elements: list[str] = [] seen_elements: set[str] = set() for item in items_account: for elem in item["pattern名称"].split("+"): elem = elem.strip() if elem and elem not in seen_elements: all_elements.append(elem) seen_elements.add(elem) matched_results = await match_derivation_to_post_points( all_elements, account_name, post_id, match_threshold=thr ) elem_match_map: dict[str, list] = {} for m in matched_results: elem_match_map.setdefault(m["推导选题点"], []).append({ "帖子选题点": m["帖子选题点"], "匹配分数": m["匹配分数"], }) for item in items_account: pattern_matches = [] for elem in item["pattern名称"].split("+"): elem = elem.strip() for post_match in elem_match_map.get(elem, []): pattern_matches.append({ "pattern元素": elem, "帖子选题点": post_match["帖子选题点"], "匹配分数": post_match["匹配分数"], }) distinct_post_points = len({m["帖子选题点"] for m in pattern_matches}) item["帖子选题点匹配"] = ( pattern_matches if distinct_post_points >= 2 else "无" ) if items_account and post_id: node_info_map = _build_node_info(account_name) all_candidates_set: set[str] = set() item_unmatched_info: list[list[tuple[str, list[str], str]]] = [] for item in items_account: pattern_matches = item.get("帖子选题点匹配", []) matched_elems = ( {m["pattern元素"] for m in pattern_matches} if isinstance(pattern_matches, list) else set() ) all_elems = [e.strip() for e in item["pattern名称"].split("+")] unmatched = [e for e in all_elems if e not in matched_elems] elem_candidates: list[tuple[str, list[str], str]] = [] for elem in unmatched: info = node_info_map.get(elem) if not info: continue if info["type"] == "class" and info["children"]: candidates = info["children"] expand_type = "子节点" else: candidates = info["siblings"] expand_type = "兄弟节点" if candidates: elem_candidates.append((elem, candidates, expand_type)) all_candidates_set.update(candidates) item_unmatched_info.append(elem_candidates) if all_candidates_set: candidate_matches = await match_derivation_to_post_points( list(all_candidates_set), account_name, post_id, match_threshold=thr ) cand_match_map: dict[str, list[tuple[str, float]]] = {} for m in candidate_matches: cand_match_map.setdefault(m["推导选题点"], []).append( (m["帖子选题点"], m["匹配分数"]) ) for item, elem_cands in zip(items_account, item_unmatched_info): for elem, candidates, expand_type in elem_cands: best_cand, best_pp, best_sc = None, None, -1.0 for cand in candidates: for pp, sc in cand_match_map.get(cand, []): if sc > best_sc: best_cand, best_pp, best_sc = cand, pp, sc if best_cand is not None: if not isinstance(item.get("帖子选题点匹配"), list): item["帖子选题点匹配"] = [] item["帖子选题点匹配"].append({ "pattern元素": elem, "帖子选题点": best_pp, "匹配分数": best_sc, "扩展节点": best_cand, "扩展类型": expand_type, }) for item in items_account: matches = item.get("帖子选题点匹配") if not isinstance(matches, list): continue best_by_pp: dict[str, dict] = {} for m in matches: pp = m["帖子选题点"] if pp not in best_by_pp or m["匹配分数"] > best_by_pp[pp]["匹配分数"]: best_by_pp[pp] = m item["帖子选题点匹配"] = list(best_by_pp.values()) items_account = _mix_by_ratio(items_account, account_top_n) account_pattern_names = {str(x.get("pattern名称", "")).strip() for x in items_account} # ---------- 平台库 pattern(xiaohongshu/tree 条件概率 + xiaohongshu/match_data 匹配)---------- items_platform: list[dict[str, Any]] = [] items_platform = get_platform_patterns_by_conditional_ratio( derived_list, conditional_ratio_threshold / 5, candidate_top_n, post_id ) if post_id: _attach_platform_pattern_post_matches(items_platform, post_id, thr) else: for item in items_platform: item["帖子选题点匹配"] = "无" items_platform = [ x for x in items_platform if str(x.get("pattern名称", "")).strip() not in account_pattern_names ] for item in items_platform: matches = item.get("帖子选题点匹配") if not isinstance(matches, list): continue best_by_pp: dict[str, dict] = {} for m in matches: pp = m["帖子选题点"] if pp not in best_by_pp or m["匹配分数"] > best_by_pp[pp]["匹配分数"]: best_by_pp[pp] = m item["帖子选题点匹配"] = list(best_by_pp.values()) items_platform = _mix_by_ratio(items_platform, platform_top_n) def _format_pattern_block(xs: list[dict[str, Any]]) -> list[str]: lines: list[str] = [] for x in xs: match_info = x.get("帖子选题点匹配", "无") if isinstance(match_info, list): match_str = "、".join( ( f"{m['扩展节点']}({m['pattern元素']}的{m['扩展类型']})→{m['帖子选题点']}({m['匹配分数']})" if "扩展节点" in m else f"{m['pattern元素']}→{m['帖子选题点']}({m['匹配分数']})" ) for m in match_info ) else: match_str = str(match_info) lines.append( f"- {x['pattern名称']}\t条件概率={x['条件概率']}\t帖子选题点匹配={match_str}" ) return lines lines_out: list[str] = [] lines_out.append( "【优先使用】第一节为账号 pattern;第二节为平台库 pattern。" ) lines_out.append("") lines_out.append("—— 账号 pattern ——") if not items_account: lines_out.append( f"(无:未找到条件概率 >= {conditional_ratio_threshold} 的 pattern)" ) else: lines_out.extend(_format_pattern_block(items_account)) lines_out.append("") lines_out.append("—— 平台库 pattern ——") if not items_platform: lines_out.append( "(无:未找到达标 pattern)" ) else: lines_out.extend(_format_pattern_block(items_platform)) output = "\n".join(lines_out) return ToolResult( title=f"符合条件概率的 Pattern ({account_name}, 阈值={conditional_ratio_threshold})", output=output, metadata={ "account_name": account_name, "conditional_ratio_threshold": conditional_ratio_threshold, "match_score_threshold": thr, "top_n": top_n, "account_pattern_count": len(items_account), "platform_pattern_count": len(items_platform), "count": len(items_account) + len(items_platform), }, ) except Exception as e: return ToolResult( title="查找 Pattern 失败", output=str(e), error=str(e), ) def main() -> None: """本地测试:用家有大志账号、已推导选题点,查询符合条件概率阈值的 pattern(含帖子匹配)。""" import asyncio account_name = "家有大志" post_id = "68fb6a5c000000000302e5de" # 已推导选题点,每项:已推导的选题点 + 推导来源人设树节点 # derived_items = [ # {"topic": "分享", "source_node": "分享"}, # {"topic": "植入方式", "source_node": "植入方式"}, # {"topic": "叙事结构", "source_node": "叙事结构"}, # ] derived_items = derived_items = [] conditional_ratio_threshold = 0.2 top_n = 200 # 1)直接调用核心函数(不含帖子匹配,仅验证排序逻辑) # derived_list = _parse_derived_list(derived_items) # items = get_patterns_by_conditional_ratio( # account_name, derived_list, conditional_ratio_threshold, top_n, post_id # ) # print(f"账号: {account_name}, 阈值: {conditional_ratio_threshold}, top_n: {top_n}") # print(f"共 {len(items)} 条 pattern:\n") # for x in items: # print(f" - {x['pattern名称']}\t条件概率={x['条件概率']}") # 2)有 agent 时通过 tool 接口再跑一遍(含帖子选题点匹配) if ToolResult is not None: async def run_tool(): result = await find_pattern( account_name=account_name, post_id=post_id, derived_items=derived_items, conditional_ratio_threshold=conditional_ratio_threshold, top_n=top_n, ) print("\n--- Tool 返回 ---") print(result.output) asyncio.run(run_tool()) if __name__ == "__main__": main()