|
|
@@ -84,7 +84,8 @@ def extract_matched_nodes_and_edges(filtered_data: Dict) -> tuple:
|
|
|
"节点名称": feature_name,
|
|
|
"节点类型": "标签",
|
|
|
"节点层级": dimension,
|
|
|
- "权重": weight
|
|
|
+ "权重": weight,
|
|
|
+ "source": "帖子"
|
|
|
}
|
|
|
|
|
|
# 避免重复添加
|
|
|
@@ -298,11 +299,146 @@ def expand_one_layer(
|
|
|
# 标记为扩展节点
|
|
|
node_copy = node.copy()
|
|
|
node_copy["是否扩展"] = True
|
|
|
+ node_copy["source"] = "人设"
|
|
|
expanded_nodes.append(node_copy)
|
|
|
|
|
|
return expanded_nodes, expanded_edges, expanded_node_ids
|
|
|
|
|
|
|
|
|
+def expand_and_filter_useful_nodes(
|
|
|
+ matched_persona_ids: Set[str],
|
|
|
+ match_edges: List[Dict],
|
|
|
+ edges_data: Dict,
|
|
|
+ nodes_data: Dict,
|
|
|
+ exclude_edge_types: List[str] = None
|
|
|
+) -> tuple:
|
|
|
+ """
|
|
|
+ 扩展人设节点一层,只保留能产生新帖子连线的扩展节点
|
|
|
+
|
|
|
+ 逻辑:如果扩展节点E连接了2个以上的已匹配人设节点,
|
|
|
+ 那么通过E可以产生新的帖子间连线,保留E
|
|
|
+
|
|
|
+ Args:
|
|
|
+ matched_persona_ids: 已匹配的人设节点ID集合
|
|
|
+ match_edges: 匹配边列表
|
|
|
+ edges_data: 边关系数据
|
|
|
+ nodes_data: 节点列表数据
|
|
|
+ exclude_edge_types: 要排除的边类型列表
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (有效扩展节点列表, 扩展边列表, 通过扩展节点的帖子镜像边列表)
|
|
|
+ """
|
|
|
+ if exclude_edge_types is None:
|
|
|
+ exclude_edge_types = []
|
|
|
+
|
|
|
+ all_edges = edges_data.get("边列表", [])
|
|
|
+
|
|
|
+ # 构建人设节点到帖子节点的映射
|
|
|
+ persona_to_posts = {}
|
|
|
+ for edge in match_edges:
|
|
|
+ post_id = edge["源节点ID"]
|
|
|
+ persona_id = edge["目标节点ID"]
|
|
|
+ if persona_id not in persona_to_posts:
|
|
|
+ persona_to_posts[persona_id] = []
|
|
|
+ if post_id not in persona_to_posts[persona_id]:
|
|
|
+ persona_to_posts[persona_id].append(post_id)
|
|
|
+
|
|
|
+ # 找出所有扩展节点及其连接的已匹配人设节点
|
|
|
+ # expanded_node_id -> [(matched_persona_id, edge), ...]
|
|
|
+ expanded_connections = {}
|
|
|
+
|
|
|
+ for edge in all_edges:
|
|
|
+ # 跳过排除的边类型
|
|
|
+ if edge["边类型"] in exclude_edge_types:
|
|
|
+ continue
|
|
|
+
|
|
|
+ source_id = edge["源节点ID"]
|
|
|
+ target_id = edge["目标节点ID"]
|
|
|
+
|
|
|
+ # 源节点是已匹配的,目标节点是扩展候选
|
|
|
+ if source_id in matched_persona_ids and target_id not in matched_persona_ids:
|
|
|
+ if target_id not in expanded_connections:
|
|
|
+ expanded_connections[target_id] = []
|
|
|
+ expanded_connections[target_id].append((source_id, edge))
|
|
|
+
|
|
|
+ # 目标节点是已匹配的,源节点是扩展候选
|
|
|
+ if target_id in matched_persona_ids and source_id not in matched_persona_ids:
|
|
|
+ if source_id not in expanded_connections:
|
|
|
+ expanded_connections[source_id] = []
|
|
|
+ expanded_connections[source_id].append((target_id, edge))
|
|
|
+
|
|
|
+ # 过滤:只保留连接2个以上已匹配人设节点的扩展节点
|
|
|
+ useful_expanded_ids = set()
|
|
|
+ useful_edges = []
|
|
|
+ post_mirror_edges = []
|
|
|
+ seen_mirror_edges = set()
|
|
|
+
|
|
|
+ for expanded_id, connections in expanded_connections.items():
|
|
|
+ connected_personas = list(set([c[0] for c in connections]))
|
|
|
+
|
|
|
+ if len(connected_personas) >= 2:
|
|
|
+ useful_expanded_ids.add(expanded_id)
|
|
|
+
|
|
|
+ # 收集边
|
|
|
+ for persona_id, edge in connections:
|
|
|
+ useful_edges.append(edge)
|
|
|
+
|
|
|
+ # 为通过此扩展节点连接的每对人设节点,创建帖子镜像边
|
|
|
+ for i, p1 in enumerate(connected_personas):
|
|
|
+ for p2 in connected_personas[i+1:]:
|
|
|
+ posts1 = persona_to_posts.get(p1, [])
|
|
|
+ posts2 = persona_to_posts.get(p2, [])
|
|
|
+
|
|
|
+ # 找出连接p1和p2的边类型
|
|
|
+ edge_types_p1 = [c[1]["边类型"] for c in connections if c[0] == p1]
|
|
|
+ edge_types_p2 = [c[1]["边类型"] for c in connections if c[0] == p2]
|
|
|
+ # 用第一个边类型作为代表
|
|
|
+ edge_type = edge_types_p1[0] if edge_types_p1 else (edge_types_p2[0] if edge_types_p2 else "扩展")
|
|
|
+
|
|
|
+ for post1 in posts1:
|
|
|
+ for post2 in posts2:
|
|
|
+ if post1 == post2:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 避免重复
|
|
|
+ edge_key = tuple(sorted([post1, post2])) + (f"二阶_{edge_type}",)
|
|
|
+ if edge_key in seen_mirror_edges:
|
|
|
+ continue
|
|
|
+ seen_mirror_edges.add(edge_key)
|
|
|
+
|
|
|
+ post_mirror_edges.append({
|
|
|
+ "源节点ID": post1,
|
|
|
+ "目标节点ID": post2,
|
|
|
+ "边类型": f"二阶_{edge_type}",
|
|
|
+ "边详情": {
|
|
|
+ "原始边类型": edge_type,
|
|
|
+ "扩展节点": expanded_id,
|
|
|
+ "源人设节点": p1,
|
|
|
+ "目标人设节点": p2
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ # 获取扩展节点详情
|
|
|
+ useful_expanded_nodes = []
|
|
|
+ all_nodes = nodes_data.get("节点列表", [])
|
|
|
+ for node in all_nodes:
|
|
|
+ if node["节点ID"] in useful_expanded_ids:
|
|
|
+ node_copy = node.copy()
|
|
|
+ node_copy["是否扩展"] = True
|
|
|
+ useful_expanded_nodes.append(node_copy)
|
|
|
+
|
|
|
+ # 边去重
|
|
|
+ seen_edges = set()
|
|
|
+ unique_edges = []
|
|
|
+ for edge in useful_edges:
|
|
|
+ edge_key = (edge["源节点ID"], edge["目标节点ID"], edge["边类型"])
|
|
|
+ if edge_key not in seen_edges:
|
|
|
+ seen_edges.add(edge_key)
|
|
|
+ unique_edges.append(edge)
|
|
|
+
|
|
|
+ return useful_expanded_nodes, unique_edges, post_mirror_edges
|
|
|
+
|
|
|
+
|
|
|
def process_filtered_result(
|
|
|
filtered_file: Path,
|
|
|
nodes_data: Dict,
|
|
|
@@ -336,18 +472,120 @@ def process_filtered_result(
|
|
|
persona_nodes = get_persona_nodes_details(persona_node_ids, nodes_data)
|
|
|
for node in persona_nodes:
|
|
|
node["是否扩展"] = False
|
|
|
+ node["source"] = "人设"
|
|
|
|
|
|
# 获取人设节点之间的边
|
|
|
persona_edges = get_edges_between_nodes(persona_node_ids, edges_data)
|
|
|
|
|
|
- # 创建帖子节点之间的镜像边(基于人设边的投影)
|
|
|
+ # 创建帖子节点之间的镜像边(基于直接人设边的投影)
|
|
|
post_edges = create_mirrored_post_edges(match_edges, persona_edges)
|
|
|
|
|
|
- # 合并节点列表(不扩展,只保留直接匹配的节点)
|
|
|
- all_nodes = post_nodes + persona_nodes
|
|
|
+ # 扩展人设节点一层,只对标签类型的节点通过"属于"边扩展到分类
|
|
|
+ # 过滤出标签类型的人设节点(只有标签才能"属于"分类)
|
|
|
+ tag_persona_ids = {pid for pid in persona_node_ids if "_标签_" in pid}
|
|
|
+ expanded_nodes, expanded_edges, _ = expand_one_layer(
|
|
|
+ tag_persona_ids, edges_data, nodes_data,
|
|
|
+ edge_types=["属于"],
|
|
|
+ direction="outgoing" # 只向外扩展:标签->分类
|
|
|
+ )
|
|
|
+
|
|
|
+ # 创建通过扩展节点的帖子镜像边(正确逻辑)
|
|
|
+ # 逻辑:帖子->标签->分类,分类之间有边,则对应帖子产生二阶边
|
|
|
+
|
|
|
+ # 1. 构建 标签 -> 帖子列表 的映射
|
|
|
+ tag_to_posts = {}
|
|
|
+ for edge in match_edges:
|
|
|
+ post_node_id = edge["源节点ID"]
|
|
|
+ tag_id = edge["目标节点ID"]
|
|
|
+ if tag_id not in tag_to_posts:
|
|
|
+ tag_to_posts[tag_id] = []
|
|
|
+ if post_node_id not in tag_to_posts[tag_id]:
|
|
|
+ tag_to_posts[tag_id].append(post_node_id)
|
|
|
+
|
|
|
+ # 2. 构建 分类 -> 标签列表 的映射(通过属于边)
|
|
|
+ expanded_node_ids = set(n["节点ID"] for n in expanded_nodes)
|
|
|
+ category_to_tags = {} # 分类 -> [连接的标签]
|
|
|
+ for edge in expanded_edges:
|
|
|
+ src, tgt = edge["源节点ID"], edge["目标节点ID"]
|
|
|
+ # 属于边:标签 -> 分类
|
|
|
+ if tgt in expanded_node_ids and src in persona_node_ids:
|
|
|
+ if tgt not in category_to_tags:
|
|
|
+ category_to_tags[tgt] = []
|
|
|
+ if src not in category_to_tags[tgt]:
|
|
|
+ category_to_tags[tgt].append(src)
|
|
|
+
|
|
|
+ # 3. 获取扩展节点(分类)之间的边
|
|
|
+ category_edges = []
|
|
|
+ for edge in edges_data.get("边列表", []):
|
|
|
+ src, tgt = edge["源节点ID"], edge["目标节点ID"]
|
|
|
+ # 两端都是扩展节点(分类)
|
|
|
+ if src in expanded_node_ids and tgt in expanded_node_ids:
|
|
|
+ category_edges.append(edge)
|
|
|
+
|
|
|
+ # 4. 基于分类之间的边,生成帖子之间的二阶镜像边
|
|
|
+ post_edges_via_expanded = []
|
|
|
+ seen_mirror = set()
|
|
|
+ for cat_edge in category_edges:
|
|
|
+ cat1, cat2 = cat_edge["源节点ID"], cat_edge["目标节点ID"]
|
|
|
+ edge_type = cat_edge["边类型"]
|
|
|
+
|
|
|
+ # 获取连接到这两个分类的标签
|
|
|
+ tags1 = category_to_tags.get(cat1, [])
|
|
|
+ tags2 = category_to_tags.get(cat2, [])
|
|
|
+
|
|
|
+ # 通过标签找到对应的帖子,产生二阶边
|
|
|
+ for tag1 in tags1:
|
|
|
+ for tag2 in tags2:
|
|
|
+ posts1 = tag_to_posts.get(tag1, [])
|
|
|
+ posts2 = tag_to_posts.get(tag2, [])
|
|
|
+ for post1 in posts1:
|
|
|
+ for post2 in posts2:
|
|
|
+ if post1 == post2:
|
|
|
+ continue
|
|
|
+ edge_key = tuple(sorted([post1, post2])) + (f"二阶_{edge_type}",)
|
|
|
+ if edge_key in seen_mirror:
|
|
|
+ continue
|
|
|
+ seen_mirror.add(edge_key)
|
|
|
+ post_edges_via_expanded.append({
|
|
|
+ "源节点ID": post1,
|
|
|
+ "目标节点ID": post2,
|
|
|
+ "边类型": f"二阶_{edge_type}",
|
|
|
+ "边详情": {
|
|
|
+ "原始边类型": edge_type,
|
|
|
+ "分类节点1": cat1,
|
|
|
+ "分类节点2": cat2,
|
|
|
+ "标签节点1": tag1,
|
|
|
+ "标签节点2": tag2
|
|
|
+ }
|
|
|
+ })
|
|
|
+
|
|
|
+ # 只保留对帖子连接有帮助的扩展节点和边
|
|
|
+ # 1. 找出产生了二阶帖子边的扩展节点(分类)
|
|
|
+ useful_expanded_ids = set()
|
|
|
+ for edge in post_edges_via_expanded:
|
|
|
+ cat1 = edge.get("边详情", {}).get("分类节点1")
|
|
|
+ cat2 = edge.get("边详情", {}).get("分类节点2")
|
|
|
+ if cat1:
|
|
|
+ useful_expanded_ids.add(cat1)
|
|
|
+ if cat2:
|
|
|
+ useful_expanded_ids.add(cat2)
|
|
|
+
|
|
|
+ # 2. 只保留有用的扩展节点
|
|
|
+ useful_expanded_nodes = [n for n in expanded_nodes if n["节点ID"] in useful_expanded_ids]
|
|
|
+
|
|
|
+ # 3. 只保留连接到有用扩展节点的属于边
|
|
|
+ useful_expanded_edges = [e for e in expanded_edges
|
|
|
+ if e["目标节点ID"] in useful_expanded_ids or e["源节点ID"] in useful_expanded_ids]
|
|
|
+
|
|
|
+ # 4. 只保留有用的分类之间的边(产生了二阶帖子边的)
|
|
|
+ useful_category_edges = [e for e in category_edges
|
|
|
+ if e["源节点ID"] in useful_expanded_ids and e["目标节点ID"] in useful_expanded_ids]
|
|
|
+
|
|
|
+ # 合并节点列表
|
|
|
+ all_nodes = post_nodes + persona_nodes + useful_expanded_nodes
|
|
|
|
|
|
# 合并边列表
|
|
|
- all_edges = match_edges + persona_edges + post_edges
|
|
|
+ all_edges = match_edges + persona_edges + post_edges + useful_expanded_edges + useful_category_edges + post_edges_via_expanded
|
|
|
# 去重边
|
|
|
seen_edges = set()
|
|
|
unique_edges = []
|
|
|
@@ -379,19 +617,25 @@ def process_filtered_result(
|
|
|
"描述": "帖子与人设的节点匹配关系",
|
|
|
"统计": {
|
|
|
"帖子节点数": len(post_nodes),
|
|
|
- "人设节点数": len(persona_nodes),
|
|
|
+ "人设节点数(直接匹配)": len(persona_nodes),
|
|
|
+ "扩展节点数(有效)": len(useful_expanded_nodes),
|
|
|
"匹配边数": len(match_edges),
|
|
|
"人设节点间边数": len(persona_edges),
|
|
|
- "帖子节点间边数": len(post_edges),
|
|
|
+ "扩展边数(有效)": len(useful_expanded_edges),
|
|
|
+ "帖子镜像边数(直接)": len(post_edges),
|
|
|
+ "帖子镜像边数(二阶)": len(post_edges_via_expanded),
|
|
|
"总节点数": len(all_nodes),
|
|
|
"总边数": len(all_edges)
|
|
|
}
|
|
|
},
|
|
|
"帖子节点列表": post_nodes,
|
|
|
"人设节点列表": persona_nodes,
|
|
|
+ "扩展节点列表": useful_expanded_nodes,
|
|
|
"匹配边列表": match_edges,
|
|
|
"人设节点间边列表": persona_edges,
|
|
|
- "帖子节点间边列表": post_edges,
|
|
|
+ "扩展边列表": useful_expanded_edges,
|
|
|
+ "帖子镜像边列表(直接)": post_edges,
|
|
|
+ "帖子镜像边列表(二阶)": post_edges_via_expanded,
|
|
|
"节点列表": all_nodes,
|
|
|
"边列表": all_edges,
|
|
|
"节点边索引": edges_by_node
|
|
|
@@ -406,9 +650,12 @@ def process_filtered_result(
|
|
|
"帖子ID": post_id,
|
|
|
"帖子节点数": len(post_nodes),
|
|
|
"人设节点数": len(persona_nodes),
|
|
|
+ "扩展节点数": len(useful_expanded_nodes),
|
|
|
"匹配边数": len(match_edges),
|
|
|
- "人设节点间边数": len(persona_edges),
|
|
|
- "帖子节点间边数": len(post_edges),
|
|
|
+ "人设边数": len(persona_edges),
|
|
|
+ "扩展边数": len(useful_expanded_edges),
|
|
|
+ "帖子边数(直接)": len(post_edges),
|
|
|
+ "帖子边数(二阶)": len(post_edges_via_expanded),
|
|
|
"总节点数": len(all_nodes),
|
|
|
"总边数": len(all_edges),
|
|
|
"输出文件": str(output_file)
|
|
|
@@ -463,8 +710,9 @@ def main():
|
|
|
print(f"\n[{i}/{len(filtered_files)}] 处理: {filtered_file.name}")
|
|
|
result = process_filtered_result(filtered_file, nodes_data, edges_data, output_dir)
|
|
|
results.append(result)
|
|
|
- print(f" 帖子节点: {result['帖子节点数']}, 人设节点: {result['人设节点数']}")
|
|
|
- print(f" 匹配边: {result['匹配边数']}, 人设边: {result['人设节点间边数']}, 帖子边: {result['帖子节点间边数']}")
|
|
|
+ print(f" 帖子节点: {result['帖子节点数']}, 人设节点: {result['人设节点数']}, 扩展节点: {result['扩展节点数']}")
|
|
|
+ print(f" 匹配边: {result['匹配边数']}, 人设边: {result['人设边数']}, 扩展边: {result['扩展边数']}")
|
|
|
+ print(f" 帖子边(直接): {result['帖子边数(直接)']}, 帖子边(二阶): {result['帖子边数(二阶)']}")
|
|
|
|
|
|
# 汇总统计
|
|
|
print("\n" + "="*60)
|
|
|
@@ -473,14 +721,20 @@ def main():
|
|
|
print(f" 处理文件数: {len(results)}")
|
|
|
total_post = sum(r['帖子节点数'] for r in results)
|
|
|
total_persona = sum(r['人设节点数'] for r in results)
|
|
|
+ total_expanded = sum(r['扩展节点数'] for r in results)
|
|
|
total_match = sum(r['匹配边数'] for r in results)
|
|
|
- total_persona_edges = sum(r['人设节点间边数'] for r in results)
|
|
|
- total_post_edges = sum(r['帖子节点间边数'] for r in results)
|
|
|
+ total_persona_edges = sum(r['人设边数'] for r in results)
|
|
|
+ total_expanded_edges = sum(r['扩展边数'] for r in results)
|
|
|
+ total_post_edges_direct = sum(r['帖子边数(直接)'] for r in results)
|
|
|
+ total_post_edges_2hop = sum(r['帖子边数(二阶)'] for r in results)
|
|
|
print(f" 总帖子节点: {total_post}")
|
|
|
print(f" 总人设节点: {total_persona}")
|
|
|
+ print(f" 总扩展节点: {total_expanded}")
|
|
|
print(f" 总匹配边: {total_match}")
|
|
|
print(f" 总人设边: {total_persona_edges}")
|
|
|
- print(f" 总帖子边: {total_post_edges}")
|
|
|
+ print(f" 总扩展边: {total_expanded_edges}")
|
|
|
+ print(f" 总帖子边(直接): {total_post_edges_direct}")
|
|
|
+ print(f" 总帖子边(二阶): {total_post_edges_2hop}")
|
|
|
print(f"\n输出目录: {output_dir}")
|
|
|
|
|
|
|