# pattern_dimension_analyze.py

  1. """
  2. Pattern 维度分析 Tool
  3. 功能概述:
  4. 1. 读取某次整体推导日志目录下各轮评估结果,累计 matched_post_point / derivation_output_point 等字段。
  5. 2. 每轮通过 derivation_output_point 在人设树中找到 cluster_level 层祖先节点(已推导维度节点集合)。
  6. 3. 从 deduped_patterns 中筛选包含已推导维度节点的 pattern,并对各元素标记是否已推导。
  7. 输入参数:
  8. - account_name: 账号名称
  9. - post_id: 帖子 ID
  10. - log_id: 推导日志目录名(形如 20260313210921)
  11. 已推导/未推导维度节点在结果中以对象列表表示,字段见 _analyze_single_round 返回说明。
  12. """
import json
import logging
import sys
from collections import deque
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Set

logger = logging.getLogger(__name__)

try:
    from agent.tools import tool, ToolResult, ToolContext
except ImportError:
    # Fallbacks so the module still imports outside the agent runtime
    def tool(*args, **kwargs):
        return lambda f: f

    ToolResult = None
    ToolContext = None

# Make utils / tools resolvable whether the file is run directly or loaded as a
# package (and keep IDE navigation working)
_root = Path(__file__).resolve().parent.parent
if str(_root) not in sys.path:
    sys.path.insert(0, str(_root))

from tools.find_tree_node import _load_trees  # load the three persona trees

_BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
_BASE_OUTPUT = Path(__file__).resolve().parent.parent / "output"

# Pattern-library keys (kept consistent with find_pattern)
TOP_KEYS = [
    "depth_4",
]
SUB_KEYS = ["two_x", "one_x", "zero_x"]

# Target depth when looking up ancestor nodes in the persona trees (root is depth 0)
CLUSTER_LEVEL = 3
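# E.g. with CLUSTER_LEVEL == 3 and a hypothetical path
# root(0) -> A(1) -> B(2) -> C(3) -> leaf(4), a leaf is attributed to ancestor C.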
# ---------------------------------------------------------------------------
# 1. Read the derivation logs: accumulate matched_post_point per round
# ---------------------------------------------------------------------------
def _round_eval_dir(account_name: str, post_id: str, log_id: str) -> Path:
    """
    Derivation-log directory:
        ../output/{account_name}/推导日志/{post_id}/{log_id}/
    """
    return _BASE_OUTPUT / account_name / "推导日志" / post_id / log_id


def _load_round_matched_points(
    account_name: str,
    post_id: str,
    log_id: str,
    max_round: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Read every {round}_评估.json under the given log directory, sort by round,
    and produce:
    [
        {
            "round": 1,
            "round_points": [
                {
                    "matched_post_point": "叙事结构",
                    "derivation_output_point": "叙事编排",
                    "matched_score": 0.9151,
                    "is_fully_derived": true,
                },
                ...
            ],
            "cumulative_points": [
                ... deduplicated list accumulated up to this round
                (keyed by derivation_output_point) ...
            ],
        },
        ...
    ]
    """
    base_dir = _round_eval_dir(account_name, post_id, log_id)
    if not base_dir.is_dir():
        return []
    eval_files: List[Tuple[int, Path]] = []
    for p in base_dir.glob("*.json"):
        name = p.name
        # Only handle *_评估.json
        if not name.endswith("评估.json"):
            continue
        try:
            round_str = name.split("_", 1)[0]
            r = int(round_str)
        except Exception:
            continue
        eval_files.append((r, p))
    if max_round is not None:
        eval_files = [(r, p) for r, p in eval_files if r <= max_round]
    eval_files.sort(key=lambda x: x[0])
    results: List[Dict[str, Any]] = []
    cumulative: List[Dict[str, Any]] = []
    cumulative_set: Set[str] = set()  # deduplicated by derivation_output_point
    for r, path in eval_files:
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception:
            continue
        eval_results = data.get("eval_results") or []
        round_points: List[Dict[str, Any]] = []
        seen_in_round: Set[str] = set()
        for item in eval_results:
            if not isinstance(item, dict):
                continue
            if not item.get("is_matched"):
                continue
            dop = item.get("derivation_output_point")
            if dop is None:
                continue
            dop = str(dop).strip()
            if not dop:
                continue
            # Deduplicate within the round by derivation_output_point
            if dop in seen_in_round:
                continue
            seen_in_round.add(dop)
            mpp = item.get("matched_post_point")
            entry: Dict[str, Any] = {
                "matched_post_point": str(mpp).strip() if mpp is not None else None,
                "derivation_output_point": dop,
                "matched_score": item.get("matched_score"),
                "is_fully_derived": item.get("is_fully_derived"),
            }
            round_points.append(entry)
        # Merge into the cumulative list (deduplicated by derivation_output_point)
        for entry in round_points:
            dop = entry["derivation_output_point"]
            if dop not in cumulative_set:
                cumulative_set.add(dop)
                cumulative.append(entry)
        results.append(
            {
                "round": r,
                "round_points": round_points,
                "cumulative_points": list(cumulative),
            }
        )
    return results
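# A sketch of the cumulative dedupe above, with hypothetical points: if round 1
# matches derivation_output_point "P1" and round 2 matches "P1" and "P2", round 2's
# cumulative_points keeps the round-1 entry for "P1" (first occurrence wins) and
# appends only "P2".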
# ---------------------------------------------------------------------------
# 2. Read the pattern library and score by matched_post_point
# ---------------------------------------------------------------------------
def _pattern_file(account_name: str) -> Path:
    """Pattern-library file: ../input/{account_name}/原始数据/pattern/processed_edge_data.json"""
    return _BASE_INPUT / account_name / "原始数据" / "pattern" / "processed_edge_data.json"


def _load_raw_patterns(account_name: str) -> List[Dict[str, Any]]:
    """
    Read all raw patterns from the pattern library (keeping the items structure,
    no merging). Each element of the returned list mirrors a pattern in the raw
    JSON (the item-level point / dimension fields are irrelevant here).
    """
    path = _pattern_file(account_name)
    if not path.is_file():
        return []
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    patterns: List[Dict[str, Any]] = []
    for top in TOP_KEYS:
        block = data.get(top)
        if not isinstance(block, dict):
            continue
        for sub in SUB_KEYS:
            items = block.get(sub) or []
            if isinstance(items, list):
                for p in items:
                    if isinstance(p, dict):
                        patterns.append(p)
    return patterns
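# Layout of processed_edge_data.json as inferred from TOP_KEYS / SUB_KEYS and the
# fields read elsewhere in this module (the exact schema is an assumption):
#   {"depth_4": {"two_x": [pattern, ...], "one_x": [...], "zero_x": [...]}}
# where each pattern is a dict carrying at least "support" and
# "items": [{"name": "..."}, ...].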
def _slim_pattern_for_dedupe(p: Dict[str, Any]) -> Tuple[float, List[str]]:
    """
    Extract a pattern's support plus its deduplicated item-name list (merged by
    name, order-insensitive), to stay aligned with the dedupe logic in
    find_pattern.py.
    """
    items = p.get("items") or []
    names = [str(it.get("name") or "").strip() for it in items if isinstance(it, dict)]
    seen: Set[str] = set()
    unique: List[str] = []
    for n in names:
        if n and n not in seen:
            seen.add(n)
            unique.append(n)
    try:
        support = float(p.get("support", 0.0))
    except (TypeError, ValueError):
        support = 0.0
    return support, unique


def _dedupe_patterns(raw_patterns: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Deduplicate patterns by their item-name set (order-insensitive), following
    find_pattern.py:
    - the key is sorted(unique item names)
    - per key, keep only the pattern with the highest support (preserving its
      raw items structure for later scoring)
    """
    key_to_best: Dict[Tuple[str, ...], Dict[str, Any]] = {}
    key_to_support: Dict[Tuple[str, ...], float] = {}
    for p in raw_patterns:
        support, unique = _slim_pattern_for_dedupe(p)
        if not unique:
            continue
        key = tuple(sorted(unique))
        best_support = key_to_support.get(key)
        if best_support is None or support > best_support:
            key_to_support[key] = support
            key_to_best[key] = p
    return list(key_to_best.values())
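# Minimal dedupe sketch with hypothetical patterns: both reduce to the name set
# {"A", "B"}, so a single entry survives and the higher support wins.
#
#   _dedupe_patterns([
#       {"support": 0.2, "items": [{"name": "A"}, {"name": "B"}]},
#       {"support": 0.5, "items": [{"name": "B"}, {"name": "A"}, {"name": "A"}]},
#   ])
#   # -> one pattern: the support-0.5 entry, raw items structure intact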
# ---------------------------------------------------------------------------
# 3. Persona-tree node info & cluster-node search
# ---------------------------------------------------------------------------
class TreeIndex:
    """
    Persona-tree index:
    - node_info: node -> { "parent": parent node name, "children": [child names...],
                           "depth": depth, "dimension": dimension name }
    - roots: dimension name -> root node name (the dimension name itself)
    - merged_tree: the 实质/形式/意图 trees merged into a single JSON
                   (top-level keys are 实质/形式/意图)
    """

    def __init__(self, account_name: str) -> None:
        self.account_name = account_name
        self.node_info: Dict[str, Dict[str, Any]] = {}
        self.roots: Dict[str, str] = {}
        # The three trees merged into one JSON: {"实质": {...}, "形式": {...}, "意图": {...}}
        self.merged_tree: Dict[str, Dict[str, Any]] = {}
        self._build()

    def _build(self) -> None:
        trees = _load_trees(self.account_name)
        # 1) Merge the three trees into one JSON: {"实质": {...}, "形式": {...}, "意图": {...}}
        merged: Dict[str, Dict[str, Any]] = {}
        for dim_name, root in trees:
            if isinstance(root, dict):
                merged[dim_name] = root
        self.merged_tree = merged
        # 2) Build the parent/children structure from the merged JSON
        for dim_name, root in merged.items():
            root_name = dim_name
            self.roots[dim_name] = root_name
            if root_name not in self.node_info:
                self.node_info[root_name] = {
                    "parent": None,
                    "children": [],
                    "dimension": dim_name,
                    "depth": 0,
                }

            def walk(parent_name: str, node_dict: Dict[str, Any]):
                children = node_dict.get("children") or {}
                for name, child in children.items():
                    if not isinstance(child, dict):
                        continue
                    if name not in self.node_info:
                        self.node_info[name] = {
                            "parent": parent_name,
                            "children": [],
                            "dimension": dim_name,
                            "depth": None,  # computed later in one pass
                        }
                    else:
                        # Only update parent when it would not create a
                        # self-reference (the tree may contain same-named
                        # parent/child nodes)
                        if name != parent_name:
                            self.node_info[name]["parent"] = parent_name
                        self.node_info[name]["dimension"] = dim_name
                    # Maintain the parent's children list
                    if parent_name not in self.node_info:
                        self.node_info[parent_name] = {
                            "parent": None,
                            "children": [],
                            "dimension": dim_name,
                            "depth": 0,
                        }
                    if name not in self.node_info[parent_name]["children"]:
                        self.node_info[parent_name]["children"].append(name)
                    walk(name, child)

            walk(root_name, root)
        # Compute every node's depth in one pass (BFS from the roots)
        q = deque()
        for dim_name, root_name in self.roots.items():
            if root_name not in self.node_info:
                continue
            self.node_info[root_name]["depth"] = 0
            q.append(root_name)
        while q:
            cur = q.popleft()
            cur_depth = self.node_info[cur].get("depth", 0) or 0
            for child in self.node_info[cur].get("children", []):
                self.node_info.setdefault(child, {})
                if self.node_info[child].get("depth") is None:
                    self.node_info[child]["depth"] = cur_depth + 1
                    # On the BFS's first visit to a node (i.e. via a shortest
                    # path), also fix its parent pointer so parent and depth stay
                    # consistent. If a name occurs in several places in the tree,
                    # walk() overwrites parent with the last traversal's parent,
                    # which can point deeper into the tree; climbing the parent
                    # chain in find_ancestor_at_level would then regress in depth
                    # (walk downward) or return a wrong ancestor / None. Fixing
                    # parents during BFS keeps the chain monotonically decreasing
                    # toward the root.
                    self.node_info[child]["parent"] = cur
                    q.append(child)
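    # After _build(), a node_info entry looks like (hypothetical node):
    #   {"parent": "<parent name>", "children": ["<child name>", ...],
    #    "dimension": "实质", "depth": 2}
    # and roots maps each dimension name to itself, e.g. {"实质": "实质", ...}.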
    def find_ancestor_at_level(self, node_name: str, level: int) -> Optional[str]:
        """
        Find the ancestor of node_name whose depth == level in the persona trees.
        - If node_name itself has depth == level, return it directly.
        - If node_name has depth < level (shallower than the target), return it.
        - Otherwise climb the parent chain and return the first ancestor with
          depth == level.

        Notes:
        An earlier implementation guarded against accidental cycles with a visited
        set and returned None as soon as a "repeated node" was seen; with
        same-named nodes in the tree and an overwritten parent pointer, that
        wrongly returned None. This version only climbs the parent chain, with no
        visited-based cutoff:
        - each step looks only at the current node's depth and parent;
        - once depth <= level is reached, the current node is returned;
        - if parent is empty, the highest node reached so far is returned.
        With well-formed trees (acyclic parent pointers) the walk terminates in
        finitely many steps; if the underlying data ever forms a cycle, it must
        be repaired while building node_info, so ancestor lookup no longer
        carries that defensive duty.
        """
        info = self.node_info.get(node_name)
        if not info:
            return None
        depth = info.get("depth")
        if depth is None:
            return None
        if depth <= level:
            return node_name
        # Climb the parent chain only, with no visited-based cutoff; return the
        # current node once depth <= level or parent is empty.
        cur = node_name
        while True:
            cur_info = self.node_info.get(cur) or {}
            cur_depth = cur_info.get("depth")
            if cur_depth is None:
                return cur
            if cur_depth <= level:
                return cur
            parent = cur_info.get("parent")
            if parent is None:
                return cur
            cur = parent
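    # Climb sketch with level == 3 and a hypothetical parent chain
    # leaf(depth 5) -> n4(4) -> n3(3) -> n2(2) -> root(0):
    # find_ancestor_at_level("leaf", 3) walks up twice and returns "n3", while
    # any node already at depth <= 3 is returned as-is.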
    # Cluster search (no longer split by dimension)
    def find_clusters(
        self,
        elements: List[str],
        cluster_level: int,
    ) -> List[Dict[str, Any]]:
        """
        Across all persona trees, find cluster nodes for the given element list
        (dimensions are no longer required to match).

        Rules (fixed clustering depth cluster_level):
        - Clustering is decided only at nodes with depth == cluster_level:
          * if a node's subtree contains >= 2 of the elements, and no higher
            (shallower-depth) cluster node exists on that path yet, treat it as
            a cluster node.
        - For elements that cannot cluster upward, look up their ancestor at
          depth == cluster_level; if it exists, use it as the element's
          "single-element cluster" node.
        - Returns:
        [
            {
                "cluster_node": "node name",
                "from_elements": ["element A", "element B", ...]
            },
            ...
        ]
        """
        # Keep only elements that actually exist in the persona trees
        elem_set: Set[str] = set()
        for e in elements:
            e = str(e).strip()
            if not e:
                continue
            info = self.node_info.get(e)
            if not info:
                continue
            elem_set.add(e)
        if not elem_set:
            return []
        # First, count the elements contained in each node's subtree (across all
        # dimension roots).
        # Note: the persona-tree data may contain accidental cycles or repeated
        # references; the visited set keeps the recursion from looping forever.
        subtree_count: Dict[str, int] = {}

        def dfs_count(node: str, visited: Set[str]) -> int:
            if node in visited:
                # Cycle detected: return 0 to avoid infinite recursion
                return 0
            visited.add(node)
            cnt = 1 if node in elem_set else 0
            for ch in self.node_info.get(node, {}).get("children", []):
                cnt += dfs_count(ch, visited)
            subtree_count[node] = cnt
            return cnt

        for root_name in self.roots.values():
            dfs_count(root_name, set())
        # Then, top-down, prefer the "higher" cluster node (but only at the
        # cluster_level layer):
        # - once a node is chosen as a cluster node, none of its descendants are
        #   (so clustering happens as high up as possible);
        # a visited set again guards against accidental cycles.
        clusters: Set[str] = set()

        def dfs_select(node: str, ancestor_selected: bool, visited: Set[str]) -> None:
            if node in visited:
                return
            visited.add(node)
            info = self.node_info.get(node) or {}
            depth = info.get("depth", 0) or 0
            cnt = subtree_count.get(node, 0)
            selected_here = False
            # Select the current node only if no ancestor is selected yet, it
            # sits at cluster_level, and its subtree count qualifies
            if (not ancestor_selected) and depth == cluster_level and cnt >= 2:
                clusters.add(node)
                selected_here = True
            # If an ancestor or the current node is selected, descendants are
            # not considered as cluster nodes
            for ch in info.get("children", []):
                dfs_select(ch, ancestor_selected or selected_here, visited)

        for root_name in self.roots.values():
            dfs_select(root_name, False, set())
        if not clusters:
            return []
        # Collect the elements actually covered by each cluster node
        cluster_to_elements: Dict[str, Set[str]] = {c: set() for c in clusters}
        for e in elem_set:
            cur = e
            visited: Set[str] = set()
            while cur and cur not in visited:
                visited.add(cur)
                if cur in clusters:
                    cluster_to_elements[cur].add(e)
                parent = self.node_info.get(cur, {}).get("parent")
                if parent is None:
                    break
                cur = parent
        out: List[Dict[str, Any]] = []
        # 1) Multi-element clusters: count only elements covered by cluster nodes
        #    that are actually emitted, so nodes covering fewer than 2 elements
        #    are not marked as covered and no element gets lost.
        covered_elems: Set[str] = set()
        for node in clusters:
            elems = sorted(cluster_to_elements.get(node) or [])
            if len(elems) < 2:
                # The main clustering path only considers nodes covering >= 2 elements
                continue
            out.append(
                {
                    "cluster_node": node,
                    "from_elements": elems,
                }
            )
            for e in elems:
                covered_elems.add(e)
        # 2) Elements that cannot cluster upward get a "single-element cluster"
        uncovered = elem_set - covered_elems
        # Group uncovered elements by their cluster_level-depth ancestor so that
        # several elements under the same ancestor merge into one cluster instead
        # of several single-element clusters.
        single_clusters: Dict[str, Set[str]] = {}
        for e in uncovered:
            # For a single-element cluster, cluster_node must be an ancestor,
            # not the element itself; fixed here to the ancestor at
            # depth == cluster_level.
            info_e = self.node_info.get(e) or {}
            parent = info_e.get("parent")
            cur = parent
            best_ancestor: Optional[str] = None
            visited_chain: Set[str] = set()
            while cur and cur not in visited_chain:
                visited_chain.add(cur)
                info = self.node_info.get(cur) or {}
                depth = info.get("depth", 0) or 0
                if depth == cluster_level:
                    best_ancestor = cur
                    break
                parent = info.get("parent")
                if parent is None:
                    break
                cur = parent
            if best_ancestor:
                single_clusters.setdefault(best_ancestor, set()).add(e)
        for anc, elems in single_clusters.items():
            out.append(
                {
                    "cluster_node": anc,
                    "from_elements": sorted(elems),
                }
            )
        # For stable output, sort by from_elements size (descending), then by
        # node name for equal sizes
        out.sort(key=lambda x: (-len(x["from_elements"]), x["cluster_node"]))
        return out
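# find_clusters sketch (hypothetical tree, cluster_level == 3): if elements "a"
# and "b" sit below depth-3 node "N" while "c" sits below depth-3 node "M", the
# result is
#   [{"cluster_node": "N", "from_elements": ["a", "b"]},
#    {"cluster_node": "M", "from_elements": ["c"]}]
# where "M" comes from the single-element fallback in step 2.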
# ---------------------------------------------------------------------------
# 4. Run pattern & cluster analysis on a single round
# ---------------------------------------------------------------------------
def _dim_obj(
    tree_node_name: str,
    tree_index: TreeIndex,
    matched_point: Optional[str] = None,
) -> Dict[str, Any]:
    dim = (tree_index.node_info.get(tree_node_name) or {}).get("dimension") or ""
    o: Dict[str, Any] = {
        "tree_node_name": tree_node_name,
        "dimension": dim,
    }
    if matched_point is not None:
        o["matched_point"] = matched_point
    return o


def _entry_to_matched_point(entry: Dict[str, Any]) -> str:
    """Use matched_post_point when is_fully_derived is true, otherwise use derivation_output_point."""
    dop = entry.get("derivation_output_point")
    dop_s = str(dop).strip() if dop is not None else ""
    if entry.get("is_fully_derived") is True:
        mpp = entry.get("matched_post_point")
        return str(mpp).strip() if mpp is not None else ""
    return dop_s
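# E.g. (hypothetical point names) {"derivation_output_point": "P1",
# "matched_post_point": "Q1", "is_fully_derived": True} -> "Q1"; the same entry
# with is_fully_derived falsy -> "P1".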
def _analyze_single_round(
    patterns: List[Dict[str, Any]],
    tree_index: TreeIndex,
    cumulative_points: List[Dict[str, Any]],
    cluster_level: int = CLUSTER_LEVEL,
) -> Dict[str, Any]:
    """
    Run the dimension analysis for one round (given its cumulative point list):
    1. Take each derivation_output_point from cumulative_points and find its
       cluster_level-depth ancestor in the persona trees
       -> derived_ancestor_set (the set of derived dimension nodes).
    2. Filter deduped_patterns for patterns whose elements are at least 50% nodes
       from derived_ancestor_set (patterns with fewer than 5 items are skipped).
    3. Tag every element of the filtered patterns:
       - element in derived_ancestor_set -> is_derived=True (derived dimension)
       - otherwise -> is_derived=False (underived dimension)
    4. Aggregate the derived_dims / underived_dims object lists.

    Return structure (excerpt):
    - derived_ancestor_nodes: [{ tree_node_name, dimension, matched_point }, ...]
    - derived_dims: [{ tree_node_name, dimension, matched_point }, ...]
    - underived_dims: [{ tree_node_name, dimension }, ...] (no matched_point)
    """
    # 1. Collect derived_ancestor_set, accumulating each ancestor's
    #    matched_point along the way
    derived_ancestor_set: Set[str] = set()
    ancestor_to_matched: Dict[str, List[str]] = {}
    for entry in cumulative_points:
        if not isinstance(entry, dict):
            continue
        dop = entry.get("derivation_output_point")
        if not dop:
            continue
        ancestor = tree_index.find_ancestor_at_level(str(dop).strip(), cluster_level)
        if not ancestor:
            continue
        derived_ancestor_set.add(ancestor)
        pt = _entry_to_matched_point(entry)
        if pt and pt not in ancestor_to_matched.get(ancestor, []):
            ancestor_to_matched.setdefault(ancestor, []).append(pt)
    # 2. Filter patterns: derived dimension nodes must make up >= 50% of all elements
    filtered_patterns: List[Dict[str, Any]] = []
    for p in patterns:
        items = p.get("items") or []
        item_names = [
            str(it.get("name") or "").strip()
            for it in items
            if isinstance(it, dict)
        ]
        if not item_names:
            continue
        if len(item_names) < 5:
            continue
        derived_count = sum(1 for name in item_names if name in derived_ancestor_set)
        if derived_count / len(item_names) >= 0.5:
            filtered_patterns.append(p)
    print(
        f"filtered_patterns: {len(filtered_patterns)}, "
        f"derived_ancestor_set: {len(derived_ancestor_set)}"
    )
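    # Worked example of the >= 50% filter above (hypothetical names): with
    # derived_ancestor_set = {"X", "Y"}, items ["X", "Y", "a", "b", "c"] score
    # 2/5 = 0.4 and the pattern is dropped, while ["X", "Y", "X", "a", "b"]
    # scores 3/5 = 0.6 and is kept (duplicate names count per occurrence);
    # patterns with fewer than 5 items are skipped outright.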
    def _matched_join(name: str) -> str:
        pts = ancestor_to_matched.get(name) or []
        return ", ".join(pts) if pts else ""

    derived_ancestor_nodes: List[Dict[str, Any]] = []
    for anc in sorted(derived_ancestor_set):
        derived_ancestor_nodes.append(
            _dim_obj(anc, tree_index, matched_point=_matched_join(anc) or "")
        )
    # 3. Classify the filtered patterns' elements and aggregate the dimension lists
    derived_dims: List[Dict[str, Any]] = []
    underived_dims: List[Dict[str, Any]] = []
    derived_dims_seen: Set[str] = set()
    underived_dims_seen: Set[str] = set()
    scored_patterns: List[Dict[str, Any]] = []
    for p in filtered_patterns:
        items = p.get("items") or []
        tagged_items: List[Dict[str, Any]] = []
        for it in items:
            if not isinstance(it, dict):
                continue
            name = str(it.get("name") or "").strip()
            is_derived = name in derived_ancestor_set
            tagged_items.append(
                {
                    "name": name,
                    "is_derived": is_derived,
                }
            )
            if is_derived:
                if name and name not in derived_dims_seen:
                    derived_dims_seen.add(name)
                    derived_dims.append(
                        _dim_obj(
                            name,
                            tree_index,
                            matched_point=_matched_join(name) or "",
                        )
                    )
            else:
                if name and name not in underived_dims_seen:
                    underived_dims_seen.add(name)
                    underived_dims.append(_dim_obj(name, tree_index))
        scored_patterns.append(
            {
                "id": p.get("id"),
                "support": p.get("support"),
                "items": tagged_items,
            }
        )
    # Drop nodes from underived_dims that overlap with derived_dims
    underived_dims = [
        d for d in underived_dims if d["tree_node_name"] not in derived_dims_seen
    ]
    # Sort by the count of is_derived=True elements (descending), then by total
    # element count (descending)
    scored_patterns.sort(
        key=lambda x: (
            sum(1 for it in x.get("items", []) if it.get("is_derived")),
            len(x.get("items", [])),
        ),
        reverse=True,
    )
    return {
        "cumulative_points": list(cumulative_points),
        "derived_ancestor_nodes": derived_ancestor_nodes,
        "patterns": scored_patterns,
        "derived_dims": derived_dims,
        "underived_dims": underived_dims,
        "patterns_count": len(scored_patterns),
        "derived_dim_count": len(derived_dims),
        "underived_dim_count": len(underived_dims),
    }
def _format_round_dimension_text(
    derived_dims: List[Dict[str, Any]],
    underived_dims: List[Dict[str, Any]],
) -> str:
    """Derived/underived dimensions, one line each: 维度:tree_node_name,匹配点:matched_point"""
    lines: List[str] = ["【已推导的维度】"]
    for d in derived_dims:
        name = d.get("tree_node_name") or ""
        mp = d.get("matched_point") or "-"
        lines.append(f"维度:{name},匹配点:{mp}")
    if not derived_dims:
        lines.append("(无)")
    lines.append("")
    lines.append("【未推导的维度】")
    for d in underived_dims:
        name = d.get("tree_node_name") or ""
        lines.append(f"维度:{name}")
    if not underived_dims:
        lines.append("(无)")
    return "\n".join(lines)
def pattern_dimension_analyze(
    account_name: str,
    post_id: str,
    log_id: str,
) -> Dict[str, Any]:
    """
    Main entry point for pattern dimension analysis.

    Parameters
    ----------
    account_name : account name (locates the data directories under input / output)
    post_id : post ID (locates the derivation log)
    log_id : derivation-log directory name
             (../output/{account_name}/推导日志/{post_id}/{log_id}/)

    Overview
    --------
    The clustering depth is fixed at CLUSTER_LEVEL (default 3). For every round:
    1. Find each derivation_output_point's ancestor at that depth in the persona
       trees -> the set of derived dimension nodes.
    2. Filter for patterns containing derived dimension nodes (>= 50% of items).
    3. Tag each pattern element as derived or not, and aggregate the
       derived_dims / underived_dims object lists.
    """
    eval_dir = _round_eval_dir(account_name, post_id, log_id)
    if not eval_dir.is_dir():
        raise FileNotFoundError(f"推导日志目录不存在: {eval_dir}")
    round_infos = _load_round_matched_points(account_name, post_id, log_id)
    if not round_infos:
        return {
            "account_name": account_name,
            "post_id": post_id,
            "log_id": log_id,
            "cluster_level": CLUSTER_LEVEL,
            "rounds": [],
            "message": "未在指定日志目录下找到任何评估结果文件(*_评估.json)",
        }
    tree_index = TreeIndex(account_name)
    # Read & dedupe the pattern library once for the whole analysis, avoiding
    # repeated IO and parsing per round
    raw_patterns = _load_raw_patterns(account_name)
    deduped_patterns = _dedupe_patterns(raw_patterns)
    print(f"deduped_patterns len: {len(deduped_patterns)}")
    rounds_output: List[Dict[str, Any]] = []
    for info in round_infos:
        r = info["round"]
        cumulative_points = info["cumulative_points"]
        analyzed = _analyze_single_round(
            patterns=deduped_patterns,
            tree_index=tree_index,
            cumulative_points=cumulative_points,
        )
        analyzed["round"] = r
        rounds_output.append(analyzed)
    return {
        "account_name": account_name,
        "post_id": post_id,
        "log_id": log_id,
        "cluster_level": CLUSTER_LEVEL,
        "rounds": rounds_output,
    }
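# Usage sketch (IDs reuse the __main__ test parameters at the bottom of this file):
#
#   result = pattern_dimension_analyze(
#       account_name="家有大志",
#       post_id="69185d49000000000d00f94e",
#       log_id="20260318221136",
#   )
#   for rnd in result["rounds"]:
#       print(rnd["round"], rnd["derived_dim_count"], rnd["underived_dim_count"])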
def round_pattern_dimension_analyze_core(
    account_name: str,
    post_id: str,
    log_id: str,
    round: int,
) -> Dict[str, Any]:
    """
    Use only the evaluation files up to and including round `round` to obtain the
    cumulative topic-point state at the end of that round, then run the dimension
    analysis. Returns the single-round analyzed structure (with derived_dims /
    underived_dims etc.); on failure the dict carries an error field.
    """
    eval_dir = _round_eval_dir(account_name, post_id, log_id)
    if not eval_dir.is_dir():
        return {"error": f"推导日志目录不存在: {eval_dir}"}
    round_infos = _load_round_matched_points(
        account_name, post_id, log_id, max_round=round
    )
    if not round_infos:
        return {
            "error": f"在 {eval_dir} 下未找到第 {round} 轮及之前的 *_评估.json",
        }
    last = round_infos[-1]
    if last.get("round") != round:
        return {
            "error": (
                f"指定轮次 {round} 的评估文件不存在;"
                f"当前仅加载到第 {last.get('round')} 轮"
            ),
        }
    tree_index = TreeIndex(account_name)
    raw_patterns = _load_raw_patterns(account_name)
    deduped_patterns = _dedupe_patterns(raw_patterns)
    analyzed = _analyze_single_round(
        patterns=deduped_patterns,
        tree_index=tree_index,
        cumulative_points=last["cumulative_points"],
    )
    analyzed["round"] = round
    return analyzed
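# E.g. round_pattern_dimension_analyze_core(account, post, log, 2) loads only
# 1_评估.json and 2_评估.json and analyzes the cumulative state after round 2;
# if 2_评估.json is missing, the returned dict carries an "error" field instead.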
@tool()
async def round_pattern_dimension_analyze(
    account_name: str,
    post_id: str,
    log_id: str,
    round: int,
) -> Any:
    """
    Derivation dimension analysis: returns the dimensions derived as of the given
    round plus the candidate underived dimensions.

    Args:
        account_name: account name
        post_id: post ID
        log_id: derivation-log directory name
        round: derivation round (positive integer)

    Returns:
        ToolResult: output is readable text with a 「已推导的维度」 section and a
        「未推导的维度」 section; each line reads 「维度:tree_node_name,匹配点:matched_point」
        (underived lines always show "-"; matched_point rule: when is_fully_derived
        is true use the topic point, otherwise the derivation output point).
    """
    if ToolResult is None:
        return None
    logger.info(
        "round_pattern_dimension_analyze: account=%s post_id=%s log_id=%s round=%s",
        account_name,
        post_id,
        log_id,
        round,
    )
    try:
        r = int(round)
        if r < 1:
            return ToolResult(
                title="维度分析: 轮次无效",
                output="",
                error="round 须为 >= 1 的整数",
            )
    except (TypeError, ValueError):
        return ToolResult(
            title="维度分析: 轮次无效",
            output="",
            error="round 须为整数",
        )
    try:
        analyzed = round_pattern_dimension_analyze_core(
            account_name, post_id, log_id, r
        )
        if analyzed.get("error"):
            return ToolResult(
                title=f"维度分析 第{r}轮 失败",
                output="",
                error=str(analyzed["error"]),
            )
        # Save next to the {round}_评估.json files
        out_dir = _round_eval_dir(account_name, post_id, log_id)
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{r}_维度分析.json"
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(analyzed, f, ensure_ascii=False, indent=2)
        derived = analyzed.get("derived_dims") or []
        underived = analyzed.get("underived_dims") or []
        text = _format_round_dimension_text(derived, underived)
        meta = (
            f"round={r}, derived={len(derived)}, underived={len(underived)}, "
            f"patterns={analyzed.get('patterns_count', 0)}"
        )
        return ToolResult(
            title=f"第 {r} 轮维度分析(已推导 {len(derived)} / 未推导 {len(underived)})",
            output=text,
            metadata={"round_pattern_dimension_analyze": meta},
        )
    except Exception as e:
        logger.exception("round_pattern_dimension_analyze failed: %s", e)
        return ToolResult(
            title="维度分析失败",
            output="",
            error=str(e),
        )
def main(account_name, post_id, log_id) -> None:
    """Simple local test: analyze one derivation log and write the results to the output directories."""
    result = pattern_dimension_analyze(
        account_name=account_name,
        post_id=post_id,
        log_id=log_id,
    )
    # Print the first 4000 characters for a quick look
    # print(json.dumps(result, ensure_ascii=False, indent=2)[:4000] + "...")
    # Output file 1:
    # ../output/{account_name}/推导日志/{post_id}/{log_id}/{post_id}_pattern_dimension_analyze.json
    out_dir = _round_eval_dir(account_name, post_id, log_id)
    out_dir.mkdir(parents=True, exist_ok=True)
    output_file_name = f"{post_id}_pattern_dimension_analyze.json"
    out_path = out_dir / output_file_name
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    # Output file 2:
    # ../output/{account_name}/整体推导维度分析/{post_id}_pattern_dimension_analyze.json
    overall_dir = _BASE_OUTPUT / account_name / "整体推导维度分析"
    overall_dir.mkdir(parents=True, exist_ok=True)
    overall_output_file_name = f"{post_id}_pattern_dimension_analyze.json"
    overall_out_path = overall_dir / overall_output_file_name
    with open(overall_out_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    print(f"\n分析结果已写入: {out_path}")
def main_round_pattern_dimension_analyze(
    account_name: str,
    post_id: str,
    log_id: str,
    round_num: int,
) -> None:
    """Local test: call round_pattern_dimension_analyze directly and print the ToolResult."""
    import asyncio

    async def _run() -> None:
        result = await round_pattern_dimension_analyze(
            account_name=account_name,
            post_id=post_id,
            log_id=log_id,
            round=round_num,
        )
        if result is None:
            print(
                "round_pattern_dimension_analyze 返回 None:请先将 Agent 项目根目录加入 PYTHONPATH,"
                "或在 __main__ 中保证能 import agent.tools.ToolResult"
            )
            return
        if result.error:
            print(f"错误: {result.error}")
        else:
            print(result.title)
            print(result.output)

    asyncio.run(_run())
if __name__ == "__main__":
    import importlib.util

    # Load ToolResult directly instead of `import agent`, which would pull in
    # the full dependency chain (e.g. langchain)
    _agent_root = Path(__file__).resolve().parents[3]
    _models_py = _agent_root / "agent" / "tools" / "models.py"
    if _models_py.is_file():
        _spec = importlib.util.spec_from_file_location(
            "_pattern_dim_tool_models", _models_py
        )
        if _spec and _spec.loader:
            _m = importlib.util.module_from_spec(_spec)
            _spec.loader.exec_module(_m)
            globals()["ToolResult"] = _m.ToolResult
            if getattr(_m, "ToolContext", None) is not None:
                globals()["ToolContext"] = _m.ToolContext

    # ---------- Switches & parameters (edit here) ----------
    run_round_pattern_test = False
    run_full_pattern_analyze = True
    test_account_name = "家有大志"
    test_post_id = "69185d49000000000d00f94e"
    test_log_id = "20260318221136"
    test_round = 1
    items = [
        {"post_id": "68fb6a5c000000000302e5de", "log_id": "20260319134630"},
        {"post_id": "69185d49000000000d00f94e", "log_id": "20260319140603"},
        {"post_id": "6921937a000000001b0278d1", "log_id": "20260319141843"},
    ]
    if run_round_pattern_test:
        main_round_pattern_dimension_analyze(
            test_account_name,
            test_post_id,
            test_log_id,
            test_round,
        )
    elif run_full_pattern_analyze:
        for item in items:
            main(test_account_name, item["post_id"], item["log_id"])