find_pattern.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772
  1. """
  2. 查找 Pattern Tool - 从 pattern 库中获取符合条件概率阈值的 pattern
  3. 功能:
  4. - 账号:读取 input/{账号}/处理后数据/pattern/pattern.json,条件概率基于账号人设树;
  5. 元素与帖子选题点匹配走账号 match_data / point_match,并支持人设树子节点、兄弟节点扩展。
  6. - 平台库:读取 input/xiaohongshu/pattern/processed_edge_data.json,条件概率基于 xiaohongshu/tree;
  7. 元素匹配仅使用 input/xiaohongshu/match_data/{post_id}_匹配_all.json。
  8. """
  9. import json
  10. import sys
  11. from pathlib import Path
  12. from typing import Any
  13. # 保证直接运行或作为包加载时都能解析 utils / tools(IDE 可跳转)
  14. _root = Path(__file__).resolve().parent.parent
  15. if str(_root) not in sys.path:
  16. sys.path.insert(0, str(_root))
  17. from utils.conditional_ratio_calc import (
  18. build_node_index_for_tree_dir,
  19. calc_pattern_conditional_ratio,
  20. calc_pattern_conditional_ratio_with_index,
  21. )
  22. from tools.point_match import (
  23. DEFAULT_MATCH_THRESHOLD,
  24. _load_match_data,
  25. match_derivation_to_post_points,
  26. )
  27. from tools.find_tree_node import _load_trees
  28. try:
  29. from agent.tools import tool, ToolResult, ToolContext
  30. except ImportError:
  31. def tool(*args, **kwargs):
  32. return lambda f: f
  33. ToolResult = None # 仅用 main() 测核心逻辑时可无 agent
  34. ToolContext = None
# Key definitions, kept consistent with pattern_data_process.
# TOP_KEYS are the top-level sections read from a pattern JSON file;
# SUB_KEYS are the sub-lists read from each of those sections.
TOP_KEYS = [
    "depth_max_with_name",
    "depth_mixed",
    "depth_max_concrete",
    "depth2_medium",
    "depth1_abstract",
    "depth_max_minus_1",
    "depth_max_minus_2",
    "depth_3",
    "depth_4",
]
SUB_KEYS = ["two_x", "one_x", "zero_x"]
# Root of all input data: <parent of this file's directory>/input
_BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
# When sorting, a (derived topic point, pattern element) pair whose match_data
# score is at least this value gets priority (same value as the original
# account-section logic).
_MATCH_PRIOR_MIN_SCORE = 0.8
_PLATFORM_TREE_DIR = _BASE_INPUT / "xiaohongshu" / "tree"
_PLATFORM_PATTERN_FILE = _BASE_INPUT / "xiaohongshu" / "pattern" / "processed_edge_data.json"
  53. def _build_node_info(account_name: str) -> dict[str, dict]:
  54. """
  55. 构建人设树节点信息映射: node_name -> {
  56. "type": 节点 _type("class" / "ID" 等),
  57. "children": 子节点名称列表(仅分类节点有值),
  58. "siblings": 兄弟节点名称列表(不含自身),
  59. }
  60. """
  61. node_info: dict[str, dict] = {}
  62. def _walk(node_dict: dict):
  63. children_dict = node_dict.get("children") or {}
  64. child_entries = [(n, c) for n, c in children_dict.items() if isinstance(c, dict)]
  65. child_names = [n for n, _ in child_entries]
  66. for name, child in child_entries:
  67. sub_children = child.get("children") or {}
  68. sub_child_names = [n for n, c in sub_children.items() if isinstance(c, dict)]
  69. node_info[name] = {
  70. "type": child.get("_type", ""),
  71. "children": sub_child_names,
  72. "siblings": [n for n in child_names if n != name],
  73. }
  74. _walk(child)
  75. for _dim_name, root in _load_trees(account_name):
  76. _walk(root)
  77. return node_info
  78. def _pattern_file(account_name: str) -> Path:
  79. """pattern 库文件:../input/{account_name}/处理后数据/pattern/pattern.json"""
  80. return _BASE_INPUT / account_name / "处理后数据" / "pattern" / "pattern.json"
