
feat: refine the data structures of persona graph nodes and edges

- Nodes gain a postIds field; parent categories recursively include the post IDs of all their child categories
- Nodes gain probGlobal (global probability) and probToParent (probability relative to the parent node)
- The root node and dimension nodes also get postIds and postCount
- "包含" (contains) edges use the child node's probToParent as their score
- Co-occurrence edges uniformly use Jaccard as the score (including intra-dimension category co-occurrence edges)
- Edge detail fields are unified: postIds, postCount, sourcePostIds, targetPostIds

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
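
For orientation, here is a minimal sketch (in Python) of the node and edge shapes the updated script should emit, reconstructed from the diff below. The detail fields (postIds, postCount, probGlobal, probToParent, jaccard, sourcePostIds, targetPostIds) are taken from the code; the top-level node keys are assumed from the create_node() arguments, and all concrete names, IDs, and values are illustrative only.

# Illustrative sketch only -- names, IDs, and values below are invented; field names follow the diff.
example_node = {
    "domain": "人设",
    "dimension": "灵感点",
    "type": "分类",                          # assumed key; create_node() receives this as node_type=
    "name": "户外",                          # hypothetical category name
    "detail": {
        "parentPath": [],
        "postIds": ["p1", "p2", "p3"],       # recursively collected from this node and all children, de-duplicated
        "postCount": 3,
        "probGlobal": 0.3,                   # postCount / total number of distinct posts (assume 10 in total)
        "probToParent": 0.6,                 # postCount / parent's postCount, parent found via the 属于 edge (assume 5)
        "sources": [],
    },
}

example_edge = {
    "source": "人设:灵感点:分类:户外",
    "target": "人设:目的点:分类:记录",
    "type": "分类共现",
    "score": 0.25,                           # co-occurrence edges now use Jaccard as the score
    "detail": {
        "postIds": ["p2"],                   # posts shared by source and target
        "postCount": 1,
        "jaccard": 0.25,                     # |shared| / |union of both nodes' postIds| = 1 / 4
        "sourcePostIds": ["p1", "p2", "p3"],
        "targetPostIds": ["p2", "p4"],
    },
}

Under this scheme, a 包含 (contains) edge's score equals the target (child) node's probToParent, while 属于 (belongs-to) edges keep a score of 1.0.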
yangxiaohui 9 hours ago
parent
commit
56ee14c3f4
1 changed file with 174 additions and 38 deletions

script/data_processing/build_persona_graph.py  +174 −38

@@ -207,10 +207,7 @@ def extract_category_nodes_from_pattern(
             if isinstance(value, dict):
                 current_path = parent_path + [key]
 
-                # Get the post list
-                post_ids = value.get("帖子列表", [])
-
-                # Build the node sources
+                # Build the node sources (collect only this node's own features)
                 node_sources = []
                 if "特征列表" in value:
                     for feature in value["特征列表"]:
@@ -220,14 +217,10 @@ def extract_category_nodes_from_pattern(
                             "postId": feature.get("帖子id", "")
                         }
                         node_sources.append(source)
-                else:
-                    node_sources = collect_sources_recursively(value)
 
-                # Compute the post count
-                if post_ids:
-                    post_count = len(post_ids)
-                else:
-                    post_count = len(set(s.get("postId", "") for s in node_sources if s.get("postId")))
+                # Collect the post ID list (recursively gather post IDs from this node and all children, de-duplicated)
+                all_sources = collect_sources_recursively(value)
+                unique_post_ids = list(set(s.get("postId", "") for s in all_sources if s.get("postId")))
 
                 # 构建节点
                 node_id = build_node_id("人设", dimension_name, "分类", key)
@@ -238,7 +231,8 @@ def extract_category_nodes_from_pattern(
                     name=key,
                     detail={
                         "parentPath": parent_path.copy(),
-                        "postCount": post_count,
+                        "postIds": unique_post_ids,
+                        "postCount": len(unique_post_ids),
                         "sources": node_sources
                     }
                 )
@@ -318,6 +312,7 @@ def extract_tag_nodes_from_pattern(
             name=tag_info["name"],
             detail={
                 "parentPath": tag_info["parentPath"],
+                "postIds": list(tag_info["postIds"]),
                 "postCount": len(tag_info["postIds"]),
                 "sources": tag_info["sources"]
             }
@@ -365,6 +360,10 @@ def extract_belong_contain_edges(
         parent_id = category_name_to_id.get(parent_name)
 
         if parent_id:
+            # Get the postIds of the source and target
+            child_post_ids = node_data["detail"].get("postIds", [])
+            parent_post_ids = nodes.get(parent_id, {}).get("detail", {}).get("postIds", [])
+
             # 属于 (belongs-to) edge: child → parent
             edge_id = build_edge_id(node_id, "属于", parent_id)
             edges[edge_id] = create_edge(
@@ -372,7 +371,10 @@ def extract_belong_contain_edges(
                 target=parent_id,
                 edge_type="属于",
                 score=1.0,
-                detail={}
+                detail={
+                    "sourcePostIds": child_post_ids,
+                    "targetPostIds": parent_post_ids
+                }
             )
 
             # 包含 (contains) edge: parent → child
@@ -382,7 +384,10 @@ def extract_belong_contain_edges(
                 target=node_id,
                 edge_type="包含",
                 score=1.0,
-                detail={}
+                detail={
+                    "sourcePostIds": parent_post_ids,
+                    "targetPostIds": child_post_ids
+                }
             )
 
     return edges
@@ -390,10 +395,14 @@ def extract_belong_contain_edges(
 
 # ==================== Extract category co-occurrence edges (cross-dimension) from association analysis ====================
 
-def extract_category_cooccur_edges(associations_data: Dict) -> Dict[str, Dict]:
+def extract_category_cooccur_edges(associations_data: Dict, nodes: Dict[str, Dict]) -> Dict[str, Dict]:
     """
     Extract category co-occurrence edges (cross-dimension) from dimension_associations_analysis.json
 
+    Args:
+        associations_data: the association analysis data
+        nodes: the already-built node data (used to look up each node's postIds)
+
     Returns:
         { edgeId: edgeData }
     """
@@ -449,6 +458,10 @@ def extract_category_cooccur_edges(associations_data: Dict) -> Dict[str, Dict]:
                         # Use Jaccard as the score
                         jaccard = assoc.get("Jaccard相似度", 0)
 
+                        # Get the postIds of the source and target
+                        source_post_ids = nodes.get(source_node_id, {}).get("detail", {}).get("postIds", [])
+                        target_post_ids = nodes.get(target_node_id, {}).get("detail", {}).get("postIds", [])
+
                         edge_id = build_edge_id(source_node_id, "分类共现", target_node_id)
                         edges[edge_id] = create_edge(
                             source=source_node_id,
@@ -456,10 +469,12 @@ def extract_category_cooccur_edges(associations_data: Dict) -> Dict[str, Dict]:
                             edge_type="分类共现",
                             score=jaccard,
                             detail={
+                                "postIds": assoc.get("共同帖子ID", []),
+                                "postCount": assoc.get("共同帖子数", 0),
                                 "jaccard": jaccard,
                                 "overlapCoef": assoc.get("重叠系数", 0),
-                                "cooccurCount": assoc.get("共同帖子数", 0),
-                                "cooccurPosts": assoc.get("共同帖子ID", [])
+                                "sourcePostIds": source_post_ids,
+                                "targetPostIds": target_post_ids
                             }
                         )
 
@@ -468,10 +483,14 @@ def extract_category_cooccur_edges(associations_data: Dict) -> Dict[str, Dict]:
 
 # ==================== Extract category co-occurrence edges (intra-dimension) from association analysis ====================
 
-def extract_intra_category_cooccur_edges(intra_data: Dict) -> Dict[str, Dict]:
+def extract_intra_category_cooccur_edges(intra_data: Dict, nodes: Dict[str, Dict]) -> Dict[str, Dict]:
     """
     Extract intra-dimension category co-occurrence edges from intra_dimension_associations_analysis.json
 
+    Args:
+        intra_data: the intra-dimension association analysis data
+        nodes: the already-built node data (used to look up each node's postIds)
+
     Returns:
         { edgeId: edgeData }
     """
@@ -514,14 +533,30 @@ def extract_intra_category_cooccur_edges(intra_data: Dict) -> Dict[str, Dict]:
                         edges[edge_id]["detail"]["pointCount"] += point_count
                         edges[edge_id]["detail"]["pointNames"].extend(point_names)
                     else:
+                        # Get the postIds of the source and target
+                        cat1_post_ids = nodes.get(cat1_id, {}).get("detail", {}).get("postIds", [])
+                        cat2_post_ids = nodes.get(cat2_id, {}).get("detail", {}).get("postIds", [])
+
+                        # Compute Jaccard (based on posts)
+                        cat1_set = set(cat1_post_ids)
+                        cat2_set = set(cat2_post_ids)
+                        intersection = cat1_set & cat2_set
+                        union = cat1_set | cat2_set
+                        jaccard = round(len(intersection) / len(union), 4) if union else 0
+
                         edges[edge_id] = create_edge(
                             source=cat1_id,
                             target=cat2_id,
                             edge_type="分类共现",
-                            score=point_count,  # use the point count as the score for now; can be normalized later
+                            score=jaccard,
                             detail={
+                                "postIds": list(intersection),
+                                "postCount": len(intersection),
+                                "jaccard": jaccard,
                                 "pointCount": point_count,
-                                "pointNames": point_names.copy()
+                                "pointNames": point_names.copy(),
+                                "sourcePostIds": cat1_post_ids,
+                                "targetPostIds": cat2_post_ids
                             }
                         )
 
@@ -530,15 +565,19 @@ def extract_intra_category_cooccur_edges(intra_data: Dict) -> Dict[str, Dict]:
 
 # ==================== Extract tag co-occurrence edges from historical posts ====================
 
-def extract_tag_cooccur_edges(historical_posts_dir: Path) -> Dict[str, Dict]:
+def extract_tag_cooccur_edges(historical_posts_dir: Path, nodes: Dict[str, Dict]) -> Dict[str, Dict]:
     """
     Extract tag co-occurrence edges from the deconstruction results of historical posts
 
+    Args:
+        historical_posts_dir: the directory of historical posts
+        nodes: the already-built node data (tags' postIds are used to compute Jaccard)
+
     Returns:
         { edgeId: edgeData }
     """
     edges = {}
-    cooccur_map = {}  # (tag1_id, tag2_id, dimension) -> { cooccurPosts: set() }
+    cooccur_map = {}  # (tag1_id, tag2_id) -> { postIds: set() }
 
     if not historical_posts_dir.exists():
         print(f"  警告: 历史帖子目录不存在: {historical_posts_dir}")
@@ -632,27 +671,37 @@ def extract_tag_cooccur_edges(historical_posts_dir: Path) -> Dict[str, Dict]:
                         key = (tag1_id, tag2_id)
 
                         if key not in cooccur_map:
-                            cooccur_map[key] = {"cooccurPosts": set()}
+                            cooccur_map[key] = {"postIds": set()}
 
-                        cooccur_map[key]["cooccurPosts"].add(post_id)
+                        cooccur_map[key]["postIds"].add(post_id)
 
         except Exception as e:
             print(f"  警告: 处理文件 {file_path.name} 时出错: {e}")
 
     # Convert to edges
     for (tag1_id, tag2_id), info in cooccur_map.items():
-        cooccur_posts = list(info["cooccurPosts"])
-        cooccur_count = len(cooccur_posts)
+        cooccur_post_ids = list(info["postIds"])
+        cooccur_count = len(cooccur_post_ids)
+
+        # Get the two tags' post sets and compute Jaccard
+        tag1_post_ids = nodes.get(tag1_id, {}).get("detail", {}).get("postIds", [])
+        tag2_post_ids = nodes.get(tag2_id, {}).get("detail", {}).get("postIds", [])
+
+        union_count = len(set(tag1_post_ids) | set(tag2_post_ids))
+        jaccard = round(cooccur_count / union_count, 4) if union_count > 0 else 0
 
         edge_id = build_edge_id(tag1_id, "标签共现", tag2_id)
         edges[edge_id] = create_edge(
             source=tag1_id,
             target=tag2_id,
             edge_type="标签共现",
-            score=cooccur_count,  # use the co-occurrence count for now; can be normalized later
+            score=jaccard,
             detail={
-                "cooccurCount": cooccur_count,
-                "cooccurPosts": cooccur_posts
+                "postIds": cooccur_post_ids,
+                "postCount": cooccur_count,
+                "jaccard": jaccard,
+                "sourcePostIds": tag1_post_ids,
+                "targetPostIds": tag2_post_ids
             }
         )
 
@@ -937,45 +986,66 @@ def main():
 
     # Category co-occurrence edges (cross-dimension)
     print("\n提取分类共现边(跨点):")
-    category_cooccur_edges = extract_category_cooccur_edges(associations_data)
+    category_cooccur_edges = extract_category_cooccur_edges(associations_data, all_nodes)
     all_edges.update(category_cooccur_edges)
     print(f"  分类共现边: {len(category_cooccur_edges)}")
 
     # Category co-occurrence edges (intra-dimension)
     print("\n提取分类共现边(点内):")
-    intra_category_edges = extract_intra_category_cooccur_edges(intra_associations_data)
+    intra_category_edges = extract_intra_category_cooccur_edges(intra_associations_data, all_nodes)
     all_edges.update(intra_category_edges)
     print(f"  分类共现边: {len(intra_category_edges)}")
 
     # Tag co-occurrence edges
     print("\n提取标签共现边:")
-    tag_cooccur_edges = extract_tag_cooccur_edges(historical_posts_dir)
+    tag_cooccur_edges = extract_tag_cooccur_edges(historical_posts_dir, all_nodes)
     all_edges.update(tag_cooccur_edges)
     print(f"  标签共现边: {len(tag_cooccur_edges)}")
 
     # ===== Add the root node and dimension nodes =====
     print("\n添加根节点和维度节点:")
 
+    # Collect all post IDs (for the root node)
+    all_post_ids_for_root = set()
+    for node in all_nodes.values():
+        post_ids = node["detail"].get("postIds", [])
+        all_post_ids_for_root.update(post_ids)
+
     # Root node
     root_id = "人设:人设:人设:人设"
+    root_post_ids = list(all_post_ids_for_root)
     all_nodes[root_id] = create_node(
         domain="人设",
         dimension="人设",
         node_type="人设",
         name="人设",
-        detail={}
+        detail={
+            "postIds": root_post_ids,
+            "postCount": len(root_post_ids)
+        }
     )
 
     # Dimension nodes + edges
     dimensions = ["灵感点", "目的点", "关键点"]
     for dim in dimensions:
+        # Collect the post IDs of all nodes under this dimension
+        dim_post_ids = set()
+        for node in all_nodes.values():
+            if node["dimension"] == dim:
+                post_ids = node["detail"].get("postIds", [])
+                dim_post_ids.update(post_ids)
+        dim_post_ids_list = list(dim_post_ids)
+
         dim_id = f"人设:{dim}:{dim}:{dim}"
         all_nodes[dim_id] = create_node(
             domain="人设",
             dimension=dim,
             node_type=dim,
             name=dim,
-            detail={}
+            detail={
+                "postIds": dim_post_ids_list,
+                "postCount": len(dim_post_ids_list)
+            }
         )
 
         # 属于 (belongs-to) edge: dimension -> root
@@ -985,7 +1055,10 @@ def main():
             target=root_id,
             edge_type="属于",
             score=1.0,
-            detail={}
+            detail={
+                "sourcePostIds": dim_post_ids_list,
+                "targetPostIds": root_post_ids
+            }
         )
 
         # 包含 (contains) edge: root -> dimension
@@ -995,7 +1068,10 @@ def main():
             target=dim_id,
             edge_type="包含",
             score=1.0,
-            detail={}
+            detail={
+                "sourcePostIds": root_post_ids,
+                "targetPostIds": dim_post_ids_list
+            }
         )
 
         # Find the top-level categories under this dimension (those without a parent) and add edges
@@ -1006,6 +1082,8 @@ def main():
         ]
 
         for cat_id, cat_data in dim_categories:
+            cat_post_ids = cat_data["detail"].get("postIds", [])
+
             # 属于 (belongs-to) edge: top-level category -> dimension
             edge_id = build_edge_id(cat_id, "属于", dim_id)
             all_edges[edge_id] = create_edge(
@@ -1013,7 +1091,10 @@ def main():
                 target=dim_id,
                 edge_type="属于",
                 score=1.0,
-                detail={}
+                detail={
+                    "sourcePostIds": cat_post_ids,
+                    "targetPostIds": dim_post_ids_list
+                }
             )
 
             # 包含 (contains) edge: dimension -> top-level category
@@ -1023,7 +1104,10 @@ def main():
                 target=cat_id,
                 edge_type="包含",
                 score=1.0,
-                detail={}
+                detail={
+                    "sourcePostIds": dim_post_ids_list,
+                    "targetPostIds": cat_post_ids
+                }
             )
 
     print(f"  添加节点: 1 根节点 + 3 维度节点 = 4")
@@ -1039,6 +1123,58 @@ def main():
     for t, count in sorted(edge_type_counts.items(), key=lambda x: -x[1]):
         print(f"  {t}: {count}")
 
+    # ===== Compute node probabilities =====
+    print("\n" + "=" * 60)
+    print("计算节点概率...")
+
+    # 1. Compute the total post count (the union of all post IDs)
+    all_post_ids = set()
+    for node in all_nodes.values():
+        post_ids = node["detail"].get("postIds", [])
+        all_post_ids.update(post_ids)
+    total_post_count = len(all_post_ids)
+    print(f"  总帖子数: {total_post_count}")
+
+    # 2. Compute probabilities for each node
+    for node_id, node in all_nodes.items():
+        post_count = node["detail"].get("postCount", 0)
+
+        # Global probability
+        if total_post_count > 0:
+            node["detail"]["probGlobal"] = round(post_count / total_post_count, 4)
+        else:
+            node["detail"]["probGlobal"] = 0
+
+        # Probability relative to the parent node
+        # Find the parent node via the "属于" (belongs-to) edge
+        parent_edge_id = None
+        for edge_id, edge in all_edges.items():
+            if edge["source"] == node_id and edge["type"] == "属于":
+                parent_node_id = edge["target"]
+                parent_node = all_nodes.get(parent_node_id)
+                if parent_node:
+                    parent_post_count = parent_node["detail"].get("postCount", 0)
+                    if parent_post_count > 0:
+                        node["detail"]["probToParent"] = round(post_count / parent_post_count, 4)
+                    else:
+                        node["detail"]["probToParent"] = 0
+                break
+        else:
+            # No parent node (this is the root node)
+            node["detail"]["probToParent"] = 1.0
+
+    print(f"  已为 {len(all_nodes)} 个节点计算概率")
+
+    # 3. Update the scores of "包含" (contains) edges (use the child node's probToParent)
+    contain_edge_updated = 0
+    for edge_id, edge in all_edges.items():
+        if edge["type"] == "包含":
+            target_node = all_nodes.get(edge["target"])
+            if target_node:
+                edge["score"] = target_node["detail"].get("probToParent", 1.0)
+                contain_edge_updated += 1
+    print(f"  已更新 {contain_edge_updated} 条包含边的分数")
+
     # ===== Build indexes =====
     print("\n" + "=" * 60)
     print("构建索引...")