#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
创作模式推导 - 第二步:基于共现关系的迭代推导
输入:起点分析结果 + 待分析节点数据
输出:推导结果(已知点集合 + 推导关系)
算法:
1. 初始化:起点分析中 score >= 0.8 的点 → 已知点集合
2. 迭代:
   - 从新加入的已知点中,筛选人设匹配分数 >= 0.8 的
   - 获取它们的所属分类的历史共现分类ID列表
   - 遍历未知点(人设匹配 >= 0.8),检查其所属分类ID是否在共现列表中
   - 如果在,加入已知点,建立关系
3. 直到没有新点加入
"""
import json
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set

# Make the project root importable before pulling in project-local modules.
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from script.data_processing.path_config import PathConfig
- # ===== 配置 =====
- ORIGIN_SCORE_THRESHOLD = 0.8 # 起点分数阈值
- MATCH_SCORE_THRESHOLD = 0.8 # 人设匹配分数阈值
# ===== Data structures =====
@dataclass
class AnalysisNode:
    """A node awaiting analysis, carrying persona-match and co-occurrence data."""

    节点ID: str
    节点名称: str
    节点分类: str
    节点维度: str
    人设匹配分数: float
    所属分类ID: Optional[str]
    # Maps a historically co-occurring category ID to its co-occurrence strength.
    历史共现分类: Dict[str, float] = field(default_factory=dict)

    @classmethod
    def from_raw(cls, raw: Dict) -> "AnalysisNode":
        """Build an AnalysisNode from a raw dict as loaded from JSON."""
        match_info = raw.get("人设匹配") or {}
        category_info = match_info.get("所属分类") or {}
        # Collect co-occurring categories, skipping entries with no ID.
        co_occur_map: Dict[str, float] = {}
        for entry in category_info.get("历史共现分类", []):
            entry_id = entry.get("节点ID")
            if entry_id:
                co_occur_map[entry_id] = entry.get("共现度", 0)
        return cls(
            节点ID=raw.get("节点ID", ""),
            节点名称=raw.get("节点名称", ""),
            节点分类=raw.get("节点分类", ""),
            节点维度=raw.get("节点维度", ""),
            人设匹配分数=match_info.get("匹配分数", 0),
            所属分类ID=category_info.get("节点ID"),
            历史共现分类=co_occur_map,
        )
@dataclass
class DerivedRelation:
    """A directed relation inferred during co-occurrence derivation."""

    来源节点ID: str
    来源节点名称: str
    目标节点ID: str
    目标节点名称: str
    关系类型: str    # always "共现推导" for relations produced by this module
    推导轮次: int    # iteration round in which the relation was discovered
    共现分类ID: str  # the co-occurring category ID that linked the two nodes
    共现度: float    # co-occurrence strength of that category
@dataclass
class DerivationResult:
    """Aggregate outcome of the derivation for a single post."""

    帖子ID: str
    起点列表: List[Dict]      # [{节点ID, 节点名称, 起点分数}]
    已知点列表: List[Dict]    # [{节点ID, 节点名称, 加入轮次, 加入原因}]
    推导关系列表: List[Dict]  # DerivedRelation instances serialized via asdict
    推导轮次: int
    未知点列表: List[Dict]    # nodes that were never derived
# ===== Data loading =====
def load_json(file_path: Path) -> Dict:
    """Read *file_path* as UTF-8 JSON and return the parsed object."""
    with file_path.open(encoding="utf-8") as f:
        return json.load(f)
def get_origin_result_files(config: PathConfig) -> List[Path]:
    """Return every origin-analysis result file under the intermediate dir, sorted by name."""
    results_dir = config.intermediate_dir / "origin_analysis_result"
    return sorted(results_dir.glob("*_起点分析.json"))
def get_prepared_file(config: PathConfig, post_id: str) -> Optional[Path]:
    """Locate the prepared-data file for *post_id*, or None when it is missing."""
    prepared_dir = config.intermediate_dir / "origin_analysis_prepared"
    # glob yields at most one match for an exact-name pattern; take the first if any.
    return next(iter(prepared_dir.glob(f"{post_id}_待分析数据.json")), None)
# ===== Core algorithm =====
- def derive_patterns(
- nodes: List[AnalysisNode],
- origin_scores: Dict[str, float], # {节点名称: 起点分数}
- ) -> DerivationResult:
- """
- 基于共现关系的迭代推导
- Args:
- nodes: 所有待分析节点
- origin_scores: 起点分析的分数 {节点名称: score}
- Returns:
- DerivationResult
- """
- # 构建索引
- node_by_name: Dict[str, AnalysisNode] = {n.节点名称: n for n in nodes}
- node_by_id: Dict[str, AnalysisNode] = {n.节点ID: n for n in nodes}
- # 1. 初始化已知点集合(起点分数 >= 0.8)
- known_names: Set[str] = set()
- known_info: List[Dict] = [] # {节点ID, 节点名称, 加入轮次, 加入原因}
- origins: List[Dict] = []
- for name, score in origin_scores.items():
- if score >= ORIGIN_SCORE_THRESHOLD:
- known_names.add(name)
- node = node_by_name.get(name)
- if node:
- origins.append({
- "节点ID": node.节点ID,
- "节点名称": name,
- "起点分数": score,
- })
- known_info.append({
- "节点ID": node.节点ID,
- "节点名称": name,
- "加入轮次": 0,
- "加入原因": f"起点(score={score:.2f})",
- })
- # 未知点集合
- unknown_names: Set[str] = set(node_by_name.keys()) - known_names
- # 推导关系
- relations: List[DerivedRelation] = []
- # 2. 迭代推导
- round_num = 0
- new_known_this_round = known_names.copy() # 第0轮新加入的就是起点
- while new_known_this_round:
- round_num += 1
- print(f"\n 第 {round_num} 轮推导...")
- # 本轮新加入的点
- new_known_next_round: Set[str] = set()
- # 遍历上一轮新加入的已知点
- for known_name in new_known_this_round:
- known_node = node_by_name.get(known_name)
- if not known_node:
- continue
- # 过滤:人设匹配分数 >= 0.8
- if known_node.人设匹配分数 < MATCH_SCORE_THRESHOLD:
- continue
- # 获取历史共现分类 {ID: 共现度}
- co_occur_map = known_node.历史共现分类
- if not co_occur_map:
- continue
- # 遍历未知点
- for unknown_name in list(unknown_names):
- unknown_node = node_by_name.get(unknown_name)
- if not unknown_node:
- continue
- # 过滤:人设匹配分数 >= 0.8
- if unknown_node.人设匹配分数 < MATCH_SCORE_THRESHOLD:
- continue
- # 检查:未知点的所属分类ID 是否在已知点的共现列表中
- if unknown_node.所属分类ID and unknown_node.所属分类ID in co_occur_map:
- # 找到关联!
- co_occur_score = co_occur_map[unknown_node.所属分类ID]
- new_known_next_round.add(unknown_name)
- # 建立关系
- relations.append(DerivedRelation(
- 来源节点ID=known_node.节点ID,
- 来源节点名称=known_name,
- 目标节点ID=unknown_node.节点ID,
- 目标节点名称=unknown_name,
- 关系类型="共现推导",
- 推导轮次=round_num,
- 共现分类ID=unknown_node.所属分类ID,
- 共现度=co_occur_score,
- ))
- print(f" {known_name} → {unknown_name} (共现度: {co_occur_score:.2f})")
- # 更新集合
- for name in new_known_next_round:
- node = node_by_name.get(name)
- if node:
- known_info.append({
- "节点ID": node.节点ID,
- "节点名称": name,
- "加入轮次": round_num,
- "加入原因": "共现推导",
- })
- known_names.update(new_known_next_round)
- unknown_names -= new_known_next_round
- new_known_this_round = new_known_next_round
- if not new_known_next_round:
- print(f" 无新点加入,推导结束")
- break
- # 3. 构建未知点列表
- unknown_list = []
- for name in unknown_names:
- node = node_by_name.get(name)
- if node:
- unknown_list.append({
- "节点ID": node.节点ID,
- "节点名称": name,
- "节点维度": node.节点维度,
- "人设匹配分数": node.人设匹配分数,
- "未加入原因": "人设匹配分数不足" if node.人设匹配分数 < MATCH_SCORE_THRESHOLD else "无共现关联",
- })
- return DerivationResult(
- 帖子ID="", # 由调用方设置
- 起点列表=origins,
- 已知点列表=known_info,
- 推导关系列表=[asdict(r) for r in relations],
- 推导轮次=round_num,
- 未知点列表=unknown_list,
- )
# ===== Processing =====
def process_single_post(
    origin_file: Path,
    config: PathConfig,
) -> Optional[Dict]:
    """Run derivation for one post and persist the result as JSON.

    Returns the serialized result dict, or None when required inputs are missing.
    """
    origin_payload = load_json(origin_file)
    post_id = origin_payload.get("帖子id", "unknown")
    print(f"\n{'=' * 60}")
    print(f"处理帖子: {post_id}")
    print("-" * 60)

    # Extract per-node origin scores; bail out if the analysis produced nothing.
    score_section = origin_payload.get("输出", {})
    if not score_section:
        print(" 错误: 起点分析结果为空")
        return None
    origin_scores = {
        node_name: details.get("score", 0)
        for node_name, details in score_section.items()
    }

    # Full node details come from the prepared-data file.
    prepared_file = get_prepared_file(config, post_id)
    if not prepared_file:
        print(f" 错误: 未找到待分析数据文件")
        return None
    raw_nodes = load_json(prepared_file).get("待分析节点列表", [])
    nodes = [AnalysisNode.from_raw(item) for item in raw_nodes]
    print(f" 节点数: {len(nodes)}")

    # Show qualifying origin points, highest score first.
    qualifying = [
        (node_name, score)
        for node_name, score in origin_scores.items()
        if score >= ORIGIN_SCORE_THRESHOLD
    ]
    print(f" 起点 (score >= {ORIGIN_SCORE_THRESHOLD}): {len(qualifying)} 个")
    for node_name, score in sorted(qualifying, key=lambda pair: pair[1], reverse=True):
        print(f" ★ {node_name}: {score:.2f}")

    result = derive_patterns(nodes, origin_scores)
    result.帖子ID = post_id

    print(f"\n 推导轮次: {result.推导轮次}")
    print(f" 已知点: {len(result.已知点列表)} 个")
    print(f" 推导关系: {len(result.推导关系列表)} 条")
    print(f" 未知点: {len(result.未知点列表)} 个")

    # Persist alongside the other intermediate artifacts.
    output_dir = config.intermediate_dir / "pattern_derivation"
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"{post_id}_模式推导.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(asdict(result), f, ensure_ascii=False, indent=2)
    print(f"\n 已保存: {output_file.name}")
    return asdict(result)
# ===== Entry point =====
def main(
    post_id: Optional[str] = None,  # was `str = None`: implicit Optional is invalid typing
    all_posts: bool = False,
):
    """Run pattern derivation for one post, the first post, or all posts.

    Args:
        post_id: optional post ID to process; takes precedence over all_posts.
        all_posts: process every post when True; otherwise only the first file.
    """
    config = PathConfig()
    print(f"账号: {config.account_name}")
    print(f"起点分数阈值: {ORIGIN_SCORE_THRESHOLD}")
    print(f"匹配分数阈值: {MATCH_SCORE_THRESHOLD}")

    origin_files = get_origin_result_files(config)
    if not origin_files:
        print("错误: 没有找到起点分析结果,请先运行 analyze_creation_origin.py")
        return

    # Decide which files to process: a specific post, all posts, or the first.
    if post_id:
        target_file = next(
            (f for f in origin_files if post_id in f.name),
            None
        )
        if not target_file:
            print(f"错误: 未找到帖子 {post_id} 的起点分析结果")
            return
        files_to_process = [target_file]
    elif all_posts:
        files_to_process = origin_files
    else:
        files_to_process = [origin_files[0]]
    print(f"待处理帖子数: {len(files_to_process)}")

    results = []
    for i, origin_file in enumerate(files_to_process, 1):
        print(f"\n{'#' * 60}")
        print(f"# 处理帖子 {i}/{len(files_to_process)}")
        print(f"{'#' * 60}")
        result = process_single_post(origin_file, config)
        if result:
            results.append(result)

    # Summary.
    print(f"\n{'#' * 60}")
    print(f"# 完成! 共处理 {len(results)} 个帖子")
    print(f"{'#' * 60}")
    print("\n汇总:")
    for result in results:
        # Distinct name: the original re-bound and shadowed the post_id parameter here.
        pid = result.get("帖子ID")
        known_count = len(result.get("已知点列表", []))
        relation_count = len(result.get("推导关系列表", []))
        unknown_count = len(result.get("未知点列表", []))
        print(f" {pid}: 已知={known_count}, 关系={relation_count}, 未知={unknown_count}")
if __name__ == "__main__":
    import argparse

    # Command-line front end: forward the parsed flags straight to main().
    parser = argparse.ArgumentParser(description="创作模式推导")
    parser.add_argument("--post-id", type=str, help="帖子ID")
    parser.add_argument("--all-posts", action="store_true", help="处理所有帖子")
    cli_args = parser.parse_args()
    main(post_id=cli_args.post_id, all_posts=cli_args.all_posts)