Procházet zdrojové kódy

feat: 优化匹配过滤和边类型区分

- filter_how_results: 阈值从0.6改为0.5
- build_match_graph: 匹配边分为匹配_相同(相似度≥0.8)和匹配_相似(0.5≤相似度<0.8)
- visualize_match_graph: 相同边实线,相似边虚线
- extract_nodes_and_edges: 递归收集中间层级分类的特征来源,修复缺失节点问题

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
yangxiaohui před 5 dny
rodič
revize
0e9eb86f42

+ 10 - 4
script/data_processing/build_match_graph.py

@@ -151,13 +151,19 @@ def extract_matched_nodes_and_edges(filtered_data: Dict) -> tuple:
                             )
                             persona_node_ids.add(persona_node_id)
 
-                            # 创建匹配边
+                            # 创建匹配边(根据相似度区分类型)
+                            similarity = match_detail.get("相似度", 0)
+                            if similarity >= 0.8:
+                                edge_type = "匹配_相同"
+                            else:
+                                edge_type = "匹配_相似"
+
                             match_edge = {
                                 "源节点ID": tag_node_id,
                                 "目标节点ID": persona_node_id,
-                                "边类型": "匹配",
+                                "边类型": edge_type,
                                 "边详情": {
-                                    "相似度": match_detail.get("相似度", 0),
+                                    "相似度": similarity,
                                     "说明": match_detail.get("说明", "")
                                 }
                             }
