|
|
@@ -6,6 +6,7 @@
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
+import itertools
|
|
|
import json
|
|
|
from pathlib import Path
|
|
|
from typing import Any
|
|
|
@@ -101,23 +102,31 @@ def calc_node_conditional_ratio(
|
|
|
base_dir: 可选,input 根目录;不传则使用相对本文件的 ../input
|
|
|
|
|
|
计算规则:
|
|
|
- 已推导的帖子集合 = 各「推导来源人设树节点」在人设树中的 post_ids 的交集(方法内从树读取)
|
|
|
- 分子 = |已推导的帖子集合 ∩ N 的 post_ids|
|
|
|
- 分母 = |已推导的帖子集合 ∩ P 的 post_ids|
|
|
|
- 条件概率 = 分子/分母,且 ≤1;分母为 0 时返回 1。
|
|
|
+ 已推导的帖子集合:从 derived_list 中先取「最多选题点」的交集,再逐步减少到 1 个选题点,
|
|
|
+ 对每种选题点子集分别计算条件概率,最后取最大值。
|
|
|
+ 对每种情况:已推导的帖子集合 = 该子集中各「推导来源人设树节点」在人设树中的 post_ids 的交集;
|
|
|
+ 分子 = |已推导的帖子集合 ∩ N 的 post_ids|,分母 = |已推导的帖子集合 ∩ P 的 post_ids|;
|
|
|
+ 条件概率 = 分子/分母,且 ≤1;分母为 0 时该情况跳过。
|
|
|
"""
|
|
|
index = _build_node_index(account_name, base_dir)
|
|
|
- derived_post_ids = _derived_post_ids_from_sources(derived_list, index)
|
|
|
if tree_node_name not in index:
|
|
|
- return 1.0
|
|
|
+ return 0.0
|
|
|
n_pids, p_pids = index[tree_node_name]
|
|
|
set_n = set(n_pids)
|
|
|
set_p = set(p_pids)
|
|
|
- den = len(derived_post_ids & set_p)
|
|
|
- if den == 0:
|
|
|
- return 0.0
|
|
|
- num = len(derived_post_ids & set_n)
|
|
|
- return min(1.0, num / den)
|
|
|
+
|
|
|
+ max_ratio = 0.0
|
|
|
+ # 从「最多选题点」到 1 个选题点:对每种子集大小,取所有组合,分别算条件概率后取最大
|
|
|
+ for k in range(len(derived_list), 0, -1):
|
|
|
+ for combo in itertools.combinations(derived_list, k):
|
|
|
+ derived_post_ids = _derived_post_ids_from_sources(list(combo), index)
|
|
|
+ den = len(derived_post_ids & set_p)
|
|
|
+ if den == 0:
|
|
|
+ continue
|
|
|
+ num = len(derived_post_ids & set_n)
|
|
|
+ ratio = min(1.0, num / den)
|
|
|
+ max_ratio = max(max_ratio, ratio)
|
|
|
+ return max_ratio
|
|
|
|
|
|
|
|
|
def _pattern_nodes_and_post_count(pattern: dict[str, Any]) -> tuple[list[str], int, float]:
|
|
|
@@ -197,13 +206,17 @@ def _test_with_user_example() -> None:
|
|
|
account_name = "家有大志"
|
|
|
# 已推导列表:(已推导的选题点, 推导来源人设树节点)
|
|
|
derived_list: list[DerivedItem] = [
|
|
|
- # ("分享", "分享"),
|
|
|
- # ("柴犬", "动物角色"),
|
|
|
+ ("分享", "分享"),
|
|
|
+ ("叙事结构", "叙事结构"),
|
|
|
+ ("图片文字", "图片文字"),
|
|
|
+ ("补充说明式", "补充说明式"),
|
|
|
+ ("幽默化标题", "幽默化标题"),
|
|
|
+ ("标题", "标题"),
|
|
|
]
|
|
|
|
|
|
# 1)人设树节点「恶作剧」的条件概率
|
|
|
- r_node = calc_node_conditional_ratio(account_name, derived_list, "恶作剧")
|
|
|
- print(f"1) 人设树节点「恶作剧」条件概率: {r_node}")
|
|
|
+ r_node = calc_node_conditional_ratio(account_name, derived_list, "柴犬主角")
|
|
|
+ print(f"1) 人设树节点条件概率: {r_node}")
|
|
|
|
|
|
# 2)pattern 分享+动物角色+创意表达 post_count=2 的条件概率
|
|
|
pattern = {"i": "分享+动物角色+创意表达", "post_count": 2, "s": 0.3}
|