"""
Find Pattern tool - fetch patterns from the pattern library that meet a
conditional-probability threshold.

Reads an account's pattern library, merges and de-duplicates the patterns,
filters them by conditional probability and returns the top-N patterns
(pattern name and conditional probability).
"""

import importlib.util
import json
from pathlib import Path
from typing import Any, Optional

try:
    from agent.tools import tool, ToolResult, ToolContext
except ImportError:
    def tool(*args, **kwargs):
        """Stand-in for the agent decorator factory: leaves the function unchanged."""
        return lambda f: f

    # The agent package is optional when only main() is used to test the core logic.
    ToolResult = None
    ToolContext = None

# Key definitions, kept identical to those used by pattern_data_process.
TOP_KEYS = [
    "depth_max_with_name",
    "depth_mixed",
    "depth_max_concrete",
    "depth2_medium",
    "depth1_abstract",
]
SUB_KEYS = ["two_x", "one_x", "zero_x"]

# Merged patterns whose support is below this value are dropped.
_MIN_SUPPORT = 0.1

# Load conditional_ratio_calc from ../utils by file path (same approach as find_tree_node).
_utils_dir = Path(__file__).resolve().parent.parent / "utils"
_cond_spec = importlib.util.spec_from_file_location(
    "conditional_ratio_calc",
    _utils_dir / "conditional_ratio_calc.py",
)
if _cond_spec is None or _cond_spec.loader is None:
    # Fail with a readable message instead of an AttributeError on None below.
    raise ImportError(
        f"cannot build an import spec for {_utils_dir / 'conditional_ratio_calc.py'}"
    )
_cond_mod = importlib.util.module_from_spec(_cond_spec)
_cond_spec.loader.exec_module(_cond_mod)
calc_pattern_conditional_ratio = _cond_mod.calc_pattern_conditional_ratio

_BASE_INPUT = Path(__file__).resolve().parent.parent / "input"


def _pattern_file(account_name: str) -> Path:
    """Return the pattern library path for an account.

    The file lives at
    ``../input/{account_name}/原始数据/pattern/processed_edge_data.json``
    relative to this module's directory.
    """
    return _BASE_INPUT / account_name / "原始数据" / "pattern" / "processed_edge_data.json"


def _slim_pattern(p: dict) -> tuple[float, int, list[str], int]:
    """Extract the fields of one raw pattern that the merge step needs.

    Args:
        p: A raw pattern dict; reads ``items`` (each item's ``name``),
            ``support``, ``length`` and ``post_count``.

    Returns:
        ``(support, length, names, post_count)`` where ``support`` is rounded
        to 4 decimals and ``names`` is the list of item names with duplicates
        removed and first-occurrence order preserved.
    """
    # Items that are not dicts or carry no name are skipped rather than
    # aborting the whole library load with a KeyError.
    names = [
        item["name"]
        for item in (p.get("items") or [])
        if isinstance(item, dict) and item.get("name") is not None
    ]
    # dict.fromkeys keeps insertion order, giving an order-preserving dedupe.
    unique = list(dict.fromkeys(names))
    # "or 0" also covers an explicit JSON null, which float()/int() would reject.
    support = round(float(p.get("support") or 0), 4)
    length = int(p.get("length") or 0)
    post_count = int(p.get("post_count") or 0)
    return support, length, unique, post_count


