  1. """
  2. 查找 Pattern Tool - 从 pattern 库中获取符合条件概率阈值的 pattern
  3. 功能:读取账号的 pattern 库,合并去重后按条件概率筛选,返回 topN 条 pattern(含 pattern 名称、条件概率)。
  4. """
  5. import importlib.util
  6. import json
  7. from pathlib import Path
  8. from typing import Any, Optional
  9. try:
  10. from agent.tools import tool, ToolResult, ToolContext
  11. except ImportError:
  12. def tool(*args, **kwargs):
  13. return lambda f: f
  14. ToolResult = None # 仅用 main() 测核心逻辑时可无 agent
  15. ToolContext = None
  16. # 与 pattern_data_process 一致的 key 定义
  17. TOP_KEYS = [
  18. "depth_max_with_name",
  19. "depth_mixed",
  20. "depth_max_concrete",
  21. "depth2_medium",
  22. "depth1_abstract",
  23. ]
  24. SUB_KEYS = ["two_x", "one_x", "zero_x"]
  25. # 加载 conditional_ratio_calc(与 find_tree_node 一致)
  26. _utils_dir = Path(__file__).resolve().parent.parent / "utils"
  27. _cond_spec = importlib.util.spec_from_file_location(
  28. "conditional_ratio_calc",
  29. _utils_dir / "conditional_ratio_calc.py",
  30. )
  31. _cond_mod = importlib.util.module_from_spec(_cond_spec)
  32. _cond_spec.loader.exec_module(_cond_mod)
  33. calc_pattern_conditional_ratio = _cond_mod.calc_pattern_conditional_ratio
  34. _BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  35. # 加载 point_match(用于检查 pattern 元素是否匹配帖子选题点)
  36. _point_match_spec = importlib.util.spec_from_file_location(
  37. "point_match",
  38. Path(__file__).resolve().parent / "point_match.py",
  39. )
  40. _point_match_mod = importlib.util.module_from_spec(_point_match_spec)
  41. _point_match_spec.loader.exec_module(_point_match_mod)
  42. _match_derivation_to_post_points = _point_match_mod.match_derivation_to_post_points
  43. _load_match_data = _point_match_mod._load_match_data
  44. def _pattern_file(account_name: str) -> Path:
  45. """pattern 库文件:../input/{account_name}/原始数据/pattern/processed_edge_data.json"""
  46. return _BASE_INPUT / account_name / "原始数据" / "pattern" / "processed_edge_data.json"
  47. def _slim_pattern(p: dict) -> tuple[float, int, list[str], int]:
  48. """提取 name 列表(去重保序)、support、length、post_count。"""
  49. names = [item["name"] for item in (p.get("items") or [])]
  50. seen = set()
  51. unique = []
  52. for n in names:
  53. if n not in seen:
  54. seen.add(n)
  55. unique.append(n)
  56. support = round(float(p.get("support", 0)), 4)
  57. length = int(p.get("length", 0))
  58. post_count = int(p.get("post_count", 0))
  59. return support, length, unique, post_count
  60. def _merge_and_dedupe(patterns: list[dict]) -> list[dict]:
  61. """
  62. 按 items 的 name 集合去重(不区分顺序),留 support 最大;
  63. 输出格式保留 s、l、i(nameA+nameB+nameC)及 post_count,供条件概率计算使用。
  64. """
  65. key_to_best: dict[tuple, tuple[float, int, int]] = {}
  66. for p in patterns:
  67. support, length, unique, post_count = _slim_pattern(p)
  68. if not unique:
  69. continue
  70. key = tuple(sorted(unique))
  71. if key not in key_to_best or support > key_to_best[key][0]:
  72. key_to_best[key] = (support, length, post_count)
  73. out = []
  74. for k, (s, l, post_count) in key_to_best.items():
  75. if s < 0.1:
  76. continue
  77. out.append({
  78. "s": s,
  79. "l": l,
  80. "i": "+".join(k),
  81. "post_count": post_count,
  82. })
  83. out.sort(key=lambda x: x["s"] * x["l"], reverse=True)
  84. return out
  85. def _load_and_merge_patterns(account_name: str) -> list[dict]:
  86. """读取 pattern 库 JSON,按 TOP_KEYS/SUB_KEYS 合并为列表并做合并、去重。"""
  87. path = _pattern_file(account_name)
  88. if not path.is_file():
  89. return []
  90. with open(path, "r", encoding="utf-8") as f:
  91. data = json.load(f)
  92. all_patterns = []
  93. for top in TOP_KEYS:
  94. if top not in data:
  95. continue
  96. block = data[top]
  97. for sub in SUB_KEYS:
  98. all_patterns.extend(block.get(sub) or [])
  99. return _merge_and_dedupe(all_patterns)
  100. def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
  101. """将 agent 传入的 [{"topic": "x", "source_node": "y"}, ...] 转为 DerivedItem 列表。"""
  102. out = []
  103. for item in derived_items:
  104. if isinstance(item, dict):
  105. topic = item.get("topic") or item.get("已推导的选题点")
  106. source = item.get("source_node") or item.get("推导来源人设树节点")
  107. if topic is not None and source is not None:
  108. out.append((str(topic).strip(), str(source).strip()))
  109. elif isinstance(item, (list, tuple)) and len(item) >= 2:
  110. out.append((str(item[0]).strip(), str(item[1]).strip()))
  111. return out
  112. def get_patterns_by_conditional_ratio(
  113. account_name: str,
  114. derived_list: list[tuple[str, str]],
  115. conditional_ratio_threshold: float,
  116. top_n: int,
  117. post_id: str = "",
  118. ) -> list[dict[str, Any]]:
  119. """
  120. 从 pattern 库中获取条件概率 >= 阈值的 pattern,按以下优先级排序后返回 top_n 条:
  121. 1. pattern 元素中直接包含已推导选题点(topic)的排最前;
  122. 2. pattern 元素与任意已推导选题点的匹配分 >= 0.8 的次之(从 match_data 文件读取,
  123. key 为 (帖子选题点, 人设树节点),pattern 元素视为人设树节点);
  124. 3. 按条件概率降序;
  125. 4. 按 length 降序。
  126. derived_list 为空时,条件概率使用 pattern 自身的 support(s)。
  127. 返回每项:pattern名称(nameA+nameB+nameC)、条件概率。
  128. """
  129. merged = _load_and_merge_patterns(account_name)
  130. if not merged:
  131. return []
  132. base_dir = _BASE_INPUT
  133. scored: list[tuple[dict, float]] = []
  134. if not derived_list:
  135. # derived_items 为空:条件概率取 pattern 本身的 support (s)
  136. for p in merged:
  137. ratio = float(p.get("s", 0))
  138. if ratio >= conditional_ratio_threshold:
  139. scored.append((p, ratio))
  140. else:
  141. for p in merged:
  142. ratio = calc_pattern_conditional_ratio(
  143. account_name, derived_list, p, base_dir=base_dir
  144. )
  145. if ratio >= conditional_ratio_threshold:
  146. scored.append((p, ratio))
  147. derived_topics = {topic for topic, _ in derived_list} if derived_list else set()
  148. # 次优先:从 match_data 文件加载 (帖子选题点, 人设树节点) -> 匹配分,
  149. # 用已推导选题点(topic)作为帖子选题点,pattern 元素作为人设树节点,
  150. # 检查是否存在匹配分 >= 0.8 的组合。
  151. match_lookup: dict[tuple[str, str], float] = {}
  152. if derived_topics and post_id:
  153. match_lookup = _load_match_data(account_name, post_id)
  154. def _sort_key(x: tuple[dict, float]) -> tuple:
  155. p, ratio = x
  156. elements = set(p["i"].split("+"))
  157. has_derived = bool(elements & derived_topics)
  158. has_high_match = False
  159. if not has_derived and match_lookup:
  160. for elem in elements:
  161. for dt in derived_topics:
  162. if match_lookup.get((dt, elem), 0.0) >= 0.8:
  163. has_high_match = True
  164. break
  165. if has_high_match:
  166. break
  167. return (not has_derived, not has_high_match, -ratio, -p["l"])
  168. scored.sort(key=_sort_key)
  169. result = []
  170. for p, ratio in scored[:top_n]:
  171. result.append({
  172. "pattern名称": p["i"],
  173. "条件概率": round(ratio, 6),
  174. })
  175. return result
  176. @tool(
  177. description="按条件概率从 pattern 库中筛选 pattern,优先返回包含已推导选题点的 pattern,并检查每个 pattern 的元素是否与帖子选题点匹配。"
  178. "功能:根据账号与已推导选题点(可选),筛选条件概率不低于阈值的 pattern;当 derived_items 非空时,优先返回 pattern 元素中包含已推导选题点的 pattern;同时对每个 pattern 的所有元素做帖子选题点匹配,匹配结果直接包含在返回数据中。"
  179. "参数:account_name 为账号名;post_id 为帖子ID,用于加载帖子选题点并做匹配判断;derived_items 为已推导选题点列表,每项含 topic(或已推导的选题点)与 source_node(或推导来源人设树节点),可为空,为空时条件概率使用 pattern 自身的 support;conditional_ratio_threshold 为条件概率阈值;top_n 为返回条数上限,默认 100。"
  180. "返回:ToolResult,output 为可读的 pattern 列表文本,metadata.items 为列表,每项含「pattern名称」(nameA+nameB+nameC 形式)、「条件概率」、「帖子选题点匹配」(匹配到帖子选题点的元素列表,每项含 pattern元素、帖子选题点与匹配分数;若无匹配则为字符串'无匹配帖子选题点')。"
  181. )
  182. async def find_pattern(
  183. account_name: str,
  184. post_id: str,
  185. derived_items: list[dict[str, str]],
  186. conditional_ratio_threshold: float,
  187. top_n: int = 100,
  188. context: Optional[ToolContext] = None,
  189. ) -> ToolResult:
  190. """
  191. 按条件概率阈值从 pattern 库筛选 pattern,返回最多 top_n 条(按条件概率降序)。
  192. 当 derived_items 非空时,优先返回元素中包含已推导选题点的 pattern。
  193. 返回前对每个 pattern 的所有元素做帖子选题点匹配,匹配结果直接包含在返回数据中。
  194. 参数
  195. -------
  196. account_name : 账号名,用于定位该账号的 pattern 库。
  197. post_id : 帖子ID,用于加载帖子选题点并与 pattern 元素做匹配判断。
  198. derived_items : 已推导选题点列表,可为空。非空时每项为字典,需含 topic(或「已推导的选题点」)与 source_node(或「推导来源人设树节点」);为空时各 pattern 的条件概率取其自身 support。
  199. conditional_ratio_threshold : 条件概率阈值,仅返回条件概率 >= 该值的 pattern。
  200. top_n : 返回条数上限,默认 100。
  201. context : 可选,Agent 工具上下文。
  202. 返回
  203. -------
  204. ToolResult:
  205. - title: 结果标题。
  206. - output: 可读的 pattern 列表文本(每行:pattern名称、条件概率、帖子匹配情况)。
  207. - metadata: 含 account_name、conditional_ratio_threshold、top_n、count、items;
  208. items 为列表,每项为 {"pattern名称": str, "条件概率": float,
  209. "帖子选题点匹配": list[{"pattern元素": str, "帖子选题点": str, "匹配分数": float}] 或 "无匹配帖子选题点"}。
  210. - 出错时 error 为错误信息。
  211. """
  212. pattern_path = _pattern_file(account_name)
  213. if not pattern_path.is_file():
  214. return ToolResult(
  215. title="Pattern 库不存在",
  216. output=f"pattern 文件不存在: {pattern_path}",
  217. error="Pattern file not found",
  218. )
  219. try:
  220. derived_list = _parse_derived_list(derived_items or [])
  221. items = get_patterns_by_conditional_ratio(
  222. account_name, derived_list, conditional_ratio_threshold, top_n, post_id
  223. )
  224. # 批量收集所有 pattern 元素,统一做一次帖子选题点匹配
  225. if items and post_id:
  226. all_elements: list[str] = []
  227. seen_elements: set[str] = set()
  228. for item in items:
  229. for elem in item["pattern名称"].split("+"):
  230. elem = elem.strip()
  231. if elem and elem not in seen_elements:
  232. all_elements.append(elem)
  233. seen_elements.add(elem)
  234. matched_results = await _match_derivation_to_post_points(all_elements, account_name, post_id)
  235. elem_match_map: dict[str, list] = {}
  236. for m in matched_results:
  237. elem_match_map.setdefault(m["推导选题点"], []).append({
  238. "帖子选题点": m["帖子选题点"],
  239. "匹配分数": m["匹配分数"],
  240. })
  241. for item in items:
  242. pattern_matches = []
  243. for elem in item["pattern名称"].split("+"):
  244. elem = elem.strip()
  245. for post_match in elem_match_map.get(elem, []):
  246. pattern_matches.append({
  247. "pattern元素": elem,
  248. "帖子选题点": post_match["帖子选题点"],
  249. "匹配分数": post_match["匹配分数"],
  250. })
  251. item["帖子选题点匹配"] = pattern_matches if pattern_matches else "无匹配帖子选题点"
  252. if not items:
  253. output = f"未找到条件概率 >= {conditional_ratio_threshold} 的 pattern"
  254. else:
  255. lines = []
  256. for x in items:
  257. match_info = x.get("帖子选题点匹配", "未查询")
  258. if isinstance(match_info, list):
  259. match_str = "、".join(
  260. f"{m['pattern元素']}→{m['帖子选题点']}({m['匹配分数']})" for m in match_info
  261. )
  262. else:
  263. match_str = str(match_info)
  264. lines.append(f"- {x['pattern名称']}\t条件概率={x['条件概率']}\t帖子匹配={match_str}")
  265. output = "\n".join(lines)
  266. return ToolResult(
  267. title=f"符合条件概率的 Pattern ({account_name}, 阈值={conditional_ratio_threshold})",
  268. output=output,
  269. metadata={
  270. "account_name": account_name,
  271. "conditional_ratio_threshold": conditional_ratio_threshold,
  272. "top_n": top_n,
  273. "count": len(items),
  274. "items": items,
  275. },
  276. )
  277. except Exception as e:
  278. return ToolResult(
  279. title="查找 Pattern 失败",
  280. output=str(e),
  281. error=str(e),
  282. )
  283. def main() -> None:
  284. """本地测试:用家有大志账号、已推导选题点,查询符合条件概率阈值的 pattern(含帖子匹配)。"""
  285. import asyncio
  286. account_name = "家有大志"
  287. post_id = "68fb6a5c000000000302e5de"
  288. # 已推导选题点,每项:已推导的选题点 + 推导来源人设树节点
  289. derived_items = [
  290. # {"topic": "分享", "source_node": "分享"},
  291. {"topic": "柴犬", "source_node": "动物角色"},
  292. {"topic": "叙事结构", "source_node": "叙事逻辑"},
  293. ]
  294. conditional_ratio_threshold = 0.01
  295. top_n = 100
  296. # 1)直接调用核心函数(不含帖子匹配,仅验证排序逻辑)
  297. derived_list = _parse_derived_list(derived_items)
  298. items = get_patterns_by_conditional_ratio(
  299. account_name, derived_list, conditional_ratio_threshold, top_n, post_id
  300. )
  301. print(f"账号: {account_name}, 阈值: {conditional_ratio_threshold}, top_n: {top_n}")
  302. print(f"共 {len(items)} 条 pattern:\n")
  303. for x in items:
  304. print(f" - {x['pattern名称']}\t条件概率={x['条件概率']}")
  305. # 2)有 agent 时通过 tool 接口再跑一遍(含帖子选题点匹配)
  306. if ToolResult is not None:
  307. async def run_tool():
  308. result = await find_pattern(
  309. account_name=account_name,
  310. post_id=post_id,
  311. derived_items=derived_items,
  312. conditional_ratio_threshold=conditional_ratio_threshold,
  313. top_n=top_n,
  314. )
  315. print("\n--- Tool 返回 ---")
  316. print(result.output)
  317. asyncio.run(run_tool())
  318. if __name__ == "__main__":
  319. main()