| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256 |
- """
- 查找 Pattern Tool - 从 pattern 库中获取符合条件概率阈值的 pattern
- 功能:读取账号的 pattern 库,合并去重后按条件概率筛选,返回 topN 条 pattern(含 pattern 名称、条件概率)。
- """
- import importlib.util
- import json
- from pathlib import Path
- from typing import Any, Optional
# Optional agent dependency: when the agent package is absent, install no-op
# stand-ins so the core logic can still be exercised via main().
try:
    from agent.tools import tool, ToolResult, ToolContext
except ImportError:
    def tool(*args, **kwargs):
        # No-op decorator factory: returns the decorated function unchanged.
        return lambda f: f
    ToolResult = None  # agent only needed for the tool interface; main() can test core logic without it
    ToolContext = None
# Key layout of the pattern-library JSON; must stay consistent with
# pattern_data_process.
TOP_KEYS = [
    "depth_max_with_name",
    "depth_mixed",
    "depth_max_concrete",
    "depth2_medium",
    "depth1_abstract",
]
SUB_KEYS = ["two_x", "one_x", "zero_x"]
# Load conditional_ratio_calc by file path (same approach as find_tree_node):
# the ../utils directory is loaded via importlib rather than a package import.
_utils_dir = Path(__file__).resolve().parent.parent / "utils"
_cond_spec = importlib.util.spec_from_file_location(
    "conditional_ratio_calc",
    _utils_dir / "conditional_ratio_calc.py",
)
_cond_mod = importlib.util.module_from_spec(_cond_spec)
_cond_spec.loader.exec_module(_cond_mod)
calc_pattern_conditional_ratio = _cond_mod.calc_pattern_conditional_ratio

# Root directory holding per-account input data.
_BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
def _pattern_file(account_name: str) -> Path:
    """Return the account's pattern-library file.

    Layout: ../input/{account_name}/原始数据/pattern/processed_edge_data.json
    """
    return _BASE_INPUT.joinpath(
        account_name, "原始数据", "pattern", "processed_edge_data.json"
    )
- def _slim_pattern(p: dict) -> tuple[float, int, list[str], int]:
- """提取 name 列表(去重保序)、support、length、post_count。"""
- names = [item["name"] for item in (p.get("items") or [])]
- seen = set()
- unique = []
- for n in names:
- if n not in seen:
- seen.add(n)
- unique.append(n)
- support = round(float(p.get("support", 0)), 4)
- length = int(p.get("length", 0))
- post_count = int(p.get("post_count", 0))
- return support, length, unique, post_count
def _merge_and_dedupe(patterns: list[dict], min_support: float = 0.1) -> list[dict]:
    """Merge patterns, deduplicating by their order-insensitive set of item names.

    For each name set, the entry with the highest support wins. Entries whose
    support is below *min_support* are dropped.

    Args:
        patterns: raw pattern dicts as stored in the library JSON.
        min_support: minimum support required to keep a pattern (default 0.1,
            matching the previously hard-coded threshold).

    Returns:
        List of dicts with keys "s" (support), "l" (length), "i" (names joined
        with "+", sorted) and "post_count" (needed by the conditional-ratio
        calculation), sorted by s * l descending.
    """
    key_to_best: dict[tuple, tuple[float, int, int]] = {}
    for p in patterns:
        support, length, unique, post_count = _slim_pattern(p)
        if not unique:
            continue
        key = tuple(sorted(unique))
        # Keep the highest-support representative for each name set.
        if key not in key_to_best or support > key_to_best[key][0]:
            key_to_best[key] = (support, length, post_count)
    out = [
        {"s": s, "l": length, "i": "+".join(key), "post_count": post_count}
        for key, (s, length, post_count) in key_to_best.items()
        if s >= min_support
    ]
    out.sort(key=lambda x: x["s"] * x["l"], reverse=True)
    return out
def _load_and_merge_patterns(account_name: str) -> list[dict]:
    """Read the account's pattern-library JSON, flatten all TOP_KEYS/SUB_KEYS
    sections into one list, then merge and deduplicate it."""
    path = _pattern_file(account_name)
    if not path.is_file():
        return []
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    collected: list[dict] = []
    for top in TOP_KEYS:
        if top in data:
            section = data[top]
            for sub in SUB_KEYS:
                collected.extend(section.get(sub) or [])
    return _merge_and_dedupe(collected)