def _platform_pattern_file() -> Path:
    """Return the platform pattern library path (the module constant _PLATFORM_PATTERN_FILE)."""
    return _PLATFORM_PATTERN_FILE
  84. def _slim_pattern(p: dict) -> tuple[float, int, list[str], int]:
  85. """提取 name 列表(去重保序)、support、length、post_count。"""
  86. names = [item["name"] for item in (p.get("items") or [])]
  87. seen = set()
  88. unique = []
  89. for n in names:
  90. if n not in seen:
  91. seen.add(n)
  92. unique.append(n)
  93. support = round(float(p.get("support", 0)), 4)
  94. length = int(p.get("length", 0))
  95. post_count = int(p.get("post_count", 0))
  96. return support, length, unique, post_count
  97. def _merge_and_dedupe(patterns: list[dict]) -> list[dict]:
  98. """
  99. 按 items 的 name 集合去重(不区分顺序),留 support 最大;
  100. 输出格式保留 s、l、i(nameA+nameB+nameC)及 post_count,供条件概率计算使用。
  101. """
  102. key_to_best: dict[tuple, tuple[float, int, int]] = {}
  103. for p in patterns:
  104. support, length, unique, post_count = _slim_pattern(p)
  105. if not unique:
  106. continue
  107. key = tuple(sorted(unique))
  108. if key not in key_to_best or support > key_to_best[key][0]:
  109. key_to_best[key] = (support, length, post_count)
  110. out = []
  111. for k, (s, l, post_count) in key_to_best.items():
  112. out.append({
  113. "s": s,
  114. "l": l,
  115. "i": "+".join(k),
  116. "post_count": post_count,
  117. })
  118. out.sort(key=lambda x: x["s"] * x["l"], reverse=True)
  119. return out
  120. def _load_and_merge_patterns(account_name: str) -> list[dict]:
  121. """读取 pattern 库 JSON,按 TOP_KEYS/SUB_KEYS 合并为列表并做合并、去重。"""
  122. path = _pattern_file(account_name)
  123. if not path.is_file():
  124. return []
  125. with open(path, "r", encoding="utf-8") as f:
  126. data = json.load(f)
  127. all_patterns = []
  128. for top in TOP_KEYS:
  129. if top not in data:
  130. continue
  131. block = data[top]
  132. for sub in SUB_KEYS:
  133. all_patterns.extend(block.get(sub) or [])
  134. return _merge_and_dedupe(all_patterns)
  135. def _load_and_merge_platform_patterns() -> list[dict]:
  136. """读取平台库 pattern JSON,结构与账号库相同,合并去重。"""
  137. path = _platform_pattern_file()
  138. if not path.is_file():
  139. return []
  140. with open(path, "r", encoding="utf-8") as f:
  141. data = json.load(f)
  142. all_patterns = []
  143. for top in TOP_KEYS:
  144. if top not in data:
  145. continue
  146. block = data[top]
  147. for sub in SUB_KEYS:
  148. all_patterns.extend(block.get(sub) or [])
  149. return _merge_and_dedupe(all_patterns)
  150. def _load_platform_match_pair_lookup(post_id: str) -> dict[tuple[str, str], float]:
  151. """
  152. xiaohongshu/match_data/{post_id}_匹配_all.json
  153. -> (帖子选题点, 人设树节点名) -> 最高 match_score(跨 dimension 合并)。
  154. """
  155. lookup: dict[tuple[str, str], float] = {}
  156. if not post_id:
  157. return lookup
  158. path = _BASE_INPUT / "xiaohongshu" / "match_data" / f"{post_id}_匹配_all.json"
  159. if not path.is_file():
  160. return lookup
  161. try:
  162. with open(path, "r", encoding="utf-8") as f:
  163. data = json.load(f)
  164. except Exception:
  165. return lookup
  166. if not isinstance(data, list):
  167. return lookup
  168. for item in data:
  169. if not isinstance(item, dict):
  170. continue
  171. topic = item.get("name")
  172. personas = item.get("match_personas")
  173. if topic is None or not isinstance(personas, list):
  174. continue
  175. topic_s = str(topic).strip()
  176. if not topic_s:
  177. continue
  178. for mp in personas:
  179. if not isinstance(mp, dict):
  180. continue
  181. elem = mp.get("name")
  182. score = mp.get("match_score")
  183. if elem is None or score is None:
  184. continue
  185. elem_s = str(elem).strip()
  186. if not elem_s:
  187. continue
  188. try:
  189. sc = float(score)
  190. except (TypeError, ValueError):
  191. continue
  192. key = (topic_s, elem_s)
  193. if key not in lookup or sc > lookup[key]:
  194. lookup[key] = sc
  195. return lookup
  196. def _platform_element_post_match_map(
  197. post_id: str,
  198. match_score_threshold: float,
  199. ) -> dict[str, dict[str, float]]:
  200. """
  201. 平台库:节点名称(不区分 dimension)-> {帖子选题点: 最高分},
  202. 仅保留 match_score >= match_score_threshold 的对。
  203. """
  204. out: dict[str, dict[str, float]] = {}
  205. if not post_id:
  206. return out
  207. path = _BASE_INPUT / "xiaohongshu" / "match_data" / f"{post_id}_匹配_all.json"
  208. if not path.is_file():
  209. return out
  210. try:
  211. with open(path, "r", encoding="utf-8") as f:
  212. data = json.load(f)
  213. except Exception:
  214. return out
  215. if not isinstance(data, list):
  216. return out
  217. thr = float(match_score_threshold)
  218. for item in data:
  219. if not isinstance(item, dict):
  220. continue
  221. topic = item.get("name")
  222. personas = item.get("match_personas")
  223. if topic is None or not isinstance(personas, list):
  224. continue
  225. topic_s = str(topic).strip()
  226. if not topic_s:
  227. continue
  228. for mp in personas:
  229. if not isinstance(mp, dict):
  230. continue
  231. elem = mp.get("name")
  232. score = mp.get("match_score")
  233. if elem is None or score is None:
  234. continue
  235. try:
  236. sc = float(score)
  237. except (TypeError, ValueError):
  238. continue
  239. if sc < thr:
  240. continue
  241. elem_s = str(elem).strip()
  242. if not elem_s:
  243. continue
  244. bucket = out.setdefault(elem_s, {})
  245. prev = bucket.get(topic_s)
  246. if prev is None or sc > prev:
  247. bucket[topic_s] = sc
  248. return out
  249. def _parse_derived_list(derived_items: list[dict[str, str]]) -> list[tuple[str, str]]:
  250. """将 agent 传入的 [{"topic": "x", "source_node": "y"}, ...] 转为 DerivedItem 列表。"""
  251. out = []
  252. for item in derived_items:
  253. if isinstance(item, dict):
  254. topic = item.get("topic") or item.get("已推导的选题点")
  255. source = item.get("source_node") or item.get("推导来源人设树节点")
  256. if topic is not None and source is not None:
  257. out.append((str(topic).strip(), str(source).strip()))
  258. elif isinstance(item, (list, tuple)) and len(item) >= 2:
  259. out.append((str(item[0]).strip(), str(item[1]).strip()))
  260. return out
  261. def get_patterns_by_conditional_ratio(
  262. account_name: str,
  263. derived_list: list[tuple[str, str]],
  264. conditional_ratio_threshold: float,
  265. top_n: int,
  266. post_id: str = "",
  267. ) -> list[dict[str, Any]]:
  268. """
  269. 从 pattern 库中获取条件概率 >= 阈值的 pattern,按以下优先级排序后返回 top_n 条:
  270. 1. pattern 元素中直接包含已推导选题点(topic)的排最前;
  271. 2. pattern 元素与任意已推导选题点的匹配分 >= 0.8 的次之(从 match_data 文件读取,
  272. key 为 (帖子选题点, 人设树节点),pattern 元素视为人设树节点);
  273. 3. 按条件概率降序;
  274. 4. 按 length 降序。
  275. derived_list 为空时,条件概率使用 pattern 自身的 support(s)。
  276. 返回每项:pattern名称(nameA+nameB+nameC)、条件概率。
  277. """
  278. merged = _load_and_merge_patterns(account_name)
  279. if not merged:
  280. return []
  281. base_dir = _BASE_INPUT
  282. scored: list[tuple[dict, float]] = []
  283. if not derived_list:
  284. # derived_items 为空:条件概率取 pattern 本身的 support (s)
  285. for p in merged:
  286. ratio = float(p.get("s", 0))
  287. if ratio >= conditional_ratio_threshold:
  288. scored.append((p, ratio))
  289. else:
  290. for p in merged:
  291. ratio = calc_pattern_conditional_ratio(
  292. account_name, derived_list, p, base_dir=base_dir
  293. )
  294. if ratio >= conditional_ratio_threshold:
  295. scored.append((p, ratio))
  296. derived_topics = {topic for topic, _ in derived_list} if derived_list else set()
  297. # 次优先:从 match_data 文件加载 (帖子选题点, 人设树节点) -> 匹配分,
  298. # 用已推导选题点(topic)作为帖子选题点,pattern 元素作为人设树节点,
  299. # 检查是否存在匹配分 >= 0.8 的组合。
  300. match_lookup: dict[tuple[str, str], float] = {}
  301. if derived_topics and post_id:
  302. match_lookup = _load_match_data(account_name, post_id)
  303. def _sort_key(x: tuple[dict, float]) -> tuple:
  304. p, ratio = x
  305. elements = set(p["i"].split("+"))
  306. has_derived = bool(elements & derived_topics)
  307. has_high_match = False
  308. if not has_derived and match_lookup:
  309. for elem in elements:
  310. for dt in derived_topics:
  311. if match_lookup.get((dt, elem), 0.0) >= _MATCH_PRIOR_MIN_SCORE:
  312. has_high_match = True
  313. break
  314. if has_high_match:
  315. break
  316. return (not has_derived, not has_high_match, -ratio, -p["l"])
  317. scored.sort(key=_sort_key)
  318. result = []
  319. for p, ratio in scored[:top_n]:
  320. result.append({
  321. "pattern名称": p["i"],
  322. "条件概率": round(ratio, 6),
  323. })
  324. return result
  325. def get_platform_patterns_by_conditional_ratio(
  326. derived_list: list[tuple[str, str]],
  327. conditional_ratio_threshold: float,
  328. top_n: int,
  329. post_id: str = "",
  330. ) -> list[dict[str, Any]]:
  331. """
  332. 平台库 pattern:数据来自 xiaohongshu/pattern/processed_edge_data.json,
  333. 条件概率基于 xiaohongshu/tree 的节点索引(与账号侧 calc_pattern 规则一致)。
  334. 排序优先级规则与 get_patterns_by_conditional_ratio 一致,高分参照 xiaohongshu/match_data。
  335. """
  336. merged = _load_and_merge_platform_patterns()
  337. if not merged:
  338. return []
  339. platform_index = build_node_index_for_tree_dir(_PLATFORM_TREE_DIR)
  340. scored: list[tuple[dict, float]] = []
  341. if not derived_list:
  342. for p in merged:
  343. ratio = float(p.get("s", 0))
  344. if ratio >= conditional_ratio_threshold:
  345. scored.append((p, ratio))
  346. else:
  347. for p in merged:
  348. ratio = calc_pattern_conditional_ratio_with_index(derived_list, p, platform_index)
  349. if ratio >= conditional_ratio_threshold:
  350. scored.append((p, ratio))
  351. derived_topics = {topic for topic, _ in derived_list} if derived_list else set()
  352. match_lookup: dict[tuple[str, str], float] = {}
  353. if derived_topics and post_id:
  354. match_lookup = _load_platform_match_pair_lookup(post_id)
  355. def _sort_key(x: tuple[dict, float]) -> tuple:
  356. p, ratio = x
  357. elements = set(p["i"].split("+"))
  358. has_derived = bool(elements & derived_topics)
  359. has_high_match = False
  360. if not has_derived and match_lookup:
  361. for elem in elements:
  362. for dt in derived_topics:
  363. if match_lookup.get((dt, elem), 0.0) >= _MATCH_PRIOR_MIN_SCORE:
  364. has_high_match = True
  365. break
  366. if has_high_match:
  367. break
  368. return (not has_derived, not has_high_match, -ratio, -p["l"])
  369. scored.sort(key=_sort_key)
  370. result = []
  371. for p, ratio in scored[:top_n]:
  372. result.append({
  373. "pattern名称": p["i"],
  374. "条件概率": round(ratio, 6),
  375. })
  376. return result
  377. def _attach_platform_pattern_post_matches(
  378. items: list[dict[str, Any]],
  379. post_id: str,
  380. match_score_threshold: float,
  381. ) -> None:
  382. """就地写入 帖子选题点匹配:仅使用 xiaohongshu/match_data,元素为节点名(跨 dimension 聚合)。"""
  383. if not items or not post_id:
  384. for it in items:
  385. it["帖子选题点匹配"] = "无"
  386. return
  387. elem_map = _platform_element_post_match_map(post_id, float(match_score_threshold))
  388. for item in items:
  389. pattern_matches: list[dict[str, Any]] = []
  390. for elem in item["pattern名称"].split("+"):
  391. elem = elem.strip()
  392. if not elem:
  393. continue
  394. for post_topic, sc in (elem_map.get(elem) or {}).items():
  395. pattern_matches.append({
  396. "pattern元素": elem,
  397. "帖子选题点": post_topic,
  398. "匹配分数": round(sc, 6),
  399. })
  400. distinct_post_points = len({m["帖子选题点"] for m in pattern_matches})
  401. item["帖子选题点匹配"] = (
  402. pattern_matches if distinct_post_points >= 2 else "无"
  403. )
