pattern_dimension_analyze.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761
  1. """
  2. Pattern 维度聚类分析 Tool
  3. 功能概述:
  4. 1. 读取某次整体推导日志目录下各轮评估结果,累计每轮已匹配的 matched_post_point。
  5. 2. 基于帖子与 pattern 库的匹配结果,对 pattern 元素做打分与分类(已推导/未推导)。
  6. 3. 在账号人设树(实质/形式/意图)中,分别为「已推导元素」「未推导元素」寻找聚类节点。
  7. 输入参数:
  8. - account_name: 账号名称
  9. - post_id: 帖子 ID
  10. - log_id: 推导日志目录名(形如 20260313210921)
  11. """
  12. import json
  13. import sys
  14. from pathlib import Path
  15. from typing import Any, Dict, List, Optional, Tuple, Set
  16. # 保证直接运行或作为包加载时都能解析 utils / tools(IDE 可跳转)
  17. _root = Path(__file__).resolve().parent.parent
  18. if str(_root) not in sys.path:
  19. sys.path.insert(0, str(_root))
  20. from tools.point_match import _load_match_data # 帖子选题点与人设树节点匹配分
  21. from tools.find_tree_node import _load_trees # 加载三棵人设树
# Base data directories relative to the repository root:
# - _BASE_INPUT : per-account raw input data (pattern library, match data)
# - _BASE_OUTPUT: per-account derivation output logs
_BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
_BASE_OUTPUT = Path(__file__).resolve().parent.parent / "output"
# Pattern library key definitions (must stay in sync with find_pattern):
# TOP_KEYS are the top-level buckets of processed_edge_data.json, and each
# bucket nests the SUB_KEYS lists of raw patterns.
TOP_KEYS = [
    "depth_max_with_name",
    "depth_mixed",
    "depth_max_concrete",
    "depth2_medium",
    "depth1_abstract",
]
SUB_KEYS = ["two_x", "one_x", "zero_x"]
  33. # ---------------------------------------------------------------------------
  34. # 1. 读取推导日志:按轮次累计 matched_post_point
  35. # ---------------------------------------------------------------------------
  36. def _round_eval_dir(account_name: str, post_id: str, log_id: str) -> Path:
  37. """
  38. 推导日志目录:
  39. ../output/{account_name}/推导日志/{post_id}/{log_id}/
  40. """
  41. return _BASE_OUTPUT / account_name / "推导日志" / post_id / log_id
  42. def _load_round_matched_points(
  43. account_name: str,
  44. post_id: str,
  45. log_id: str,
  46. ) -> List[Dict[str, Any]]:
  47. """
  48. 读取指定日志目录下所有 {轮次}.评估.json,按轮次排序,生成:
  49. [
  50. {
  51. "round": 1,
  52. "round_points": [... 本轮 matched_post_point 去重 ...],
  53. "cumulative_points": [... 累计到本轮的 matched_post_point 去重 ...],
  54. },
  55. ...
  56. ]
  57. """
  58. base_dir = _round_eval_dir(account_name, post_id, log_id)
  59. if not base_dir.is_dir():
  60. return []
  61. eval_files: List[Tuple[int, Path]] = []
  62. for p in base_dir.glob("*.json"):
  63. name = p.name
  64. # 只处理 *_评估.json
  65. if not name.endswith("评估.json"):
  66. continue
  67. try:
  68. round_str = name.split("_", 1)[0]
  69. r = int(round_str)
  70. except Exception:
  71. continue
  72. eval_files.append((r, p))
  73. eval_files.sort(key=lambda x: x[0])
  74. results: List[Dict[str, Any]] = []
  75. cumulative: List[str] = []
  76. cumulative_set: Set[str] = set()
  77. for r, path in eval_files:
  78. try:
  79. with open(path, "r", encoding="utf-8") as f:
  80. data = json.load(f)
  81. except Exception:
  82. continue
  83. eval_results = data.get("eval_results") or []
  84. round_points: List[str] = []
  85. for item in eval_results:
  86. if not isinstance(item, dict):
  87. continue
  88. if not item.get("is_matched"):
  89. continue
  90. # 根据是否已完全推导,选择不同的帖子选题点字段:
  91. # - is_fully_derived 为 False 时,使用 derivation_output_point
  92. # - 其他情况(True 或缺失)使用 matched_post_point(兼容旧数据)
  93. if item.get("is_fully_derived") is False:
  94. mp = item.get("derivation_output_point")
  95. else:
  96. mp = item.get("matched_post_point")
  97. if mp is None:
  98. continue
  99. mp = str(mp).strip()
  100. if not mp:
  101. continue
  102. if mp not in round_points:
  103. round_points.append(mp)
  104. # 累加到本轮
  105. for mp in round_points:
  106. if mp not in cumulative_set:
  107. cumulative_set.add(mp)
  108. cumulative.append(mp)
  109. results.append(
  110. {
  111. "round": r,
  112. "round_points": round_points,
  113. "cumulative_points": list(cumulative),
  114. }
  115. )
  116. return results
  117. # ---------------------------------------------------------------------------
  118. # 2. 读取 pattern 库并按 matched_post_point 打分
  119. # ---------------------------------------------------------------------------
  120. def _pattern_file(account_name: str) -> Path:
  121. """pattern 库文件:../input/{account_name}/原始数据/pattern/processed_edge_data.json"""
  122. return _BASE_INPUT / account_name / "原始数据" / "pattern" / "processed_edge_data.json"
  123. def _load_raw_patterns(account_name: str) -> List[Dict[str, Any]]:
  124. """
  125. 读取 pattern 库中所有原始 pattern(保留 items 结构,不做合并)。
  126. 返回列表中每个元素形如原始 JSON 中的 pattern(此处不关心 item 的 point / dimension 字段)。
  127. """
  128. path = _pattern_file(account_name)
  129. if not path.is_file():
  130. return []
  131. with open(path, "r", encoding="utf-8") as f:
  132. data = json.load(f)
  133. patterns: List[Dict[str, Any]] = []
  134. for top in TOP_KEYS:
  135. block = data.get(top)
  136. if not isinstance(block, dict):
  137. continue
  138. for sub in SUB_KEYS:
  139. items = block.get(sub) or []
  140. if isinstance(items, list):
  141. for p in items:
  142. if isinstance(p, dict):
  143. patterns.append(p)
  144. return patterns
  145. def _slim_pattern_for_dedupe(p: Dict[str, Any]) -> Tuple[float, List[str]]:
  146. """
  147. 提取 pattern 的 support 与去重后的 item name 列表(按名称合并,不关心顺序),
  148. 用于与 find_pattern.py 中的去重逻辑对齐。
  149. """
  150. items = p.get("items") or []
  151. names = [str(it.get("name") or "").strip() for it in items if isinstance(it, dict)]
  152. seen: Set[str] = set()
  153. unique: List[str] = []
  154. for n in names:
  155. if n and n not in seen:
  156. seen.add(n)
  157. unique.append(n)
  158. try:
  159. support = float(p.get("support", 0.0))
  160. except (TypeError, ValueError):
  161. support = 0.0
  162. return support, unique
  163. def _dedupe_patterns(raw_patterns: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
  164. """
  165. 按 pattern 的 item name 集合去重(不区分顺序),与 find_pattern.py 的思路一致:
  166. - key 为 sorted(unique item names)
  167. - 同一个 key 仅保留 support 最大的 pattern(保留其原始 items 结构,方便后续打分)
  168. """
  169. key_to_best: Dict[Tuple[str, ...], Dict[str, Any]] = {}
  170. key_to_support: Dict[Tuple[str, ...], float] = {}
  171. for p in raw_patterns:
  172. support, unique = _slim_pattern_for_dedupe(p)
  173. if not unique:
  174. continue
  175. key = tuple(sorted(unique))
  176. best_support = key_to_support.get(key)
  177. if best_support is None or support > best_support:
  178. key_to_support[key] = support
  179. key_to_best[key] = p
  180. return list(key_to_best.values())
def _score_patterns_by_matched_points(
    patterns: List[Dict[str, Any]],
    account_name: str,
    post_id: str,
    matched_post_points: List[str],
    match_threshold: float,
) -> List[Dict[str, Any]]:
    """
    Score the items of each pattern against the matched_post_point list.

    - Match scores come from ../input/{account_name}/match_data/{post_id}_匹配_all.json,
      looked up with the key (post topic point, persona tree node).
    - For every item of every pattern:
        * item["name"] is treated as a persona tree node name
        * a score is looked up for each matched_post_point and the max is kept
    - Only patterns with at least one item scoring >= match_threshold are kept.

    Returned pattern shape (items no longer keep point / dimension fields):
    {
        "id": xxx,
        "support": xxx,
        "items": [
            {
                "name": "xxx",
                "type": "xxx",
                "matched_post_point": "xxx" | null,
                "matched_score": float,
            },
            ...
        ],
    }
    """
    if not patterns or not matched_post_points:
        return []
    match_lookup = _load_match_data(account_name, post_id)
    # Normalize the post points; drop entries that are empty after stripping.
    matched_post_points = [str(x).strip() for x in matched_post_points if str(x).strip()]
    if not matched_post_points:
        return []
    results: List[Dict[str, Any]] = []
    for p in patterns:
        items = p.get("items") or []
        if not isinstance(items, list):
            continue
        scored_items: List[Dict[str, Any]] = []
        # Best score seen among this pattern's items; decides whether the
        # pattern survives the match_threshold filter below.
        max_item_score = 0.0
        for it in items:
            if not isinstance(it, dict):
                continue
            name = str(it.get("name") or "").strip()
            _type = str(it.get("type") or "").strip()
            best_score = 0.0
            best_post_point: Optional[str] = None
            if name:
                for post_point in matched_post_points:
                    # A post topic point identical to the node name counts as
                    # a perfect match, bypassing the lookup table.
                    if post_point == name:
                        s = 1.0
                    else:
                        score = match_lookup.get((post_point, name))
                        if score is None:
                            continue
                        try:
                            s = float(score)
                        except (TypeError, ValueError):
                            continue
                    # Strict '>' keeps the earliest post point on score ties.
                    if s > best_score:
                        best_score = s
                        best_post_point = post_point
            if best_score > max_item_score:
                max_item_score = best_score
            scored_items.append(
                {
                    "name": name,
                    "type": _type,
                    "matched_post_point": best_post_point,
                    "matched_score": round(best_score, 6),
                }
            )
        if not scored_items:
            continue
        if max_item_score < match_threshold:
            # This pattern did not match the post strongly enough this round.
            continue
        results.append(
            {
                "id": p.get("id"),
                "support": p.get("support"),
                "items": scored_items,
            }
        )
    return results
  270. # ---------------------------------------------------------------------------
  271. # 3. 人设树节点信息 & 聚类节点搜索
  272. # ---------------------------------------------------------------------------
class TreeIndex:
    """
    Persona tree index.

    - node_info: node name -> { "parent": parent name, "children": [child names...],
      "depth": depth (root is 0), "dimension": dimension name }
    - roots: dimension name -> root node name (the dimension name itself)
    - merged_tree: the three trees (实质 / 形式 / 意图) merged into one JSON whose
      top-level keys are the dimension names
    """
    def __init__(self, account_name: str) -> None:
        self.account_name = account_name
        # Node bookkeeping; filled by _build().
        self.node_info: Dict[str, Dict[str, Any]] = {}
        self.roots: Dict[str, str] = {}
        # Merged JSON of the three trees: {"实质": {...}, "形式": {...}, "意图": {...}}
        self.merged_tree: Dict[str, Dict[str, Any]] = {}
        self._build()
    def _build(self) -> None:
        """Build merged_tree, node_info and roots from the account's trees."""
        trees = _load_trees(self.account_name)
        # 1) Merge the trees into a single JSON keyed by dimension name.
        merged: Dict[str, Dict[str, Any]] = {}
        for dim_name, root in trees:
            if isinstance(root, dict):
                merged[dim_name] = root
        self.merged_tree = merged
        # 2) Build the parent/children structure from the merged JSON.
        for dim_name, root in merged.items():
            root_name = dim_name
            self.roots[dim_name] = root_name
            if root_name not in self.node_info:
                self.node_info[root_name] = {
                    "parent": None,
                    "children": [],
                    "dimension": dim_name,
                    "depth": 0,
                }
            def walk(parent_name: str, node_dict: Dict[str, Any]):
                # "children" maps child name -> subtree dict.
                children = node_dict.get("children") or {}
                for name, child in children.items():
                    if not isinstance(child, dict):
                        continue
                    if name not in self.node_info:
                        self.node_info[name] = {
                            "parent": parent_name,
                            "children": [],
                            "dimension": dim_name,
                            "depth": None,  # computed later by BFS
                        }
                    else:
                        # Only update parent when it would not create a
                        # self-reference (the tree may contain a parent and
                        # child sharing the same name).
                        if name != parent_name:
                            self.node_info[name]["parent"] = parent_name
                        self.node_info[name]["dimension"] = dim_name
                    # Maintain the parent's children list.
                    if parent_name not in self.node_info:
                        self.node_info[parent_name] = {
                            "parent": None,
                            "children": [],
                            "dimension": dim_name,
                            "depth": 0,
                        }
                    if name not in self.node_info[parent_name]["children"]:
                        self.node_info[parent_name]["children"].append(name)
                    walk(name, child)
            walk(root_name, root)
        # Compute every node's depth by BFS from the roots; only unvisited
        # nodes (depth is None) are assigned, which also guards against cycles.
        from collections import deque
        q = deque()
        for dim_name, root_name in self.roots.items():
            if root_name not in self.node_info:
                continue
            self.node_info[root_name]["depth"] = 0
            q.append(root_name)
        while q:
            cur = q.popleft()
            cur_depth = self.node_info[cur].get("depth", 0) or 0
            for child in self.node_info[cur].get("children", []):
                self.node_info.setdefault(child, {})
                if self.node_info[child].get("depth") is None:
                    self.node_info[child]["depth"] = cur_depth + 1
                    q.append(child)
    # Cluster search (dimension no longer considered)
    def find_clusters(
        self,
        elements: List[str],
        cluster_level: int,
    ) -> List[Dict[str, Any]]:
        """
        Find cluster nodes for the given elements across all persona trees
        (elements no longer need to share a dimension).

        Rules (fixed cluster depth ``cluster_level``):
        - Clustering is decided only at nodes with depth == cluster_level:
          * a node becomes a cluster node when its subtree contains >= 2 of
            the elements and no shallower cluster node was already selected
            on its path.
        - Elements that cannot join an upward cluster get a "single element
          cluster" at their depth == cluster_level ancestor, if one exists.
        - Returns:
            [
                {
                    "cluster_node": "node name",
                    "from_elements": ["element A", "element B", ...]
                },
                ...
            ]
        """
        # Keep only elements that actually exist in the persona trees.
        elem_set: Set[str] = set()
        for e in elements:
            e = str(e).strip()
            if not e:
                continue
            info = self.node_info.get(e)
            if not info:
                continue
            elem_set.add(e)
        if not elem_set:
            return []
        # Count how many target elements each node's subtree contains
        # (across the roots of all dimensions).
        # NOTE: tree data may contain unexpected cycles or repeated
        # references; the visited set prevents infinite recursion.
        subtree_count: Dict[str, int] = {}
        def dfs_count(node: str, visited: Set[str]) -> int:
            if node in visited:
                # Cycle detected: contribute 0 to avoid infinite recursion.
                return 0
            visited.add(node)
            cnt = 1 if node in elem_set else 0
            for ch in self.node_info.get(node, {}).get("children", []):
                cnt += dfs_count(ch, visited)
            subtree_count[node] = cnt
            return cnt
        for root_name in self.roots.values():
            dfs_count(root_name, set())
        # Top-down pass that prefers the shallower candidate (but only at
        # cluster_level depth):
        # - once a node is selected, its descendants are never selected,
        #   keeping clusters as high in the tree as possible;
        # - the visited set again guards against accidental cycles.
        clusters: Set[str] = set()
        def dfs_select(node: str, ancestor_selected: bool, visited: Set[str]) -> None:
            if node in visited:
                return
            visited.add(node)
            info = self.node_info.get(node) or {}
            depth = info.get("depth", 0) or 0
            cnt = subtree_count.get(node, 0)
            selected_here = False
            # Select this node only when no ancestor was selected, it sits at
            # cluster_level, and its subtree covers at least two elements.
            if (not ancestor_selected) and depth == cluster_level and cnt >= 2:
                clusters.add(node)
                selected_here = True
            # If an ancestor (or this node) is selected, descendants are not.
            for ch in info.get("children", []):
                dfs_select(ch, ancestor_selected or selected_here, visited)
        for root_name in self.roots.values():
            dfs_select(root_name, False, set())
        if not clusters:
            return []
        # For each cluster node, collect the elements it actually covers by
        # walking each element's parent chain upwards.
        cluster_to_elements: Dict[str, Set[str]] = {c: set() for c in clusters}
        for e in elem_set:
            cur = e
            visited: Set[str] = set()
            while cur and cur not in visited:
                visited.add(cur)
                if cur in clusters:
                    cluster_to_elements[cur].add(e)
                parent = self.node_info.get(cur, {}).get("parent")
                if parent is None:
                    break
                cur = parent
        out: List[Dict[str, Any]] = []
        # 1) Multi-element clusters: count only elements covered by clusters
        #    that are actually emitted, so that elements under a node covering
        #    fewer than 2 of them are not wrongly marked as covered (and lost).
        covered_elems: Set[str] = set()
        for node in clusters:
            elems = sorted(cluster_to_elements.get(node) or [])
            if len(elems) < 2:
                # The main clustering pass only emits nodes covering >= 2 elements.
                continue
            out.append(
                {
                    "cluster_node": node,
                    "from_elements": elems,
                }
            )
            for e in elems:
                covered_elems.add(e)
        # 2) Elements that could not join an upward cluster get a
        #    "single element cluster".
        uncovered = elem_set - covered_elems
        # Group uncovered elements by their cluster_level ancestor so that
        # several elements under the same ancestor merge into one cluster
        # instead of producing multiple single-element clusters.
        single_clusters: Dict[str, Set[str]] = {}
        for e in uncovered:
            # The cluster_node must be an ancestor, never the element itself;
            # the ancestor at depth == cluster_level is chosen.
            info_e = self.node_info.get(e) or {}
            parent = info_e.get("parent")
            cur = parent
            best_ancestor: Optional[str] = None
            visited_chain: Set[str] = set()
            while cur and cur not in visited_chain:
                visited_chain.add(cur)
                info = self.node_info.get(cur) or {}
                depth = info.get("depth", 0) or 0
                if depth == cluster_level:
                    best_ancestor = cur
                    break
                parent = info.get("parent")
                if parent is None:
                    break
                cur = parent
            if best_ancestor:
                single_clusters.setdefault(best_ancestor, set()).add(e)
        for anc, elems in single_clusters.items():
            out.append(
                {
                    "cluster_node": anc,
                    "from_elements": sorted(elems),
                }
            )
        # Stable output ordering: larger element groups first, node name as
        # the tie-breaker.
        out.sort(key=lambda x: (-len(x["from_elements"]), x["cluster_node"]))
        return out
  490. # ---------------------------------------------------------------------------
  491. # 4. 对单轮数据执行 pattern & 聚类分析
  492. # ---------------------------------------------------------------------------
  493. def _analyze_single_round(
  494. account_name: str,
  495. post_id: str,
  496. patterns: List[Dict[str, Any]],
  497. tree_index: TreeIndex,
  498. cumulative_points: List[str],
  499. match_threshold: float,
  500. cluster_level: int,
  501. ) -> Dict[str, Any]:
  502. """
  503. 对某一轮(给定累计 matched_post_point 列表)执行分析:
  504. - 筛选与帖子匹配度 >= match_threshold 的 pattern
  505. - 将 pattern 元素按 matched_score 分为「已推导元素」与「未推导元素」
  506. - 在三棵人设树中(不区分维度)为两组元素分别寻找聚类节点
  507. """
  508. patterns = _score_patterns_by_matched_points(
  509. patterns=patterns,
  510. account_name=account_name,
  511. post_id=post_id,
  512. matched_post_points=cumulative_points,
  513. match_threshold=match_threshold,
  514. )
  515. print(f"_score_patterns_by_matched_points len: {len(patterns)}")
  516. # 已推导 / 未推导 元素列表(不再按维度拆分)
  517. derived_elems: List[str] = []
  518. underived_elems: List[str] = []
  519. for p in patterns:
  520. for it in p.get("items", []):
  521. if not isinstance(it, dict):
  522. continue
  523. node_name = str(it.get("name") or "").strip()
  524. if not node_name:
  525. continue
  526. score = float(it.get("matched_score") or 0.0)
  527. if score >= match_threshold:
  528. derived_elems.append(node_name)
  529. else:
  530. underived_elems.append(node_name)
  531. # 为避免重复元素干扰统计与聚类,先做去重
  532. derived_set: List[str] = list(dict.fromkeys(derived_elems))
  533. underived_set: List[str] = list(dict.fromkeys(underived_elems))
  534. clusters: Dict[str, Any] = {
  535. "derived": [],
  536. "underived": [],
  537. }
  538. # 已推导元素聚类
  539. if derived_set:
  540. c = tree_index.find_clusters(derived_set, cluster_level=cluster_level)
  541. clusters["derived"] = c or []
  542. # 未推导元素聚类
  543. if underived_set:
  544. c = tree_index.find_clusters(underived_set, cluster_level=cluster_level)
  545. clusters["underived"] = c or []
  546. # 在同一轮中,如果某个 cluster_node 已经在 derived 聚类里出现过,
  547. # 则从 underived 聚类中剔除该 cluster_node,避免重复展示。
  548. if isinstance(clusters.get("derived"), list) and isinstance(clusters.get("underived"), list):
  549. derived_nodes = {
  550. str(item.get("cluster_node"))
  551. for item in clusters["derived"]
  552. if isinstance(item, dict) and item.get("cluster_node") is not None
  553. }
  554. if derived_nodes:
  555. filtered_underived = []
  556. for item in clusters["underived"]:
  557. if not isinstance(item, dict):
  558. continue
  559. node = str(item.get("cluster_node"))
  560. if node in derived_nodes:
  561. continue
  562. filtered_underived.append(item)
  563. clusters["underived"] = filtered_underived
  564. return {
  565. "matched_post_points": list(cumulative_points),
  566. "patterns": patterns,
  567. "clusters": clusters,
  568. # 统计信息:
  569. # - patterns_count: 本轮参与分析的 pattern 数量
  570. # - derived_cluster_count: 已推导元素聚类节点数量
  571. # - underived_cluster_count: 未推导元素聚类节点数量
  572. "patterns_count": len(patterns),
  573. "derived_cluster_count": len(clusters["derived"]) if isinstance(clusters.get("derived"), list) else 0,
  574. "underived_cluster_count": len(clusters["underived"]) if isinstance(clusters.get("underived"), list) else 0,
  575. }
  576. def pattern_dimension_analyze(
  577. account_name: str,
  578. post_id: str,
  579. log_id: str,
  580. match_threshold: float = 0.6,
  581. cluster_level: int = 2,
  582. ) -> Dict[str, Any]:
  583. """
  584. Pattern 维度分析主入口。
  585. 参数
  586. -------
  587. account_name : 账号名(用于定位 input / output 下的数据目录)
  588. post_id : 帖子 ID(用于定位推导日志与帖子匹配数据)
  589. log_id : 推导日志目录名(../output/{account_name}/推导日志/{post_id}/{log_id}/)
  590. match_threshold : pattern 元素与 matched_post_point 的最小匹配分,默认 0.6
  591. cluster_level : 在人设树中搜索聚类节点的聚类层级(root 为 0 层),默认 2
  592. """
  593. eval_dir = _round_eval_dir(account_name, post_id, log_id)
  594. if not eval_dir.is_dir():
  595. raise FileNotFoundError(f"推导日志目录不存在: {eval_dir}")
  596. round_infos = _load_round_matched_points(account_name, post_id, log_id)
  597. if not round_infos:
  598. return {
  599. "account_name": account_name,
  600. "post_id": post_id,
  601. "log_id": log_id,
  602. "match_threshold": match_threshold,
  603. "cluster_level": cluster_level,
  604. "rounds": [],
  605. "message": "未在指定日志目录下找到任何评估结果文件(*_评估.json)",
  606. }
  607. tree_index = TreeIndex(account_name)
  608. # pattern 库只在整体分析时读取 & 去重一次,避免每一轮重复 IO 与解析
  609. raw_patterns = _load_raw_patterns(account_name)
  610. deduped_patterns = _dedupe_patterns(raw_patterns)
  611. print(f"deduped_patterns len: {len(deduped_patterns)}")
  612. rounds_output: List[Dict[str, Any]] = []
  613. for info in round_infos:
  614. r = info["round"]
  615. cumulative_points = info["cumulative_points"]
  616. analyzed = _analyze_single_round(
  617. account_name=account_name,
  618. post_id=post_id,
  619. patterns=deduped_patterns,
  620. tree_index=tree_index,
  621. cumulative_points=cumulative_points,
  622. match_threshold=match_threshold,
  623. cluster_level=cluster_level,
  624. )
  625. analyzed["round"] = r
  626. rounds_output.append(analyzed)
  627. result = {
  628. "account_name": account_name,
  629. "post_id": post_id,
  630. "log_id": log_id,
  631. "match_threshold": match_threshold,
  632. "cluster_level": cluster_level,
  633. "rounds": rounds_output,
  634. }
  635. return result
  636. def main() -> None:
  637. """本地简单测试:以家有大志账号的一次推导日志做分析,并将结果写入输出目录。"""
  638. account_name = "家有大志"
  639. post_id = "68fb6a5c000000000302e5de"
  640. # 需要根据实际运行结果修改为最新的 log_id
  641. log_id = "20260317112639"
  642. result = pattern_dimension_analyze(
  643. account_name=account_name,
  644. post_id=post_id,
  645. log_id=log_id,
  646. match_threshold=0.5,
  647. cluster_level=2,
  648. )
  649. # 控制台打印前 4000 字符,便于快速查看
  650. # print(json.dumps(result, ensure_ascii=False, indent=2)[:4000] + "...")
  651. # 写入输出文件:../output/{account_name}/推导日志/{post_id}/{log_id}/pattern_dimension_analyze.json
  652. out_dir = _round_eval_dir(account_name, post_id, log_id)
  653. out_dir.mkdir(parents=True, exist_ok=True)
  654. output_file_name = f"{post_id}_pattern_dimension_analyze.json"
  655. out_path = out_dir / output_file_name
  656. with open(out_path, "w", encoding="utf-8") as f:
  657. json.dump(result, f, ensure_ascii=False, indent=2)
  658. print(f"\n分析结果已写入: {out_path}")
  659. if __name__ == "__main__":
  660. main()