|
@@ -11,6 +11,9 @@ import json
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
from typing import Any
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
# Node name -> (post_ids of that node, post_ids of its parent node).
# frozensets are used so batch computations can reuse them without converting
# the same lists to sets over and over.
NodePostIndex = dict[str, tuple[frozenset[str], frozenset[str]]]

# Derived list: each item is (derived topic point, source persona-tree node),
# e.g. ("分享", "分享") or ("柴犬", "动物角色").
# The post_ids of the source persona-tree node are read from the persona tree
# when the conditional probability is computed.
DerivedItem = tuple[str, str]
|
|
@@ -86,11 +89,36 @@ def _derived_post_ids_from_sources(
|
|
|
return common if common is not None else set()
|
|
return common if common is not None else set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def _derived_post_ids_from_frozen_index(
|
|
|
|
|
+ derived_list: list[DerivedItem],
|
|
|
|
|
+ index: NodePostIndex,
|
|
|
|
|
+) -> frozenset[str]:
|
|
|
|
|
+ """与 _derived_post_ids_from_sources 相同语义,索引为 frozenset 版(批量场景复用)。"""
|
|
|
|
|
+ common: frozenset[str] | None = None
|
|
|
|
|
+ for _topic_point, source_node in derived_list:
|
|
|
|
|
+ if source_node not in index:
|
|
|
|
|
+ continue
|
|
|
|
|
+ pids = index[source_node][0]
|
|
|
|
|
+ common = pids if common is None else common & pids
|
|
|
|
|
+ return common if common is not None else frozenset()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def build_node_post_index(account_name: str, base_dir: Path | None = None) -> NodePostIndex:
    """Build the persona-tree node index of an account with frozenset values.

    The raw index is obtained from ``_build_node_index`` and each value pair
    is converted to frozensets once, so that repeated (batch) calls to
    ``calc_node_conditional_ratio`` can share the result instead of redoing
    the list -> set conversion per call.

    Args:
        account_name: Account whose persona tree is indexed; forwarded to
            ``_build_node_index``.
        base_dir: Optional root directory; forwarded to ``_build_node_index``.

    Returns:
        Mapping of node name -> (node post_ids, parent post_ids), both as
        frozensets.
    """
    raw = _build_node_index(account_name, base_dir)
    return {
        name: (frozenset(node_pids), frozenset(parent_pids))
        for name, (node_pids, parent_pids) in raw.items()
    }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def calc_node_conditional_ratio(
    account_name: str,
    derived_list: list[DerivedItem],
    tree_node_name: str,
    base_dir: Path | None = None,
    node_post_index: NodePostIndex | None = None,
    target_ratio: float | None = None,
) -> float:
    """Compute the conditional probability of persona-tree node N under its parent P.

    Args:
        account_name: Account name; only used to build the index when
            ``node_post_index`` is not supplied.
        derived_list: Already-derived items, each
            (derived topic point, source persona-tree node).
        tree_node_name: Name of persona-tree node N (exact string match
            against the index keys).
        base_dir: Optional input root directory, forwarded to
            ``build_node_post_index`` when the index has to be built.
        node_post_index: Optional index pre-built by ``build_node_post_index``.
            Passing it when evaluating many nodes avoids rebuilding the index
            on every call.
        target_ratio: Optional target probability. As soon as some
            combination reaches this value the current maximum is returned,
            which cuts the combination search short.

    Returns:
        The maximum conditional probability over all evaluated combinations,
        rounded to 4 decimals. 0.0 when N is not in the index, when no usable
        source node exists, or when every combination has a zero denominator.

    Calculation rule:
        For every non-empty subset of the distinct source nodes, the derived
        post set is the intersection of those nodes' post_ids.
        numerator   = |derived post set ∩ post_ids of N|
        denominator = |derived post set ∩ post_ids of P|
        ratio = numerator / denominator, capped at 1; subsets whose
        denominator is 0 are skipped. The result is the maximum ratio.
    """
    index = (
        node_post_index
        if node_post_index is not None
        else build_node_post_index(account_name, base_dir)
    )
    entry = index.get(tree_node_name)
    if entry is None:
        return 0.0
    set_n, set_p = entry
    if not set_p:
        # Every denominator would be 0, so every combination would be skipped.
        return 0.0

    # Shrink the subset enumeration (2^L) without changing the result:
    # - repeated source nodes add nothing to an intersection;
    # - source nodes missing from the index were always ignored;
    # - a source disjoint from P forces a zero denominator for every
    #   combination that contains it, and such combinations are skipped.
    seen_sources: set[str] = set()
    source_sets: list[frozenset[str]] = []
    for _topic, source_node in derived_list:
        if source_node in seen_sources:
            continue
        seen_sources.add(source_node)
        source_entry = index.get(source_node)
        if source_entry is None:
            continue
        pids = source_entry[0]
        if pids.isdisjoint(set_p):
            continue
        source_sets.append(pids)

    if not source_sets:
        return 0.0

    # Smaller sets first: intersections shrink sooner, so each `&` is cheaper.
    # All subsets are still enumerated.
    source_sets.sort(key=len)

    max_ratio = 0.0
    # From one source up to all of them: evaluate every combination of each
    # size and keep the largest ratio.
    for k in range(1, len(source_sets) + 1):
        for combo_sets in itertools.combinations(source_sets, k):
            derived_post_ids = frozenset.intersection(*combo_sets)
            den = len(derived_post_ids & set_p)
            if den == 0:
                continue
            num = len(derived_post_ids & set_n)
            ratio = min(1.0, num / den)
            if ratio > max_ratio:
                max_ratio = ratio
            # The ratio is capped at 1.0, so nothing can beat it; likewise
            # stop once the caller-provided target has been reached.
            if max_ratio >= 1.0 or (
                target_ratio is not None and max_ratio >= target_ratio
            ):
                return round(max_ratio, 4)
    return round(max_ratio, 4)
|
|
|
|
|
|
|
|
|
|
|