def _merge_and_dedupe(
    patterns: list[dict],
    min_support: float = _MIN_SUPPORT,
) -> list[dict]:
    """De-duplicate patterns by their set of item names.

    Patterns whose item names form the same set (order-insensitive) are
    treated as one; the entry with the largest support wins.

    Args:
        patterns: Raw pattern dicts as stored in the pattern library.
        min_support: Merged patterns with support below this value are dropped.

    Returns:
        A list of dicts with keys ``s`` (support), ``l`` (length), ``i``
        (names joined as ``nameA+nameB+nameC``) and ``post_count``, sorted by
        ``s * l`` in descending order. ``i`` and ``post_count`` are what the
        conditional-probability calculation consumes.
    """
    key_to_best: dict[tuple, tuple[float, int, int]] = {}
    for p in patterns:
        if not isinstance(p, dict):
            continue  # ignore malformed entries instead of failing the whole merge
        support, length, unique, post_count = _slim_pattern(p)
        if not unique:
            continue
        # Sorting the names makes the key independent of item order.
        key = tuple(sorted(unique))
        if key not in key_to_best or support > key_to_best[key][0]:
            key_to_best[key] = (support, length, post_count)

    out = [
        {"s": s, "l": l, "i": "+".join(k), "post_count": post_count}
        for k, (s, l, post_count) in key_to_best.items()
        if s >= min_support
    ]
    out.sort(key=lambda x: x["s"] * x["l"], reverse=True)
    return out


def _load_and_merge_patterns(account_name: str) -> list[dict]:
    """Read the account's pattern library JSON and return merged patterns.

    Collects every pattern list found under ``TOP_KEYS`` / ``SUB_KEYS`` and
    passes them through ``_merge_and_dedupe``. Returns an empty list when the
    file does not exist or its top level is not a JSON object.
    """
    path = _pattern_file(account_name)
    if not path.is_file():
        return []
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        return []

    all_patterns: list[dict] = []
    for top in TOP_KEYS:
        block = data.get(top)
        if not isinstance(block, dict):
            continue  # missing key, JSON null or unexpected shape
        for sub in SUB_KEYS:
            all_patterns.extend(block.get(sub) or [])
    return _merge_and_dedupe(all_patterns)


def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
    """Convert agent-supplied derived items into ``(topic, source_node)`` tuples.

    Each item may be a dict with keys ``topic`` / ``source_node`` (or the
    Chinese equivalents ``已推导的选题点`` / ``推导来源人设树节点``), or a
    list/tuple whose first two elements are the topic and the source node.
    Dict items missing either value are skipped. Both values are converted to
    ``str`` and stripped. ``None`` as input yields an empty list.
    """
    out = []
    for item in derived_items or []:
        if isinstance(item, dict):
            topic = item.get("topic") or item.get("已推导的选题点")
            source = item.get("source_node") or item.get("推导来源人设树节点")
            if topic is not None and source is not None:
                out.append((str(topic).strip(), str(source).strip()))
        elif isinstance(item, (list, tuple)) and len(item) >= 2:
            out.append((str(item[0]).strip(), str(item[1]).strip()))
    return out


def get_patterns_by_conditional_ratio(
    account_name: str,
    derived_list: list[tuple[str, str]],
    conditional_ratio_threshold: float,
    top_n: int,
) -> list[dict[str, Any]]:
    """Return the top patterns whose conditional probability meets the threshold.

    Args:
        account_name: Account whose pattern library is read.
        derived_list: ``(topic, source_node)`` tuples already derived.
        conditional_ratio_threshold: Minimum conditional probability (inclusive).
        top_n: Maximum number of patterns to return; a value <= 0 yields ``[]``.

    Returns:
        At most ``top_n`` dicts with keys ``pattern名称`` (``nameA+nameB+nameC``)
        and ``条件概率`` (rounded to 6 decimals), ordered by conditional
        probability descending, ties broken by pattern length descending.
    """
    # A negative top_n would make the slice below return "all but the last |n|"
    # entries, so non-positive values are rejected before any file I/O.
    if top_n <= 0:
        return []

    merged = _load_and_merge_patterns(account_name)
    if not merged:
        return []

    scored: list[tuple[dict, float]] = []
    for p in merged:
        # calc_pattern_conditional_ratio needs the pattern to carry "i" and "post_count".
        ratio = calc_pattern_conditional_ratio(
            account_name, derived_list, p, base_dir=_BASE_INPUT
        )
        if ratio >= conditional_ratio_threshold:
            scored.append((p, ratio))

    # Conditional probability high to low; equal ratios ordered by length descending.
    scored.sort(key=lambda x: (-x[1], -x[0]["l"]))
    return [
        {"pattern名称": p["i"], "条件概率": round(ratio, 6)}
        for p, ratio in scored[:top_n]
    ]


