""" 条件概率计算工具: 1)计算某个人设树节点在父节点下的条件概率; 2)计算某个 pattern 的条件概率。 """ from __future__ import annotations import itertools import json from pathlib import Path from typing import Any # 已推导列表:每项为 (已推导的选题点, 推导来源人设树节点),如 ("分享","分享")、("柴犬","动物角色") # 推导来源人设树节点的 post_ids 在计算条件概率时从人设树中读取 DerivedItem = tuple[str, str] def _tree_dir(account_name: str, base_dir: Path | None = None) -> Path: """人设树目录:../input/{account_name}/原始数据/tree/(相对本文件所在目录)。""" if base_dir is not None: return base_dir / account_name / "原始数据" / "tree" return Path(__file__).resolve().parent.parent / "input" / account_name / "原始数据" / "tree" def _load_trees(account_name: str, base_dir: Path | None = None) -> list[tuple[str, dict]]: """加载该账号下所有维度的人设树。返回 [(维度名, 根节点 dict), ...]。""" td = _tree_dir(account_name, base_dir) if not td.is_dir(): return [] result = [] for p in td.glob("*.json"): try: with open(p, "r", encoding="utf-8") as f: data = json.load(f) # 文件格式为 { "实质": { ... } } 或 { "形式": { ... } } for dim_name, root in data.items(): if isinstance(root, dict): result.append((dim_name, root)) break except Exception: continue return result def _post_ids_of(node: dict) -> list[str]: """从树节点中取出 _post_ids,无则返回空列表。""" return list(node.get("_post_ids") or []) def _build_node_index(account_name: str, base_dir: Path | None = None) -> dict[str, tuple[list[str], list[str]]]: """ 遍历所有维度的人设树,建立 节点名 -> (该节点 post_ids, 父节点 post_ids)。 同一节点名在多个分支出现时,保留第一次遇到的(保证父子一致)。 """ index: dict[str, tuple[list[str], list[str]]] = {} for _dim, root in _load_trees(account_name, base_dir): parent_pids = _post_ids_of(root) def walk(parent_ids: list[str], node_dict: dict) -> None: for name, child in (node_dict.get("children") or {}).items(): if not isinstance(child, dict): continue if name not in index: index[name] = (_post_ids_of(child), list(parent_ids)) walk(_post_ids_of(child), child) walk(parent_pids, root) return index def _derived_post_ids_from_sources( derived_list: list[DerivedItem], index: dict[str, tuple[list[str], list[str]]], ) -> set[str]: """根据 derived_list 中的「推导来源人设树节点」在人设树中的 post_ids 取交集,得到已推导的帖子集合。""" common: set[str] | None = None for _topic_point, source_node in derived_list: if source_node not in index: continue pids = set(index[source_node][0]) if common is None: common = pids else: common &= pids return common if common is not None else set() def calc_node_conditional_ratio( account_name: str, derived_list: list[DerivedItem], tree_node_name: str, base_dir: Path | None = None, ) -> float: """ 计算人设树节点 N 在父节点 P 下的条件概率。 参数: account_name: 账号名称 derived_list: 已推导列表,每项 (已推导的选题点, 推导来源人设树节点) tree_node_name: 人设树节点 N 的名称(字符串匹配) base_dir: 可选,input 根目录;不传则使用相对本文件的 ../input 计算规则: 已推导的帖子集合:从 derived_list 中先取「最多选题点」的交集,再逐步减少到 1 个选题点, 对每种选题点子集分别计算条件概率,最后取最大值。 对每种情况:已推导的帖子集合 = 该子集中各「推导来源人设树节点」在人设树中的 post_ids 的交集; 分子 = |已推导的帖子集合 ∩ N 的 post_ids|,分母 = |已推导的帖子集合 ∩ P 的 post_ids|; 条件概率 = 分子/分母,且 ≤1;分母为 0 时该情况跳过。 """ index = _build_node_index(account_name, base_dir) if tree_node_name not in index: return 0.0 n_pids, p_pids = index[tree_node_name] set_n = set(n_pids) set_p = set(p_pids) max_ratio = 0.0 # 从「最多选题点」到 1 个选题点:对每种子集大小,取所有组合,分别算条件概率后取最大 for k in range(len(derived_list), 0, -1): for combo in itertools.combinations(derived_list, k): derived_post_ids = _derived_post_ids_from_sources(list(combo), index) den = len(derived_post_ids & set_p) if den == 0: continue num = len(derived_post_ids & set_n) ratio = min(1.0, num / den) max_ratio = max(max_ratio, ratio) return max_ratio def _pattern_nodes_and_post_count(pattern: dict[str, Any]) -> tuple[list[str], int, float]: """ 从 pattern 中解析出节点列表和 post_count。支持 nodes + post_count 或 i + post_count。 返回的 post_count 表示该 pattern 本身的帖子数,在条件概率计算中作为分子(即 pattern 本身的概率/占比的分子)。 """ nodes = pattern.get("nodes") if nodes is not None and isinstance(nodes, list): nodes = [str(x).strip() for x in nodes if x] else: raw = pattern.get("i") or pattern.get("pattern_str") or "" nodes = [x.strip() for x in str(raw).replace("+", " ").split() if x.strip()] post_count = int(pattern.get("post_count", 0)) support = pattern.get("s", 0.0) return nodes, post_count, support def calc_pattern_conditional_ratio( account_name: str, derived_list: list[DerivedItem], pattern: dict[str, Any], base_dir: Path | None = None, ) -> float: """ 计算某个 pattern 的条件概率。 参数: account_name: 账号名称 derived_list: 已推导列表,每项 (已推导的选题点, 推导来源人设树节点) pattern: 至少包含节点列表与 post_count。 - 节点列表: key 为 "nodes"(list)或 "i"(字符串,用 + 连接) - post_count: 该 pattern 的帖子数量,作为分子 base_dir: 可选,input 根目录 计算规则: 取 pattern 中「已被推导」的节点(其名称出现在 derived 的推导来源中), 在人设树中取这些节点的 post_ids 的交集作为分母; 分子 = pattern.post_count(由 _pattern_nodes_and_post_count 解析得到,表示 pattern 本身的帖子数)。 条件概率 = 分子/分母,且 ≤1;分母为 0 时返回 1。 """ pattern_nodes, post_count, pattern_s = _pattern_nodes_and_post_count(pattern) if not pattern_nodes or post_count <= 0: return pattern_s derived_sources = set(source for _post, source in derived_list) # pattern 中已被推导的节点 derived_pattern_nodes = [n for n in pattern_nodes if n in derived_sources] if not derived_pattern_nodes: return pattern_s index = _build_node_index(account_name, base_dir) # 仅使用在人设树中存在的「已被推导」节点,取它们在树中的 post_ids 的交集 derived_in_tree = [n for n in derived_pattern_nodes if n in index] if not derived_in_tree: return pattern_s common: set[str] | None = None for name in derived_in_tree: pids = set(index[name][0]) if common is None: common = pids else: common &= pids if common is None or len(common) == 0: return pattern_s den = len(common) # 分子为 pattern 本身的帖子数(post_count),分母为条件集合大小 return min(1.0, post_count / den) def _test_with_user_example() -> None: """ 使用你提供的测试数据:已推导 (分享|分享)、(柴犬|动物角色); 人设树节点:恶作剧;pattern:分享+动物角色+创意表达 post_count=2。 推导来源的 post_ids 在方法内部从人设树读取。 """ account_name = "家有大志" # 已推导列表:(已推导的选题点, 推导来源人设树节点) derived_list: list[DerivedItem] = [ ("分享", "分享"), ("叙事结构", "叙事结构"), ("图片文字", "图片文字"), ("补充说明式", "补充说明式"), ("幽默化标题", "幽默化标题"), ("标题", "标题"), ] # 1)人设树节点「恶作剧」的条件概率 r_node = calc_node_conditional_ratio(account_name, derived_list, "柴犬主角") print(f"1) 人设树节点条件概率: {r_node}") # 2)pattern 分享+动物角色+创意表达 post_count=2 的条件概率 pattern = {"i": "分享+动物角色+创意表达", "post_count": 2, "s": 0.3} r_pattern = calc_pattern_conditional_ratio(account_name, derived_list, pattern) print(f"2) pattern 分享+动物角色+创意表达 (post_count=2) 条件概率: {r_pattern}") if __name__ == "__main__": _test_with_user_example()