Ver código fonte

fix: 优化推导图谱输出格式

- 节点只保留基本信息,不添加额外属性
- 移除固有资产边(中间过程不需要)
- 组合节点ID包含成员类型以保证唯一性
- 成员按名称排序避免重复

🤖 Generated with [Claude Code](https://claude.ai/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
yangxiaohui 1 dia atrás
pai
commit
78fb4827c0
1 arquivo alterado com 160 adições e 1 exclusão
  1. 160 1
      script/data_processing/analyze_node_origin.py

+ 160 - 1
script/data_processing/analyze_node_origin.py

@@ -412,6 +412,158 @@ async def analyze_node_origin(
         }
 
 
+# ===== 图谱构建函数 =====
+
def _tag_node_id(dimension: str, name: str) -> str:
    """Return the node ID of a post-domain tag node."""
    return f"帖子:{dimension}:标签:{name}"


def _ensure_tag_node(nodes: Dict, feature: Dict) -> None:
    """Register a basic tag node for ``feature`` unless one already exists."""
    name = feature.get("特征名称", "")
    dimension = feature.get("特征类型", "关键点")
    nodes.setdefault(
        _tag_node_id(dimension, name),
        {
            "name": name,
            "type": "标签",
            "dimension": dimension,
            "domain": "帖子",
            "detail": {},
        },
    )


def build_origin_graph(all_results: List[Dict], post_id: str) -> Dict:
    """
    Convert per-target analysis results into a derivation graph.

    Nodes are keyed by node ID and edges by ``"{source}|{type}|{target}"``,
    so repeated entries collapse onto a single record.

    Args:
        all_results: Analysis results, one dict per target feature. The keys
            read here are "输入" (with "目标特征" and "候选特征"), "目标特征"
            and "推理类型分类" (with "单独推理" and "组合推理"). Missing or
            null values are treated as empty.
        post_id: Post ID, stored in the graph metadata.

    Returns:
        A dict with "meta" (post ID, graph type, node/edge counts),
        "nodes" and "edges".
    """
    nodes: Dict[str, Dict] = {}
    edges: Dict[str, Dict] = {}

    # Pass 1: collect feature nodes from the inputs (basic info only).
    for result in all_results:
        target_input = result.get("输入") or {}
        _ensure_tag_node(nodes, target_input.get("目标特征") or {})
        for candidate in target_input.get("候选特征") or []:
            _ensure_tag_node(nodes, candidate or {})

    # Pass 2: build derivation edges.
    for result in all_results:
        target_info = (result.get("输入") or {}).get("目标特征") or {}
        # Fall back to the input's feature name so the edge target matches the
        # node registered in pass 1 when the top-level name is absent.
        target_name = result.get("目标特征") or target_info.get("特征名称", "")
        target_node_id = _tag_node_id(
            target_info.get("特征类型", "关键点"), target_name
        )

        reasoning = result.get("推理类型分类") or {}

        # Single-feature derivations: source feature -> target.
        for item in reasoning.get("单独推理") or []:
            source_node_id = _tag_node_id(
                item.get("特征类型", "关键点"), item.get("特征名称", "")
            )
            edges[f"{source_node_id}|推导|{target_node_id}"] = {
                "source": source_node_id,
                "target": target_node_id,
                "type": "推导",
                "score": item.get("可能性", 0),
                "detail": {
                    "推理类型": "单独推理",
                    "正向概率": item.get("正向概率", 0),
                    "反向概率": item.get("反向概率", 0),
                    "推理说明": item.get("推理说明", ""),
                },
            }

        # Combined derivations: a virtual node stands for the combination.
        for item in reasoning.get("组合推理") or []:
            members = item.get("组合成员") or []
            member_types = item.get("成员类型") or []
            if len(member_types) != len(members):
                # Types cannot be paired with members; use the default type.
                member_types = ["关键点"] * len(members)

            # Sort on (name, type) rather than name alone: members sharing a
            # name but differing in type would otherwise keep their input
            # order and yield different IDs for the same combination.
            pairs = sorted(
                zip(members, member_types),
                key=lambda pair: (str(pair[0]), str(pair[1])),
            )
            sorted_members = [member for member, _ in pairs]
            sorted_types = [m_type for _, m_type in pairs]

            # The combo name (and therefore its ID) embeds each member's type.
            combo_name = " + ".join(
                f"{m_type}:{member}" for member, m_type in pairs
            )
            combo_node_id = f"帖子:组合:组合:{combo_name}"
            nodes.setdefault(
                combo_node_id,
                {
                    "name": combo_name,
                    "type": "组合",
                    "dimension": "组合",
                    "domain": "帖子",
                    "detail": {
                        "成员": sorted_members,
                        "成员类型": sorted_types,
                    },
                },
            )

            # Combo node -> target.
            synergy = item.get("协同效应分析") or {}
            edges[f"{combo_node_id}|推导|{target_node_id}"] = {
                "source": combo_node_id,
                "target": target_node_id,
                "type": "推导",
                "score": item.get("可能性", 0),
                "detail": {
                    "推理类型": "组合推理",
                    "正向概率": item.get("正向概率", 0),
                    "反向概率": item.get("反向概率", 0),
                    "协同增益": synergy.get("协同增益", 0),
                    "推理说明": item.get("推理说明", ""),
                },
            }

            # Member -> combo node.
            for member, m_type in pairs:
                m_node_id = _tag_node_id(m_type, member)
                edges.setdefault(
                    f"{m_node_id}|组成|{combo_node_id}",
                    {
                        "source": m_node_id,
                        "target": combo_node_id,
                        "type": "组成",
                        "score": 1.0,
                        "detail": {},
                    },
                )

    return {
        "meta": {
            "postId": post_id,
            "type": "推导图谱",
            "stats": {
                "nodeCount": len(nodes),
                "edgeCount": len(edges),
            },
        },
        "nodes": nodes,
        "edges": edges,
    }
+
+
 # ===== 辅助函数 =====
 
 def get_all_target_names(post_graph: Dict) -> List[str]:
@@ -558,9 +710,16 @@ async def main(
     with open(output_file, "w", encoding="utf-8") as f:
         json.dump(merged_output, f, ensure_ascii=False, indent=2)
 
+    # 生成推导关系图谱
+    graph_output = build_origin_graph(all_results, actual_post_id)
+    graph_file = output_dir / f"{actual_post_id}_推导图谱.json"
+    with open(graph_file, "w", encoding="utf-8") as f:
+        json.dump(graph_output, f, ensure_ascii=False, indent=2)
+
     print("\n" + "=" * 60)
     print(f"完成! 共分析 {len(target_names)} 个目标特征")
-    print(f"结果已保存到: {output_file}")
+    print(f"分析结果: {output_file}")
+    print(f"推导图谱: {graph_file}")
     if log_url:
         print(f"Trace: {log_url}")