#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Creation-pattern derivation, step 2: iterative derivation over co-occurrence.

Input:  origin-analysis results + prepared per-post node data.
Output: derivation result (known-node set + derived relations).

Algorithm:
    1. Initialise: nodes whose origin score >= 0.8 seed the known set.
    2. Iterate:
       - From nodes newly added in the previous round, keep those whose
         persona-match score >= 0.8.
       - Collect the historical co-occurrence category IDs of their categories.
       - For every unknown node (persona match >= 0.8), check whether its own
         category ID appears in that co-occurrence list.
       - If so, add it to the known set and record a relation.
    3. Stop when a round adds no new node.
"""

import json
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set

# Make the project root importable when this file is run as a script.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from script.data_processing.path_config import PathConfig

# ===== Configuration =====
ORIGIN_SCORE_THRESHOLD = 0.8  # minimum origin-analysis score to seed the known set
MATCH_SCORE_THRESHOLD = 0.8   # minimum persona-match score to take part in derivation


# ===== Data structures =====
# NOTE: field names are intentionally Chinese identifiers — they define the
# JSON output schema via dataclasses.asdict() and must not be renamed.

@dataclass
class AnalysisNode:
    """A node awaiting analysis."""
    节点ID: str
    节点名称: str
    节点分类: str
    节点维度: str
    人设匹配分数: float
    所属分类ID: Optional[str]
    # {category ID: co-occurrence score}
    历史共现分类: Dict[str, float] = field(default_factory=dict)

    @classmethod
    def from_raw(cls, raw: Dict) -> "AnalysisNode":
        """Build an AnalysisNode from one raw dict of the prepared data file.

        Missing sub-objects are tolerated (`or {}` guards against explicit
        nulls as well as absent keys); co-occurrence entries without a
        节点ID are dropped.
        """
        match_info = raw.get("人设匹配") or {}
        match_score = match_info.get("匹配分数", 0)
        category_info = match_info.get("所属分类") or {}
        category_id = category_info.get("节点ID")
        co_occur_list = category_info.get("历史共现分类", [])
        co_occur_map = {
            c.get("节点ID"): c.get("共现度", 0)
            for c in co_occur_list
            if c.get("节点ID")
        }
        return cls(
            节点ID=raw.get("节点ID", ""),
            节点名称=raw.get("节点名称", ""),
            节点分类=raw.get("节点分类", ""),
            节点维度=raw.get("节点维度", ""),
            人设匹配分数=match_score,
            所属分类ID=category_id,
            历史共现分类=co_occur_map,
        )


@dataclass
class DerivedRelation:
    """A relation produced by the derivation."""
    来源节点ID: str
    来源节点名称: str
    目标节点ID: str
    目标节点名称: str
    关系类型: str    # always "共现推导"
    推导轮次: int
    共现分类ID: str  # the co-occurrence category through which the link was made
    共现度: float    # co-occurrence score


@dataclass
class DerivationResult:
    """Overall result of the derivation for one post."""
    帖子ID: str
    起点列表: List[Dict]      # {节点ID, 节点名称, 起点分数}
    已知点列表: List[Dict]    # {节点ID, 节点名称, 加入轮次, 加入原因}
    推导关系列表: List[Dict]  # DerivedRelation as dicts
    推导轮次: int
    未知点列表: List[Dict]    # nodes never derived


# ===== Data loading =====

def load_json(file_path: Path) -> Dict:
    """Load a UTF-8 JSON file and return the parsed object."""
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_origin_result_files(config: PathConfig) -> List[Path]:
    """Return all origin-analysis result files, sorted by name."""
    result_dir = config.intermediate_dir / "origin_analysis_result"
    return sorted(result_dir.glob("*_起点分析.json"))


def get_prepared_file(config: PathConfig, post_id: str) -> Optional[Path]:
    """Return the prepared-data file for *post_id*, or None if absent."""
    prepared_dir = config.intermediate_dir / "origin_analysis_prepared"
    files = list(prepared_dir.glob(f"{post_id}_待分析数据.json"))
    return files[0] if files else None


# ===== Core algorithm =====

def derive_patterns(
    nodes: List[AnalysisNode],
    origin_scores: Dict[str, float],  # {node name: origin score}
) -> DerivationResult:
    """Iteratively derive known nodes via category co-occurrence.

    Args:
        nodes: every node awaiting analysis.
        origin_scores: origin-analysis scores keyed by node name.

    Returns:
        DerivationResult — 帖子ID is left empty for the caller to fill in.
    """
    # Index nodes by name for O(1) lookups.
    node_by_name: Dict[str, AnalysisNode] = {n.节点名称: n for n in nodes}

    # 1. Seed the known set with nodes whose origin score passes the threshold.
    known_names: Set[str] = set()
    known_info: List[Dict] = []  # {节点ID, 节点名称, 加入轮次, 加入原因}
    origins: List[Dict] = []
    for name, score in origin_scores.items():
        if score >= ORIGIN_SCORE_THRESHOLD:
            known_names.add(name)
            node = node_by_name.get(name)
            if node:
                origins.append({
                    "节点ID": node.节点ID,
                    "节点名称": name,
                    "起点分数": score,
                })
                known_info.append({
                    "节点ID": node.节点ID,
                    "节点名称": name,
                    "加入轮次": 0,
                    "加入原因": f"起点(score={score:.2f})",
                })

    # Everything not seeded is initially unknown.
    unknown_names: Set[str] = set(node_by_name.keys()) - known_names

    relations: List[DerivedRelation] = []

    # 2. Iterate until a round adds no new node.
    round_num = 0
    new_known_this_round = known_names.copy()  # round 0's "new" nodes are the seeds
    while new_known_this_round:
        round_num += 1
        print(f"\n 第 {round_num} 轮推导...")

        new_known_next_round: Set[str] = set()

        # Only nodes added in the previous round can propagate further.
        for known_name in new_known_this_round:
            known_node = node_by_name.get(known_name)
            if not known_node:
                continue
            # Filter: persona-match score must pass the threshold.
            if known_node.人设匹配分数 < MATCH_SCORE_THRESHOLD:
                continue
            # Historical co-occurrence categories {ID: score}.
            co_occur_map = known_node.历史共现分类
            if not co_occur_map:
                continue

            # NOTE: an unknown node is only removed from unknown_names after
            # the round ends, so several known nodes may each record a
            # relation to the same target within one round (all sources kept).
            for unknown_name in list(unknown_names):
                unknown_node = node_by_name.get(unknown_name)
                if not unknown_node:
                    continue
                # Filter: persona-match score must pass the threshold.
                if unknown_node.人设匹配分数 < MATCH_SCORE_THRESHOLD:
                    continue
                # Link when the unknown node's category appears in the known
                # node's co-occurrence map.
                if unknown_node.所属分类ID and unknown_node.所属分类ID in co_occur_map:
                    co_occur_score = co_occur_map[unknown_node.所属分类ID]
                    new_known_next_round.add(unknown_name)
                    relations.append(DerivedRelation(
                        来源节点ID=known_node.节点ID,
                        来源节点名称=known_name,
                        目标节点ID=unknown_node.节点ID,
                        目标节点名称=unknown_name,
                        关系类型="共现推导",
                        推导轮次=round_num,
                        共现分类ID=unknown_node.所属分类ID,
                        共现度=co_occur_score,
                    ))
                    print(f" {known_name} → {unknown_name} (共现度: {co_occur_score:.2f})")

        # Promote this round's discoveries to the known set.
        for name in new_known_next_round:
            node = node_by_name.get(name)
            if node:
                known_info.append({
                    "节点ID": node.节点ID,
                    "节点名称": name,
                    "加入轮次": round_num,
                    "加入原因": "共现推导",
                })
        known_names.update(new_known_next_round)
        unknown_names -= new_known_next_round
        new_known_this_round = new_known_next_round
        if not new_known_next_round:
            print(" 无新点加入,推导结束")
            break

    # 3. Build the list of nodes that were never derived.
    unknown_list = []
    for name in unknown_names:
        node = node_by_name.get(name)
        if node:
            unknown_list.append({
                "节点ID": node.节点ID,
                "节点名称": name,
                "节点维度": node.节点维度,
                "人设匹配分数": node.人设匹配分数,
                "未加入原因": "人设匹配分数不足"
                if node.人设匹配分数 < MATCH_SCORE_THRESHOLD
                else "无共现关联",
            })

    return DerivationResult(
        帖子ID="",  # set by the caller
        起点列表=origins,
        已知点列表=known_info,
        推导关系列表=[asdict(r) for r in relations],
        推导轮次=round_num,
        未知点列表=unknown_list,
    )


# ===== Processing =====

def process_single_post(
    origin_file: Path,
    config: PathConfig,
) -> Optional[Dict]:
    """Run the derivation for one post and persist the result.

    Returns the result as a dict, or None when inputs are missing/empty.
    """
    # Load the origin-analysis result.
    origin_data = load_json(origin_file)
    post_id = origin_data.get("帖子id", "unknown")

    print(f"\n{'=' * 60}")
    print(f"处理帖子: {post_id}")
    print("-" * 60)

    # Extract origin scores.
    origin_output = origin_data.get("输出", {})
    if not origin_output:
        print(" 错误: 起点分析结果为空")
        return None
    origin_scores = {name: info.get("score", 0) for name, info in origin_output.items()}

    # Load the prepared data (full node information).
    prepared_file = get_prepared_file(config, post_id)
    if not prepared_file:
        print(f" 错误: 未找到待分析数据文件")
        return None
    prepared_data = load_json(prepared_file)
    raw_nodes = prepared_data.get("待分析节点列表", [])

    # Convert to AnalysisNode instances.
    nodes = [AnalysisNode.from_raw(raw) for raw in raw_nodes]
    print(f" 节点数: {len(nodes)}")

    # Show the seeds.
    origins = [
        (name, score)
        for name, score in origin_scores.items()
        if score >= ORIGIN_SCORE_THRESHOLD
    ]
    print(f" 起点 (score >= {ORIGIN_SCORE_THRESHOLD}): {len(origins)} 个")
    for name, score in sorted(origins, key=lambda x: -x[1]):
        print(f" ★ {name}: {score:.2f}")

    # Run the derivation.
    result = derive_patterns(nodes, origin_scores)
    result.帖子ID = post_id

    # Report.
    print(f"\n 推导轮次: {result.推导轮次}")
    print(f" 已知点: {len(result.已知点列表)} 个")
    print(f" 推导关系: {len(result.推导关系列表)} 条")
    print(f" 未知点: {len(result.未知点列表)} 个")

    # Persist.
    output_dir = config.intermediate_dir / "pattern_derivation"
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{post_id}_模式推导.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(asdict(result), f, ensure_ascii=False, indent=2)
    print(f"\n 已保存: {output_file.name}")

    return asdict(result)


# ===== Entry point =====

def main(
    post_id: Optional[str] = None,
    all_posts: bool = False,
):
    """Run the derivation.

    Args:
        post_id: optional post ID; process only the matching post.
        all_posts: process every post (otherwise only the first one).
    """
    config = PathConfig()
    print(f"账号: {config.account_name}")
    print(f"起点分数阈值: {ORIGIN_SCORE_THRESHOLD}")
    print(f"匹配分数阈值: {MATCH_SCORE_THRESHOLD}")

    # Collect origin-analysis result files.
    origin_files = get_origin_result_files(config)
    if not origin_files:
        print("错误: 没有找到起点分析结果,请先运行 analyze_creation_origin.py")
        return

    # Decide which posts to process.
    if post_id:
        target_file = next(
            (f for f in origin_files if post_id in f.name),
            None
        )
        if not target_file:
            print(f"错误: 未找到帖子 {post_id} 的起点分析结果")
            return
        files_to_process = [target_file]
    elif all_posts:
        files_to_process = origin_files
    else:
        files_to_process = [origin_files[0]]

    print(f"待处理帖子数: {len(files_to_process)}")

    # Process each post.
    results = []
    for i, origin_file in enumerate(files_to_process, 1):
        print(f"\n{'#' * 60}")
        print(f"# 处理帖子 {i}/{len(files_to_process)}")
        print(f"{'#' * 60}")
        result = process_single_post(origin_file, config)
        if result:
            results.append(result)

    # Summary.
    print(f"\n{'#' * 60}")
    print(f"# 完成! 共处理 {len(results)} 个帖子")
    print(f"{'#' * 60}")
    print("\n汇总:")
    for result in results:
        pid = result.get("帖子ID")
        known_count = len(result.get("已知点列表", []))
        relation_count = len(result.get("推导关系列表", []))
        unknown_count = len(result.get("未知点列表", []))
        print(f" {pid}: 已知={known_count}, 关系={relation_count}, 未知={unknown_count}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="创作模式推导")
    parser.add_argument("--post-id", type=str, help="帖子ID")
    parser.add_argument("--all-posts", action="store_true", help="处理所有帖子")
    args = parser.parse_args()

    main(
        post_id=args.post_id,
        all_posts=args.all_posts,
    )