build_post_tree.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 构建帖子树的中间数据
  5. 输入:match_graph/*.json, results/*.json
  6. 输出:match_graph/post_trees.json(包含所有帖子的树结构)
  7. """
  8. import json
  9. from pathlib import Path
  10. import sys
  11. # 添加项目根目录到路径
  12. project_root = Path(__file__).parent.parent.parent
  13. sys.path.insert(0, str(project_root))
  14. from script.data_processing.path_config import PathConfig
  15. def build_post_trees():
  16. """构建所有帖子的树数据"""
  17. config = PathConfig()
  18. print(f"账号: {config.account_name}")
  19. print(f"输出版本: {config.output_version}")
  20. print()
  21. match_graph_dir = config.intermediate_dir / "match_graph"
  22. results_dir = config.intermediate_dir.parent / "results"
  23. output_file = match_graph_dir / "post_trees.json"
  24. # 读取所有匹配图谱文件
  25. graph_files = sorted(match_graph_dir.glob("*_match_graph.json"))
  26. print(f"找到 {len(graph_files)} 个匹配图谱文件")
  27. all_post_trees = []
  28. for i, graph_file in enumerate(graph_files, 1):
  29. print(f"\n[{i}/{len(graph_files)}] 处理: {graph_file.name}")
  30. with open(graph_file, "r", encoding="utf-8") as f:
  31. match_graph_data = json.load(f)
  32. post_id = match_graph_data["说明"]["帖子ID"]
  33. post_title = match_graph_data["说明"].get("帖子标题", "")
  34. # 读取完整帖子详情
  35. post_detail = {
  36. "title": post_title,
  37. "post_id": post_id
  38. }
  39. how_file = results_dir / f"{post_id}_how.json"
  40. if how_file.exists():
  41. with open(how_file, "r", encoding="utf-8") as f:
  42. how_data = json.load(f)
  43. if "帖子详情" in how_data:
  44. post_detail = how_data["帖子详情"]
  45. post_detail["post_id"] = post_id
  46. print(f" 读取帖子详情: {how_file.name}")
  47. # 获取帖子点和帖子标签
  48. post_points = match_graph_data.get("帖子点节点列表", [])
  49. post_tags = match_graph_data.get("帖子标签节点列表", [])
  50. belong_edges = match_graph_data.get("帖子属于边列表", [])
  51. # 获取匹配边(帖子标签 -> 人设标签)
  52. all_edges = match_graph_data.get("边列表", [])
  53. match_edges = [e for e in all_edges if e["边类型"].startswith("匹配_")]
  54. print(f" 帖子点: {len(post_points)}, 帖子标签: {len(post_tags)}, 属于边: {len(belong_edges)}, 匹配边: {len(match_edges)}")
  55. # 构建树结构
  56. # 维度颜色
  57. dim_colors = {
  58. "灵感点": "#f39c12",
  59. "目的点": "#3498db",
  60. "关键点": "#9b59b6"
  61. }
  62. # 构建节点映射
  63. point_map = {}
  64. for n in post_points:
  65. point_map[n["节点ID"]] = {
  66. "id": n["节点ID"],
  67. "name": n["节点名称"],
  68. "nodeType": "点",
  69. "level": n.get("节点层级", ""),
  70. "dimColor": dim_colors.get(n.get("节点层级", ""), "#888"),
  71. "description": n.get("描述", ""),
  72. "children": []
  73. }
  74. tag_map = {}
  75. for n in post_tags:
  76. tag_map[n["节点ID"]] = {
  77. "id": n["节点ID"],
  78. "name": n["节点名称"],
  79. "nodeType": "标签",
  80. "level": n.get("节点层级", ""),
  81. "dimColor": dim_colors.get(n.get("节点层级", ""), "#888"),
  82. "weight": n.get("权重", 0),
  83. "children": []
  84. }
  85. # 获取所有节点(用于查找扩展节点)
  86. all_nodes = match_graph_data.get("节点列表", [])
  87. expanded_nodes_map = {}
  88. for n in all_nodes:
  89. if n.get("是否扩展"):
  90. expanded_nodes_map[n["节点ID"]] = n
  91. # 构建人设节点之间的边关系(用于找扩展节点)
  92. # 边类型:属于、包含、分类共现等
  93. persona_edges = [e for e in all_edges if not e["边类型"].startswith("匹配_")]
  94. # 构建帖子标签到人设匹配的映射
  95. tag_to_persona_matches = {}
  96. direct_persona_ids = set() # 记录直接匹配的人设ID
  97. for e in match_edges:
  98. src_id = e["源节点ID"] # 帖子标签
  99. tgt_id = e["目标节点ID"] # 人设标签
  100. edge_type = e["边类型"] # 匹配_相同 或 匹配_相似
  101. edge_detail = e.get("边详情", {})
  102. similarity = edge_detail.get("相似度", 0)
  103. if src_id not in tag_to_persona_matches:
  104. tag_to_persona_matches[src_id] = []
  105. direct_persona_ids.add(tgt_id)
  106. # 从人设标签ID提取维度和名称
  107. persona_name = tgt_id
  108. persona_level = ""
  109. if "_标签_" in tgt_id:
  110. parts = tgt_id.split("_标签_")
  111. persona_level = parts[0]
  112. persona_name = parts[1] if len(parts) > 1 else tgt_id
  113. elif "_分类_" in tgt_id:
  114. parts = tgt_id.split("_分类_")
  115. persona_level = parts[0]
  116. persona_name = parts[1] if len(parts) > 1 else tgt_id
  117. # 判断原始节点类型(分类/标签)
  118. original_type = "标签" if "_标签_" in tgt_id else ("分类" if "_分类_" in tgt_id else "标签")
  119. persona_node = {
  120. "id": f"persona_{tgt_id}",
  121. "name": persona_name,
  122. "nodeType": "人设",
  123. "originalType": original_type, # 原始类型:分类或标签
  124. "personaId": tgt_id,
  125. "level": persona_level,
  126. "dimColor": dim_colors.get(persona_level, "#2ecc71"),
  127. "matchType": edge_type.replace("匹配_", ""),
  128. "similarity": similarity,
  129. "children": []
  130. }
  131. tag_to_persona_matches[src_id].append(persona_node)
  132. # 为每个直接匹配的人设节点找扩展节点(第二层)
  133. persona_to_expanded = {}
  134. for e in persona_edges:
  135. src_id = e["源节点ID"]
  136. tgt_id = e["目标节点ID"]
  137. edge_type = e["边类型"]
  138. # 如果源是直接匹配节点,目标是扩展节点
  139. if src_id in direct_persona_ids and tgt_id in expanded_nodes_map:
  140. if src_id not in persona_to_expanded:
  141. persona_to_expanded[src_id] = []
  142. exp_node = expanded_nodes_map[tgt_id]
  143. exp_name = exp_node.get("节点名称", tgt_id)
  144. exp_level = exp_node.get("节点层级", "")
  145. # 扩展节点的原始类型
  146. exp_original_type = exp_node.get("节点类型", "标签")
  147. expanded_node = {
  148. "id": f"expanded_{tgt_id}",
  149. "name": exp_name,
  150. "nodeType": "人设扩展",
  151. "originalType": exp_original_type, # 分类或标签
  152. "personaId": tgt_id,
  153. "level": exp_level,
  154. "dimColor": dim_colors.get(exp_level, "#2ecc71"),
  155. "edgeType": edge_type,
  156. "children": []
  157. }
  158. # 避免重复
  159. if not any(x["personaId"] == tgt_id for x in persona_to_expanded[src_id]):
  160. persona_to_expanded[src_id].append(expanded_node)
  161. # 如果目标是直接匹配节点,源是扩展节点
  162. if tgt_id in direct_persona_ids and src_id in expanded_nodes_map:
  163. if tgt_id not in persona_to_expanded:
  164. persona_to_expanded[tgt_id] = []
  165. exp_node = expanded_nodes_map[src_id]
  166. exp_name = exp_node.get("节点名称", src_id)
  167. exp_level = exp_node.get("节点层级", "")
  168. exp_original_type = exp_node.get("节点类型", "标签")
  169. expanded_node = {
  170. "id": f"expanded_{src_id}",
  171. "name": exp_name,
  172. "nodeType": "人设扩展",
  173. "originalType": exp_original_type,
  174. "personaId": src_id,
  175. "level": exp_level,
  176. "dimColor": dim_colors.get(exp_level, "#2ecc71"),
  177. "edgeType": edge_type,
  178. "children": []
  179. }
  180. if not any(x["personaId"] == src_id for x in persona_to_expanded[tgt_id]):
  181. persona_to_expanded[tgt_id].append(expanded_node)
  182. # 将扩展节点添加到对应的人设节点下
  183. expanded_count = 0
  184. for tag_id, persona_nodes in tag_to_persona_matches.items():
  185. for persona_node in persona_nodes:
  186. persona_id = persona_node["personaId"]
  187. if persona_id in persona_to_expanded:
  188. persona_node["children"] = persona_to_expanded[persona_id]
  189. expanded_count += len(persona_to_expanded[persona_id])
  190. # 将人设匹配节点添加到对应标签下
  191. persona_count = 0
  192. for tag_id, persona_nodes in tag_to_persona_matches.items():
  193. if tag_id in tag_map:
  194. tag_map[tag_id]["children"] = persona_nodes
  195. persona_count += len(persona_nodes)
  196. print(f" 人设匹配节点(1层): {persona_count}, 扩展节点(2层): {expanded_count}")
  197. # 根据属于边,把标签挂到点下面
  198. for e in belong_edges:
  199. tag_node = tag_map.get(e["源节点ID"])
  200. point_node = point_map.get(e["目标节点ID"])
  201. if tag_node and point_node:
  202. point_node["children"].append(tag_node)
  203. # 按维度分组点节点
  204. dimensions = ["灵感点", "目的点", "关键点"]
  205. dimension_children = []
  206. for dim in dimensions:
  207. dim_points = [
  208. point_map[n["节点ID"]]
  209. for n in post_points
  210. if n.get("节点层级") == dim and n["节点ID"] in point_map
  211. ]
  212. if dim_points:
  213. dim_node = {
  214. "id": f"dim_{dim}",
  215. "name": dim,
  216. "nodeType": "维度",
  217. "isDimension": True,
  218. "dimColor": dim_colors[dim],
  219. "children": dim_points
  220. }
  221. dimension_children.append(dim_node)
  222. # 根节点(帖子)
  223. root_node = {
  224. "id": f"post_{post_id}",
  225. "name": post_title[:20] + "..." if len(post_title) > 20 else post_title,
  226. "nodeType": "帖子",
  227. "isRoot": True,
  228. "postDetail": post_detail,
  229. "children": dimension_children
  230. }
  231. # 统计节点数
  232. total_nodes = 1 + len(dimension_children) # 根节点 + 维度节点
  233. for dim_node in dimension_children:
  234. total_nodes += len(dim_node["children"]) # 点节点
  235. for point_node in dim_node["children"]:
  236. total_nodes += len(point_node["children"]) # 标签节点
  237. for tag_node in point_node["children"]:
  238. total_nodes += len(tag_node["children"]) # 人设节点(1层)
  239. for persona_node in tag_node["children"]:
  240. total_nodes += len(persona_node["children"]) # 扩展节点(2层)
  241. post_tree = {
  242. "postId": post_id,
  243. "postTitle": post_title,
  244. "postDetail": post_detail,
  245. "root": root_node,
  246. "stats": {
  247. "totalNodes": total_nodes,
  248. "pointCount": len(post_points),
  249. "tagCount": len(post_tags),
  250. "personaCount": persona_count
  251. }
  252. }
  253. all_post_trees.append(post_tree)
  254. print(f" 构建完成: {total_nodes} 个节点(人设1层: {persona_count}, 扩展2层: {expanded_count})")
  255. # 输出
  256. output_data = {
  257. "说明": {
  258. "描述": "帖子树结构数据(每个帖子一棵树)",
  259. "帖子数": len(all_post_trees)
  260. },
  261. "postTrees": all_post_trees
  262. }
  263. with open(output_file, "w", encoding="utf-8") as f:
  264. json.dump(output_data, f, ensure_ascii=False, indent=2)
  265. print()
  266. print("=" * 60)
  267. print(f"构建完成!")
  268. print(f" 帖子数: {len(all_post_trees)}")
  269. print(f" 输出文件: {output_file}")
  270. return output_file
  271. if __name__ == "__main__":
  272. build_post_trees()