build_post_tree.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 构建帖子树的中间数据
  5. 输入:match_graph/*.json, results/*.json
  6. 输出:match_graph/post_trees.json(包含所有帖子的树结构)
  7. """
  8. import json
  9. from pathlib import Path
  10. import sys
  11. # 添加项目根目录到路径
  12. project_root = Path(__file__).parent.parent.parent
  13. sys.path.insert(0, str(project_root))
  14. from script.data_processing.path_config import PathConfig
  15. def build_post_trees():
  16. """构建所有帖子的树数据"""
  17. config = PathConfig()
  18. print(f"账号: {config.account_name}")
  19. print(f"输出版本: {config.output_version}")
  20. print()
  21. match_graph_dir = config.intermediate_dir / "match_graph"
  22. results_dir = config.intermediate_dir.parent / "results"
  23. output_file = match_graph_dir / "post_trees.json"
  24. # 读取所有匹配图谱文件
  25. graph_files = sorted(match_graph_dir.glob("*_match_graph.json"))
  26. print(f"找到 {len(graph_files)} 个匹配图谱文件")
  27. all_post_trees = []
  28. for i, graph_file in enumerate(graph_files, 1):
  29. print(f"\n[{i}/{len(graph_files)}] 处理: {graph_file.name}")
  30. with open(graph_file, "r", encoding="utf-8") as f:
  31. match_graph_data = json.load(f)
  32. post_id = match_graph_data["说明"]["帖子ID"]
  33. post_title = match_graph_data["说明"].get("帖子标题", "")
  34. # 读取完整帖子详情
  35. post_detail = {
  36. "title": post_title,
  37. "post_id": post_id
  38. }
  39. how_file = results_dir / f"{post_id}_how.json"
  40. if how_file.exists():
  41. with open(how_file, "r", encoding="utf-8") as f:
  42. how_data = json.load(f)
  43. if "帖子详情" in how_data:
  44. post_detail = how_data["帖子详情"]
  45. post_detail["post_id"] = post_id
  46. print(f" 读取帖子详情: {how_file.name}")
  47. # 获取帖子点和帖子标签
  48. post_points = match_graph_data.get("帖子点节点列表", [])
  49. post_tags = match_graph_data.get("帖子标签节点列表", [])
  50. belong_edges = match_graph_data.get("帖子属于边列表", [])
  51. print(f" 帖子点: {len(post_points)}, 帖子标签: {len(post_tags)}, 属于边: {len(belong_edges)}")
  52. # 构建树结构
  53. # 维度颜色
  54. dim_colors = {
  55. "灵感点": "#f39c12",
  56. "目的点": "#3498db",
  57. "关键点": "#9b59b6"
  58. }
  59. # 构建节点映射
  60. point_map = {}
  61. for n in post_points:
  62. point_map[n["节点ID"]] = {
  63. "id": n["节点ID"],
  64. "name": n["节点名称"],
  65. "nodeType": "点",
  66. "level": n.get("节点层级", ""),
  67. "dimColor": dim_colors.get(n.get("节点层级", ""), "#888"),
  68. "description": n.get("描述", ""),
  69. "children": []
  70. }
  71. tag_map = {}
  72. for n in post_tags:
  73. tag_map[n["节点ID"]] = {
  74. "id": n["节点ID"],
  75. "name": n["节点名称"],
  76. "nodeType": "标签",
  77. "level": n.get("节点层级", ""),
  78. "dimColor": dim_colors.get(n.get("节点层级", ""), "#888"),
  79. "weight": n.get("权重", 0),
  80. "children": []
  81. }
  82. # 根据属于边,把标签挂到点下面
  83. for e in belong_edges:
  84. tag_node = tag_map.get(e["源节点ID"])
  85. point_node = point_map.get(e["目标节点ID"])
  86. if tag_node and point_node:
  87. point_node["children"].append(tag_node)
  88. # 按维度分组点节点
  89. dimensions = ["灵感点", "目的点", "关键点"]
  90. dimension_children = []
  91. for dim in dimensions:
  92. dim_points = [
  93. point_map[n["节点ID"]]
  94. for n in post_points
  95. if n.get("节点层级") == dim and n["节点ID"] in point_map
  96. ]
  97. if dim_points:
  98. dim_node = {
  99. "id": f"dim_{dim}",
  100. "name": dim,
  101. "nodeType": "维度",
  102. "isDimension": True,
  103. "dimColor": dim_colors[dim],
  104. "children": dim_points
  105. }
  106. dimension_children.append(dim_node)
  107. # 根节点(帖子)
  108. root_node = {
  109. "id": f"post_{post_id}",
  110. "name": post_title[:20] + "..." if len(post_title) > 20 else post_title,
  111. "nodeType": "帖子",
  112. "isRoot": True,
  113. "postDetail": post_detail,
  114. "children": dimension_children
  115. }
  116. # 统计节点数
  117. total_nodes = 1 + len(dimension_children) # 根节点 + 维度节点
  118. for dim_node in dimension_children:
  119. total_nodes += len(dim_node["children"]) # 点节点
  120. for point_node in dim_node["children"]:
  121. total_nodes += len(point_node["children"]) # 标签节点
  122. post_tree = {
  123. "postId": post_id,
  124. "postTitle": post_title,
  125. "postDetail": post_detail,
  126. "root": root_node,
  127. "stats": {
  128. "totalNodes": total_nodes,
  129. "pointCount": len(post_points),
  130. "tagCount": len(post_tags)
  131. }
  132. }
  133. all_post_trees.append(post_tree)
  134. print(f" 构建完成: {total_nodes} 个节点")
  135. # 输出
  136. output_data = {
  137. "说明": {
  138. "描述": "帖子树结构数据(每个帖子一棵树)",
  139. "帖子数": len(all_post_trees)
  140. },
  141. "postTrees": all_post_trees
  142. }
  143. with open(output_file, "w", encoding="utf-8") as f:
  144. json.dump(output_data, f, ensure_ascii=False, indent=2)
  145. print()
  146. print("=" * 60)
  147. print(f"构建完成!")
  148. print(f" 帖子数: {len(all_post_trees)}")
  149. print(f" 输出文件: {output_file}")
  150. return output_file
  151. if __name__ == "__main__":
  152. build_post_trees()