# find_pattern.py
  1. """
  2. 查找 Pattern Tool - 从 pattern 库中获取符合条件概率阈值的 pattern
  3. 功能:读取账号的 pattern 库,合并去重后按条件概率筛选,返回 topN 条 pattern(含 pattern 名称、条件概率)。
  4. """
  5. import json
  6. import sys
  7. from pathlib import Path
  8. from typing import Any, Optional
  9. # 保证直接运行或作为包加载时都能解析 utils / tools(IDE 可跳转)
  10. _root = Path(__file__).resolve().parent.parent
  11. if str(_root) not in sys.path:
  12. sys.path.insert(0, str(_root))
  13. from utils.conditional_ratio_calc import calc_pattern_conditional_ratio
  14. from tools.point_match import _load_match_data, match_derivation_to_post_points
  15. from tools.find_tree_node import _load_trees
# Optional agent integration: fall back to no-op stand-ins so main() can
# exercise the core logic without the agent package installed.
try:
    from agent.tools import tool, ToolResult, ToolContext
except ImportError:
    def tool(*args, **kwargs):
        return lambda f: f
    ToolResult = None  # agent not required when only main() tests the core logic
    ToolContext = None

# Key layout matches pattern_data_process: top-level depth buckets, each
# holding two_x / one_x / zero_x sub-lists of pattern dicts.
TOP_KEYS = [
    "depth_max_with_name",
    "depth_mixed",
    "depth_max_concrete",
    "depth2_medium",
    "depth1_abstract",
    "depth_max_minus_1",
    "depth_max_minus_2",
    "depth_3",
    "depth_4",
]
SUB_KEYS = ["two_x", "one_x", "zero_x"]
# Root of all account input data: ../input relative to this file's parent package.
_BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  37. def _build_node_info(account_name: str) -> dict[str, dict]:
  38. """
  39. 构建人设树节点信息映射: node_name -> {
  40. "type": 节点 _type("class" / "ID" 等),
  41. "children": 子节点名称列表(仅分类节点有值),
  42. "siblings": 兄弟节点名称列表(不含自身),
  43. }
  44. """
  45. node_info: dict[str, dict] = {}
  46. def _walk(node_dict: dict):
  47. children_dict = node_dict.get("children") or {}
  48. child_entries = [(n, c) for n, c in children_dict.items() if isinstance(c, dict)]
  49. child_names = [n for n, _ in child_entries]
  50. for name, child in child_entries:
  51. sub_children = child.get("children") or {}
  52. sub_child_names = [n for n, c in sub_children.items() if isinstance(c, dict)]
  53. node_info[name] = {
  54. "type": child.get("_type", ""),
  55. "children": sub_child_names,
  56. "siblings": [n for n in child_names if n != name],
  57. }
  58. _walk(child)
  59. for _dim_name, root in _load_trees(account_name):
  60. _walk(root)
  61. return node_info
  62. def _pattern_file(account_name: str) -> Path:
  63. """pattern 库文件:../input/{account_name}/原始数据/pattern/processed_edge_data.json"""
  64. return _BASE_INPUT / account_name / "原始数据" / "pattern" / "processed_edge_data.json"
  65. def _slim_pattern(p: dict) -> tuple[float, int, list[str], int]:
  66. """提取 name 列表(去重保序)、support、length、post_count。"""
  67. names = [item["name"] for item in (p.get("items") or [])]
  68. seen = set()
  69. unique = []
  70. for n in names:
  71. if n not in seen:
  72. seen.add(n)
  73. unique.append(n)
  74. support = round(float(p.get("support", 0)), 4)
  75. length = int(p.get("length", 0))
  76. post_count = int(p.get("post_count", 0))
  77. return support, length, unique, post_count
  78. def _merge_and_dedupe(patterns: list[dict]) -> list[dict]:
  79. """
  80. 按 items 的 name 集合去重(不区分顺序),留 support 最大;
  81. 输出格式保留 s、l、i(nameA+nameB+nameC)及 post_count,供条件概率计算使用。
  82. """
  83. key_to_best: dict[tuple, tuple[float, int, int]] = {}
  84. for p in patterns:
  85. support, length, unique, post_count = _slim_pattern(p)
  86. if not unique:
  87. continue
  88. key = tuple(sorted(unique))
  89. if key not in key_to_best or support > key_to_best[key][0]:
  90. key_to_best[key] = (support, length, post_count)
  91. out = []
  92. for k, (s, l, post_count) in key_to_best.items():
  93. out.append({
  94. "s": s,
  95. "l": l,
  96. "i": "+".join(k),
  97. "post_count": post_count,
  98. })
  99. out.sort(key=lambda x: x["s"] * x["l"], reverse=True)
  100. return out
  101. def _load_and_merge_patterns(account_name: str) -> list[dict]:
  102. """读取 pattern 库 JSON,按 TOP_KEYS/SUB_KEYS 合并为列表并做合并、去重。"""
  103. path = _pattern_file(account_name)
  104. if not path.is_file():
  105. return []
  106. with open(path, "r", encoding="utf-8") as f:
  107. data = json.load(f)
  108. all_patterns = []
  109. for top in TOP_KEYS:
  110. if top not in data:
  111. continue
  112. block = data[top]
  113. for sub in SUB_KEYS:
  114. all_patterns.extend(block.get(sub) or [])
  115. return _merge_and_dedupe(all_patterns)
  116. def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
  117. """将 agent 传入的 [{"topic": "x", "source_node": "y"}, ...] 转为 DerivedItem 列表。"""
  118. out = []
  119. for item in derived_items:
  120. if isinstance(item, dict):
  121. topic = item.get("topic") or item.get("已推导的选题点")
  122. source = item.get("source_node") or item.get("推导来源人设树节点")
  123. if topic is not None and source is not None:
  124. out.append((str(topic).strip(), str(source).strip()))
  125. elif isinstance(item, (list, tuple)) and len(item) >= 2:
  126. out.append((str(item[0]).strip(), str(item[1]).strip()))
  127. return out
  128. def get_patterns_by_conditional_ratio(
  129. account_name: str,
  130. derived_list: list[tuple[str, str]],
  131. conditional_ratio_threshold: float,
  132. top_n: int,
  133. post_id: str = "",
  134. ) -> list[dict[str, Any]]:
  135. """
  136. 从 pattern 库中获取条件概率 >= 阈值的 pattern,按以下优先级排序后返回 top_n 条:
  137. 1. pattern 元素中直接包含已推导选题点(topic)的排最前;
  138. 2. pattern 元素与任意已推导选题点的匹配分 >= 0.8 的次之(从 match_data 文件读取,
  139. key 为 (帖子选题点, 人设树节点),pattern 元素视为人设树节点);
  140. 3. 按条件概率降序;
  141. 4. 按 length 降序。
  142. derived_list 为空时,条件概率使用 pattern 自身的 support(s)。
  143. 返回每项:pattern名称(nameA+nameB+nameC)、条件概率。
  144. """
  145. merged = _load_and_merge_patterns(account_name)
  146. print(f"_load_and_merge_patterns,patterns: {len(merged)}")
  147. if not merged:
  148. return []
  149. base_dir = _BASE_INPUT
  150. scored: list[tuple[dict, float]] = []
  151. if not derived_list:
  152. # derived_items 为空:条件概率取 pattern 本身的 support (s)
  153. for p in merged:
  154. ratio = float(p.get("s", 0))
  155. if ratio >= conditional_ratio_threshold:
  156. scored.append((p, ratio))
  157. else:
  158. for p in merged:
  159. ratio = calc_pattern_conditional_ratio(
  160. account_name, derived_list, p, base_dir=base_dir
  161. )
  162. if ratio >= conditional_ratio_threshold:
  163. scored.append((p, ratio))
  164. derived_topics = {topic for topic, _ in derived_list} if derived_list else set()
  165. # 次优先:从 match_data 文件加载 (帖子选题点, 人设树节点) -> 匹配分,
  166. # 用已推导选题点(topic)作为帖子选题点,pattern 元素作为人设树节点,
  167. # 检查是否存在匹配分 >= 0.8 的组合。
  168. match_lookup: dict[tuple[str, str], float] = {}
  169. if derived_topics and post_id:
  170. match_lookup = _load_match_data(account_name, post_id)
  171. def _sort_key(x: tuple[dict, float]) -> tuple:
  172. p, ratio = x
  173. elements = set(p["i"].split("+"))
  174. has_derived = bool(elements & derived_topics)
  175. has_high_match = False
  176. if not has_derived and match_lookup:
  177. for elem in elements:
  178. for dt in derived_topics:
  179. if match_lookup.get((dt, elem), 0.0) >= 0.8:
  180. has_high_match = True
  181. break
  182. if has_high_match:
  183. break
  184. return (not has_derived, not has_high_match, -ratio, -p["l"])
  185. scored.sort(key=_sort_key)
  186. result = []
  187. for p, ratio in scored[:top_n]:
  188. result.append({
  189. "pattern名称": p["i"],
  190. "条件概率": round(ratio, 6),
  191. })
  192. return result
