Explorar o código

how agent 维度分析

liuzhiheng hai 1 mes
pai
achega
35e5f0d97b

+ 4 - 5
examples_how/overall_derivation/generate_visualize_data.py

@@ -120,7 +120,7 @@ def build_derivation_result(
     all_keys = {_topic_point_key(t) for t in topic_points}
     topic_by_key = {_topic_point_key(t): t for t in topic_points}
 
-    # 分轮次收集 (round_num, name) -> (matched_score, is_fully_derived),同一轮同名取首次出现
+    # 分轮次收集 (round_num, name) -> (matched_score, is_fully_derived),同一轮同名保留 matched_score 最高的
     score_by_round_name: dict[tuple[int, str], tuple[float, bool]] = {}
     for round_idx, eval_data in enumerate(evals):
         round_num = eval_data.get("round", round_idx + 1)
@@ -130,9 +130,6 @@ def build_derivation_result(
             mp = (er.get("matched_post_point") or er.get("matched_post_topic") or er.get("match_post_point") or "").strip()
             if not mp:
                 continue
-            key = (round_num, mp)
-            if key in score_by_round_name:
-                continue
             score = er.get("matched_score")
             if score is None:
                 score = 1.0
@@ -142,7 +139,9 @@ def build_derivation_result(
                 except (TypeError, ValueError):
                     score = 1.0
             is_fully = er.get("is_fully_derived", True)
-            score_by_round_name[key] = (score, bool(is_fully))
+            key = (round_num, mp)
+            if key not in score_by_round_name or score > score_by_round_name[key][0]:
+                score_by_round_name[key] = (score, bool(is_fully))
 
     result = []
     derived_names_so_far: set[str] = set()

+ 82 - 42
examples_how/overall_derivation/tools/pattern_dimension_analyze.py

@@ -17,8 +17,6 @@ import sys
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Set
 
-from Cython.Includes.libc.stdio import printf
-
 # 保证直接运行或作为包加载时都能解析 utils / tools(IDE 可跳转)
 _root = Path(__file__).resolve().parent.parent
 if str(_root) not in sys.path:
@@ -105,7 +103,15 @@ def _load_round_matched_points(
                 continue
             if not item.get("is_matched"):
                 continue
-            mp = item.get("matched_post_point")
+
+            # 根据是否已完全推导,选择不同的帖子选题点字段:
+            # - is_fully_derived 为 False 时,使用 derivation_output_point
+            # - 其他情况(True 或缺失)使用 matched_post_point(兼容旧数据)
+            if item.get("is_fully_derived") is False:
+                mp = item.get("derivation_output_point")
+            else:
+                mp = item.get("matched_post_point")
+
             if mp is None:
                 continue
             mp = str(mp).strip()
@@ -261,13 +267,17 @@ def _score_patterns_by_matched_points(
             best_post_point: Optional[str] = None
             if name:
                 for post_point in matched_post_points:
-                    score = match_lookup.get((post_point, name))
-                    if score is None:
-                        continue
-                    try:
-                        s = float(score)
-                    except (TypeError, ValueError):
-                        continue
+                    # 如果帖子选题点与节点名称完全一致,直接视为满分匹配
+                    if post_point == name:
+                        s = 1.0
+                    else:
+                        score = match_lookup.get((post_point, name))
+                        if score is None:
+                            continue
+                        try:
+                            s = float(score)
+                        except (TypeError, ValueError):
+                            continue
                     if s > best_score:
                         best_score = s
                         best_post_point = post_point
@@ -397,15 +407,17 @@ class TreeIndex:
     def find_clusters(
         self,
         elements: List[str],
-        min_level: int,
+        cluster_level: int,
     ) -> List[Dict[str, Any]]:
         """
         在所有人设树中,为给定元素列表寻找聚类节点(不再要求 dimension 一致)。
 
-        规则:
-        - 从 min_level 开始,自上而下扫描
+        规则(固定聚类层级 cluster_level)
+        - 仅在 depth == cluster_level 的节点上做聚类判断
             * 若某节点子树中包含的元素数量 >= 2,
               且在该路径上尚未存在更高层(深度更小)的聚类节点,则将其视为一个聚类节点。
+        - 对无法向上形成聚类的元素,为其寻找 depth == cluster_level 的祖先节点,
+          若存在则作为该元素的「单元素聚类」节点。
         - 返回:
         [
           {
@@ -447,7 +459,7 @@ class TreeIndex:
         for root_name in self.roots.values():
             dfs_count(root_name, set())
 
-        # 再自上而下优先选择「更上层」聚类节点(但不低于 min_level):
+        # 再自上而下优先选择「更上层」聚类节点(但仅在 cluster_level 层):
         # - 若当前节点已作为聚类节点,则其子孙不再作为聚类节点(保证尽量向上聚类);
         # 同样需要防止意外的环导致递归过深,这里使用 visited 集合。
         clusters: Set[str] = set()
@@ -461,8 +473,8 @@ class TreeIndex:
             cnt = subtree_count.get(node, 0)
 
             selected_here = False
-            # 仅当祖先尚未被选中、当前节点满足条件时,选当前节点为聚类节点
-            if (not ancestor_selected) and depth >= min_level and cnt >= 2:
+            # 仅当祖先尚未被选中、当前节点位于 cluster_level 层且满足条件时,选当前节点为聚类节点
+            if (not ancestor_selected) and depth == cluster_level and cnt >= 2:
                 clusters.add(node)
                 selected_here = True
 
@@ -492,6 +504,9 @@ class TreeIndex:
                 cur = parent
 
         out: List[Dict[str, Any]] = []
+        # 1)多元素聚类:仅统计真正输出的聚类节点所覆盖的元素,
+        #    避免把「元素数不足 2 的节点」也算作已覆盖,从而导致元素丢失。
+        covered_elems: Set[str] = set()
         for node in clusters:
             elems = sorted(cluster_to_elements.get(node) or [])
             if len(elems) < 2:
@@ -503,15 +518,19 @@ class TreeIndex:
                     "from_elements": elems,
                 }
             )
+            for e in elems:
+                covered_elems.add(e)
 
         # 2)对无法向上形成聚类的元素,给一个「单元素聚类」
-        covered_elems: Set[str] = set()
-        for elems in cluster_to_elements.values():
-            covered_elems.update(elems)
         uncovered = elem_set - covered_elems
 
+        # 将未覆盖元素按「cluster_level 层级的祖先节点」分组,确保同一个祖先节点下的
+        # 多个元素合并为一个聚类,而不是多个单元素聚类。
+        single_clusters: Dict[str, Set[str]] = {}
+
         for e in uncovered:
-            # 单元素聚类时,cluster_node 应为「祖先节点」,不直接使用元素自身
+            # 单元素聚类时,cluster_node 应为「祖先节点」,不直接使用元素自身。
+            # 这里固定选择 depth == cluster_level 的祖先节点。
             info_e = self.node_info.get(e) or {}
             parent = info_e.get("parent")
             cur = parent
@@ -521,22 +540,23 @@ class TreeIndex:
                 visited_chain.add(cur)
                 info = self.node_info.get(cur) or {}
                 depth = info.get("depth", 0) or 0
-                if depth >= min_level:
-                    # 遍历方向为从元素向上(depth 单调递减),
-                    # 每次满足 >= min_level 就更新,循环结束后得到的是
-                    # depth 最小(离元素最远)且不低于 min_level 的祖先节点。
+                if depth == cluster_level:
                     best_ancestor = cur
+                    break
                 parent = info.get("parent")
                 if parent is None:
                     break
                 cur = parent
             if best_ancestor:
-                out.append(
-                    {
-                        "cluster_node": best_ancestor,
-                        "from_elements": [e],
-                    }
-                )
+                single_clusters.setdefault(best_ancestor, set()).add(e)
+
+        for anc, elems in single_clusters.items():
+            out.append(
+                {
+                    "cluster_node": anc,
+                    "from_elements": sorted(elems),
+                }
+            )
 
         # 为了输出更稳定,按 from_elements 的元素数量从大到小排序,数量相同再按节点名排序
         out.sort(key=lambda x: (-len(x["from_elements"]), x["cluster_node"]))
@@ -554,7 +574,7 @@ def _analyze_single_round(
     tree_index: TreeIndex,
     cumulative_points: List[str],
     match_threshold: float,
-    min_level: int,
+    cluster_level: int,
 ) -> Dict[str, Any]:
     """
     对某一轮(给定累计 matched_post_point 列表)执行分析:
@@ -569,7 +589,7 @@ def _analyze_single_round(
         matched_post_points=cumulative_points,
         match_threshold=match_threshold,
     )
-    printf(f"_score_patterns_by_matched_points len: {len(patterns)}")
+    print(f"_score_patterns_by_matched_points len: {len(patterns)}")
 
     # 已推导 / 未推导 元素列表(不再按维度拆分)
     derived_elems: List[str] = []
@@ -599,14 +619,33 @@ def _analyze_single_round(
 
     # 已推导元素聚类
     if derived_set:
-        c = tree_index.find_clusters(derived_set, min_level=min_level)
+        c = tree_index.find_clusters(derived_set, cluster_level=cluster_level)
         clusters["derived"] = c or []
 
     # 未推导元素聚类
     if underived_set:
-        c = tree_index.find_clusters(underived_set, min_level=min_level)
+        c = tree_index.find_clusters(underived_set, cluster_level=cluster_level)
         clusters["underived"] = c or []
 
+    # 在同一轮中,如果某个 cluster_node 已经在 derived 聚类里出现过,
+    # 则从 underived 聚类中剔除该 cluster_node,避免重复展示。
+    if isinstance(clusters.get("derived"), list) and isinstance(clusters.get("underived"), list):
+        derived_nodes = {
+            str(item.get("cluster_node"))
+            for item in clusters["derived"]
+            if isinstance(item, dict) and item.get("cluster_node") is not None
+        }
+        if derived_nodes:
+            filtered_underived = []
+            for item in clusters["underived"]:
+                if not isinstance(item, dict):
+                    continue
+                node = str(item.get("cluster_node"))
+                if node in derived_nodes:
+                    continue
+                filtered_underived.append(item)
+            clusters["underived"] = filtered_underived
+
     return {
         "matched_post_points": list(cumulative_points),
         "patterns": patterns,
@@ -627,7 +666,7 @@ def pattern_dimension_analyze(
     post_id: str,
     log_id: str,
     match_threshold: float = 0.6,
-    min_level: int = 2,
+    cluster_level: int = 2,
 ) -> Dict[str, Any]:
     """
     Pattern 维度分析主入口。
@@ -638,7 +677,7 @@ def pattern_dimension_analyze(
     post_id : 帖子 ID(用于定位推导日志与帖子匹配数据)
     log_id : 推导日志目录名(../output/{account_name}/推导日志/{post_id}/{log_id}/)
     match_threshold : pattern 元素与 matched_post_point 的最小匹配分,默认 0.6
-    min_level : 在人设树中搜索聚类节点时的最小层级(root 为 0 层),默认 2
+    cluster_level : 在人设树中搜索聚类节点的聚类层级(root 为 0 层),默认 2
 
     """
     eval_dir = _round_eval_dir(account_name, post_id, log_id)
@@ -652,7 +691,7 @@ def pattern_dimension_analyze(
             "post_id": post_id,
             "log_id": log_id,
             "match_threshold": match_threshold,
-            "min_level": min_level,
+            "cluster_level": cluster_level,
             "rounds": [],
             "message": "未在指定日志目录下找到任何评估结果文件(*_评估.json)",
         }
@@ -661,7 +700,7 @@ def pattern_dimension_analyze(
     # pattern 库只在整体分析时读取 & 去重一次,避免每一轮重复 IO 与解析
     raw_patterns = _load_raw_patterns(account_name)
     deduped_patterns = _dedupe_patterns(raw_patterns)
-    printf(f"deduped_patterns len: {len(deduped_patterns)}")
+    print(f"deduped_patterns len: {len(deduped_patterns)}")
 
     rounds_output: List[Dict[str, Any]] = []
     for info in round_infos:
@@ -674,7 +713,7 @@ def pattern_dimension_analyze(
             tree_index=tree_index,
             cumulative_points=cumulative_points,
             match_threshold=match_threshold,
-            min_level=min_level,
+            cluster_level=cluster_level,
         )
         analyzed["round"] = r
         rounds_output.append(analyzed)
@@ -684,7 +723,7 @@ def pattern_dimension_analyze(
         "post_id": post_id,
         "log_id": log_id,
         "match_threshold": match_threshold,
-        "min_level": min_level,
+        "cluster_level": cluster_level,
         "rounds": rounds_output,
     }
     return result
@@ -702,7 +741,7 @@ def main() -> None:
         post_id=post_id,
         log_id=log_id,
         match_threshold=0.5,
-        min_level=2,
+        cluster_level=2,
     )
     # 控制台打印前 4000 字符,便于快速查看
     # print(json.dumps(result, ensure_ascii=False, indent=2)[:4000] + "...")
@@ -710,7 +749,8 @@ def main() -> None:
     # 写入输出文件:../output/{account_name}/推导日志/{post_id}/{log_id}/pattern_dimension_analyze.json
     out_dir = _round_eval_dir(account_name, post_id, log_id)
     out_dir.mkdir(parents=True, exist_ok=True)
-    out_path = out_dir / "pattern_dimension_analyze.json"
+    output_file_name = f"{post_id}_pattern_dimension_analyze.json"
+    out_path = out_dir / output_file_name
     with open(out_path, "w", encoding="utf-8") as f:
         json.dump(result, f, ensure_ascii=False, indent=2)
     print(f"\n分析结果已写入: {out_path}")