  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 创作起点分析 - 数据准备脚本
  5. 第一步:根据帖子图谱 + 人设图谱,把信息压缩到待分析节点中
  6. 输入:帖子图谱 + 人设图谱
  7. 输出:待分析数据结构
  8. """
import json
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Add the project root to sys.path so the project-local import below resolves.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from script.data_processing.path_config import PathConfig
  17. # ===== 数据加载函数 =====
  18. def load_json(file_path: Path) -> Dict:
  19. """加载JSON文件"""
  20. with open(file_path, "r", encoding="utf-8") as f:
  21. return json.load(f)
  22. def get_post_graph_files(config: PathConfig) -> List[Path]:
  23. """获取所有帖子图谱文件"""
  24. post_graph_dir = config.intermediate_dir / "post_graph"
  25. return sorted(post_graph_dir.glob("*_帖子图谱.json"))
  26. # ===== 数据提取函数 =====
  27. def extract_post_detail(post_graph: Dict) -> Dict:
  28. """
  29. 提取帖子详情(保留原始字段名)
  30. """
  31. meta = post_graph.get("meta", {})
  32. post_detail = meta.get("postDetail", {})
  33. return {
  34. "postId": meta.get("postId", ""),
  35. "postTitle": meta.get("postTitle", ""),
  36. "body_text": post_detail.get("body_text", ""),
  37. "images": post_detail.get("images", []),
  38. "video": post_detail.get("video"),
  39. "publish_time": post_detail.get("publish_time", ""),
  40. "like_count": post_detail.get("like_count", 0),
  41. "collect_count": post_detail.get("collect_count", 0),
  42. }
  43. def extract_analysis_nodes(post_graph: Dict, persona_graph: Dict) -> List[Dict]:
  44. """
  45. 提取待分析节点列表
  46. 待分析节点 = 灵感点 + 目的点(不包括关键点,关键点是支撑信息)
  47. """
  48. nodes = post_graph.get("nodes", {})
  49. edges = post_graph.get("edges", {})
  50. persona_nodes = persona_graph.get("nodes", {})
  51. persona_index = persona_graph.get("index", {})
  52. # 1. 收集关键点信息(用于支撑信息)
  53. keypoints = {}
  54. for node_id, node in nodes.items():
  55. if node.get("type") == "标签" and node.get("dimension") == "关键点":
  56. keypoints[node_id] = {
  57. "名称": node.get("name", ""),
  58. "描述": node.get("detail", {}).get("description", ""),
  59. }
  60. # 2. 分析支撑关系:关键点 → 灵感点/目的点
  61. support_map = {} # {target_node_id: [支撑的关键点信息]}
  62. for edge_id, edge in edges.items():
  63. if edge.get("type") == "支撑":
  64. source_id = edge.get("source", "")
  65. target_id = edge.get("target", "")
  66. if source_id in keypoints:
  67. if target_id not in support_map:
  68. support_map[target_id] = []
  69. support_map[target_id].append(keypoints[source_id])
  70. # 3. 分析关联关系
  71. relation_map = {} # {node_id: [关联的节点名称]}
  72. for edge_id, edge in edges.items():
  73. if edge.get("type") == "关联":
  74. source_id = edge.get("source", "")
  75. target_id = edge.get("target", "")
  76. source_name = nodes.get(source_id, {}).get("name", "")
  77. target_name = nodes.get(target_id, {}).get("name", "")
  78. # 双向记录
  79. if source_id not in relation_map:
  80. relation_map[source_id] = []
  81. relation_map[source_id].append(target_name)
  82. if target_id not in relation_map:
  83. relation_map[target_id] = []
  84. relation_map[target_id].append(source_name)
  85. # 4. 分析人设匹配
  86. match_map = {} # {node_id: 匹配信息}
  87. persona_out_edges = persona_index.get("outEdges", {})
  88. def get_node_info(node_id: str) -> Optional[Dict]:
  89. """获取人设节点的标准信息"""
  90. node = persona_nodes.get(node_id, {})
  91. if not node:
  92. return None
  93. detail = node.get("detail", {})
  94. parent_path = detail.get("parentPath", [])
  95. return {
  96. "节点ID": node_id,
  97. "节点名称": node.get("name", ""),
  98. "节点分类": "/".join(parent_path) if parent_path else "",
  99. "节点维度": node.get("dimension", ""),
  100. "节点类型": node.get("type", ""),
  101. "人设全局占比": detail.get("probGlobal", 0),
  102. "父类下占比": detail.get("probToParent", 0),
  103. }
  104. def get_parent_category_id(node_id: str) -> Optional[str]:
  105. """通过属于边获取父分类节点ID"""
  106. belong_edges = persona_out_edges.get(node_id, {}).get("属于", [])
  107. for edge in belong_edges:
  108. target_id = edge.get("target", "")
  109. target_node = persona_nodes.get(target_id, {})
  110. if target_node.get("type") == "分类":
  111. return target_id
  112. return None
  113. for edge_id, edge in edges.items():
  114. if edge.get("type") == "匹配":
  115. source_id = edge.get("source", "")
  116. target_id = edge.get("target", "")
  117. # 只处理 帖子节点 → 人设节点 的匹配
  118. if source_id.startswith("帖子:") and target_id.startswith("人设:"):
  119. match_score = edge.get("score", 0)
  120. persona_node = persona_nodes.get(target_id, {})
  121. if persona_node:
  122. node_type = persona_node.get("type", "")
  123. # 获取匹配节点信息
  124. match_node_info = get_node_info(target_id)
  125. if not match_node_info:
  126. continue
  127. # 确定所属分类节点
  128. if node_type == "标签":
  129. # 标签:找父分类
  130. category_id = get_parent_category_id(target_id)
  131. else:
  132. # 分类:就是自己
  133. category_id = target_id
  134. # 获取所属分类信息和常见搭配
  135. category_info = None
  136. if category_id:
  137. category_node = persona_nodes.get(category_id, {})
  138. if category_node:
  139. category_detail = category_node.get("detail", {})
  140. category_path = category_detail.get("parentPath", [])
  141. category_info = {
  142. "节点ID": category_id,
  143. "节点名称": category_node.get("name", ""),
  144. "节点分类": "/".join(category_path) if category_path else "",
  145. "节点维度": category_node.get("dimension", ""),
  146. "节点类型": "分类",
  147. "人设全局占比": category_detail.get("probGlobal", 0),
  148. "父类下占比": category_detail.get("probToParent", 0),
  149. "历史共现分类": [],
  150. }
  151. # 获取分类共现节点(按共现度降序排列)
  152. co_occur_edges = persona_out_edges.get(category_id, {}).get("分类共现", [])
  153. co_occur_edges_sorted = sorted(co_occur_edges, key=lambda x: x.get("score", 0), reverse=True)
  154. for co_edge in co_occur_edges_sorted[:5]: # 取前5个
  155. co_target_id = co_edge.get("target", "")
  156. co_score = co_edge.get("score", 0)
  157. co_node = persona_nodes.get(co_target_id, {})
  158. if co_node:
  159. co_detail = co_node.get("detail", {})
  160. co_path = co_detail.get("parentPath", [])
  161. category_info["历史共现分类"].append({
  162. "节点ID": co_target_id,
  163. "节点名称": co_node.get("name", ""),
  164. "节点分类": "/".join(co_path) if co_path else "",
  165. "节点维度": co_node.get("dimension", ""),
  166. "节点类型": "分类",
  167. "人设全局占比": co_detail.get("probGlobal", 0),
  168. "父类下占比": co_detail.get("probToParent", 0),
  169. "共现度": round(co_score, 4),
  170. })
  171. match_map[source_id] = {
  172. "匹配节点": match_node_info,
  173. "匹配分数": round(match_score, 4),
  174. "所属分类": category_info,
  175. }
  176. # 5. 构建待分析节点列表(灵感点、目的点、关键点)
  177. analysis_nodes = []
  178. for node_id, node in nodes.items():
  179. if node.get("type") == "标签" and node.get("domain") == "帖子":
  180. dimension = node.get("dimension", "")
  181. if dimension in ["灵感点", "目的点", "关键点"]:
  182. # 人设匹配信息
  183. match_info = match_map.get(node_id)
  184. analysis_nodes.append({
  185. "节点ID": node_id,
  186. "节点名称": node.get("name", ""),
  187. "节点分类": node.get("category", ""), # 根分类:意图/实质/形式
  188. "节点维度": dimension,
  189. "节点类型": node.get("type", ""),
  190. "节点描述": node.get("detail", {}).get("description", ""),
  191. "人设匹配": match_info,
  192. })
  193. # 6. 构建可能的关系列表
  194. relation_list = []
  195. # 支撑关系:关键点 → 灵感点/目的点
  196. for edge_id, edge in edges.items():
  197. if edge.get("type") == "支撑":
  198. source_id = edge.get("source", "")
  199. target_id = edge.get("target", "")
  200. if source_id in keypoints:
  201. relation_list.append({
  202. "来源节点": source_id,
  203. "目标节点": target_id,
  204. "关系类型": "支撑",
  205. })
  206. # 关联关系:节点之间的关联(去重,只记录一次)
  207. seen_relations = set()
  208. for edge_id, edge in edges.items():
  209. if edge.get("type") == "关联":
  210. source_id = edge.get("source", "")
  211. target_id = edge.get("target", "")
  212. # 用排序后的元组作为key去重
  213. key = tuple(sorted([source_id, target_id]))
  214. if key not in seen_relations:
  215. seen_relations.add(key)
  216. relation_list.append({
  217. "来源节点": source_id,
  218. "目标节点": target_id,
  219. "关系类型": "关联",
  220. })
  221. return analysis_nodes, relation_list
  222. def prepare_analysis_data(post_graph: Dict, persona_graph: Dict) -> Dict:
  223. """
  224. 准备完整的分析数据
  225. Returns:
  226. {
  227. "帖子详情": {...},
  228. "待分析节点列表": [...],
  229. "可能的关系列表": [...]
  230. }
  231. """
  232. analysis_nodes, relation_list = extract_analysis_nodes(post_graph, persona_graph)
  233. return {
  234. "帖子详情": extract_post_detail(post_graph),
  235. "待分析节点列表": analysis_nodes,
  236. "可能的关系列表": relation_list,
  237. }
  238. # ===== 显示函数 =====
  239. def display_prepared_data(data: Dict):
  240. """显示准备好的数据"""
  241. post = data["帖子详情"]
  242. nodes = data["待分析节点列表"]
  243. relations = data["可能的关系列表"]
  244. print(f"\n帖子: {post['postId']}")
  245. print(f"标题: {post['postTitle']}")
  246. print(f"正文: {post['body_text'][:100]}...")
  247. print(f"\n待分析节点 ({len(nodes)} 个):")
  248. for node in nodes:
  249. match = node.get("人设匹配")
  250. category = node.get('节点分类', '')
  251. print(f" - [{node['节点ID']}] {node['节点名称']} ({node['节点维度']}/{category})")
  252. if match:
  253. match_node = match.get("匹配节点", {})
  254. category_node = match.get("所属分类", {})
  255. print(f" 匹配: {match_node.get('节点名称', '')} ({match_node.get('节点类型', '')}, 全局占比={match_node.get('人设全局占比', 0):.2%})")
  256. if category_node:
  257. co_count = len(category_node.get("历史共现分类", []))
  258. print(f" 所属分类: {category_node.get('节点名称', '')} (全局占比={category_node.get('人设全局占比', 0):.2%}, {co_count}个历史共现分类)")
  259. else:
  260. print(f" 人设: 无匹配")
  261. print(f"\n可能的关系 ({len(relations)} 条):")
  262. for rel in relations:
  263. rel_type = rel["关系类型"]
  264. if rel_type == "支撑":
  265. print(f" - {rel['来源节点']} → {rel['目标节点']} [支撑]")
  266. else:
  267. print(f" - {rel['来源节点']} ↔ {rel['目标节点']} [关联]")
  268. # ===== 处理函数 =====
  269. def process_single_post(
  270. post_file: Path,
  271. persona_graph: Dict,
  272. config: PathConfig,
  273. save: bool = True,
  274. ) -> Dict:
  275. """
  276. 处理单个帖子
  277. Args:
  278. post_file: 帖子图谱文件路径
  279. persona_graph: 人设图谱数据
  280. config: 路径配置
  281. save: 是否保存结果
  282. Returns:
  283. 准备好的分析数据
  284. """
  285. # 加载帖子图谱
  286. post_graph = load_json(post_file)
  287. post_id = post_graph.get("meta", {}).get("postId", "unknown")
  288. print(f"\n{'=' * 60}")
  289. print(f"处理帖子: {post_id}")
  290. print("-" * 60)
  291. # 准备数据
  292. data = prepare_analysis_data(post_graph, persona_graph)
  293. # 显示
  294. display_prepared_data(data)
  295. # 保存
  296. if save:
  297. output_dir = config.intermediate_dir / "origin_analysis_prepared"
  298. output_dir.mkdir(parents=True, exist_ok=True)
  299. output_file = output_dir / f"{post_id}_待分析数据.json"
  300. with open(output_file, "w", encoding="utf-8") as f:
  301. json.dump(data, f, ensure_ascii=False, indent=2)
  302. print(f"\n已保存: {output_file.name}")
  303. return data
  304. # ===== 主函数 =====
  305. def main(
  306. post_id: str = None,
  307. all_posts: bool = False,
  308. save: bool = True,
  309. ):
  310. """
  311. 主函数
  312. Args:
  313. post_id: 帖子ID,可选
  314. all_posts: 是否处理所有帖子
  315. save: 是否保存结果
  316. """
  317. config = PathConfig()
  318. print(f"账号: {config.account_name}")
  319. # 加载人设图谱
  320. persona_graph_file = config.intermediate_dir / "人设图谱.json"
  321. if not persona_graph_file.exists():
  322. print(f"错误: 人设图谱文件不存在: {persona_graph_file}")
  323. return
  324. persona_graph = load_json(persona_graph_file)
  325. print(f"人设图谱节点数: {len(persona_graph.get('nodes', {}))}")
  326. # 获取帖子图谱文件
  327. post_graph_files = get_post_graph_files(config)
  328. if not post_graph_files:
  329. print("错误: 没有找到帖子图谱文件")
  330. return
  331. # 确定要处理的帖子
  332. if post_id:
  333. target_file = next(
  334. (f for f in post_graph_files if post_id in f.name),
  335. None
  336. )
  337. if not target_file:
  338. print(f"错误: 未找到帖子 {post_id}")
  339. return
  340. files_to_process = [target_file]
  341. elif all_posts:
  342. files_to_process = post_graph_files
  343. else:
  344. files_to_process = [post_graph_files[0]]
  345. print(f"待处理帖子数: {len(files_to_process)}")
  346. # 处理
  347. results = []
  348. for i, post_file in enumerate(files_to_process, 1):
  349. print(f"\n{'#' * 60}")
  350. print(f"# 处理帖子 {i}/{len(files_to_process)}")
  351. print(f"{'#' * 60}")
  352. data = process_single_post(
  353. post_file=post_file,
  354. persona_graph=persona_graph,
  355. config=config,
  356. save=save,
  357. )
  358. results.append(data)
  359. print(f"\n{'#' * 60}")
  360. print(f"# 完成! 共处理 {len(results)} 个帖子")
  361. print(f"{'#' * 60}")
  362. return results
  363. if __name__ == "__main__":
  364. import argparse
  365. parser = argparse.ArgumentParser(description="创作起点分析 - 数据准备")
  366. parser.add_argument("--post-id", type=str, help="帖子ID")
  367. parser.add_argument("--all-posts", action="store_true", help="处理所有帖子")
  368. parser.add_argument("--no-save", action="store_true", help="不保存结果")
  369. args = parser.parse_args()
  370. main(
  371. post_id=args.post_id,
  372. all_posts=args.all_posts,
  373. save=not args.no_save,
  374. )