@tool(
    description="按条件概率从 pattern 库中筛选 pattern,优先返回包含已推导选题点的 pattern,并检查每个 pattern 的元素是否与帖子选题点匹配。"
    "功能:根据账号与已推导选题点(可选),筛选条件概率不低于阈值的 pattern;当 derived_items 非空时,优先返回 pattern 元素中包含已推导选题点的 pattern;同时对每个 pattern 的所有元素做帖子选题点匹配,匹配结果直接包含在返回数据中。"
    "参数:account_name 为账号名;post_id 为帖子ID,用于加载帖子选题点并做匹配判断;derived_items 为已推导选题点列表,每项含 topic(或已推导的选题点)与 source_node(或推导来源人设树节点),可为空,为空时条件概率使用 pattern 自身的 support;conditional_ratio_threshold 为条件概率阈值;top_n 为返回条数上限,默认 100。"
    "返回:ToolResult,output 为可读的 pattern 列表文本"
)
async def find_pattern(
    account_name: str,
    post_id: str,
    derived_items: list[dict[str, str]],
    conditional_ratio_threshold: float,
    top_n: int = 100,
    context: Optional[ToolContext] = None,
) -> ToolResult:
    """
    Filter patterns from the pattern library by a conditional-ratio threshold
    and return at most top_n of them (sorted by conditional ratio, descending).
    When derived_items is non-empty, patterns whose elements contain a derived
    topic are returned first. Before returning, every pattern element is matched
    against the post's topic points and the match results are embedded in the
    returned data.

    Parameters
    -------
    account_name : account name, used to locate that account's pattern library.
    post_id : post ID, used to load the post's topic points for element matching.
    derived_items : list of derived topic points; may be empty. Each entry is a
        dict with topic (or "已推导的选题点") and source_node (or
        "推导来源人设树节点"). When empty, each pattern's conditional ratio is
        its own support.
    conditional_ratio_threshold : only patterns with a conditional ratio >= this
        value are returned.
    top_n : maximum number of entries to return, default 100.
    context : optional agent tool context.

    Returns
    -------
    ToolResult:
        - title: result title.
        - output: readable pattern list text (one line per pattern: name,
          conditional ratio, post-match info).
          "帖子选题点匹配" is "无" when there is no match; otherwise a
          list[{"pattern元素", "帖子选题点", "匹配分数"}].
        - on failure, error holds the error message.
    """
    pattern_path = _pattern_file(account_name)
    if not pattern_path.is_file():
        return ToolResult(
            title="Pattern 库不存在",
            output=f"pattern 文件不存在: {pattern_path}",
            error="Pattern file not found",
        )
    try:
        derived_list = _parse_derived_list(derived_items or [])
        items = get_patterns_by_conditional_ratio(
            account_name, derived_list, conditional_ratio_threshold, top_n, post_id
        )
        # Batch-collect all distinct pattern elements, then match them against
        # the post's topic points in a single call.
        if items and post_id:
            all_elements: list[str] = []
            seen_elements: set[str] = set()
            for item in items:
                for elem in item["pattern名称"].split("+"):
                    elem = elem.strip()
                    if elem and elem not in seen_elements:
                        all_elements.append(elem)
                        seen_elements.add(elem)
            matched_results = await match_derivation_to_post_points(all_elements, account_name, post_id)
            # element name -> list of {post topic point, match score}
            elem_match_map: dict[str, list] = {}
            for m in matched_results:
                elem_match_map.setdefault(m["推导选题点"], []).append({
                    "帖子选题点": m["帖子选题点"],
                    "匹配分数": m["匹配分数"],
                })
            for item in items:
                pattern_matches = []
                for elem in item["pattern名称"].split("+"):
                    elem = elem.strip()
                    for post_match in elem_match_map.get(elem, []):
                        pattern_matches.append({
                            "pattern元素": elem,
                            "帖子选题点": post_match["帖子选题点"],
                            "匹配分数": post_match["匹配分数"],
                        })
                # Only report match info when the pattern's elements hit at
                # least 2 distinct post topic points; otherwise mark as "无".
                distinct_post_points = len({m["帖子选题点"] for m in pattern_matches})
                item["帖子选题点匹配"] = (
                    pattern_matches if distinct_post_points >= 2 else "无"
                )
            # [temporary] keep only records with post topic-point matches
            # (distinct_post_points >= 2); easy to delete later.
            items = [x for x in items if isinstance(x.get("帖子选题点匹配"), list)]
        # For pattern elements that did not match any post topic point, expand
        # the match via persona-tree child/sibling nodes.
        if items and post_id:
            node_info_map = _build_node_info(account_name)
            all_candidates_set: set[str] = set()
            # Per item: list of (unmatched element, candidate nodes, expand type).
            item_unmatched_info: list[list[tuple[str, list[str]]]] = []
            for item in items:
                pattern_matches = item.get("帖子选题点匹配", [])
                matched_elems = (
                    {m["pattern元素"] for m in pattern_matches}
                    if isinstance(pattern_matches, list) else set()
                )
                all_elems = [e.strip() for e in item["pattern名称"].split("+")]
                unmatched = [e for e in all_elems if e not in matched_elems]
                elem_candidates: list[tuple[str, list[str], str]] = []
                for elem in unmatched:
                    info = node_info_map.get(elem)
                    if not info:
                        continue
                    # Classification nodes expand downward into children;
                    # everything else expands sideways into siblings.
                    if info["type"] == "class" and info["children"]:
                        candidates = info["children"]
                        expand_type = "子节点"
                    else:
                        candidates = info["siblings"]
                        expand_type = "兄弟节点"
                    if candidates:
                        elem_candidates.append((elem, candidates, expand_type))
                        all_candidates_set.update(candidates)
                item_unmatched_info.append(elem_candidates)
            if all_candidates_set:
                candidate_matches = await match_derivation_to_post_points(
                    list(all_candidates_set), account_name, post_id
                )
                cand_match_map: dict[str, list[tuple[str, float]]] = {}
                for m in candidate_matches:
                    cand_match_map.setdefault(m["推导选题点"], []).append(
                        (m["帖子选题点"], m["匹配分数"])
                    )
                for item, elem_cands in zip(items, item_unmatched_info):
                    for elem, candidates, expand_type in elem_cands:
                        # Keep only the single best-scoring expansion per element.
                        best_cand, best_pp, best_sc = None, None, -1.0
                        for cand in candidates:
                            for pp, sc in cand_match_map.get(cand, []):
                                if sc > best_sc:
                                    best_cand, best_pp, best_sc = cand, pp, sc
                        if best_cand is not None:
                            item["帖子选题点匹配"].append({
                                "pattern元素": elem,
                                "帖子选题点": best_pp,
                                "匹配分数": best_sc,
                                "扩展节点": best_cand,
                                "扩展类型": expand_type,
                            })
            # Dedupe post topic points within one pattern: when the same post
            # topic point appears more than once, keep only the highest score.
            for item in items:
                matches = item.get("帖子选题点匹配")
                if not isinstance(matches, list):
                    continue
                best_by_pp: dict[str, dict] = {}
                for m in matches:
                    pp = m["帖子选题点"]
                    if pp not in best_by_pp or m["匹配分数"] > best_by_pp[pp]["匹配分数"]:
                        best_by_pp[pp] = m
                item["帖子选题点匹配"] = list(best_by_pp.values())
        if not items:
            output = f"未找到条件概率 >= {conditional_ratio_threshold} 的 pattern"
        else:
            lines = []
            for x in items:
                match_info = x.get("帖子选题点匹配", "无")
                if isinstance(match_info, list):
                    match_str = "、".join(
                        (
                            f"{m['扩展节点']}({m['pattern元素']}的{m['扩展类型']})→{m['帖子选题点']}({m['匹配分数']})"
                            if "扩展节点" in m else
                            f"{m['pattern元素']}→{m['帖子选题点']}({m['匹配分数']})"
                        )
                        for m in match_info
                    )
                else:
                    match_str = str(match_info)
                lines.append(f"- {x['pattern名称']}\t条件概率={x['条件概率']}\t帖子选题点匹配={match_str}")
            output = "\n".join(lines)
        return ToolResult(
            title=f"符合条件概率的 Pattern ({account_name}, 阈值={conditional_ratio_threshold})",
            output=output,
            metadata={
                "account_name": account_name,
                "conditional_ratio_threshold": conditional_ratio_threshold,
                "top_n": top_n,
                "count": len(items),
            },
        )
    except Exception as e:
        return ToolResult(
            title="查找 Pattern 失败",
            output=str(e),
            error=str(e),
        )
  371. def main() -> None:
  372. """本地测试:用家有大志账号、已推导选题点,查询符合条件概率阈值的 pattern(含帖子匹配)。"""
  373. import asyncio
  374. account_name = "家有大志"
  375. post_id = "68fb6a5c000000000302e5de"
  376. # 已推导选题点,每项:已推导的选题点 + 推导来源人设树节点
  377. # derived_items = [
  378. # {"topic": "分享", "source_node": "分享"},
  379. # {"topic": "植入方式", "source_node": "植入方式"},
  380. # {"topic": "叙事结构", "source_node": "叙事结构"},
  381. # ]
  382. derived_items = derived_items = [{"source_node":"分享","topic":"分享"},{"source_node":"叙事结构","topic":"叙事结构"},{"source_node":"图片文字","topic":"图片文字"},{"source_node":"补充说明式","topic":"补充说明式"},{"source_node":"幽默化标题","topic":"幽默化标题"},{"source_node":"标题","topic":"标题"}]
  383. conditional_ratio_threshold = 0.2
  384. top_n = 2000
  385. # 1)直接调用核心函数(不含帖子匹配,仅验证排序逻辑)
  386. # derived_list = _parse_derived_list(derived_items)
  387. # items = get_patterns_by_conditional_ratio(
  388. # account_name, derived_list, conditional_ratio_threshold, top_n, post_id
  389. # )
  390. # print(f"账号: {account_name}, 阈值: {conditional_ratio_threshold}, top_n: {top_n}")
  391. # print(f"共 {len(items)} 条 pattern:\n")
  392. # for x in items:
  393. # print(f" - {x['pattern名称']}\t条件概率={x['条件概率']}")
  394. # 2)有 agent 时通过 tool 接口再跑一遍(含帖子选题点匹配)
  395. if ToolResult is not None:
  396. async def run_tool():
  397. result = await find_pattern(
  398. account_name=account_name,
  399. post_id=post_id,
  400. derived_items=derived_items,
  401. conditional_ratio_threshold=conditional_ratio_threshold,
  402. top_n=top_n,
  403. )
  404. print("\n--- Tool 返回 ---")
  405. print(result.output)
  406. asyncio.run(run_tool())
  407. if __name__ == "__main__":
  408. main()