| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032 |
- """
- Pattern 维度分析 Tool
- 功能概述:
- 1. 读取某次整体推导日志目录下各轮评估结果,累计 matched_post_point / derivation_output_point 等字段。
- 2. 每轮通过 derivation_output_point 在人设树中找到 cluster_level 层祖先节点(已推导维度节点集合)。
- 3. 从 deduped_patterns 中筛选包含已推导维度节点的 pattern,并对各元素标记是否已推导。
- 输入参数:
- - account_name: 账号名称
- - post_id: 帖子 ID
- - log_id: 推导日志目录名(形如 20260313210921)
- 已推导/未推导维度节点在结果中以对象列表表示,字段见 _analyze_single_round 返回说明。
- """
- import json
- import logging
- import sys
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple, Set
- logger = logging.getLogger(__name__)
# Optional dependency: when the agent framework is unavailable (e.g. running
# this file standalone), fall back to a no-op decorator and None sentinels.
try:
    from agent.tools import tool, ToolResult, ToolContext
except ImportError:
    def tool(*args, **kwargs):
        return lambda f: f
    ToolResult = None
    ToolContext = None
# Make utils / tools importable whether this file is run directly or loaded
# as part of the package (also lets the IDE resolve / jump to them).
_root = Path(__file__).resolve().parent.parent
if str(_root) not in sys.path:
    sys.path.insert(0, str(_root))
from tools.find_tree_node import _load_trees  # loads the three persona trees
_BASE_INPUT = Path(__file__).resolve().parent.parent / "input"
_BASE_OUTPUT = Path(__file__).resolve().parent.parent / "output"
# Pattern-library key definitions (kept consistent with find_pattern)
TOP_KEYS = [
    "depth_4",
]
SUB_KEYS = ["two_x", "one_x", "zero_x"]
# Target depth when looking up ancestor nodes in the persona trees (root = depth 0)
CLUSTER_LEVEL = 3
- # ---------------------------------------------------------------------------
- # 1. 读取推导日志:按轮次累计 matched_post_point
- # ---------------------------------------------------------------------------
def _round_eval_dir(account_name: str, post_id: str, log_id: str) -> Path:
    """Return the derivation-log directory for one run.

    Layout: ../output/{account_name}/推导日志/{post_id}/{log_id}/
    """
    log_root = _BASE_OUTPUT / account_name / "推导日志"
    return log_root / post_id / log_id
