derive_pattern_relations.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
  3. """
  4. 创作模式推导 - 第二步:基于共现关系的迭代推导
  5. 输入:起点分析结果 + 待分析节点数据
  6. 输出:推导结果(已知点集合 + 推导关系)
  7. 算法:
  8. 1. 初始化:起点分析中 score >= 0.8 的点 → 已知点集合
  9. 2. 迭代:
  10. - 从新加入的已知点中,筛选人设匹配分数 >= 0.8 的
  11. - 获取它们的所属分类的历史共现分类ID列表
  12. - 遍历未知点(人设匹配 >= 0.8),检查其所属分类ID是否在共现列表中
  13. - 如果在,加入已知点,建立关系
  14. 3. 直到没有新点加入
  15. """

import json
import sys
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Dict, List, Optional, Set

# Add the project root to the import path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from script.data_processing.path_config import PathConfig

# ===== Configuration =====
ORIGIN_SCORE_THRESHOLD = 0.8  # minimum origin score to seed the known set
MATCH_SCORE_THRESHOLD = 0.8   # minimum persona match score to take part in derivation

# ===== Data structures =====
@dataclass
class AnalysisNode:
    """A node awaiting analysis."""
    节点ID: str
    节点名称: str
    节点分类: str
    节点维度: str
    人设匹配分数: float
    所属分类ID: Optional[str]
    历史共现分类: Dict[str, float] = field(default_factory=dict)  # {category ID: co-occurrence score}

    @classmethod
    def from_raw(cls, raw: Dict) -> "AnalysisNode":
        """Build an AnalysisNode from a raw data dict."""
        match_info = raw.get("人设匹配") or {}
        match_score = match_info.get("匹配分数", 0)
        category_info = match_info.get("所属分类") or {}
        category_id = category_info.get("节点ID")
        co_occur_list = category_info.get("历史共现分类", [])
        co_occur_map = {
            c.get("节点ID"): c.get("共现度", 0)
            for c in co_occur_list
            if c.get("节点ID")
        }
        return cls(
            节点ID=raw.get("节点ID", ""),
            节点名称=raw.get("节点名称", ""),
            节点分类=raw.get("节点分类", ""),
            节点维度=raw.get("节点维度", ""),
            人设匹配分数=match_score,
            所属分类ID=category_id,
            历史共现分类=co_occur_map,
        )
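
# Expected raw input shape for AnalysisNode.from_raw (inferred from the keys
# it reads above; the values here are illustrative placeholders, not real data):
# {
#     "节点ID": "n001",
#     "节点名称": "...",
#     "节点分类": "...",
#     "节点维度": "...",
#     "人设匹配": {
#         "匹配分数": 0.9,
#         "所属分类": {
#             "节点ID": "c001",
#             "历史共现分类": [{"节点ID": "c002", "共现度": 0.6}]
#         }
#     }
# }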

@dataclass
class DerivedRelation:
    """A derived relation."""
    来源节点ID: str
    来源节点名称: str
    目标节点ID: str
    目标节点名称: str
    关系类型: str    # always "共现推导" (co-occurrence derivation)
    推导轮次: int
    共现分类ID: str  # the co-occurrence category through which the relation was established
    共现度: float    # co-occurrence score


@dataclass
class DerivationResult:
    """The derivation result for one post."""
    帖子ID: str
    起点列表: List[Dict]      # {节点ID, 节点名称, 起点分数}
    已知点列表: List[Dict]    # {节点ID, 节点名称, 加入轮次, 加入原因}
    推导关系列表: List[Dict]  # DerivedRelation instances as dicts
    推导轮次: int
    未知点列表: List[Dict]    # nodes that were never derived

# ===== Data loading =====
def load_json(file_path: Path) -> Dict:
    """Load a JSON file."""
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_origin_result_files(config: PathConfig) -> List[Path]:
    """Collect all origin-analysis result files."""
    result_dir = config.intermediate_dir / "origin_analysis_result"
    return sorted(result_dir.glob("*_起点分析.json"))


def get_prepared_file(config: PathConfig, post_id: str) -> Optional[Path]:
    """Find the prepared-data file for a post, if present."""
    prepared_dir = config.intermediate_dir / "origin_analysis_prepared"
    files = list(prepared_dir.glob(f"{post_id}_待分析数据.json"))
    return files[0] if files else None
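
# Expected layout under config.intermediate_dir, as read and written by this
# script (the origin-result file stem is assumed to start with the post ID):
#   origin_analysis_result/{帖子ID}_起点分析.json       (step-1 output, read here)
#   origin_analysis_prepared/{帖子ID}_待分析数据.json   (prepared nodes, read here)
#   pattern_derivation/{帖子ID}_模式推导.json           (written by this script)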

# ===== Core algorithm =====
def derive_patterns(
    nodes: List[AnalysisNode],
    origin_scores: Dict[str, float],  # {node name: origin score}
) -> DerivationResult:
    """
    Iterative derivation over co-occurrence relations.

    Args:
        nodes: all nodes awaiting analysis
        origin_scores: scores from the origin analysis, {node name: score}

    Returns:
        DerivationResult
    """
    # Index nodes by name
    node_by_name: Dict[str, AnalysisNode] = {n.节点名称: n for n in nodes}

    # 1. Seed the known set (origin score >= threshold)
    known_names: Set[str] = set()
    known_info: List[Dict] = []  # {节点ID, 节点名称, 加入轮次, 加入原因}
    origins: List[Dict] = []
    for name, score in origin_scores.items():
        if score >= ORIGIN_SCORE_THRESHOLD:
            node = node_by_name.get(name)
            if node:
                known_names.add(name)
                origins.append({
                    "节点ID": node.节点ID,
                    "节点名称": name,
                    "起点分数": score,
                })
                known_info.append({
                    "节点ID": node.节点ID,
                    "节点名称": name,
                    "加入轮次": 0,
                    "加入原因": f"起点(score={score:.2f})",
                })

    # The unknown set is everything else
    unknown_names: Set[str] = set(node_by_name.keys()) - known_names

    # Derived relations
    relations: List[DerivedRelation] = []

    # 2. Iterate
    round_num = 0
    new_known_this_round = known_names.copy()  # round 0's "new" nodes are the origins
    while new_known_this_round:
        round_num += 1
        print(f"\n  Derivation round {round_num}...")
        # Nodes added in this round
        new_known_next_round: Set[str] = set()
        # Walk the nodes added in the previous round
        for known_name in new_known_this_round:
            known_node = node_by_name.get(known_name)
            if not known_node:
                continue
            # Filter: persona match score >= threshold
            if known_node.人设匹配分数 < MATCH_SCORE_THRESHOLD:
                continue
            # Historical co-occurrence categories {ID: co-occurrence score}
            co_occur_map = known_node.历史共现分类
            if not co_occur_map:
                continue
            # Walk the unknown nodes
            for unknown_name in list(unknown_names):
                unknown_node = node_by_name.get(unknown_name)
                if not unknown_node:
                    continue
                # Filter: persona match score >= threshold
                if unknown_node.人设匹配分数 < MATCH_SCORE_THRESHOLD:
                    continue
                # Is the unknown node's category ID in the known node's co-occurrence list?
                if unknown_node.所属分类ID and unknown_node.所属分类ID in co_occur_map:
                    # Link found
                    co_occur_score = co_occur_map[unknown_node.所属分类ID]
                    new_known_next_round.add(unknown_name)
                    # Record the relation
                    relations.append(DerivedRelation(
                        来源节点ID=known_node.节点ID,
                        来源节点名称=known_name,
                        目标节点ID=unknown_node.节点ID,
                        目标节点名称=unknown_name,
                        关系类型="共现推导",
                        推导轮次=round_num,
                        共现分类ID=unknown_node.所属分类ID,
                        共现度=co_occur_score,
                    ))
                    print(f"    {known_name} → {unknown_name} (co-occurrence: {co_occur_score:.2f})")
        # Update the sets
        for name in new_known_next_round:
            node = node_by_name.get(name)
            if node:
                known_info.append({
                    "节点ID": node.节点ID,
                    "节点名称": name,
                    "加入轮次": round_num,
                    "加入原因": "共现推导",
                })
        known_names.update(new_known_next_round)
        unknown_names -= new_known_next_round
        new_known_this_round = new_known_next_round
        if not new_known_next_round:
            print("  No new nodes added; derivation finished.")
            break

    # 3. Build the unknown-node list
    unknown_list = []
    for name in unknown_names:
        node = node_by_name.get(name)
        if node:
            unknown_list.append({
                "节点ID": node.节点ID,
                "节点名称": name,
                "节点维度": node.节点维度,
                "人设匹配分数": node.人设匹配分数,
                "未加入原因": "人设匹配分数不足" if node.人设匹配分数 < MATCH_SCORE_THRESHOLD else "无共现关联",
            })

    return DerivationResult(
        帖子ID="",  # set by the caller
        起点列表=origins,
        已知点列表=known_info,
        推导关系列表=[asdict(r) for r in relations],
        推导轮次=round_num,
        未知点列表=unknown_list,
    )
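
# Minimal usage sketch (hypothetical node names; in practice `nodes` is built
# from the prepared-data file via AnalysisNode.from_raw):
#   nodes = [AnalysisNode.from_raw(raw) for raw in raw_nodes]
#   result = derive_patterns(nodes, {"节点A": 0.9, "节点B": 0.5})
#   print(result.推导轮次, len(result.推导关系列表))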

# ===== Processing =====
def process_single_post(
    origin_file: Path,
    config: PathConfig,
) -> Optional[Dict]:
    """Process a single post."""
    # Load the origin-analysis result
    origin_data = load_json(origin_file)
    post_id = origin_data.get("帖子id", "unknown")
    print(f"\n{'=' * 60}")
    print(f"Processing post: {post_id}")
    print("-" * 60)

    # Extract the origin scores
    origin_output = origin_data.get("输出", {})
    if not origin_output:
        print("  Error: origin-analysis result is empty")
        return None
    origin_scores = {name: info.get("score", 0) for name, info in origin_output.items()}

    # Load the prepared data (full node information)
    prepared_file = get_prepared_file(config, post_id)
    if not prepared_file:
        print("  Error: prepared-data file not found")
        return None
    prepared_data = load_json(prepared_file)
    raw_nodes = prepared_data.get("待分析节点列表", [])

    # Convert to AnalysisNode
    nodes = [AnalysisNode.from_raw(raw) for raw in raw_nodes]
    print(f"  Node count: {len(nodes)}")

    # Show the origins
    origins = [(name, score) for name, score in origin_scores.items() if score >= ORIGIN_SCORE_THRESHOLD]
    print(f"  Origins (score >= {ORIGIN_SCORE_THRESHOLD}): {len(origins)}")
    for name, score in sorted(origins, key=lambda x: -x[1]):
        print(f"    ★ {name}: {score:.2f}")

    # Run the derivation
    result = derive_patterns(nodes, origin_scores)
    result.帖子ID = post_id

    # Show the result
    print(f"\n  Derivation rounds: {result.推导轮次}")
    print(f"  Known nodes: {len(result.已知点列表)}")
    print(f"  Derived relations: {len(result.推导关系列表)}")
    print(f"  Unknown nodes: {len(result.未知点列表)}")

    # Save the result
    output_dir = config.intermediate_dir / "pattern_derivation"
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{post_id}_模式推导.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(asdict(result), f, ensure_ascii=False, indent=2)
    print(f"\n  Saved: {output_file.name}")
    return asdict(result)

# ===== Entry point =====
def main(
    post_id: Optional[str] = None,
    all_posts: bool = False,
):
    """
    Entry point.

    Args:
        post_id: optional post ID
        all_posts: whether to process every post
    """
    config = PathConfig()
    print(f"Account: {config.account_name}")
    print(f"Origin score threshold: {ORIGIN_SCORE_THRESHOLD}")
    print(f"Match score threshold: {MATCH_SCORE_THRESHOLD}")

    # Collect the origin-analysis result files
    origin_files = get_origin_result_files(config)
    if not origin_files:
        print("Error: no origin-analysis results found; run analyze_creation_origin.py first")
        return

    # Decide which posts to process
    if post_id:
        target_file = next(
            (f for f in origin_files if post_id in f.name),
            None
        )
        if not target_file:
            print(f"Error: no origin-analysis result found for post {post_id}")
            return
        files_to_process = [target_file]
    elif all_posts:
        files_to_process = origin_files
    else:
        files_to_process = [origin_files[0]]
    print(f"Posts to process: {len(files_to_process)}")

    # Process
    results = []
    for i, origin_file in enumerate(files_to_process, 1):
        print(f"\n{'#' * 60}")
        print(f"# Processing post {i}/{len(files_to_process)}")
        print(f"{'#' * 60}")
        result = process_single_post(origin_file, config)
        if result:
            results.append(result)

    # Summary
    print(f"\n{'#' * 60}")
    print(f"# Done! Processed {len(results)} post(s)")
    print(f"{'#' * 60}")
    print("\nSummary:")
    for result in results:
        post_id = result.get("帖子ID")
        known_count = len(result.get("已知点列表", []))
        relation_count = len(result.get("推导关系列表", []))
        unknown_count = len(result.get("未知点列表", []))
        print(f"  {post_id}: known={known_count}, relations={relation_count}, unknown={unknown_count}")

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Creation pattern derivation")
    parser.add_argument("--post-id", type=str, help="post ID")
    parser.add_argument("--all-posts", action="store_true", help="process all posts")
    args = parser.parse_args()
    main(
        post_id=args.post_id,
        all_posts=args.all_posts,
    )
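
# Usage (with no flags, only the first origin-result file found is processed):
#   python derive_pattern_relations.py
#   python derive_pattern_relations.py --post-id <帖子ID>
#   python derive_pattern_relations.py --all-posts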