find_pattern.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
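"""
Find Pattern tool - fetch patterns from the pattern libraries whose conditional
probability meets a threshold.

Features:
- Account library: reads input/{account}/处理后数据/pattern/pattern.json; the
  conditional probability is computed against the account's persona tree.
- Platform library: reads input/xiaohongshu/pattern/processed_edge_data.json; the
  conditional probability is computed against xiaohongshu/tree.

All patterns are ordered by (conditional probability * pattern element length),
descending; the account library gets 60% of the quota and the platform library 40%.
"""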
  8. import json
  9. import sys
  10. from pathlib import Path
  11. from typing import Any
  12. # 保证直接运行或作为包加载时都能解析 utils / tools(IDE 可跳转)
  13. _root = Path(__file__).resolve().parent.parent
  14. if str(_root) not in sys.path:
  15. sys.path.insert(0, str(_root))
  16. from utils.conditional_ratio_calc import (
  17. build_node_index_for_tree_dir,
  18. calc_pattern_conditional_ratio,
  19. calc_pattern_conditional_ratio_with_index,
  20. )
  21. from tools.point_match import (
  22. DEFAULT_MATCH_THRESHOLD,
  23. )
  24. try:
  25. from agent.tools import tool, ToolResult, ToolContext
  26. except ImportError:
  27. def tool(*args, **kwargs):
  28. return lambda f: f
  29. ToolResult = None # 仅用 main() 测核心逻辑时可无 agent
  30. ToolContext = None
  31. # 与 pattern_data_process 一致的 key 定义
  32. TOP_KEYS = [
  33. "depth_max_with_name",
  34. "depth_mixed",
  35. "depth_max_concrete",
  36. "depth2_medium",
  37. "depth1_abstract",
  38. "depth_max_minus_1",
  39. "depth_max_minus_2",
  40. "depth_3",
  41. "depth_4",
  42. ]
  43. SUB_KEYS = ["two_x", "one_x", "zero_x"]
  44. _BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
  45. _PLATFORM_TREE_DIR = _BASE_INPUT / "xiaohongshu" / "tree"
  46. _PLATFORM_PATTERN_FILE = _BASE_INPUT / "xiaohongshu" / "pattern" / "processed_edge_data.json"
  47. def _pattern_file(account_name: str) -> Path:
  48. """pattern 库文件:../input/{account_name}/处理后数据/pattern/pattern.json"""
  49. return _BASE_INPUT / account_name / "处理后数据" / "pattern" / "pattern.json"
  50. def _platform_pattern_file() -> Path:
  51. """平台库 pattern:../input/xiaohongshu/pattern/processed_edge_data.json"""
  52. return _PLATFORM_PATTERN_FILE
  53. def _slim_pattern(p: dict) -> tuple[float, int, list[str], int]:
  54. """提取 name 列表(去重保序)、support、length、post_count。"""
  55. names = [item["name"] for item in (p.get("items") or [])]
  56. seen = set()
  57. unique = []
  58. for n in names:
  59. if n not in seen:
  60. seen.add(n)
  61. unique.append(n)
  62. support = round(float(p.get("support", 0)), 4)
  63. length = int(p.get("length", 0))
  64. post_count = int(p.get("post_count", 0))
  65. return support, length, unique, post_count
  66. def _merge_and_dedupe(patterns: list[dict]) -> list[dict]:
  67. """
  68. 按 items 的 name 集合去重(不区分顺序),留 support 最大;
  69. 输出格式保留 s、l、i(nameA+nameB+nameC)及 post_count,供条件概率计算使用。
  70. """
  71. key_to_best: dict[tuple, tuple[float, int, int]] = {}
  72. for p in patterns:
  73. support, length, unique, post_count = _slim_pattern(p)
  74. if not unique:
  75. continue
  76. key = tuple(sorted(unique))
  77. if key not in key_to_best or support > key_to_best[key][0]:
  78. key_to_best[key] = (support, length, post_count)
  79. out = []
  80. for k, (s, l, post_count) in key_to_best.items():
  81. out.append({
  82. "s": s,
  83. "l": l,
  84. "i": "+".join(k),
  85. "post_count": post_count,
  86. })
  87. out.sort(key=lambda x: x["s"] * x["l"], reverse=True)
  88. return out
  89. def _load_and_merge_patterns(account_name: str) -> list[dict]:
  90. """读取 pattern 库 JSON,按 TOP_KEYS/SUB_KEYS 合并为列表并做合并、去重。"""
  91. path = _pattern_file(account_name)
  92. if not path.is_file():
  93. return []
  94. with open(path, "r", encoding="utf-8") as f:
  95. data = json.load(f)
  96. all_patterns = []
  97. for top in TOP_KEYS:
  98. if top not in data:
  99. continue
  100. block = data[top]
  101. for sub in SUB_KEYS:
  102. all_patterns.extend(block.get(sub) or [])
  103. return _merge_and_dedupe(all_patterns)
  104. def _load_and_merge_platform_patterns() -> list[dict]:
  105. """读取平台库 pattern JSON,结构与账号库相同,合并去重。"""
  106. path = _platform_pattern_file()
  107. if not path.is_file():
  108. return []
  109. with open(path, "r", encoding="utf-8") as f:
  110. data = json.load(f)
  111. all_patterns = []
  112. for top in TOP_KEYS:
  113. if top not in data:
  114. continue
  115. block = data[top]
  116. for sub in SUB_KEYS:
  117. all_patterns.extend(block.get(sub) or [])
  118. return _merge_and_dedupe(all_patterns)
  119. def _load_match_lookup(file_path: Path) -> dict[tuple[str, str], float]:
  120. """
  121. 读取 match_data 文件,返回 (帖子选题点, 人设树节点) -> 最高匹配分。
  122. 文件格式:[{"name": 帖子选题点, "match_personas": [{"name": 节点名, "match_score": float}]}]
  123. """
  124. lookup: dict[tuple[str, str], float] = {}
  125. if not file_path.is_file():
  126. return lookup
  127. try:
  128. with open(file_path, "r", encoding="utf-8") as f:
  129. data = json.load(f)
  130. except Exception:
  131. return lookup
  132. if not isinstance(data, list):
  133. return lookup
  134. for item in data:
  135. if not isinstance(item, dict):
  136. continue
  137. topic = item.get("name")
  138. personas = item.get("match_personas")
  139. if topic is None or not isinstance(personas, list):
  140. continue
  141. topic_s = str(topic).strip()
  142. if not topic_s:
  143. continue
  144. for mp in personas:
  145. if not isinstance(mp, dict):
  146. continue
  147. node = mp.get("name")
  148. score = mp.get("match_score")
  149. if node is None or score is None:
  150. continue
  151. try:
  152. sc = float(score)
  153. except (TypeError, ValueError):
  154. continue
  155. key = (topic_s, str(node).strip())
  156. if key not in lookup or sc > lookup[key]:
  157. lookup[key] = sc
  158. return lookup
  159. def _pattern_has_derived_match(
  160. pattern_name: str,
  161. derived_topics: set[str],
  162. match_lookup: dict[tuple[str, str], float],
  163. threshold: float,
  164. ) -> bool:
  165. """pattern 中至少有一个元素与任意 derived_topic 的匹配分 >= threshold。"""
  166. for elem in (e.strip() for e in pattern_name.split("+")):
  167. if not elem:
  168. continue
  169. for topic in derived_topics:
  170. if match_lookup.get((topic, elem), 0.0) >= threshold:
  171. return True
  172. return False
  173. def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
  174. """将 agent 传入的 [{"topic": "x", "source_node": "y"}, ...] 转为 DerivedItem 列表。"""
  175. out = []
  176. for item in derived_items:
  177. if isinstance(item, dict):
  178. topic = item.get("topic") or item.get("已推导的选题点")
  179. source = item.get("source_node") or item.get("推导来源人设树节点")
  180. if topic is not None and source is not None:
  181. out.append((str(topic).strip(), str(source).strip()))
  182. elif isinstance(item, (list, tuple)) and len(item) >= 2:
  183. out.append((str(item[0]).strip(), str(item[1]).strip()))
  184. return out
  185. def get_patterns_by_conditional_ratio(
  186. account_name: str,
  187. derived_list: list[tuple[str, str]],
  188. conditional_ratio_threshold: float,
  189. top_n: int,
  190. ) -> list[dict[str, Any]]:
  191. """
  192. 从 pattern 库中获取条件概率 >= 阈值的 pattern,按 条件概率 * pattern元素长度 降序返回 top_n 条。
  193. derived_list 为空时,条件概率使用 pattern 自身的 support(s)。
  194. 返回每项:pattern名称(nameA+nameB+nameC)、条件概率。
  195. """
  196. merged = _load_and_merge_patterns(account_name)
  197. if not merged:
  198. return []
  199. base_dir = _BASE_INPUT
  200. scored: list[tuple[dict, float]] = []
  201. if not derived_list:
  202. for p in merged:
  203. ratio = float(p.get("s", 0))
  204. if ratio >= conditional_ratio_threshold:
  205. scored.append((p, ratio))
  206. else:
  207. for p in merged:
  208. ratio = calc_pattern_conditional_ratio(
  209. account_name, derived_list, p, base_dir=base_dir
  210. )
  211. if ratio >= conditional_ratio_threshold:
  212. scored.append((p, ratio))
  213. scored.sort(key=lambda x: -(x[1] * x[0]["l"]))
  214. result = []
  215. for p, ratio in scored[:top_n]:
  216. result.append({
  217. "pattern名称": p["i"],
  218. "条件概率": round(ratio, 6),
  219. })
  220. return result
  221. def get_platform_patterns_by_conditional_ratio(
  222. derived_list: list[tuple[str, str]],
  223. conditional_ratio_threshold: float,
  224. top_n: int,
  225. ) -> list[dict[str, Any]]:
  226. """
  227. 平台库 pattern:数据来自 xiaohongshu/pattern/processed_edge_data.json,
  228. 条件概率基于 xiaohongshu/tree 的节点索引(与账号侧 calc_pattern 规则一致)。
  229. 按 条件概率 * pattern元素长度 降序返回 top_n 条。
  230. """
  231. merged = _load_and_merge_platform_patterns()
  232. if not merged:
  233. return []
  234. platform_index = build_node_index_for_tree_dir(_PLATFORM_TREE_DIR)
  235. scored: list[tuple[dict, float]] = []
  236. if not derived_list:
  237. for p in merged:
  238. ratio = float(p.get("s", 0))
  239. if ratio >= conditional_ratio_threshold:
  240. scored.append((p, ratio))
  241. else:
  242. for p in merged:
  243. ratio = calc_pattern_conditional_ratio_with_index(derived_list, p, platform_index)
  244. if ratio >= conditional_ratio_threshold:
  245. scored.append((p, ratio))
  246. scored.sort(key=lambda x: -(x[1] * x[0]["l"]))
  247. result = []
  248. for p, ratio in scored[:top_n]:
  249. result.append({
  250. "pattern名称": p["i"],
  251. "条件概率": round(ratio, 6),
  252. })
  253. return result
# ---------------------------------------------------------------------------
# Agent Tool
# ---------------------------------------------------------------------------
  257. @tool()
  258. async def find_pattern(
  259. account_name: str,
  260. post_id: str,
  261. derived_items: list[dict[str, str]],
  262. conditional_ratio_threshold: float,
  263. top_n: int = 100,
  264. match_score_threshold: float = DEFAULT_MATCH_THRESHOLD,
  265. ) -> ToolResult:
  266. """
  267. 按条件概率阈值从 pattern 库筛选:第一节为账号 pattern(优先使用),第二节为平台库 pattern。
  268. 所有 pattern 按 条件概率 * pattern元素长度 降序排列。
  269. Args:
  270. account_name : 账号名,用于定位该账号的 pattern 库。
  271. post_id : 帖子ID,用于加载 match_data 过滤(derived_items 非空时生效)。
  272. derived_items : 已推导选题点列表,可为空。
  273. conditional_ratio_threshold : 条件概率阈值。
  274. top_n : 最终返回总条数上限。
  275. match_score_threshold : pattern 元素与帖子选题点的匹配分阈值。
  276. Returns:
  277. ToolResult:output 分「账号 pattern」「平台库 pattern」两段;平台段已排除与账号段 pattern 名称完全相同的项。
  278. """
  279. pattern_path = _pattern_file(account_name)
  280. if not pattern_path.is_file():
  281. return ToolResult(
  282. title="Pattern 库不存在",
  283. output=f"pattern 文件不存在: {pattern_path}",
  284. error="Pattern file not found",
  285. )
  286. try:
  287. derived_list = _parse_derived_list(derived_items or [])
  288. derived_topics = {topic for topic, _ in derived_list}
  289. thr = float(match_score_threshold)
  290. total_top_n = max(0, int(top_n))
  291. account_top_n = int(total_top_n * 0.6)
  292. platform_top_n = total_top_n - account_top_n
  293. # 有过滤时候选池放大,以保证过滤后仍有足够数量
  294. candidate_mult = max(total_top_n * 5, 500) if derived_topics and post_id else 0
  295. # 预加载 match_lookup(仅当 derived_topics 非空且有 post_id 时)
  296. account_match_lookup: dict[tuple[str, str], float] = {}
  297. platform_match_lookup: dict[tuple[str, str], float] = {}
  298. if derived_topics and post_id:
  299. account_match_file = (
  300. _BASE_INPUT / account_name / "处理后数据" / "match_data"
  301. / f"{post_id}_匹配_all.json"
  302. )
  303. platform_match_file = (
  304. _BASE_INPUT / "xiaohongshu" / "match_data" / f"{post_id}_匹配_all.json"
  305. )
  306. account_match_lookup = _load_match_lookup(account_match_file)
  307. platform_match_lookup = _load_match_lookup(platform_match_file)
  308. def _filter_by_derived_match(
  309. items: list[dict],
  310. match_lookup: dict[tuple[str, str], float],
  311. ) -> list[dict]:
  312. """derived_topics 非空时过滤:pattern 至少有一个元素与任意 topic 匹配分 >= thr。"""
  313. if not derived_topics or not post_id:
  314. return items
  315. return [
  316. x for x in items
  317. if _pattern_has_derived_match(
  318. str(x.get("pattern名称", "")), derived_topics, match_lookup, thr
  319. )
  320. ]
  321. # ---------- 账号 pattern ----------
  322. account_candidate_n = candidate_mult if candidate_mult else account_top_n
  323. items_account_raw = get_patterns_by_conditional_ratio(
  324. account_name, derived_list, conditional_ratio_threshold, account_candidate_n
  325. )
  326. items_account = _filter_by_derived_match(items_account_raw, account_match_lookup)[:account_top_n]
  327. account_pattern_names = {str(x.get("pattern名称", "")).strip() for x in items_account}
  328. # ---------- 平台库 pattern ----------
  329. platform_candidate_n = (candidate_mult + len(account_pattern_names)) if candidate_mult else (platform_top_n + len(account_pattern_names))
  330. items_platform_raw = get_platform_patterns_by_conditional_ratio(
  331. derived_list,
  332. conditional_ratio_threshold / 5,
  333. platform_candidate_n,
  334. )
  335. items_platform = _filter_by_derived_match(
  336. [x for x in items_platform_raw if str(x.get("pattern名称", "")).strip() not in account_pattern_names],
  337. platform_match_lookup,
  338. )[:platform_top_n]
  339. def _format_pattern_block(xs: list[dict[str, Any]]) -> list[str]:
  340. return [f"- {x['pattern名称']}\t条件概率={x['条件概率']}" for x in xs]
  341. lines_out: list[str] = []
  342. lines_out.append(
  343. "【优先使用】第一节为账号 pattern(优先使用);第二节为平台库 pattern。"
  344. )
  345. lines_out.append("")
  346. lines_out.append("—— 账号 pattern ——")
  347. if not items_account:
  348. lines_out.append(
  349. f"(无:未找到条件概率 >= {conditional_ratio_threshold} 的 pattern)"
  350. )
  351. else:
  352. lines_out.extend(_format_pattern_block(items_account))
  353. lines_out.append("")
  354. lines_out.append("—— 平台库 pattern ——")
  355. if not items_platform:
  356. lines_out.append("(无:未找到达标 pattern)")
  357. else:
  358. lines_out.extend(_format_pattern_block(items_platform))
  359. output = "\n".join(lines_out)
  360. return ToolResult(
  361. title=f"符合条件概率的 Pattern ({account_name}, 阈值={conditional_ratio_threshold})",
  362. output=output,
  363. metadata={
  364. "account_name": account_name,
  365. "conditional_ratio_threshold": conditional_ratio_threshold,
  366. "top_n": top_n,
  367. "quota": {
  368. "account_top_n": account_top_n,
  369. "platform_top_n": platform_top_n,
  370. },
  371. "account_pattern_count": len(items_account),
  372. "platform_pattern_count": len(items_platform),
  373. "count": len(items_account) + len(items_platform),
  374. },
  375. )
  376. except Exception as e:
  377. return ToolResult(
  378. title="查找 Pattern 失败",
  379. output=str(e),
  380. error=str(e),
  381. )
  382. def main() -> None:
  383. """本地测试:用家有大志账号、已推导选题点,查询符合条件概率阈值的 pattern。"""
  384. import asyncio
  385. account_name = "家有大志"
  386. post_id = "68fb6a5c000000000302e5de"
  387. derived_items = [
  388. {"topic": "分享", "source_node": "分享"},
  389. {"topic": "植入方式", "source_node": "植入方式"},
  390. {"topic": "叙事结构", "source_node": "叙事结构"},
  391. ]
  392. derived_items: list[dict[str, str]] = []
  393. conditional_ratio_threshold = 0.2
  394. top_n = 500
  395. # 1)直接调用核心函数(仅验证排序逻辑)
  396. # derived_list = _parse_derived_list(derived_items)
  397. # items = get_patterns_by_conditional_ratio(
  398. # account_name, derived_list, conditional_ratio_threshold, top_n
  399. # )
  400. # print(f"账号: {account_name}, 阈值: {conditional_ratio_threshold}, top_n: {top_n}")
  401. # print(f"共 {len(items)} 条 pattern:\n")
  402. # for x in items:
  403. # print(f" - {x['pattern名称']}\t条件概率={x['条件概率']}")
  404. # 2)有 agent 时通过 tool 接口再跑一遍
  405. if ToolResult is not None:
  406. async def run_tool():
  407. result = await find_pattern(
  408. account_name=account_name,
  409. post_id=post_id,
  410. derived_items=derived_items,
  411. conditional_ratio_threshold=conditional_ratio_threshold,
  412. top_n=top_n,
  413. )
  414. print("\n--- Tool 返回 ---")
  415. print(result.output)
  416. asyncio.run(run_tool())
  417. if __name__ == "__main__":
  418. main()