build_match_graph.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 从匹配结果中构建帖子与人设的节点边关系图
  5. 输入:
  6. 1. filtered_results目录下的匹配结果文件
  7. 2. 节点列表.json
  8. 3. 边关系.json
  9. 输出:
  10. 1. match_graph目录下的节点边关系文件
  11. """
  12. import json
  13. from pathlib import Path
  14. from typing import Dict, List, Set, Any, Optional
  15. import sys
  16. # 添加项目根目录到路径
  17. project_root = Path(__file__).parent.parent.parent
  18. sys.path.insert(0, str(project_root))
  19. from script.data_processing.path_config import PathConfig
  20. def build_post_node_id(dimension: str, node_type: str, name: str) -> str:
  21. """构建帖子节点ID"""
  22. return f"帖子_{dimension}_{node_type}_{name}"
  23. def build_persona_node_id(dimension: str, node_type: str, name: str) -> str:
  24. """构建人设节点ID"""
  25. return f"{dimension}_{node_type}_{name}"
  26. def extract_matched_nodes_and_edges(filtered_data: Dict) -> tuple:
  27. """
  28. 从匹配结果中提取帖子节点、人设节点和匹配边
  29. Args:
  30. filtered_data: 匹配结果数据
  31. Returns:
  32. (帖子节点列表, 人设节点ID集合, 匹配边列表)
  33. """
  34. post_nodes = []
  35. persona_node_ids = set()
  36. match_edges = []
  37. how_result = filtered_data.get("how解构结果", {})
  38. # 维度映射
  39. dimension_mapping = {
  40. "灵感点列表": "灵感点",
  41. "目的点列表": "目的点",
  42. "关键点列表": "关键点"
  43. }
  44. for list_key, dimension in dimension_mapping.items():
  45. points = how_result.get(list_key, [])
  46. for point in points:
  47. # 遍历how步骤列表
  48. how_steps = point.get("how步骤列表", [])
  49. for step in how_steps:
  50. features = step.get("特征列表", [])
  51. for feature in features:
  52. feature_name = feature.get("特征名称", "")
  53. weight = feature.get("权重", 0)
  54. match_results = feature.get("匹配结果", [])
  55. if not feature_name:
  56. continue
  57. # 如果有匹配结果,创建帖子节点和匹配边
  58. if match_results:
  59. # 创建帖子节点(标签类型)
  60. post_node_id = build_post_node_id(dimension, "标签", feature_name)
  61. post_node = {
  62. "节点ID": post_node_id,
  63. "节点名称": feature_name,
  64. "节点类型": "标签",
  65. "节点层级": dimension,
  66. "权重": weight
  67. }
  68. # 避免重复添加
  69. if not any(n["节点ID"] == post_node_id for n in post_nodes):
  70. post_nodes.append(post_node)
  71. # 处理每个匹配结果
  72. for match in match_results:
  73. persona_name = match.get("人设特征名称", "")
  74. persona_dimension = match.get("人设特征层级", "")
  75. persona_type = match.get("特征类型", "标签")
  76. match_detail = match.get("匹配结果", {})
  77. if not persona_name or not persona_dimension:
  78. continue
  79. # 构建人设节点ID
  80. persona_node_id = build_persona_node_id(
  81. persona_dimension, persona_type, persona_name
  82. )
  83. persona_node_ids.add(persona_node_id)
  84. # 创建匹配边
  85. match_edge = {
  86. "源节点ID": post_node_id,
  87. "目标节点ID": persona_node_id,
  88. "边类型": "匹配",
  89. "边详情": {
  90. "相似度": match_detail.get("相似度", 0),
  91. "说明": match_detail.get("说明", "")
  92. }
  93. }
  94. match_edges.append(match_edge)
  95. return post_nodes, persona_node_ids, match_edges
  96. def get_persona_nodes_details(
  97. persona_node_ids: Set[str],
  98. nodes_data: Dict
  99. ) -> List[Dict]:
  100. """
  101. 从节点列表中获取人设节点的详细信息
  102. Args:
  103. persona_node_ids: 人设节点ID集合
  104. nodes_data: 节点列表数据
  105. Returns:
  106. 人设节点详情列表
  107. """
  108. persona_nodes = []
  109. all_nodes = nodes_data.get("节点列表", [])
  110. for node in all_nodes:
  111. if node["节点ID"] in persona_node_ids:
  112. persona_nodes.append(node)
  113. return persona_nodes
  114. def get_edges_between_nodes(
  115. node_ids: Set[str],
  116. edges_data: Dict
  117. ) -> List[Dict]:
  118. """
  119. 获取指定节点之间的边关系
  120. Args:
  121. node_ids: 节点ID集合
  122. edges_data: 边关系数据
  123. Returns:
  124. 节点之间的边列表
  125. """
  126. edges_between = []
  127. all_edges = edges_data.get("边列表", [])
  128. for edge in all_edges:
  129. source_id = edge["源节点ID"]
  130. target_id = edge["目标节点ID"]
  131. # 两个节点都在集合中
  132. if source_id in node_ids and target_id in node_ids:
  133. edges_between.append(edge)
  134. return edges_between
  135. def create_mirrored_post_edges(
  136. match_edges: List[Dict],
  137. persona_edges: List[Dict]
  138. ) -> List[Dict]:
  139. """
  140. 根据人设节点之间的边,创建帖子节点之间的镜像边
  141. 逻辑:如果人设节点A和B之间有边,且帖子节点X匹配A,帖子节点Y匹配B,
  142. 则创建帖子节点X和Y之间的镜像边
  143. Args:
  144. match_edges: 匹配边列表(帖子节点 -> 人设节点)
  145. persona_edges: 人设节点之间的边列表
  146. Returns:
  147. 帖子节点之间的镜像边列表
  148. """
  149. # 构建人设节点到帖子节点的反向映射
  150. # persona_id -> [post_id1, post_id2, ...]
  151. persona_to_posts = {}
  152. for edge in match_edges:
  153. post_id = edge["源节点ID"]
  154. persona_id = edge["目标节点ID"]
  155. if persona_id not in persona_to_posts:
  156. persona_to_posts[persona_id] = []
  157. if post_id not in persona_to_posts[persona_id]:
  158. persona_to_posts[persona_id].append(post_id)
  159. # 根据人设边创建帖子镜像边
  160. post_edges = []
  161. seen_edges = set()
  162. for persona_edge in persona_edges:
  163. source_persona = persona_edge["源节点ID"]
  164. target_persona = persona_edge["目标节点ID"]
  165. edge_type = persona_edge["边类型"]
  166. # 获取匹配到这两个人设节点的帖子节点
  167. source_posts = persona_to_posts.get(source_persona, [])
  168. target_posts = persona_to_posts.get(target_persona, [])
  169. # 为每对帖子节点创建镜像边
  170. for src_post in source_posts:
  171. for tgt_post in target_posts:
  172. if src_post == tgt_post:
  173. continue
  174. # 使用排序后的key避免重复(A-B 和 B-A 视为同一条边)
  175. edge_key = tuple(sorted([src_post, tgt_post])) + (edge_type,)
  176. if edge_key in seen_edges:
  177. continue
  178. seen_edges.add(edge_key)
  179. post_edge = {
  180. "源节点ID": src_post,
  181. "目标节点ID": tgt_post,
  182. "边类型": f"镜像_{edge_type}", # 标记为镜像边
  183. "边详情": {
  184. "原始边类型": edge_type,
  185. "源人设节点": source_persona,
  186. "目标人设节点": target_persona
  187. }
  188. }
  189. post_edges.append(post_edge)
  190. return post_edges
  191. def expand_one_layer(
  192. node_ids: Set[str],
  193. edges_data: Dict,
  194. nodes_data: Dict,
  195. edge_types: List[str] = None,
  196. direction: str = "both"
  197. ) -> tuple:
  198. """
  199. 从指定节点扩展一层,获取相邻节点和连接边
  200. Args:
  201. node_ids: 起始节点ID集合
  202. edges_data: 边关系数据
  203. nodes_data: 节点列表数据
  204. edge_types: 要扩展的边类型列表,None表示所有类型
  205. direction: 扩展方向
  206. - "outgoing": 只沿出边扩展(源节点在集合中,扩展到目标节点)
  207. - "incoming": 只沿入边扩展(目标节点在集合中,扩展到源节点)
  208. - "both": 双向扩展
  209. Returns:
  210. (扩展的节点列表, 扩展的边列表, 扩展的节点ID集合)
  211. """
  212. expanded_node_ids = set()
  213. expanded_edges = []
  214. all_edges = edges_data.get("边列表", [])
  215. # 找出所有与起始节点相连的边和节点
  216. for edge in all_edges:
  217. # 过滤边类型
  218. if edge_types and edge["边类型"] not in edge_types:
  219. continue
  220. source_id = edge["源节点ID"]
  221. target_id = edge["目标节点ID"]
  222. # 沿出边扩展:源节点在集合中,扩展到目标节点
  223. if direction in ["outgoing", "both"]:
  224. if source_id in node_ids and target_id not in node_ids:
  225. expanded_node_ids.add(target_id)
  226. expanded_edges.append(edge)
  227. # 沿入边扩展:目标节点在集合中,扩展到源节点
  228. if direction in ["incoming", "both"]:
  229. if target_id in node_ids and source_id not in node_ids:
  230. expanded_node_ids.add(source_id)
  231. expanded_edges.append(edge)
  232. # 获取扩展节点的详情
  233. expanded_nodes = []
  234. all_nodes = nodes_data.get("节点列表", [])
  235. for node in all_nodes:
  236. if node["节点ID"] in expanded_node_ids:
  237. # 标记为扩展节点
  238. node_copy = node.copy()
  239. node_copy["是否扩展"] = True
  240. expanded_nodes.append(node_copy)
  241. return expanded_nodes, expanded_edges, expanded_node_ids
  242. def process_filtered_result(
  243. filtered_file: Path,
  244. nodes_data: Dict,
  245. edges_data: Dict,
  246. output_dir: Path
  247. ) -> Dict:
  248. """
  249. 处理单个匹配结果文件
  250. Args:
  251. filtered_file: 匹配结果文件路径
  252. nodes_data: 节点列表数据
  253. edges_data: 边关系数据
  254. output_dir: 输出目录
  255. Returns:
  256. 处理结果统计
  257. """
  258. # 读取匹配结果
  259. with open(filtered_file, "r", encoding="utf-8") as f:
  260. filtered_data = json.load(f)
  261. post_id = filtered_data.get("帖子id", "")
  262. post_detail = filtered_data.get("帖子详情", {})
  263. post_title = post_detail.get("title", "")
  264. # 提取节点和边
  265. post_nodes, persona_node_ids, match_edges = extract_matched_nodes_and_edges(filtered_data)
  266. # 获取人设节点详情(直接匹配的,标记为非扩展)
  267. persona_nodes = get_persona_nodes_details(persona_node_ids, nodes_data)
  268. for node in persona_nodes:
  269. node["是否扩展"] = False
  270. # 获取人设节点之间的边
  271. persona_edges = get_edges_between_nodes(persona_node_ids, edges_data)
  272. # 创建帖子节点之间的镜像边(基于人设边的投影)
  273. post_edges = create_mirrored_post_edges(match_edges, persona_edges)
  274. # 合并节点列表(不扩展,只保留直接匹配的节点)
  275. all_nodes = post_nodes + persona_nodes
  276. # 合并边列表
  277. all_edges = match_edges + persona_edges + post_edges
  278. # 去重边
  279. seen_edges = set()
  280. unique_edges = []
  281. for edge in all_edges:
  282. edge_key = (edge["源节点ID"], edge["目标节点ID"], edge["边类型"])
  283. if edge_key not in seen_edges:
  284. seen_edges.add(edge_key)
  285. unique_edges.append(edge)
  286. all_edges = unique_edges
  287. # 构建节点边索引
  288. edges_by_node = {}
  289. for edge in all_edges:
  290. source_id = edge["源节点ID"]
  291. target_id = edge["目标节点ID"]
  292. edge_type = edge["边类型"]
  293. if source_id not in edges_by_node:
  294. edges_by_node[source_id] = {}
  295. if edge_type not in edges_by_node[source_id]:
  296. edges_by_node[source_id][edge_type] = {}
  297. edges_by_node[source_id][edge_type][target_id] = edge
  298. # 构建输出数据
  299. output_data = {
  300. "说明": {
  301. "帖子ID": post_id,
  302. "帖子标题": post_title,
  303. "描述": "帖子与人设的节点匹配关系",
  304. "统计": {
  305. "帖子节点数": len(post_nodes),
  306. "人设节点数": len(persona_nodes),
  307. "匹配边数": len(match_edges),
  308. "人设节点间边数": len(persona_edges),
  309. "帖子节点间边数": len(post_edges),
  310. "总节点数": len(all_nodes),
  311. "总边数": len(all_edges)
  312. }
  313. },
  314. "帖子节点列表": post_nodes,
  315. "人设节点列表": persona_nodes,
  316. "匹配边列表": match_edges,
  317. "人设节点间边列表": persona_edges,
  318. "帖子节点间边列表": post_edges,
  319. "节点列表": all_nodes,
  320. "边列表": all_edges,
  321. "节点边索引": edges_by_node
  322. }
  323. # 保存输出文件
  324. output_file = output_dir / f"{post_id}_match_graph.json"
  325. with open(output_file, "w", encoding="utf-8") as f:
  326. json.dump(output_data, f, ensure_ascii=False, indent=2)
  327. return {
  328. "帖子ID": post_id,
  329. "帖子节点数": len(post_nodes),
  330. "人设节点数": len(persona_nodes),
  331. "匹配边数": len(match_edges),
  332. "人设节点间边数": len(persona_edges),
  333. "帖子节点间边数": len(post_edges),
  334. "总节点数": len(all_nodes),
  335. "总边数": len(all_edges),
  336. "输出文件": str(output_file)
  337. }
def main():
    """Entry point: build a match-graph file for every filtered result."""
    # Shared path configuration (account, version, directories).
    config = PathConfig()
    config.ensure_dirs()
    print(f"账号: {config.account_name}")
    print(f"输出版本: {config.output_version}")
    print()
    # Input files / directories.
    filtered_results_dir = config.intermediate_dir / "filtered_results"
    nodes_file = config.intermediate_dir / "节点列表.json"
    edges_file = config.intermediate_dir / "边关系.json"
    # Output directory.
    output_dir = config.intermediate_dir / "match_graph"
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"输入:")
    print(f" 匹配结果目录: {filtered_results_dir}")
    print(f" 节点列表: {nodes_file}")
    print(f" 边关系: {edges_file}")
    print(f"\n输出目录: {output_dir}")
    print()
    # Load the persona node and edge data once, shared by every file.
    print("正在读取节点列表...")
    with open(nodes_file, "r", encoding="utf-8") as f:
        nodes_data = json.load(f)
    print(f" 共 {len(nodes_data.get('节点列表', []))} 个节点")
    print("正在读取边关系...")
    with open(edges_file, "r", encoding="utf-8") as f:
        edges_data = json.load(f)
    print(f" 共 {len(edges_data.get('边列表', []))} 条边")
    # Process every filtered match-result file.
    print("\n" + "="*60)
    print("处理匹配结果文件...")
    filtered_files = list(filtered_results_dir.glob("*_filtered.json"))
    print(f"找到 {len(filtered_files)} 个匹配结果文件")
    results = []
    for i, filtered_file in enumerate(filtered_files, 1):
        print(f"\n[{i}/{len(filtered_files)}] 处理: {filtered_file.name}")
        result = process_filtered_result(filtered_file, nodes_data, edges_data, output_dir)
        results.append(result)
        print(f" 帖子节点: {result['帖子节点数']}, 人设节点: {result['人设节点数']}")
        print(f" 匹配边: {result['匹配边数']}, 人设边: {result['人设节点间边数']}, 帖子边: {result['帖子节点间边数']}")
    # Aggregate statistics over all processed files.
    print("\n" + "="*60)
    print("处理完成!")
    print(f"\n汇总:")
    print(f" 处理文件数: {len(results)}")
    total_post = sum(r['帖子节点数'] for r in results)
    total_persona = sum(r['人设节点数'] for r in results)
    total_match = sum(r['匹配边数'] for r in results)
    total_persona_edges = sum(r['人设节点间边数'] for r in results)
    total_post_edges = sum(r['帖子节点间边数'] for r in results)
    print(f" 总帖子节点: {total_post}")
    print(f" 总人设节点: {total_persona}")
    print(f" 总匹配边: {total_match}")
    print(f" 总人设边: {total_persona_edges}")
    print(f" 总帖子边: {total_post_edges}")
    print(f"\n输出目录: {output_dir}")


if __name__ == "__main__":
    main()