# conditional_ratio_calc.py
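"""
Conditional-probability utilities:
1) compute the conditional probability of a persona-tree node under its parent node;
2) compute the conditional probability of a pattern.
"""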
  6. from __future__ import annotations
  7. import itertools
  8. import json
  9. from pathlib import Path
  10. from typing import Any
# Node name -> (post_ids of that node, post_ids of its parent node). frozensets are
# used so batch computations can reuse them and avoid repeated conversions.
NodePostIndex = dict[str, tuple[frozenset[str], frozenset[str]]]
# Derived list: each item is (derived topic point, source persona-tree node),
# e.g. ("分享", "分享") or ("柴犬", "动物角色").
# The post_ids of the source persona-tree node are read from the persona tree
# when the conditional probability is computed.
DerivedItem = tuple[str, str]
  16. def _tree_dir(account_name: str, base_dir: Path | None = None) -> Path:
  17. """人设树目录:../input/{account_name}/处理后数据/tree/(相对本文件所在目录)。"""
  18. if base_dir is not None:
  19. return base_dir / account_name / "处理后数据" / "tree"
  20. return Path(__file__).resolve().parent.parent / "input" / account_name / "处理后数据" / "tree"
  21. def _load_trees(account_name: str, base_dir: Path | None = None) -> list[tuple[str, dict]]:
  22. """加载该账号下所有维度的人设树。返回 [(维度名, 根节点 dict), ...]。"""
  23. td = _tree_dir(account_name, base_dir)
  24. if not td.is_dir():
  25. return []
  26. return _load_trees_from_directory(td)
  27. def _load_trees_from_directory(tree_dir: Path) -> list[tuple[str, dict]]:
  28. """
  29. 从指定目录加载所有人设树 JSON(每文件取顶层第一个维度根,与按账号目录加载时行为一致)。
  30. 用于平台库等人设树路径非 input/{账号}/处理后数据/tree/ 的场景。
  31. """
  32. if not tree_dir.is_dir():
  33. return []
  34. result: list[tuple[str, dict]] = []
  35. for p in sorted(tree_dir.glob("*.json")):
  36. try:
  37. with open(p, "r", encoding="utf-8") as f:
  38. data = json.load(f)
  39. for dim_name, root in data.items():
  40. if isinstance(root, dict):
  41. result.append((str(dim_name), root))
  42. break
  43. except Exception:
  44. continue
  45. return result
  46. def _post_ids_of(node: dict) -> list[str]:
  47. """从树节点中取出 _post_ids,无则返回空列表。"""
  48. return list(node.get("_post_ids") or [])
  49. def _build_node_index_from_trees(trees: list[tuple[str, dict]]) -> dict[str, tuple[list[str], list[str]]]:
  50. """
  51. 遍历多棵人设树,建立 节点名 -> (该节点 post_ids, 父节点 post_ids)。
  52. 同一节点名在多个分支出现时,保留第一次遇到的(保证父子一致)。
  53. """
  54. index: dict[str, tuple[list[str], list[str]]] = {}
  55. for _dim, root in trees:
  56. parent_pids = _post_ids_of(root)
  57. def walk(parent_ids: list[str], node_dict: dict) -> None:
  58. for name, child in (node_dict.get("children") or {}).items():
  59. if not isinstance(child, dict):
  60. continue
  61. if name not in index:
  62. index[name] = (_post_ids_of(child), list(parent_ids))
  63. walk(_post_ids_of(child), child)
  64. walk(parent_pids, root)
  65. return index
  66. def _build_node_index(account_name: str, base_dir: Path | None = None) -> dict[str, tuple[list[str], list[str]]]:
  67. """遍历账号下所有人设树,建立节点索引。"""
  68. return _build_node_index_from_trees(_load_trees(account_name, base_dir))
  69. def build_node_post_index_from_tree_dir(tree_dir: Path) -> NodePostIndex:
  70. """从任意人设树目录(如 input/xiaohongshu/tree)构建节点 post 索引,算法与账号树一致。"""
  71. raw = _build_node_index_from_trees(_load_trees_from_directory(tree_dir))
  72. return {k: (frozenset(a), frozenset(b)) for k, (a, b) in raw.items()}
  73. def build_node_index_for_tree_dir(tree_dir: Path) -> dict[str, tuple[list[str], list[str]]]:
  74. """从任意人设树目录构建节点名 -> (节点 post_ids, 父 post_ids),供 pattern 条件概率等使用。"""
  75. return _build_node_index_from_trees(_load_trees_from_directory(tree_dir))
  76. def load_persona_trees_from_dir(tree_dir: Path) -> list[tuple[str, dict]]:
  77. """从目录加载人设树列表(每 JSON 文件一个顶层维度),供遍历节点等场景复用。"""
  78. return _load_trees_from_directory(tree_dir)
  79. def _derived_post_ids_from_sources(
  80. derived_list: list[DerivedItem],
  81. index: dict[str, tuple[list[str], list[str]]],
  82. ) -> set[str]:
  83. """根据 derived_list 中的「推导来源人设树节点」在人设树中的 post_ids 取交集,得到已推导的帖子集合。"""
  84. common: set[str] | None = None
  85. for _topic_point, source_node in derived_list:
  86. if source_node not in index:
  87. continue
  88. pids = set(index[source_node][0])
  89. if common is None:
  90. common = pids
  91. else:
  92. common &= pids
  93. return common if common is not None else set()
  94. def _derived_post_ids_from_frozen_index(
  95. derived_list: list[DerivedItem],
  96. index: NodePostIndex,
  97. ) -> frozenset[str]:
  98. """与 _derived_post_ids_from_sources 相同语义,索引为 frozenset 版(批量场景复用)。"""
  99. common: frozenset[str] | None = None
  100. for _topic_point, source_node in derived_list:
  101. if source_node not in index:
  102. continue
  103. pids = index[source_node][0]
  104. common = pids if common is None else common & pids
  105. return common if common is not None else frozenset()
  106. def build_node_post_index(account_name: str, base_dir: Path | None = None) -> NodePostIndex:
  107. """
  108. 构建账号人设树的节点索引(每个节点只建一次,供批量 calc_node_conditional_ratio 复用)。
  109. 值为 (节点 post_ids, 父节点 post_ids) 的 frozenset,减少重复 list->set 与拷贝。
  110. """
  111. raw = _build_node_index(account_name, base_dir)
  112. return {k: (frozenset(a), frozenset(b)) for k, (a, b) in raw.items()}
  113. def calc_node_conditional_ratio(
  114. account_name: str,
  115. derived_list: list[DerivedItem],
  116. tree_node_name: str,
  117. base_dir: Path | None = None,
  118. node_post_index: NodePostIndex | None = None,
  119. target_ratio: float | None = None,
  120. ) -> float:
  121. """
  122. 计算人设树节点 N 在父节点 P 下的条件概率。
  123. 参数:
  124. account_name: 账号名称
  125. derived_list: 已推导列表,每项 (已推导的选题点, 推导来源人设树节点)
  126. tree_node_name: 人设树节点 N 的名称(字符串匹配)
  127. base_dir: 可选,input 根目录;不传则使用相对本文件的 ../input
  128. node_post_index: 可选,由 build_node_post_index 预构建;批量对多节点计算时传入可避免重复读盘与遍历整棵树
  129. target_ratio: 可选,目标条件概率。若某个组合的条件概率已达到该值,则直接返回(用于缩小组合搜索)
  130. 计算规则:
  131. 已推导的帖子集合:从 derived_list 中先取「最多选题点」的交集,再逐步减少到 1 个选题点,
  132. 对每种选题点子集分别计算条件概率,最后取最大值。
  133. 对每种情况:已推导的帖子集合 = 该子集中各「推导来源人设树节点」在人设树中的 post_ids 的交集;
  134. 分子 = |已推导的帖子集合 ∩ N 的 post_ids|,分母 = |已推导的帖子集合 ∩ P 的 post_ids|;
  135. 条件概率 = 分子/分母,且 ≤1;分母为 0 时该情况跳过。
  136. """
  137. index = node_post_index if node_post_index is not None else build_node_post_index(account_name, base_dir)
  138. if tree_node_name not in index:
  139. return 0.0
  140. set_n, set_p = index[tree_node_name]
  141. # 关键优化(不改变搜索空间/结果):
  142. # - derived_list 里重复的 source_node 对“交集”没有任何影响,但会把 L 变大导致 2^L 爆炸
  143. # - 不在 index 里的 source_node 原本也会被跳过,提前过滤可减少组合规模
  144. # - 组合内交集直接对 frozenset 逐步 &,避免 list(combo)/函数调用开销
  145. seen_sources: set[str] = set()
  146. source_sets: list[frozenset[str]] = []
  147. for _topic, source_node in derived_list:
  148. if source_node in seen_sources:
  149. continue
  150. seen_sources.add(source_node)
  151. tup = index.get(source_node)
  152. if tup is None:
  153. continue
  154. source_sets.append(tup[0])
  155. if not source_sets:
  156. return 0.0
  157. # 将更小的集合放前面:交集会更快“变小”,每次 & 的成本更低(仍然枚举全部子集)
  158. source_sets.sort(key=len)
  159. max_ratio = 0.0
  160. # 从 1 个选题点到「最多选题点」:对每种子集大小,取所有组合,分别算条件概率后取最大
  161. for k in range(1, len(source_sets) + 1):
  162. for combo_sets in itertools.combinations(source_sets, k):
  163. derived_post_ids = combo_sets[0]
  164. for s in combo_sets[1:]:
  165. derived_post_ids = derived_post_ids & s
  166. den = len(derived_post_ids & set_p)
  167. if den == 0:
  168. continue
  169. num = len(derived_post_ids & set_n)
  170. ratio = min(1.0, num / den)
  171. max_ratio = max(max_ratio, ratio)
  172. if target_ratio is not None and max_ratio >= target_ratio:
  173. return round(max_ratio, 4)
  174. return round(max_ratio, 4)
  175. def _pattern_nodes_and_post_count(pattern: dict[str, Any]) -> tuple[list[str], int, float]:
  176. """
  177. 从 pattern 中解析出节点列表和 post_count。支持 nodes + post_count 或 i + post_count。
  178. 返回的 post_count 表示该 pattern 本身的帖子数,在条件概率计算中作为分子(即 pattern 本身的概率/占比的分子)。
  179. """
  180. nodes = pattern.get("nodes")
  181. if nodes is not None and isinstance(nodes, list):
  182. nodes = [str(x).strip() for x in nodes if x]
  183. else:
  184. raw = pattern.get("i") or pattern.get("pattern_str") or ""
  185. nodes = [x.strip() for x in str(raw).replace("+", " ").split() if x.strip()]
  186. post_count = int(pattern.get("post_count", 0))
  187. support = pattern.get("s", 0.0)
  188. return nodes, post_count, support
  189. def calc_pattern_conditional_ratio_with_index(
  190. derived_list: list[DerivedItem],
  191. pattern: dict[str, Any],
  192. index: dict[str, tuple[list[str], list[str]]],
  193. ) -> float:
  194. """
  195. 与 calc_pattern_conditional_ratio 相同计算规则,但使用已构建的人设树节点索引
  196. (例如平台库 input/xiaohongshu/tree)。
  197. """
  198. pattern_nodes, post_count, pattern_s = _pattern_nodes_and_post_count(pattern)
  199. if not pattern_nodes or post_count <= 0:
  200. return pattern_s
  201. derived_sources = set(source for _post, source in derived_list)
  202. derived_pattern_nodes = [n for n in pattern_nodes if n in derived_sources]
  203. if not derived_pattern_nodes:
  204. return pattern_s
  205. derived_in_tree = [n for n in derived_pattern_nodes if n in index]
  206. if not derived_in_tree:
  207. return pattern_s
  208. common: set[str] | None = None
  209. for name in derived_in_tree:
  210. pids = set(index[name][0])
  211. if common is None:
  212. common = pids
  213. else:
  214. common &= pids
  215. if common is None or len(common) == 0:
  216. return pattern_s
  217. den = len(common)
  218. return round(min(1.0, post_count / den), 4)
  219. def calc_pattern_conditional_ratio(
  220. account_name: str,
  221. derived_list: list[DerivedItem],
  222. pattern: dict[str, Any],
  223. base_dir: Path | None = None,
  224. ) -> float:
  225. """
  226. 计算某个 pattern 的条件概率。
  227. 参数:
  228. account_name: 账号名称
  229. derived_list: 已推导列表,每项 (已推导的选题点, 推导来源人设树节点)
  230. pattern: 至少包含节点列表与 post_count。
  231. - 节点列表: key 为 "nodes"(list)或 "i"(字符串,用 + 连接)
  232. - post_count: 该 pattern 的帖子数量,作为分子
  233. base_dir: 可选,input 根目录
  234. 计算规则:
  235. 取 pattern 中「已被推导」的节点(其名称出现在 derived 的推导来源中),
  236. 在人设树中取这些节点的 post_ids 的交集作为分母;
  237. 分子 = pattern.post_count(由 _pattern_nodes_and_post_count 解析得到,表示 pattern 本身的帖子数)。
  238. 条件概率 = 分子/分母,且 ≤1;分母为 0 时返回 1。
  239. """
  240. index = _build_node_index(account_name, base_dir)
  241. return calc_pattern_conditional_ratio_with_index(derived_list, pattern, index)
  242. def _test_with_user_example() -> None:
  243. """
  244. 使用你提供的测试数据:已推导 (分享|分享)、(柴犬|动物角色);
  245. 人设树节点:恶作剧;pattern:分享+动物角色+创意表达 post_count=2。
  246. 推导来源的 post_ids 在方法内部从人设树读取。
  247. """
  248. account_name = "阿里多多酱"
  249. # 已推导列表:(已推导的选题点, 推导来源人设树节点)
  250. derived_list: list[DerivedItem] = [
  251. ("推广", "推广"),
  252. ("视觉调性", "视觉调性"),
  253. # ("图片文字", "图片文字"),
  254. # ("补充说明式", "补充说明式"),
  255. # ("幽默化标题", "幽默化标题"),
  256. # ("标题", "标题"),
  257. ]
  258. # 1)人设树节点「恶作剧」的条件概率
  259. r_node = calc_node_conditional_ratio(account_name, derived_list, "观念")
  260. print(f"1) 人设树节点条件概率: {r_node}")
  261. # 2)pattern 分享+动物角色+创意表达 post_count=2 的条件概率
  262. pattern = {"i": "视觉调性+辞格意象+叙事编排", "post_count": 22, "s": 0.478261}
  263. r_pattern = calc_pattern_conditional_ratio(account_name, derived_list, pattern)
  264. print(f"2) pattern 条件概率: {r_pattern}")
# Run the ad-hoc demo only when this file is executed directly as a script.
if __name__ == "__main__":
    _test_with_user_example()