def _load_round_matched_points(
    account_name: str,
    post_id: str,
    log_id: str,
    max_round: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Read every {round}_评估.json in the log directory, sorted by round number.

    Parameters
    ----------
    account_name / post_id / log_id : locate the log directory
        (see _round_eval_dir).
    max_round : when given, only rounds <= max_round are loaded.

    Returns
    -------
    A list of per-round entries:
        [
            {
                "round": 1,
                "round_points": [
                    {
                        "matched_post_point": "叙事结构",
                        "derivation_output_point": "叙事编排",
                        "matched_score": 0.9151,
                        "is_fully_derived": True,
                    },
                    ...
                ],
                "cumulative_points": [
                    ... deduped list accumulated up to this round
                        (dedupe key: derivation_output_point) ...
                ],
            },
            ...
        ]
    """
    base_dir = _round_eval_dir(account_name, post_id, log_id)
    if not base_dir.is_dir():
        return []
    eval_files: List[Tuple[int, Path]] = []
    for p in base_dir.glob("*.json"):
        name = p.name
        # Only handle files named {round}_评估.json.
        if not name.endswith("评估.json"):
            continue
        try:
            r = int(name.split("_", 1)[0])
        except (ValueError, IndexError):
            # Non-numeric prefix (e.g. summary files) — not a per-round eval.
            continue
        eval_files.append((r, p))
    if max_round is not None:
        eval_files = [(r, p) for r, p in eval_files if r <= max_round]
    eval_files.sort(key=lambda x: x[0])
    results: List[Dict[str, Any]] = []
    cumulative: List[Dict[str, Any]] = []
    cumulative_set: Set[str] = set()  # dedupe key: derivation_output_point
    for r, path in eval_files:
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, json.JSONDecodeError) as exc:
            # A corrupt or unreadable file must not abort the whole run,
            # but it should not be skipped silently either.
            logger.warning("skip unreadable eval file %s: %s", path, exc)
            continue
        eval_results = data.get("eval_results") or []
        round_points: List[Dict[str, Any]] = []
        seen_in_round: Set[str] = set()
        for item in eval_results:
            if not isinstance(item, dict):
                continue
            if not item.get("is_matched"):
                continue
            dop = item.get("derivation_output_point")
            if dop is None:
                continue
            dop = str(dop).strip()
            if not dop:
                continue
            # Dedupe within the round by derivation_output_point.
            if dop in seen_in_round:
                continue
            seen_in_round.add(dop)
            mpp = item.get("matched_post_point")
            entry: Dict[str, Any] = {
                "matched_post_point": str(mpp).strip() if mpp is not None else None,
                "derivation_output_point": dop,
                "matched_score": item.get("matched_score"),
                "is_fully_derived": item.get("is_fully_derived"),
            }
            round_points.append(entry)
        # Fold into the cumulative list (dedupe by derivation_output_point).
        for entry in round_points:
            dop = entry["derivation_output_point"]
            if dop not in cumulative_set:
                cumulative_set.add(dop)
                cumulative.append(entry)
        results.append(
            {
                "round": r,
                "round_points": round_points,
                "cumulative_points": list(cumulative),
            }
        )
    return results
- # ---------------------------------------------------------------------------
- # 2. 读取 pattern 库并按 matched_post_point 打分
- # ---------------------------------------------------------------------------
def _pattern_file(account_name: str) -> Path:
    """Pattern library file: ../input/{account_name}/原始数据/pattern/processed_edge_data.json"""
    raw_pattern_dir = _BASE_INPUT / account_name / "原始数据" / "pattern"
    return raw_pattern_dir / "processed_edge_data.json"
def _load_raw_patterns(account_name: str) -> List[Dict[str, Any]]:
    """
    Load every raw pattern from the pattern library (items structure kept,
    no merging; the item-level point / dimension fields are ignored here).
    """
    path = _pattern_file(account_name)
    if not path.is_file():
        return []
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    collected: List[Dict[str, Any]] = []
    for top_key in TOP_KEYS:
        block = data.get(top_key)
        if not isinstance(block, dict):
            continue
        for sub_key in SUB_KEYS:
            sub_items = block.get(sub_key) or []
            if not isinstance(sub_items, list):
                continue
            collected.extend(p for p in sub_items if isinstance(p, dict))
    return collected
def _slim_pattern_for_dedupe(p: Dict[str, Any]) -> Tuple[float, List[str]]:
    """
    Extract (support, ordered unique item names) from a raw pattern.

    Mirrors the dedupe-key construction in find_pattern.py: item names are
    stripped, empties dropped, and duplicates removed while preserving the
    first-seen order.
    """
    unique_names: List[str] = []
    seen_names: Set[str] = set()
    for it in p.get("items") or []:
        if not isinstance(it, dict):
            continue
        name = str(it.get("name") or "").strip()
        if name and name not in seen_names:
            seen_names.add(name)
            unique_names.append(name)
    raw_support = p.get("support", 0.0)
    try:
        support_val = float(raw_support)
    except (TypeError, ValueError):
        support_val = 0.0
    return support_val, unique_names
def _dedupe_patterns(raw_patterns: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Deduplicate patterns by their unordered set of item names.

    Same idea as find_pattern.py:
    - key = tuple(sorted(unique item names))
    - per key, only the pattern with the highest support survives
      (its original items structure is kept for later scoring).
    """
    best_by_key: Dict[Tuple[str, ...], Tuple[float, Dict[str, Any]]] = {}
    for pattern in raw_patterns:
        support, unique_names = _slim_pattern_for_dedupe(pattern)
        if not unique_names:
            continue
        key = tuple(sorted(unique_names))
        current = best_by_key.get(key)
        if current is None or support > current[0]:
            best_by_key[key] = (support, pattern)
    return [pattern for _support, pattern in best_by_key.values()]
- # ---------------------------------------------------------------------------
- # 3. 人设树节点信息 & 聚类节点搜索
- # ---------------------------------------------------------------------------
class TreeIndex:
    """
    Persona-tree index.

    Attributes
    ----------
    node_info : node name -> {"parent": parent node name,
        "children": [child node names...], "depth": depth,
        "dimension": dimension name}
    roots : dimension name -> root node name (the dimension name itself)
    merged_tree : the 实质/形式/意图 trees merged into a single JSON
        (top-level keys are the dimension names)
    """
    def __init__(self, account_name: str) -> None:
        self.account_name = account_name
        # node name -> parent/children/depth/dimension metadata
        self.node_info: Dict[str, Dict[str, Any]] = {}
        # dimension name -> root node name
        self.roots: Dict[str, str] = {}
        # The three trees merged into one JSON: {"实质": {...}, "形式": {...}, "意图": {...}}
        self.merged_tree: Dict[str, Dict[str, Any]] = {}
        self._build()
    def _build(self) -> None:
        """Build merged_tree / roots / node_info from the loaded persona trees."""
        trees = _load_trees(self.account_name)
        # 1) Merge the three trees into one JSON: {"实质": {...}, "形式": {...}, "意图": {...}}
        merged: Dict[str, Dict[str, Any]] = {}
        for dim_name, root in trees:
            if isinstance(root, dict):
                merged[dim_name] = root
        self.merged_tree = merged
        # 2) Build the parent/children structure from the merged JSON.
        for dim_name, root in merged.items():
            root_name = dim_name
            self.roots[dim_name] = root_name
            if root_name not in self.node_info:
                self.node_info[root_name] = {
                    "parent": None,
                    "children": [],
                    "dimension": dim_name,
                    "depth": 0,
                }
            def walk(parent_name: str, node_dict: Dict[str, Any]):
                # Recursively record parent/children/dimension for each child.
                children = node_dict.get("children") or {}
                for name, child in children.items():
                    if not isinstance(child, dict):
                        continue
                    if name not in self.node_info:
                        self.node_info[name] = {
                            "parent": parent_name,
                            "children": [],
                            "dimension": dim_name,
                            "depth": None,  # computed later in one BFS pass
                        }
                    else:
                        # Only update the parent when it would not create a
                        # self-reference (the tree may contain same-named
                        # parent/child nodes).
                        if name != parent_name:
                            self.node_info[name]["parent"] = parent_name
                        self.node_info[name]["dimension"] = dim_name
                    # Maintain the parent's children list.
                    if parent_name not in self.node_info:
                        self.node_info[parent_name] = {
                            "parent": None,
                            "children": [],
                            "dimension": dim_name,
                            "depth": 0,
                        }
                    if name not in self.node_info[parent_name]["children"]:
                        self.node_info[parent_name]["children"].append(name)
                    walk(name, child)
            walk(root_name, root)
        # Compute every node's depth with a BFS from all roots.
        from collections import deque
        q = deque()
        for dim_name, root_name in self.roots.items():
            if root_name not in self.node_info:
                continue
            self.node_info[root_name]["depth"] = 0
            q.append(root_name)
        while q:
            cur = q.popleft()
            cur_depth = self.node_info[cur].get("depth", 0) or 0
            for child in self.node_info[cur].get("children", []):
                self.node_info.setdefault(child, {})
                if self.node_info[child].get("depth") is None:
                    self.node_info[child]["depth"] = cur_depth + 1
                    # On the BFS's first arrival at a node (i.e. via the
                    # shortest path), also fix the parent pointer so that
                    # parent and depth stay consistent. If a name appears in
                    # several places in the tree, walk() overwrites parent
                    # with the last parent traversed, which can make parent
                    # point at a deeper node; find_ancestor_at_level would
                    # then see depth regress (grow) while climbing the parent
                    # chain, or even return a wrong ancestor / None. Fixing
                    # parents during the BFS guarantees the parent chain
                    # decreases monotonically up to the root.
                    self.node_info[child]["parent"] = cur
                    q.append(child)
    def find_ancestor_at_level(self, node_name: str, level: int) -> Optional[str]:
        """
        Find the ancestor of node_name whose depth == level.

        - If node_name itself has depth == level, return node_name.
        - If node_name is shallower than level, return node_name.
        - Otherwise climb the parent chain and return the first ancestor with
          depth == level.

        Note:
            An earlier implementation used a visited set to guard against
            accidental cycles and returned None as soon as a repeated node
            was seen; with same-named nodes in the tree (whose parent
            pointers get overwritten) that incorrectly returned None. The
            current version only walks up the parent chain, with no
            visited-based cutoff:
            - each step looks only at the current node's depth and parent;
            - as soon as depth <= level, return the current node;
            - if parent is None, return the highest node reached so far.
            With well-formed parent pointers (no cycles) this terminates in a
            finite number of steps; if the underlying data ever forms a
            cycle, it must be fixed when node_info is built — the ancestor
            lookup itself no longer defends against it.
        """
        info = self.node_info.get(node_name)
        if not info:
            return None
        depth = info.get("depth")
        if depth is None:
            return None
        if depth <= level:
            return node_name
        # Climb the parent chain only; return the current node as soon as
        # depth <= level is reached or the parent is missing.
        cur = node_name
        while True:
            cur_info = self.node_info.get(cur) or {}
            cur_depth = cur_info.get("depth")
            if cur_depth is None:
                return cur
            if cur_depth <= level:
                return cur
            parent = cur_info.get("parent")
            if parent is None:
                return cur
            cur = parent
    # Cluster search (no longer split per dimension)
    def find_clusters(
        self,
        elements: List[str],
        cluster_level: int,
    ) -> List[Dict[str, Any]]:
        """
        Find cluster nodes for the given elements across all persona trees
        (the elements' dimensions no longer need to match).

        Rules (fixed clustering level cluster_level):
        - Clustering is decided only at nodes with depth == cluster_level:
          * if a node's subtree contains >= 2 of the elements, and no
            higher-level (smaller-depth) cluster node exists on that path,
            the node becomes a cluster node.
        - For elements that cannot cluster upward, look for their ancestor at
          depth == cluster_level; if found, it becomes that element's
          "single-element cluster" node.
        - Returns:
            [
                {
                    "cluster_node": "node name",
                    "from_elements": ["element A", "element B", ...]
                },
                ...
            ]
        """
        # Keep only elements that actually exist in the persona trees.
        elem_set: Set[str] = set()
        for e in elements:
            e = str(e).strip()
            if not e:
                continue
            info = self.node_info.get(e)
            if not info:
                continue
            elem_set.add(e)
        if not elem_set:
            return []
        # First count, for every node, how many elements its subtree contains
        # (across all dimension roots).
        # NOTE: the tree data may contain unexpected cycles or repeated
        # references; the visited set prevents infinite recursion.
        subtree_count: Dict[str, int] = {}
        def dfs_count(node: str, visited: Set[str]) -> int:
            if node in visited:
                # Cycle detected — return 0 to avoid infinite recursion.
                return 0
            visited.add(node)
            cnt = 1 if node in elem_set else 0
            for ch in self.node_info.get(node, {}).get("children", []):
                cnt += dfs_count(ch, visited)
            subtree_count[node] = cnt
            return cnt
        for root_name in self.roots.values():
            dfs_count(root_name, set())
        # Then, top-down, prefer "higher" cluster nodes (but only at the
        # cluster_level layer):
        # - once a node is chosen as a cluster node, its descendants are no
        #   longer eligible (clustering happens as high as possible);
        # the visited set again guards against accidental cycles.
        clusters: Set[str] = set()
        def dfs_select(node: str, ancestor_selected: bool, visited: Set[str]) -> None:
            if node in visited:
                return
            visited.add(node)
            info = self.node_info.get(node) or {}
            depth = info.get("depth", 0) or 0
            cnt = subtree_count.get(node, 0)
            selected_here = False
            # Choose this node only when no ancestor was chosen, the node
            # sits at cluster_level, and it covers enough elements.
            if (not ancestor_selected) and depth == cluster_level and cnt >= 2:
                clusters.add(node)
                selected_here = True
            # If an ancestor or this node was selected, descendants are skipped.
            for ch in info.get("children", []):
                dfs_select(ch, ancestor_selected or selected_here, visited)
        for root_name in self.roots.values():
            dfs_select(root_name, False, set())
        if not clusters:
            return []
        # Collect which elements each cluster node actually covers.
        cluster_to_elements: Dict[str, Set[str]] = {c: set() for c in clusters}
        for e in elem_set:
            cur = e
            visited: Set[str] = set()
            while cur and cur not in visited:
                visited.add(cur)
                if cur in clusters:
                    cluster_to_elements[cur].add(e)
                parent = self.node_info.get(cur, {}).get("parent")
                if parent is None:
                    break
                cur = parent
        out: List[Dict[str, Any]] = []
        # 1) Multi-element clusters: only count elements covered by cluster
        #    nodes that are actually emitted, so nodes covering < 2 elements
        #    do not swallow elements and make them go missing.
        covered_elems: Set[str] = set()
        for node in clusters:
            elems = sorted(cluster_to_elements.get(node) or [])
            if len(elems) < 2:
                # The main clustering only considers nodes covering >= 2 elements.
                continue
            out.append(
                {
                    "cluster_node": node,
                    "from_elements": elems,
                }
            )
            for e in elems:
                covered_elems.add(e)
        # 2) Elements that could not cluster upward get a "single-element cluster".
        uncovered = elem_set - covered_elems
        # Group uncovered elements by their ancestor at cluster_level so that
        # several elements under the same ancestor merge into ONE cluster
        # instead of several single-element clusters.
        single_clusters: Dict[str, Set[str]] = {}
        for e in uncovered:
            # For single-element clusters the cluster_node must be an
            # ancestor, never the element itself; the ancestor at
            # depth == cluster_level is chosen.
            info_e = self.node_info.get(e) or {}
            parent = info_e.get("parent")
            cur = parent
            best_ancestor: Optional[str] = None
            visited_chain: Set[str] = set()
            while cur and cur not in visited_chain:
                visited_chain.add(cur)
                info = self.node_info.get(cur) or {}
                depth = info.get("depth", 0) or 0
                if depth == cluster_level:
                    best_ancestor = cur
                    break
                parent = info.get("parent")
                if parent is None:
                    break
                cur = parent
            if best_ancestor:
                single_clusters.setdefault(best_ancestor, set()).add(e)
        for anc, elems in single_clusters.items():
            out.append(
                {
                    "cluster_node": anc,
                    "from_elements": sorted(elems),
                }
            )
        # Stable output: more covering elements first, ties broken by node name.
        out.sort(key=lambda x: (-len(x["from_elements"]), x["cluster_node"]))
        return out
- # ---------------------------------------------------------------------------
- # 4. 对单轮数据执行 pattern & 聚类分析
- # ---------------------------------------------------------------------------
def _dim_obj(
    tree_node_name: str,
    tree_index: TreeIndex,
    matched_point: Optional[str] = None,
) -> Dict[str, Any]:
    """Build a dimension descriptor for one tree node.

    Returns {"tree_node_name": ..., "dimension": ...}; the "matched_point"
    key is included only when matched_point is not None.
    """
    node_meta = tree_index.node_info.get(tree_node_name) or {}
    descriptor: Dict[str, Any] = {
        "tree_node_name": tree_node_name,
        "dimension": node_meta.get("dimension") or "",
    }
    if matched_point is not None:
        descriptor["matched_point"] = matched_point
    return descriptor
def _entry_to_matched_point(entry: Dict[str, Any]) -> str:
    """Return matched_post_point when is_fully_derived is True, else derivation_output_point."""
    if entry.get("is_fully_derived") is True:
        matched = entry.get("matched_post_point")
        return "" if matched is None else str(matched).strip()
    derived = entry.get("derivation_output_point")
    return "" if derived is None else str(derived).strip()
def _analyze_single_round(
    patterns: List[Dict[str, Any]],
    tree_index: TreeIndex,
    cumulative_points: List[Dict[str, Any]],
    cluster_level: int = CLUSTER_LEVEL,
) -> Dict[str, Any]:
    """
    Run dimension analysis for one round (given the cumulative point list):
    1. Take each derivation_output_point in cumulative_points and find its
       ancestor at depth cluster_level in the persona trees ->
       derived_ancestor_set (the set of derived dimension nodes).
    2. Filter the deduped patterns down to those containing nodes from
       derived_ancestor_set.
    3. Tag every element of each filtered pattern:
       - element in derived_ancestor_set -> is_derived=True (derived)
       - otherwise -> is_derived=False (underived)
    4. Aggregate the derived_dims / underived_dims object lists.
    Returned structure (excerpt):
    - derived_ancestor_nodes: [{ tree_node_name, dimension, matched_point }, ...]
    - derived_dims: [{ tree_node_name, dimension, matched_point }, ...]
    - underived_dims: [{ tree_node_name, dimension }, ...] (no matched_point)
    """
    # 1. Collect derived_ancestor_set and accumulate per-ancestor matched points.
    derived_ancestor_set: Set[str] = set()
    ancestor_to_matched: Dict[str, List[str]] = {}
    for entry in cumulative_points:
        if not isinstance(entry, dict):
            continue
        dop = entry.get("derivation_output_point")
        if not dop:
            continue
        ancestor = tree_index.find_ancestor_at_level(str(dop).strip(), cluster_level)
        if not ancestor:
            continue
        derived_ancestor_set.add(ancestor)
        pt = _entry_to_matched_point(entry)
        if pt and pt not in ancestor_to_matched.get(ancestor, []):
            ancestor_to_matched.setdefault(ancestor, []).append(pt)
    # 2. Filter patterns: at least 5 items and >= 50% of items already derived.
    filtered_patterns: List[Dict[str, Any]] = []
    for p in patterns:
        items = p.get("items") or []
        item_names = [
            str(it.get("name") or "").strip()
            for it in items
            if isinstance(it, dict)
        ]
        if not item_names:
            continue
        if len(item_names) < 5:
            continue
        derived_count = sum(1 for name in item_names if name in derived_ancestor_set)
        if derived_count / len(item_names) >= 0.5:
            filtered_patterns.append(p)
    # Library code should log, not print (lazy %-args keep formatting off the
    # hot path when debug logging is disabled).
    logger.debug(
        "filtered_patterns: %d, derived_ancestor_set: %d",
        len(filtered_patterns),
        len(derived_ancestor_set),
    )

    def _matched_join(name: str) -> str:
        # Join all matched points recorded for an ancestor, comma-separated.
        pts = ancestor_to_matched.get(name) or []
        return ", ".join(pts) if pts else ""

    derived_ancestor_nodes: List[Dict[str, Any]] = []
    for anc in sorted(derived_ancestor_set):
        derived_ancestor_nodes.append(
            _dim_obj(anc, tree_index, matched_point=_matched_join(anc) or "")
        )
    # 3. Tag the elements of each filtered pattern and aggregate dimension lists.
    derived_dims: List[Dict[str, Any]] = []
    underived_dims: List[Dict[str, Any]] = []
    derived_dims_seen: Set[str] = set()
    underived_dims_seen: Set[str] = set()
    scored_patterns: List[Dict[str, Any]] = []
    for p in filtered_patterns:
        items = p.get("items") or []
        tagged_items: List[Dict[str, Any]] = []
        for it in items:
            if not isinstance(it, dict):
                continue
            name = str(it.get("name") or "").strip()
            is_derived = name in derived_ancestor_set
            tagged_items.append(
                {
                    "name": name,
                    "is_derived": is_derived,
                }
            )
            if is_derived:
                if name and name not in derived_dims_seen:
                    derived_dims_seen.add(name)
                    derived_dims.append(
                        _dim_obj(
                            name,
                            tree_index,
                            matched_point=_matched_join(name) or "",
                        )
                    )
            else:
                if name and name not in underived_dims_seen:
                    underived_dims_seen.add(name)
                    underived_dims.append(_dim_obj(name, tree_index))
        scored_patterns.append(
            {
                "id": p.get("id"),
                "support": p.get("support"),
                "items": tagged_items,
            }
        )
    # Drop nodes from underived_dims that also appear in derived_dims.
    underived_dims = [d for d in underived_dims if d["tree_node_name"] not in derived_dims_seen]
    # Sort patterns by number of derived items (desc), then by item count (desc).
    scored_patterns.sort(
        key=lambda x: (
            sum(1 for it in x.get("items", []) if it.get("is_derived")),
            len(x.get("items", [])),
        ),
        reverse=True,
    )
    return {
        "cumulative_points": list(cumulative_points),
        "derived_ancestor_nodes": derived_ancestor_nodes,
        "patterns": scored_patterns,
        "derived_dims": derived_dims,
        "underived_dims": underived_dims,
        "patterns_count": len(scored_patterns),
        "derived_dim_count": len(derived_dims),
        "underived_dim_count": len(underived_dims),
    }
def _format_round_dimension_text(
    derived_dims: List[Dict[str, Any]],
    underived_dims: List[Dict[str, Any]],
) -> str:
    """Render derived / underived dimensions as readable text.

    One node per line: 维度:tree_node_name,匹配点:matched_point
    (underived lines omit the matched point; empty sections show (无)).
    """
    derived_lines = [
        f"维度:{d.get('tree_node_name') or ''},匹配点:{d.get('matched_point') or '-'}"
        for d in derived_dims
    ] or ["(无)"]
    underived_lines = [
        f"维度:{d.get('tree_node_name') or ''}"
        for d in underived_dims
    ] or ["(无)"]
    return "\n".join(
        ["【已推导的维度】", *derived_lines, "", "【未推导的维度】", *underived_lines]
    )
def pattern_dimension_analyze(
    account_name: str,
    post_id: str,
    log_id: str,
) -> Dict[str, Any]:
    """
    Main entry point for pattern dimension analysis.

    Parameters
    ----------
    account_name : account name (locates the data directories under input / output)
    post_id : post ID (locates the derivation log)
    log_id : derivation-log directory name
        (../output/{account_name}/推导日志/{post_id}/{log_id}/)

    Overview
    --------
    The clustering level is fixed at CLUSTER_LEVEL (default 3). For each round:
    1. Map each derivation_output_point to its ancestor at that level in the
       persona trees -> set of derived dimension nodes.
    2. Filter patterns containing derived dimension nodes.
    3. Tag each pattern element as derived or not and aggregate
       derived_dims / underived_dims (object lists).

    Raises
    ------
    FileNotFoundError : when the derivation-log directory does not exist.
    """
    eval_dir = _round_eval_dir(account_name, post_id, log_id)
    if not eval_dir.is_dir():
        raise FileNotFoundError(f"推导日志目录不存在: {eval_dir}")
    round_infos = _load_round_matched_points(account_name, post_id, log_id)
    if not round_infos:
        return {
            "account_name": account_name,
            "post_id": post_id,
            "log_id": log_id,
            "cluster_level": CLUSTER_LEVEL,
            "rounds": [],
            "message": "未在指定日志目录下找到任何评估结果文件(*_评估.json)",
        }
    tree_index = TreeIndex(account_name)
    # The pattern library is read & deduped once for the whole analysis to
    # avoid repeating IO and parsing for every round.
    raw_patterns = _load_raw_patterns(account_name)
    deduped_patterns = _dedupe_patterns(raw_patterns)
    # Library code should log, not print.
    logger.info("deduped_patterns len: %d", len(deduped_patterns))
    rounds_output: List[Dict[str, Any]] = []
    for info in round_infos:
        r = info["round"]
        cumulative_points = info["cumulative_points"]
        analyzed = _analyze_single_round(
            patterns=deduped_patterns,
            tree_index=tree_index,
            cumulative_points=cumulative_points,
        )
        analyzed["round"] = r
        rounds_output.append(analyzed)
    return {
        "account_name": account_name,
        "post_id": post_id,
        "log_id": log_id,
        "cluster_level": CLUSTER_LEVEL,
        "rounds": rounds_output,
    }
def round_pattern_dimension_analyze_core(
    account_name: str,
    post_id: str,
    log_id: str,
    round: int,
) -> Dict[str, Any]:
    """
    Use only the eval files up to and including round `round`, reconstruct the
    cumulative topic-point state at the end of that round, and run the
    dimension analysis on it.

    Returns the single-round analyzed structure (derived_dims /
    underived_dims, ...); on failure the dict carries an "error" key instead.
    """
    eval_dir = _round_eval_dir(account_name, post_id, log_id)
    if not eval_dir.is_dir():
        return {"error": f"推导日志目录不存在: {eval_dir}"}
    round_infos = _load_round_matched_points(
        account_name, post_id, log_id, max_round=round
    )
    if not round_infos:
        return {"error": f"在 {eval_dir} 下未找到第 {round} 轮及之前的 *_评估.json"}
    last = round_infos[-1]
    if last.get("round") != round:
        return {
            "error": f"指定轮次 {round} 的评估文件不存在;当前仅加载到第 {last.get('round')} 轮"
        }
    tree_index = TreeIndex(account_name)
    deduped_patterns = _dedupe_patterns(_load_raw_patterns(account_name))
    analyzed = _analyze_single_round(
        patterns=deduped_patterns,
        tree_index=tree_index,
        cumulative_points=last["cumulative_points"],
    )
    analyzed["round"] = round
    return analyzed
@tool()
async def round_pattern_dimension_analyze(
    account_name: str,
    post_id: str,
    log_id: str,
    round: int,
) -> Any:
    """
    Derivation dimension analysis: return the derived dimensions of the
    current round plus the candidate underived dimensions.

    Args:
        account_name: account name
        post_id: post ID
        log_id: derivation-log directory name
        round: derivation round (positive integer)
    Returns:
        ToolResult: output is readable text with a「已推导的维度」and a
        「未推导的维度」section, one line per node formatted as
        「维度:tree_node_name,匹配点:matched_point」(underived lines use a
        fixed "-"; matched_point rule: if is_fully_derived is true use the
        topic point, otherwise the derivation output point).
        Returns None when ToolResult could not be imported.
    """
    if ToolResult is None:
        return None
    logger.info(
        "round_pattern_dimension_analyze: account=%s post_id=%s log_id=%s round=%s",
        account_name,
        post_id,
        log_id,
        round,
    )
    # Validate the round argument before doing any work.
    try:
        r = int(round)
        if r < 1:
            return ToolResult(
                title="维度分析: 轮次无效",
                output="",
                error="round 须为 >= 1 的整数",
            )
    except (TypeError, ValueError):
        return ToolResult(
            title="维度分析: 轮次无效",
            output="",
            error="round 须为整数",
        )
    try:
        analyzed = round_pattern_dimension_analyze_core(
            account_name, post_id, log_id, r
        )
        if analyzed.get("error"):
            return ToolResult(
                title=f"维度分析 第{r}轮 失败",
                output="",
                error=str(analyzed["error"]),
            )
        # Save next to the per-round {round}_评估.json files.
        out_dir = _round_eval_dir(account_name, post_id, log_id)
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{r}_维度分析.json"
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(analyzed, f, ensure_ascii=False, indent=2)
        derived = analyzed.get("derived_dims") or []
        underived = analyzed.get("underived_dims") or []
        text = _format_round_dimension_text(derived, underived)
        meta = (
            f"round={r}, derived={len(derived)}, underived={len(underived)}, "
            f"patterns={analyzed.get('patterns_count', 0)}"
        )
        return ToolResult(
            title=f"第 {r} 轮维度分析(已推导 {len(derived)} / 未推导 {len(underived)})",
            output=text,
            metadata={"round_pattern_dimension_analyze": meta},
        )
    except Exception as e:
        # Top-level tool boundary: log with traceback and surface the error
        # to the caller instead of raising.
        logger.exception("round_pattern_dimension_analyze failed: %s", e)
        return ToolResult(
            title="维度分析失败",
            output="",
            error=str(e),
        )
def main(account_name, post_id, log_id) -> None:
    """Local smoke test: run the full analysis for one derivation log and write the results."""
    result = pattern_dimension_analyze(
        account_name=account_name,
        post_id=post_id,
        log_id=log_id,
    )
    payload = json.dumps(result, ensure_ascii=False, indent=2)
    file_name = f"{post_id}_pattern_dimension_analyze.json"
    # Output 1: ../output/{account_name}/推导日志/{post_id}/{log_id}/{post_id}_pattern_dimension_analyze.json
    out_dir = _round_eval_dir(account_name, post_id, log_id)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / file_name
    out_path.write_text(payload, encoding="utf-8")
    # Output 2: ../output/{account_name}/整体推导维度分析/{post_id}_pattern_dimension_analyze.json
    overall_dir = _BASE_OUTPUT / account_name / "整体推导维度分析"
    overall_dir.mkdir(parents=True, exist_ok=True)
    overall_out_path = overall_dir / file_name
    overall_out_path.write_text(payload, encoding="utf-8")
    print(f"\n分析结果已写入: {out_path}")
def main_round_pattern_dimension_analyze(
    account_name: str,
    post_id: str,
    log_id: str,
    round_num: int,
) -> None:
    """Local test helper: run the round_pattern_dimension_analyze tool and print its ToolResult."""
    import asyncio

    result = asyncio.run(
        round_pattern_dimension_analyze(
            account_name=account_name,
            post_id=post_id,
            log_id=log_id,
            round=round_num,
        )
    )
    if result is None:
        print(
            "round_pattern_dimension_analyze 返回 None:请先将 Agent 项目根目录加入 PYTHONPATH,"
            "或在 __main__ 中保证能 import agent.tools.ToolResult"
        )
        return
    if result.error:
        print(f"错误: {result.error}")
    else:
        print(result.title)
        print(result.output)
if __name__ == "__main__":
    import asyncio
    import importlib.util
    # Load ToolResult directly from its file to avoid pulling the full agent
    # dependency set (e.g. langchain) via `import agent`.
    _agent_root = Path(__file__).resolve().parents[3]
    _models_py = _agent_root / "agent" / "tools" / "models.py"
    if _models_py.is_file():
        _spec = importlib.util.spec_from_file_location(
            "_pattern_dim_tool_models", _models_py
        )
        if _spec and _spec.loader:
            _m = importlib.util.module_from_spec(_spec)
            _spec.loader.exec_module(_m)
            globals()["ToolResult"] = _m.ToolResult
            if getattr(_m, "ToolContext", None) is not None:
                globals()["ToolContext"] = _m.ToolContext
    # ---------- Switches and parameters (edit here) ----------
    run_round_pattern_test = False
    run_full_pattern_analyze = True
    test_account_name = "家有大志"
    test_post_id = "69185d49000000000d00f94e"
    test_log_id = "20260318221136"
    test_round = 1
    items = [
        {"post_id": "68fb6a5c000000000302e5de", "log_id": "20260318220540"},
        {"post_id": "69185d49000000000d00f94e", "log_id": "20260318221136"},
        {"post_id": "6921937a000000001b0278d1", "log_id": "20260318221538"}
    ]
    if run_round_pattern_test:
        main_round_pattern_dimension_analyze(
            test_account_name,
            test_post_id,
            test_log_id,
            test_round,
        )
    elif run_full_pattern_analyze:
        for item in items:
            main(test_account_name, item["post_id"], item["log_id"])
|