@tool()
async def find_pattern(
    account_name: str,
    post_id: str,
    derived_items: list[dict[str, str]],
    conditional_ratio_threshold: float,
    top_n: int = 100,
    match_score_threshold: float = DEFAULT_MATCH_THRESHOLD,
) -> ToolResult:
    """
    Filter patterns from the pattern libraries by conditional-ratio threshold. The output has two
    sections: account patterns first, then platform-library patterns (xiaohongshu/pattern).
    Post matching for the account section uses the account match_data + point_match; element matching
    for the platform section uses xiaohongshu/match_data only.
    Args:
        account_name : Account name, used to locate that account's pattern library.
        post_id : Post ID.
        derived_items : List of already-derived topic points; may be empty.
        conditional_ratio_threshold : Conditional-ratio threshold for the account section; the platform section is filtered with this value divided by 5.
        top_n : Total number of patterns to return; int(top_n * 0.6) are allotted to the account section and the remainder to the platform section (either section may return fewer).
        match_score_threshold : Score threshold for matching against post topic points.
    Returns:
        ToolResult: output is split into an "account pattern" section and a "platform pattern" section; the platform section excludes entries whose pattern name is identical to one in the account section.
    """
    # NOTE(review): if the @tool decorator exposes this docstring to the agent
    # as the tool description, confirm the wording above is acceptable.

    def _split_by_post_match(
        items: list[dict[str, Any]],
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
        # An item counts as "matched" when its 帖子选题点匹配 field is a list
        # (unmatched items carry the string "无" instead).
        matched: list[dict[str, Any]] = []
        unmatched: list[dict[str, Any]] = []
        for x in items:
            if isinstance(x.get("帖子选题点匹配"), list):
                matched.append(x)
            else:
                unmatched.append(x)
        return matched, unmatched

    def _pick_with_quota(
        items: list[dict[str, Any]],
        target_count: int,
    ) -> list[dict[str, Any]]:
        # First target_count items (negative counts are clamped to 0).
        return items[:max(0, int(target_count))]

    def _mix_by_ratio(
        items: list[dict[str, Any]],
        target_count: int,
    ) -> list[dict[str, Any]]:
        # Select up to target_count items: half the quota (rounded down) from
        # matched items, the rest from unmatched items. If either bucket runs
        # short, fill the remaining slots from the not-yet-selected items in
        # their original order (items are identified by pattern name).
        if target_count <= 0:
            return []
        matched, unmatched = _split_by_post_match(items)
        matched_quota = target_count // 2
        unmatched_quota = target_count - matched_quota
        selected = _pick_with_quota(matched, matched_quota)
        selected.extend(_pick_with_quota(unmatched, unmatched_quota))
        if len(selected) < target_count:
            selected_names = {str(x.get("pattern名称", "")) for x in selected}
            fallback_pool = [
                x for x in items
                if str(x.get("pattern名称", "")) not in selected_names
            ]
            selected.extend(_pick_with_quota(fallback_pool, target_count - len(selected)))
        return selected

    # The account pattern file must exist; otherwise return an error result
    # without querying the platform library.
    pattern_path = _pattern_file(account_name)
    if not pattern_path.is_file():
        return ToolResult(
            title="Pattern 库不存在",
            output=f"pattern 文件不存在: {pattern_path}",
            error="Pattern file not found",
        )
    try:
        derived_list = _parse_derived_list(derived_items or [])
        thr = float(match_score_threshold)
        total_top_n = max(0, int(top_n))
        # Split the total budget 60/40 between account and platform sections.
        account_top_n = int(total_top_n * 0.6)
        platform_top_n = total_top_n - account_top_n
        # Enlarge the candidate pool so that bucketing by matched/unmatched
        # does not leave too few items.
        candidate_top_n = max(total_top_n * 4, total_top_n + 100)
        # ---------- Account patterns (original logic: match_data + child/sibling expansion) ----------
        items_account = get_patterns_by_conditional_ratio(
            account_name, derived_list, conditional_ratio_threshold, candidate_top_n, post_id
        )
        if not post_id:
            for item in items_account:
                item["帖子选题点匹配"] = "无"
        if items_account and post_id:
            # Step 1: match every distinct pattern element against the post's
            # topic points in a single call.
            all_elements: list[str] = []
            seen_elements: set[str] = set()
            for item in items_account:
                for elem in item["pattern名称"].split("+"):
                    elem = elem.strip()
                    if elem and elem not in seen_elements:
                        all_elements.append(elem)
                        seen_elements.add(elem)
            matched_results = await match_derivation_to_post_points(
                all_elements, account_name, post_id, match_threshold=thr
            )
            elem_match_map: dict[str, list] = {}
            for m in matched_results:
                elem_match_map.setdefault(m["推导选题点"], []).append({
                    "帖子选题点": m["帖子选题点"],
                    "匹配分数": m["匹配分数"],
                })
            for item in items_account:
                pattern_matches = []
                for elem in item["pattern名称"].split("+"):
                    elem = elem.strip()
                    for post_match in elem_match_map.get(elem, []):
                        pattern_matches.append({
                            "pattern元素": elem,
                            "帖子选题点": post_match["帖子选题点"],
                            "匹配分数": post_match["匹配分数"],
                        })
                # Keep the matches only when they cover at least two distinct
                # post topic points; otherwise mark the item as unmatched.
                distinct_post_points = len({m["帖子选题点"] for m in pattern_matches})
                item["帖子选题点匹配"] = (
                    pattern_matches if distinct_post_points >= 2 else "无"
                )
        if items_account and post_id:
            # Step 2: for elements without a direct match, try the element's
            # children (for "class" nodes that have children) or otherwise its
            # siblings in the persona tree.
            node_info_map = _build_node_info(account_name)
            all_candidates_set: set[str] = set()
            # Parallel to items_account: per item, a list of
            # (element, candidate node names, expansion type).
            item_unmatched_info: list[list[tuple[str, list[str], str]]] = []
            for item in items_account:
                pattern_matches = item.get("帖子选题点匹配", [])
                matched_elems = (
                    {m["pattern元素"] for m in pattern_matches}
                    if isinstance(pattern_matches, list) else set()
                )
                all_elems = [e.strip() for e in item["pattern名称"].split("+")]
                unmatched = [e for e in all_elems if e not in matched_elems]
                elem_candidates: list[tuple[str, list[str], str]] = []
                for elem in unmatched:
                    info = node_info_map.get(elem)
                    if not info:
                        continue
                    if info["type"] == "class" and info["children"]:
                        candidates = info["children"]
                        expand_type = "子节点"
                    else:
                        candidates = info["siblings"]
                        expand_type = "兄弟节点"
                    if candidates:
                        elem_candidates.append((elem, candidates, expand_type))
                        all_candidates_set.update(candidates)
                item_unmatched_info.append(elem_candidates)
            if all_candidates_set:
                candidate_matches = await match_derivation_to_post_points(
                    list(all_candidates_set), account_name, post_id, match_threshold=thr
                )
                cand_match_map: dict[str, list[tuple[str, float]]] = {}
                for m in candidate_matches:
                    cand_match_map.setdefault(m["推导选题点"], []).append(
                        (m["帖子选题点"], m["匹配分数"])
                    )
                for item, elem_cands in zip(items_account, item_unmatched_info):
                    for elem, candidates, expand_type in elem_cands:
                        # Keep only the single best-scoring (candidate, post
                        # topic point) pair for this element.
                        best_cand, best_pp, best_sc = None, None, -1.0
                        for cand in candidates:
                            for pp, sc in cand_match_map.get(cand, []):
                                if sc > best_sc:
                                    best_cand, best_pp, best_sc = cand, pp, sc
                        if best_cand is not None:
                            # An item previously marked "无" becomes matched
                            # once an expansion match is found.
                            if not isinstance(item.get("帖子选题点匹配"), list):
                                item["帖子选题点匹配"] = []
                            item["帖子选题点匹配"].append({
                                "pattern元素": elem,
                                "帖子选题点": best_pp,
                                "匹配分数": best_sc,
                                "扩展节点": best_cand,
                                "扩展类型": expand_type,
                            })
        # For each item keep one match per post topic point: the highest-scoring one.
        for item in items_account:
            matches = item.get("帖子选题点匹配")
            if not isinstance(matches, list):
                continue
            best_by_pp: dict[str, dict] = {}
            for m in matches:
                pp = m["帖子选题点"]
                if pp not in best_by_pp or m["匹配分数"] > best_by_pp[pp]["匹配分数"]:
                    best_by_pp[pp] = m
            item["帖子选题点匹配"] = list(best_by_pp.values())
        items_account = _mix_by_ratio(items_account, account_top_n)
        account_pattern_names = {str(x.get("pattern名称", "")).strip() for x in items_account}
        # ---------- Platform patterns (conditional ratio from xiaohongshu/tree + matching from xiaohongshu/match_data) ----------
        items_platform: list[dict[str, Any]] = []
        # The platform section uses one fifth of the caller's threshold.
        items_platform = get_platform_patterns_by_conditional_ratio(
            derived_list, conditional_ratio_threshold / 5, candidate_top_n, post_id
        )
        if post_id:
            _attach_platform_pattern_post_matches(items_platform, post_id, thr)
        else:
            for item in items_platform:
                item["帖子选题点匹配"] = "无"
        # Drop platform patterns whose name already appears in the account section.
        items_platform = [
            x for x in items_platform
            if str(x.get("pattern名称", "")).strip() not in account_pattern_names
        ]
        # Same per-topic-point de-duplication as for the account section.
        for item in items_platform:
            matches = item.get("帖子选题点匹配")
            if not isinstance(matches, list):
                continue
            best_by_pp: dict[str, dict] = {}
            for m in matches:
                pp = m["帖子选题点"]
                if pp not in best_by_pp or m["匹配分数"] > best_by_pp[pp]["匹配分数"]:
                    best_by_pp[pp] = m
            item["帖子选题点匹配"] = list(best_by_pp.values())
        items_platform = _mix_by_ratio(items_platform, platform_top_n)

        def _format_pattern_block(xs: list[dict[str, Any]]) -> list[str]:
            # Render one text line per pattern; matches found through child or
            # sibling expansion are shown with the expansion node and type.
            lines: list[str] = []
            for x in xs:
                match_info = x.get("帖子选题点匹配", "无")
                if isinstance(match_info, list):
                    match_str = "、".join(
                        (
                            f"{m['扩展节点']}({m['pattern元素']}的{m['扩展类型']})→{m['帖子选题点']}({m['匹配分数']})"
                            if "扩展节点" in m else
                            f"{m['pattern元素']}→{m['帖子选题点']}({m['匹配分数']})"
                        )
                        for m in match_info
                    )
                else:
                    match_str = str(match_info)
                lines.append(
                    f"- {x['pattern名称']}\t条件概率={x['条件概率']}\t帖子选题点匹配={match_str}"
                )
            return lines

        lines_out: list[str] = []
        lines_out.append(
            "【优先使用】第一节为账号 pattern;第二节为平台库 pattern。"
        )
        lines_out.append("")
        lines_out.append("—— 账号 pattern ——")
        if not items_account:
            lines_out.append(
                f"(无:未找到条件概率 >= {conditional_ratio_threshold} 的 pattern)"
            )
        else:
            lines_out.extend(_format_pattern_block(items_account))
        lines_out.append("")
        lines_out.append("—— 平台库 pattern ——")
        if not items_platform:
            lines_out.append(
                "(无:未找到达标 pattern)"
            )
        else:
            lines_out.extend(_format_pattern_block(items_platform))
        output = "\n".join(lines_out)
        return ToolResult(
            title=f"符合条件概率的 Pattern ({account_name}, 阈值={conditional_ratio_threshold})",
            output=output,
            metadata={
                "account_name": account_name,
                "conditional_ratio_threshold": conditional_ratio_threshold,
                "match_score_threshold": thr,
                "top_n": top_n,
                "account_pattern_count": len(items_account),
                "platform_pattern_count": len(items_platform),
                "count": len(items_account) + len(items_platform),
            },
        )
    except Exception as e:
        # Tool boundary: report any failure as an error result instead of raising.
        return ToolResult(
            title="查找 Pattern 失败",
            output=str(e),
            error=str(e),
        )
  664. def main() -> None:
  665. """本地测试:用家有大志账号、已推导选题点,查询符合条件概率阈值的 pattern(含帖子匹配)。"""
  666. import asyncio
  667. account_name = "家有大志"
  668. post_id = "68fb6a5c000000000302e5de"
  669. # 已推导选题点,每项:已推导的选题点 + 推导来源人设树节点
  670. # derived_items = [
  671. # {"topic": "分享", "source_node": "分享"},
  672. # {"topic": "植入方式", "source_node": "植入方式"},
  673. # {"topic": "叙事结构", "source_node": "叙事结构"},
  674. # ]
  675. derived_items = derived_items = []
  676. conditional_ratio_threshold = 0.2
  677. top_n = 200
  678. # 1)直接调用核心函数(不含帖子匹配,仅验证排序逻辑)
  679. # derived_list = _parse_derived_list(derived_items)
  680. # items = get_patterns_by_conditional_ratio(
  681. # account_name, derived_list, conditional_ratio_threshold, top_n, post_id
  682. # )
  683. # print(f"账号: {account_name}, 阈值: {conditional_ratio_threshold}, top_n: {top_n}")
  684. # print(f"共 {len(items)} 条 pattern:\n")
  685. # for x in items:
  686. # print(f" - {x['pattern名称']}\t条件概率={x['条件概率']}")
  687. # 2)有 agent 时通过 tool 接口再跑一遍(含帖子选题点匹配)
  688. if ToolResult is not None:
  689. async def run_tool():
  690. result = await find_pattern(
  691. account_name=account_name,
  692. post_id=post_id,
  693. derived_items=derived_items,
  694. conditional_ratio_threshold=conditional_ratio_threshold,
  695. top_n=top_n,
  696. )
  697. print("\n--- Tool 返回 ---")
  698. print(result.output)
  699. asyncio.run(run_tool())
  700. if __name__ == "__main__":
  701. main()