conditional_ratio_calc.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. """
  2. 条件概率计算工具:
  3. 1)计算某个人设树节点在父节点下的条件概率;
  4. 2)计算某个 pattern 的条件概率。
  5. """
  6. from __future__ import annotations
  7. import itertools
  8. import json
  9. from pathlib import Path
  10. from typing import Any
  11. # 已推导列表:每项为 (已推导的选题点, 推导来源人设树节点),如 ("分享","分享")、("柴犬","动物角色")
  12. # 推导来源人设树节点的 post_ids 在计算条件概率时从人设树中读取
  13. DerivedItem = tuple[str, str]
  14. def _tree_dir(account_name: str, base_dir: Path | None = None) -> Path:
  15. """人设树目录:../input/{account_name}/原始数据/tree/(相对本文件所在目录)。"""
  16. if base_dir is not None:
  17. return base_dir / account_name / "原始数据" / "tree"
  18. return Path(__file__).resolve().parent.parent / "input" / account_name / "原始数据" / "tree"
  19. def _load_trees(account_name: str, base_dir: Path | None = None) -> list[tuple[str, dict]]:
  20. """加载该账号下所有维度的人设树。返回 [(维度名, 根节点 dict), ...]。"""
  21. td = _tree_dir(account_name, base_dir)
  22. if not td.is_dir():
  23. return []
  24. result = []
  25. for p in td.glob("*.json"):
  26. try:
  27. with open(p, "r", encoding="utf-8") as f:
  28. data = json.load(f)
  29. # 文件格式为 { "实质": { ... } } 或 { "形式": { ... } }
  30. for dim_name, root in data.items():
  31. if isinstance(root, dict):
  32. result.append((dim_name, root))
  33. break
  34. except Exception:
  35. continue
  36. return result
  37. def _post_ids_of(node: dict) -> list[str]:
  38. """从树节点中取出 _post_ids,无则返回空列表。"""
  39. return list(node.get("_post_ids") or [])
  40. def _build_node_index(account_name: str, base_dir: Path | None = None) -> dict[str, tuple[list[str], list[str]]]:
  41. """
  42. 遍历所有维度的人设树,建立 节点名 -> (该节点 post_ids, 父节点 post_ids)。
  43. 同一节点名在多个分支出现时,保留第一次遇到的(保证父子一致)。
  44. """
  45. index: dict[str, tuple[list[str], list[str]]] = {}
  46. for _dim, root in _load_trees(account_name, base_dir):
  47. parent_pids = _post_ids_of(root)
  48. def walk(parent_ids: list[str], node_dict: dict) -> None:
  49. for name, child in (node_dict.get("children") or {}).items():
  50. if not isinstance(child, dict):
  51. continue
  52. if name not in index:
  53. index[name] = (_post_ids_of(child), list(parent_ids))
  54. walk(_post_ids_of(child), child)
  55. walk(parent_pids, root)
  56. return index
  57. def _derived_post_ids_from_sources(
  58. derived_list: list[DerivedItem],
  59. index: dict[str, tuple[list[str], list[str]]],
  60. ) -> set[str]:
  61. """根据 derived_list 中的「推导来源人设树节点」在人设树中的 post_ids 取交集,得到已推导的帖子集合。"""
  62. common: set[str] | None = None
  63. for _topic_point, source_node in derived_list:
  64. if source_node not in index:
  65. continue
  66. pids = set(index[source_node][0])
  67. if common is None:
  68. common = pids
  69. else:
  70. common &= pids
  71. return common if common is not None else set()
  72. def calc_node_conditional_ratio(
  73. account_name: str,
  74. derived_list: list[DerivedItem],
  75. tree_node_name: str,
  76. base_dir: Path | None = None,
  77. ) -> float:
  78. """
  79. 计算人设树节点 N 在父节点 P 下的条件概率。
  80. 参数:
  81. account_name: 账号名称
  82. derived_list: 已推导列表,每项 (已推导的选题点, 推导来源人设树节点)
  83. tree_node_name: 人设树节点 N 的名称(字符串匹配)
  84. base_dir: 可选,input 根目录;不传则使用相对本文件的 ../input
  85. 计算规则:
  86. 已推导的帖子集合:从 derived_list 中先取「最多选题点」的交集,再逐步减少到 1 个选题点,
  87. 对每种选题点子集分别计算条件概率,最后取最大值。
  88. 对每种情况:已推导的帖子集合 = 该子集中各「推导来源人设树节点」在人设树中的 post_ids 的交集;
  89. 分子 = |已推导的帖子集合 ∩ N 的 post_ids|,分母 = |已推导的帖子集合 ∩ P 的 post_ids|;
  90. 条件概率 = 分子/分母,且 ≤1;分母为 0 时该情况跳过。
  91. """
  92. index = _build_node_index(account_name, base_dir)
  93. if tree_node_name not in index:
  94. return 0.0
  95. n_pids, p_pids = index[tree_node_name]
  96. set_n = set(n_pids)
  97. set_p = set(p_pids)
  98. max_ratio = 0.0
  99. # 从「最多选题点」到 1 个选题点:对每种子集大小,取所有组合,分别算条件概率后取最大
  100. for k in range(len(derived_list), 0, -1):
  101. for combo in itertools.combinations(derived_list, k):
  102. derived_post_ids = _derived_post_ids_from_sources(list(combo), index)
  103. den = len(derived_post_ids & set_p)
  104. if den == 0:
  105. continue
  106. num = len(derived_post_ids & set_n)
  107. ratio = min(1.0, num / den)
  108. max_ratio = max(max_ratio, ratio)
  109. return max_ratio
  110. def _pattern_nodes_and_post_count(pattern: dict[str, Any]) -> tuple[list[str], int, float]:
  111. """
  112. 从 pattern 中解析出节点列表和 post_count。支持 nodes + post_count 或 i + post_count。
  113. 返回的 post_count 表示该 pattern 本身的帖子数,在条件概率计算中作为分子(即 pattern 本身的概率/占比的分子)。
  114. """
  115. nodes = pattern.get("nodes")
  116. if nodes is not None and isinstance(nodes, list):
  117. nodes = [str(x).strip() for x in nodes if x]
  118. else:
  119. raw = pattern.get("i") or pattern.get("pattern_str") or ""
  120. nodes = [x.strip() for x in str(raw).replace("+", " ").split() if x.strip()]
  121. post_count = int(pattern.get("post_count", 0))
  122. support = pattern.get("s", 0.0)
  123. return nodes, post_count, support
  124. def calc_pattern_conditional_ratio(
  125. account_name: str,
  126. derived_list: list[DerivedItem],
  127. pattern: dict[str, Any],
  128. base_dir: Path | None = None,
  129. ) -> float:
  130. """
  131. 计算某个 pattern 的条件概率。
  132. 参数:
  133. account_name: 账号名称
  134. derived_list: 已推导列表,每项 (已推导的选题点, 推导来源人设树节点)
  135. pattern: 至少包含节点列表与 post_count。
  136. - 节点列表: key 为 "nodes"(list)或 "i"(字符串,用 + 连接)
  137. - post_count: 该 pattern 的帖子数量,作为分子
  138. base_dir: 可选,input 根目录
  139. 计算规则:
  140. 取 pattern 中「已被推导」的节点(其名称出现在 derived 的推导来源中),
  141. 在人设树中取这些节点的 post_ids 的交集作为分母;
  142. 分子 = pattern.post_count(由 _pattern_nodes_and_post_count 解析得到,表示 pattern 本身的帖子数)。
  143. 条件概率 = 分子/分母,且 ≤1;分母为 0 时返回 1。
  144. """
  145. pattern_nodes, post_count, pattern_s = _pattern_nodes_and_post_count(pattern)
  146. if not pattern_nodes or post_count <= 0:
  147. return pattern_s
  148. derived_sources = set(source for _post, source in derived_list)
  149. # pattern 中已被推导的节点
  150. derived_pattern_nodes = [n for n in pattern_nodes if n in derived_sources]
  151. if not derived_pattern_nodes:
  152. return pattern_s
  153. index = _build_node_index(account_name, base_dir)
  154. # 仅使用在人设树中存在的「已被推导」节点,取它们在树中的 post_ids 的交集
  155. derived_in_tree = [n for n in derived_pattern_nodes if n in index]
  156. if not derived_in_tree:
  157. return pattern_s
  158. common: set[str] | None = None
  159. for name in derived_in_tree:
  160. pids = set(index[name][0])
  161. if common is None:
  162. common = pids
  163. else:
  164. common &= pids
  165. if common is None or len(common) == 0:
  166. return pattern_s
  167. den = len(common)
  168. # 分子为 pattern 本身的帖子数(post_count),分母为条件集合大小
  169. return min(1.0, post_count / den)
  170. def _test_with_user_example() -> None:
  171. """
  172. 使用你提供的测试数据:已推导 (分享|分享)、(柴犬|动物角色);
  173. 人设树节点:恶作剧;pattern:分享+动物角色+创意表达 post_count=2。
  174. 推导来源的 post_ids 在方法内部从人设树读取。
  175. """
  176. account_name = "家有大志"
  177. # 已推导列表:(已推导的选题点, 推导来源人设树节点)
  178. derived_list: list[DerivedItem] = [
  179. ("分享", "分享"),
  180. ("叙事结构", "叙事结构"),
  181. ("图片文字", "图片文字"),
  182. ("补充说明式", "补充说明式"),
  183. ("幽默化标题", "幽默化标题"),
  184. ("标题", "标题"),
  185. ]
  186. # 1)人设树节点「恶作剧」的条件概率
  187. r_node = calc_node_conditional_ratio(account_name, derived_list, "柴犬主角")
  188. print(f"1) 人设树节点条件概率: {r_node}")
  189. # 2)pattern 分享+动物角色+创意表达 post_count=2 的条件概率
  190. pattern = {"i": "分享+动物角色+创意表达", "post_count": 2, "s": 0.3}
  191. r_pattern = calc_pattern_conditional_ratio(account_name, derived_list, pattern)
  192. print(f"2) pattern 分享+动物角色+创意表达 (post_count=2) 条件概率: {r_pattern}")
  193. if __name__ == "__main__":
  194. _test_with_user_example()