@tool(
    description="从 pattern 库中获取符合条件概率阈值的 pattern。"
    "输入:账号名、已推导选题点列表(DerivedItem)、条件概率阈值、topN。"
    "返回:pattern 名称(nameA+nameB+nameC)及条件概率,按条件概率从高到低最多 topN 条。"
)
async def find_pattern(
    account_name: str,
    derived_items: list[dict[str, str]],
    conditional_ratio_threshold: float,
    top_n: int = 20,
    context: Optional[ToolContext] = None,
) -> ToolResult:
    """Fetch patterns from the pattern library that meet the conditional-probability threshold.

    ``derived_items`` holds the already-derived topics; each entry has the form
    ``{"topic": <derived topic>, "source_node": <persona-tree node it was derived from>}``.

    Flow: read the pattern library -> merge and de-duplicate -> compute
    conditional probability -> keep entries >= threshold -> sort by
    conditional probability descending (ties by length descending) -> return
    at most ``top_n`` entries.

    Returns:
        A ``ToolResult`` whose ``output`` lists one pattern per line (name and
        conditional probability) and whose ``metadata`` carries the same items.
        A missing library file, unusable ``derived_items`` or any exception
        raised during the lookup is reported as a ``ToolResult`` with ``error`` set.
    """
    pattern_path = _pattern_file(account_name)
    if not pattern_path.is_file():
        return ToolResult(
            title="Pattern 库不存在",
            output=f"pattern 文件不存在: {pattern_path}",
            error="Pattern file not found",
        )

    try:
        derived_list = _parse_derived_list(derived_items)
        if not derived_list:
            return ToolResult(
                title="参数无效",
                output="derived_items 不能为空,且每项需包含 topic 与 source_node(或 已推导的选题点 与 推导来源人设树节点)",
                error="Invalid derived_items",
            )

        items = get_patterns_by_conditional_ratio(
            account_name, derived_list, conditional_ratio_threshold, top_n
        )
        if not items:
            output = f"未找到条件概率 >= {conditional_ratio_threshold} 的 pattern"
        else:
            lines = [
                f"- {x['pattern名称']}\t条件概率={x['条件概率']}" for x in items
            ]
            output = "\n".join(lines)

        return ToolResult(
            title=f"符合条件概率的 Pattern ({account_name}, 阈值={conditional_ratio_threshold})",
            output=output,
            metadata={
                "account_name": account_name,
                "conditional_ratio_threshold": conditional_ratio_threshold,
                "top_n": top_n,
                "count": len(items),
                "items": items,
            },
        )
    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        return ToolResult(
            title="查找 Pattern 失败",
            output=str(e),
            error=str(e),
        )


def main() -> None:
    """Local smoke test: query matching patterns for a sample account and derived topics."""
    import asyncio

    account_name = "家有大志"
    # Already-derived topics; each entry pairs a topic with its source persona-tree node.
    derived_items = [
        {"topic": "分享", "source_node": "分享"},
        {"topic": "柴犬", "source_node": "动物角色"},
    ]
    conditional_ratio_threshold = 0.01
    top_n = 10

    # 1) Call the core function directly.
    derived_list = _parse_derived_list(derived_items)
    items = get_patterns_by_conditional_ratio(
        account_name, derived_list, conditional_ratio_threshold, top_n
    )
    print(f"账号: {account_name}, 阈值: {conditional_ratio_threshold}, top_n: {top_n}")
    print(f"共 {len(items)} 条 pattern:\n")
    for x in items:
        print(f" - {x['pattern名称']}\t条件概率={x['条件概率']}")

    # 2) When the agent package is available, run once more through the tool interface.
    if ToolResult is not None:
        async def run_tool():
            result = await find_pattern(
                account_name=account_name,
                derived_items=derived_items,
                conditional_ratio_threshold=conditional_ratio_threshold,
                top_n=top_n,
            )
            print("\n--- Tool 返回 ---")
            print(result.output)

        asyncio.run(run_tool())


if __name__ == "__main__":
    main()