@@ -513,7 +519,7 @@ def process_filtered_result(
 
     # 分离帖子侧的边:属于边(标签→点)和匹配边(标签→人设)
     post_belong_edges = [e for e in post_edges_raw if e["边类型"] == "属于"]
-    match_edges = [e for e in post_edges_raw if e["边类型"] == "匹配"]
+    match_edges = [e for e in post_edges_raw if e["边类型"].startswith("匹配_")]
 
     # 统计帖子点节点和标签节点
     post_point_nodes = [n for n in post_nodes if n["节点类型"] == "点"]

+ 28 - 2
script/data_processing/extract_nodes_and_edges.py

@@ -232,6 +232,29 @@ def extract_category_nodes_from_pattern(
     if dimension_key not in pattern_data:
         return nodes
 
+    def collect_sources_recursively(node: Dict) -> List[Dict]:
+        """递归收集节点及其所有子节点的特征来源"""
+        sources = []
+
+        # 收集当前节点的特征
+        if "特征列表" in node:
+            for feature in node["特征列表"]:
+                source = {
+                    "点的名称": feature.get("所属点", ""),
+                    "点的描述": feature.get("点描述", ""),
+                    "帖子ID": feature.get("帖子id", "")
+                }
+                sources.append(source)
+
+        # 递归收集子节点的特征
+        for key, value in node.items():
+            if key in ["特征列表", "_meta", "帖子数", "特征数", "帖子列表"]:
+                continue
+            if isinstance(value, dict):
+                sources.extend(collect_sources_recursively(value))
+
+        return sources
+
     def traverse_node(node: Dict, parent_categories: List[str]):
         """递归遍历节点"""
         for key, value in node.items():
@@ -245,7 +268,7 @@ def extract_category_nodes_from_pattern(
                 # 获取帖子列表
                 post_ids = value.get("帖子列表", [])
 
-                # 构建节点来源(从特征列表中获取)
+                # 构建节点来源(从特征列表中获取,如果没有则递归收集子分类的特征来源)
                 node_sources = []
                 if "特征列表" in value:
                     for feature in value["特征列表"]:
@@ -255,6 +278,9 @@ def extract_category_nodes_from_pattern(
                             "帖子ID": feature.get("帖子id", "")
                         }
                         node_sources.append(source)
+                else:
+                    # 没有直接特征,递归收集子分类的特征来源
+                    node_sources = collect_sources_recursively(value)
 
                 node_info = {
                     "节点ID": build_node_id(dimension_name, "分类", key),
@@ -262,7 +288,7 @@ def extract_category_nodes_from_pattern(
                     "节点类型": "分类",
                     "节点层级": dimension_name,
                     "所属分类": parent_categories.copy(),
-                    "帖子数": len(post_ids),
+                    "帖子数": len(post_ids) if post_ids else len(set(s.get("帖子ID", "") for s in node_sources if s.get("帖子ID"))),
                     "节点来源": node_sources
                 }
                 nodes.append(node_info)

+ 7 - 7
script/data_processing/filter_how_results.py

@@ -5,7 +5,7 @@ How解构结果过滤脚本
 
 从 how 解构结果中过滤出高质量的匹配结果:
 1. 移除 what解构结果 字段
-2. 只保留相似度 >= 0.6 的匹配结果
+2. 只保留相似度 >= 0.5 的 top1 匹配结果
 3. 保留特征即使其匹配结果为空
 """
 
@@ -23,7 +23,7 @@ sys.path.insert(0, str(project_root))
 from script.data_processing.path_config import PathConfig
 
 
-def filter_match_results(feature_list: List[Dict], threshold: float = 0.6) -> List[Dict]:
+def filter_match_results(feature_list: List[Dict], threshold: float = 0.5) -> List[Dict]:
     """
     过滤特征列表中的匹配结果
 
@@ -65,7 +65,7 @@ def filter_match_results(feature_list: List[Dict], threshold: float = 0.6) -> Li
     return filtered_features
 
 
-def filter_how_steps(how_steps: List[Dict], threshold: float = 0.6) -> List[Dict]:
+def filter_how_steps(how_steps: List[Dict], threshold: float = 0.5) -> List[Dict]:
     """
     过滤 how 步骤列表
 
@@ -88,7 +88,7 @@ def filter_how_steps(how_steps: List[Dict], threshold: float = 0.6) -> List[Dict
     return filtered_steps
 
 
-def filter_point_list(point_list: List[Dict], threshold: float = 0.6) -> List[Dict]:
+def filter_point_list(point_list: List[Dict], threshold: float = 0.5) -> List[Dict]:
     """
     过滤点列表(灵感点/关键点/目的点)
 
@@ -148,7 +148,7 @@ def calculate_statistics(original_point_list: List[Dict], filtered_point_list: L
     }
 
 
-def process_single_file(input_file: Path, output_file: Path, threshold: float = 0.6) -> Dict:
+def process_single_file(input_file: Path, output_file: Path, threshold: float = 0.5) -> Dict:
     """
     处理单个文件
 
@@ -215,8 +215,8 @@ def main():
     parser.add_argument(
         "--threshold",
         type=float,
-        default=0.6,
-        help="相似度阈值(默认 0.6)"
+        default=0.5,
+        help="相似度阈值(默认 0.5)"
     )
 
     args = parser.parse_args()

+ 20 - 21
script/data_processing/visualize_match_graph.py

@@ -437,7 +437,10 @@ HTML_TEMPLATE = '''<!DOCTYPE html>
         .edge-label-bg {{
             fill: rgba(0,0,0,0.7);
         }}
-        .link.match {{
+        .link.match-same {{
+            stroke: #e94560;
+        }}
+        .link.match-similar {{
             stroke: #e94560;
             stroke-dasharray: 5,5;
         }}
@@ -623,7 +626,11 @@ HTML_TEMPLATE = '''<!DOCTYPE html>
                     <div class="legend-grid">
                         <div class="legend-item">
                             <div class="legend-line" style="background: #e94560;"></div>
-                            <span>匹配</span>
+                            <span>相同(≥0.8)</span>
+                        </div>
+                        <div class="legend-item">
+                            <div class="legend-line" style="background: linear-gradient(90deg, #e94560 0%, #e94560 40%, transparent 40%, transparent 60%, #e94560 60%);"></div>
+                            <span>相似(0.5-0.8)</span>
                         </div>
                         <div class="legend-item">
                             <div class="legend-line" style="background: #9b59b6;"></div>
@@ -898,7 +905,7 @@ HTML_TEMPLATE = '''<!DOCTYPE html>
             const postNodes = nodes.filter(n => n.source === "帖子");
             const personaNodes = nodes.filter(n => n.source === "人设" && !n.是否扩展);
             const expandedNodes = nodes.filter(n => n.source === "人设" && n.是否扩展);
-            const matchLinks = links.filter(l => l.type === "匹配");
+            const matchLinks = links.filter(l => l.type.startsWith("匹配_"));
 
             // 更新匹配列表(按分数降序)
             updateMatchList(matchLinks, nodes);
@@ -1346,14 +1353,14 @@ HTML_TEMPLATE = '''<!DOCTYPE html>
                 .force("link", d3.forceLink(links).id(d => d.id)
                     .distance(d => {{
                         // 跨层连线距离
-                        if (d.type === "匹配" || d.type === "属于") {{
+                        if (d.type.startsWith("匹配_") || d.type === "属于") {{
                             return 150;  // 跨层边
                         }}
                         return 60;  // 同层边
                     }})
                     .strength(d => {{
                         // 跨层边力度弱一些,不要拉扯节点出层
-                        if (d.type === "匹配" || d.type === "属于") {{
+                        if (d.type.startsWith("匹配_") || d.type === "属于") {{
                             return 0.03;
                         }}
                         return 0.1;
@@ -1373,7 +1380,8 @@ HTML_TEMPLATE = '''<!DOCTYPE html>
 
             // 边类型到CSS类的映射
             const edgeTypeClass = {{
-                "匹配": "match",
+                "匹配_相同": "match-same",
+                "匹配_相似": "match-similar",
                 "分类共现(跨点)": "category-cross",
                 "分类共现(点内)": "category-intra",
                 "标签共现": "tag-cooccur",
@@ -1695,16 +1703,7 @@ HTML_TEMPLATE = '''<!DOCTYPE html>
             // 绘制可见的边
             const link = linkG.append("line")
                 .attr("class", d => "link " + getEdgeClass(d.type))
-                .attr("stroke-width", d => d.type === "匹配" ? 2.5 : 1.5)
-                .attr("stroke-dasharray", d => {{
-                    // 匹配边根据相似度设置虚实线
-                    if (d.type === "匹配" && d.边详情 && d.边详情.相似度 !== undefined) {{
-                        const score = d.边详情.相似度;
-                        if (score >= 0.8) return null;  // >= 0.8 实线
-                        if (score >= 0.5) return "6,4";  // 0.5-0.8 虚线
-                    }}
-                    return null;  // 默认实线
-                }});
+                .attr("stroke-width", d => d.type.startsWith("匹配_") ? 2.5 : 1.5);
 
             // 判断是否为跨层边(根据源和目标节点的层级)- 赋值给全局变量
             isCrossLayerEdge = function(d) {{
@@ -1716,7 +1715,7 @@ HTML_TEMPLATE = '''<!DOCTYPE html>
 
             // 设置跨层边的初始可见性(匹配边始终显示,其他跨层边默认隐藏)
             linkG.each(function(d) {{
-                if (d.type === "匹配") {{
+                if (d.type.startsWith("匹配_")) {{
                     d3.select(this).style("display", "block");  // 匹配边始终显示
                 }} else if (isCrossLayerEdge(d) && !showCrossLayerEdges) {{
                     d3.select(this).style("display", "none");
@@ -1724,7 +1723,7 @@ HTML_TEMPLATE = '''<!DOCTYPE html>
             }});
 
             // 为匹配边添加分数标签
-            const edgeLabels = linkG.filter(d => d.type === "匹配" && d.边详情 && d.边详情.相似度)
+            const edgeLabels = linkG.filter(d => d.type.startsWith("匹配_") && d.边详情 && d.边详情.相似度)
                 .append("g")
                 .attr("class", "edge-label-group");
 
@@ -1763,7 +1762,7 @@ HTML_TEMPLATE = '''<!DOCTYPE html>
             .on("mouseout", function(event, d) {{
                 d3.select(this.parentNode).select(".link")
                     .attr("stroke-opacity", 0.7)
-                    .attr("stroke-width", d.type === "匹配" ? 2.5 : 1.5);
+                    .attr("stroke-width", d.type.startsWith("匹配_") ? 2.5 : 1.5);
             }});
 
             // 绘制节点
@@ -2025,7 +2024,7 @@ HTML_TEMPLATE = '''<!DOCTYPE html>
                         const lTgt = typeof link.target === "object" ? link.target.id : link.target;
 
                         // 帖子->标签 的匹配边
-                        if (link.type === "匹配") {{
+                        if (link.type.startsWith("匹配_")) {{
                             if ((lSrc === sourceId && lTgt === detail.标签节点1) ||
                                 (lSrc === targetId && lTgt === detail.标签节点2)) {{
                                 highlightLinkIndices.add(i);
@@ -2057,7 +2056,7 @@ HTML_TEMPLATE = '''<!DOCTYPE html>
                         const lTgt = typeof link.target === "object" ? link.target.id : link.target;
 
                         // 匹配边
-                        if (link.type === "匹配") {{
+                        if (link.type.startsWith("匹配_")) {{
                             if ((lSrc === sourceId && lTgt === detail.源人设节点) ||
                                 (lSrc === targetId && lTgt === detail.目标人设节点)) {{
                                 highlightLinkIndices.add(i);