Преглед на файлове

生成可视化数据更新

liuzhiheng преди 1 седмица
родител
ревизия
22e57a004c
променени са 1 файла, в които са добавени 55 реда и са изтрити 21 реда
  1. 55 21
      examples_how/overall_derivation/generate_visualize_data.py

+ 55 - 21
examples_how/overall_derivation/generate_visualize_data.py

@@ -123,11 +123,12 @@ def build_derivation_result(
         eval_results = eval_data.get("eval_results") or []
         matched_post_points = set()
         for er in eval_results:
-            if er.get("match_result") != "匹配":
+            # 新格式: is_matched;旧格式: match_result == "匹配"
+            if not (er.get("is_matched") is True or er.get("match_result") == "匹配"):
                 continue
-            mp = (er.get("match_post_point") or "").strip()
-            if mp:
-                matched_post_points.add(mp)
+            mp = er.get("matched_post_point") or er.get("matched_post_topic") or er.get("match_post_point") or ""
+            if mp and str(mp).strip():
+                matched_post_points.add(str(mp).strip())
 
         new_derived_names = matched_post_points - derived_names_so_far
         derived_names_so_far |= matched_post_points
@@ -150,6 +151,14 @@ def build_derivation_result(
     return result
 
 
+def _tree_node_display_name(raw: str) -> str:
+    """人设节点可能是 a.b.c 路径形式,实际需要的是最后一段节点名 c。"""
+    s = (raw or "").strip()
+    if "." in s:
+        return s.rsplit(".", 1)[-1].strip() or s
+    return s
+
+
 def _to_tree_node(name: str, extra: dict | None = None) -> dict:
     d = {"name": name}
     if extra:
@@ -173,42 +182,59 @@ def build_visualize_edges(
 ) -> tuple[list[dict], list[dict]]:
     """
     生成 node_list(所有评估通过的帖子选题点)和 edge_list(只保留评估通过的推导路径)。
+    按轮次从小到大处理,保证每个输出节点最多只出现在一条边的 output_nodes 里,且保留的是前面轮次的数据。
     """
+    # 按轮次从小到大排序,确保优先使用前面轮次的输出节点
+    derivations = sorted(derivations, key=lambda d: d.get("round", 0))
+    evals = sorted(evals, key=lambda e: e.get("round", 0))
+
     topic_by_name = {}
     for t in topic_points:
         name = t["name"]
         if name not in topic_by_name:
             topic_by_name[name] = t
 
-    derivation_output_to_match = {}
-    for eval_data in evals:
+    # 按 (round, id, derivation_output_point) 建立评估匹配表;新格式用 is_matched+id,旧格式用 match_result 且无 id
+    match_by_round_id_output: dict[tuple[int, int, str], str] = {}
+    # 按 (round, id) 收集该条推导在评估中且 is_matched 为 true 的 derivation_output_point 列表,供 detail「待比对的推导选题点」使用
+    round_id_to_output_points: dict[tuple[int, int], list[str]] = {}
+    for round_idx, eval_data in enumerate(evals):
+        round_num = eval_data.get("round", round_idx + 1)
         for er in eval_data.get("eval_results") or []:
-            if er.get("match_result") != "匹配":
+            if not (er.get("is_matched") is True or er.get("match_result") == "匹配"):
                 continue
             out_point = (er.get("derivation_output_point") or "").strip()
-            match_point = (er.get("match_post_point") or "").strip()
-            if out_point and match_point:
-                derivation_output_to_match[out_point] = {
-                    "match_post_point": match_point,
-                    "match_reason": er.get("match_reason", ""),
-                    "eval": er,
-                }
+            dr_id = er.get("id") if er.get("id") is not None else -1
+            key = (round_num, dr_id)
+            if key not in round_id_to_output_points:
+                round_id_to_output_points[key] = []
+            if out_point:
+                round_id_to_output_points[key].append(out_point)
+            mp = er.get("matched_post_point") or er.get("matched_post_topic") or er.get("match_post_point") or ""
+            if out_point and mp and str(mp).strip():
+                mp = str(mp).strip()
+                if dr_id != -1:
+                    match_by_round_id_output[(round_num, dr_id, out_point)] = mp
+                else:
+                    match_by_round_id_output[(round_num, -1, out_point)] = mp
 
     node_list = []
     seen_nodes = set()
     edge_list = []
     level_by_name = {}
+    output_nodes_seen: set[str] = set()  # 已在之前边的 output_nodes 中出现过的节点,避免同一输出节点对应多条边
 
     for round_idx, derivation in enumerate(derivations):
         round_num = derivation.get("round", round_idx + 1)
         for dr in derivation.get("derivation_results") or []:
+            dr_id = dr.get("id")
             output_list = dr.get("output") or []
             matched_outputs = []
             for out_item in output_list:
-                info = derivation_output_to_match.get(out_item)
-                if not info:
-                    continue
-                mp = info["match_post_point"]
+                if dr_id is not None:
+                    mp = match_by_round_id_output.get((round_num, dr_id, out_item))
+                else:
+                    mp = match_by_round_id_output.get((round_num, -1, out_item))
                 if not mp:
                     continue
                 matched_outputs.append(mp)
@@ -225,13 +251,19 @@ def build_visualize_edges(
             if not matched_outputs:
                 continue
 
+            # 只保留尚未在之前边的 output_nodes 中出现过的节点,避免同一输出节点对应多条边
+            output_names_this_edge = [x for x in matched_outputs if x not in output_nodes_seen]
+            if not output_names_this_edge:
+                continue
+            output_nodes_seen.update(output_names_this_edge)
+
             input_data = dr.get("input") or {}
             derived_nodes = input_data.get("derived_nodes") or []
             tree_nodes = input_data.get("tree_nodes") or []
             patterns = input_data.get("patterns") or []
 
             input_post_nodes = [{"name": x} for x in derived_nodes]
-            input_tree_nodes = [_to_tree_node(x) for x in tree_nodes]
+            input_tree_nodes = [_to_tree_node(_tree_node_display_name(x)) for x in tree_nodes]
             if patterns and isinstance(patterns[0], str):
                 input_pattern_nodes = [_to_pattern_node(p) for p in patterns]
             elif patterns and isinstance(patterns[0], dict):
@@ -239,11 +271,13 @@ def build_visualize_edges(
             else:
                 input_pattern_nodes = []
 
-            output_nodes = [{"name": x} for x in matched_outputs]
+            output_nodes = [{"name": x} for x in output_names_this_edge]
             detail = {
                 "reason": dr.get("reason", ""),
                 "评估结果": "匹配成功",
             }
+            key_dr = (round_num, dr_id if dr_id is not None else -1)
+            detail["待比对的推导选题点"] = round_id_to_output_points.get(key_dr, [])
             if dr.get("tools"):
                 detail["tools"] = dr["tools"]
             edge_list.append({
@@ -306,5 +340,5 @@ def main(account_name, post_id, log_id):
 if __name__ == "__main__":
     account_name="家有大志"
     post_id = "68fb6a5c000000000302e5de"
-    log_id="20260304161832"
+    log_id="20260305102218"
     main(account_name, post_id, log_id)