- def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
- """将 agent 传入的 [{"topic": "x", "source_node": "y"}, ...] 转为 DerivedItem 列表。"""
- out = []
- for item in derived_items:
- if isinstance(item, dict):
- topic = item.get("topic") or item.get("已推导的选题点")
- source = item.get("source_node") or item.get("推导来源人设树节点")
- if topic is not None and source is not None:
- out.append((str(topic).strip(), str(source).strip()))
- elif isinstance(item, (list, tuple)) and len(item) >= 2:
- out.append((str(item[0]).strip(), str(item[1]).strip()))
- return out
def get_patterns_by_conditional_ratio(
    account_name: str,
    derived_list: list[tuple[str, str]],
    conditional_ratio_threshold: float,
    top_n: int,
) -> list[dict[str, Any]]:
    """Return up to *top_n* library patterns whose conditional ratio meets the threshold.

    Sorted by conditional ratio descending; ties broken by length descending.
    Each returned item carries the pattern name (nameA+nameB+nameC) and its
    conditional ratio.
    """
    candidates = _load_and_merge_patterns(account_name)
    if not candidates:
        return []
    # calc_pattern_conditional_ratio expects each pattern to carry "i" and "post_count".
    scored = [
        (pattern, ratio)
        for pattern in candidates
        if (ratio := calc_pattern_conditional_ratio(
            account_name, derived_list, pattern, base_dir=_BASE_INPUT
        )) >= conditional_ratio_threshold
    ]
    # Highest conditional ratio first; equal ratios ordered by length descending.
    scored.sort(key=lambda item: (-item[1], -item[0]["l"]))
    return [
        {"pattern名称": pattern["i"], "条件概率": round(ratio, 6)}
        for pattern, ratio in scored[:top_n]
    ]
@tool(
    description="从 pattern 库中获取符合条件概率阈值的 pattern。"
    "输入:账号名、已推导选题点列表(DerivedItem)、条件概率阈值、topN。"
    "返回:pattern 名称(nameA+nameB+nameC)及条件概率,按条件概率从高到低最多 topN 条。"
)
async def find_pattern(
    account_name: str,
    derived_items: list[dict[str, str]],
    conditional_ratio_threshold: float,
    top_n: int = 20,
    context: Optional[ToolContext] = None,
) -> ToolResult:
    """
    Fetch patterns from the library that meet the conditional-probability threshold.

    derived_items: each entry is {"topic": <derived topic>, "source_node": <persona-tree source node>}.
    Flow: load pattern library -> merge & dedupe -> compute conditional ratio ->
    keep ratio >= threshold -> sort by ratio desc (ties by length desc) -> top_n.
    Each returned entry: pattern name (nameA+nameB+nameC) and conditional ratio.
    """
    pattern_path = _pattern_file(account_name)
    if not pattern_path.is_file():
        # Missing library file is reported as a tool error, not raised.
        return ToolResult(
            title="Pattern 库不存在",
            output=f"pattern 文件不存在: {pattern_path}",
            error="Pattern file not found",
        )
    try:
        derived_list = _parse_derived_list(derived_items)
        if not derived_list:
            return ToolResult(
                title="参数无效",
                output="derived_items 不能为空,且每项需包含 topic 与 source_node(或 已推导的选题点 与 推导来源人设树节点)",
                error="Invalid derived_items",
            )
        items = get_patterns_by_conditional_ratio(
            account_name, derived_list, conditional_ratio_threshold, top_n
        )
        if not items:
            output = f"未找到条件概率 >= {conditional_ratio_threshold} 的 pattern"
        else:
            # One line per pattern: "- <name>\t条件概率=<ratio>".
            lines = [
                f"- {x['pattern名称']}\t条件概率={x['条件概率']}"
                for x in items
            ]
            output = "\n".join(lines)
        return ToolResult(
            title=f"符合条件概率的 Pattern ({account_name}, 阈值={conditional_ratio_threshold})",
            output=output,
            metadata={
                "account_name": account_name,
                "conditional_ratio_threshold": conditional_ratio_threshold,
                "top_n": top_n,
                "count": len(items),
                "items": items,
            },
        )
    except Exception as e:
        # Broad catch by design: any failure is surfaced to the agent as a
        # ToolResult error instead of propagating an exception.
        return ToolResult(
            title="查找 Pattern 失败",
            output=str(e),
            error=str(e),
        )
def main() -> None:
    """Local smoke test: query patterns for the 家有大志 account with sample derived topics."""
    import asyncio
    account_name = "家有大志"
    # Derived topics: each entry pairs a derived topic with its persona-tree source node.
    derived_items = [
        {"topic": "分享", "source_node": "分享"},
        {"topic": "柴犬", "source_node": "动物角色"},
    ]
    conditional_ratio_threshold = 0.01
    top_n = 10
    # 1) Call the core function directly (works without the agent package).
    derived_list = _parse_derived_list(derived_items)
    items = get_patterns_by_conditional_ratio(
        account_name, derived_list, conditional_ratio_threshold, top_n
    )
    print(f"账号: {account_name}, 阈值: {conditional_ratio_threshold}, top_n: {top_n}")
    print(f"共 {len(items)} 条 pattern:\n")
    for x in items:
        print(f" - {x['pattern名称']}\t条件概率={x['条件概率']}")
    # 2) When the agent package is available, exercise the tool interface too.
    if ToolResult is not None:
        async def run_tool():
            result = await find_pattern(
                account_name=account_name,
                derived_items=derived_items,
                conditional_ratio_threshold=conditional_ratio_threshold,
                top_n=top_n,
            )
            print("\n--- Tool 返回 ---")
            print(result.output)
        asyncio.run(run_tool())


if __name__ == "__main__":
    main()
|