Pārlūkot izejas kodu

how agent 条件概率计算性能优化

liuzhiheng 1 mēnesi atpakaļ
vecāks
revīzija
f274c437fc

+ 11 - 2
examples_how/overall_derivation/tools/find_tree_node.py

@@ -15,7 +15,10 @@ from typing import Any, Optional
 _root = Path(__file__).resolve().parent.parent
 if str(_root) not in sys.path:
     sys.path.insert(0, str(_root))
-from utils.conditional_ratio_calc import calc_node_conditional_ratio  # noqa: E402
+from utils.conditional_ratio_calc import (  # noqa: E402
+    build_node_post_index,
+    calc_node_conditional_ratio,
+)
 from tools.point_match import match_derivation_to_post_points  # noqa: E402
 
 try:
@@ -216,11 +219,17 @@ def get_nodes_by_conditional_ratio(
             if ratio >= threshold:
                 scored.append((node_name, ratio, parent_name, dim_for(node_name)))
     else:
+        node_post_index = build_node_post_index(account_name, base_dir)
         for node_name, parent_name in node_to_parent.items():
             if allowed_node_names is not None and node_name not in allowed_node_names:
                 continue
             ratio = calc_node_conditional_ratio(
-                account_name, derived_list, node_name, base_dir=base_dir
+                account_name,
+                derived_list,
+                node_name,
+                base_dir=base_dir,
+                node_post_index=node_post_index,
+                target_ratio=threshold,
             )
             if ratio >= threshold:
                 scored.append((node_name, ratio, parent_name, dim_for(node_name)))

+ 61 - 8
examples_how/overall_derivation/utils/conditional_ratio_calc.py

@@ -11,6 +11,9 @@ import json
 from pathlib import Path
 from typing import Any
 
+# 节点名 -> (该节点 post_ids, 父节点 post_ids),用 frozenset 便于批量计算时复用、避免重复转换
+NodePostIndex = dict[str, tuple[frozenset[str], frozenset[str]]]
+
 # 已推导列表:每项为 (已推导的选题点, 推导来源人设树节点),如 ("分享","分享")、("柴犬","动物角色")
 # 推导来源人设树节点的 post_ids 在计算条件概率时从人设树中读取
 DerivedItem = tuple[str, str]
@@ -86,11 +89,36 @@ def _derived_post_ids_from_sources(
     return common if common is not None else set()
 
 
+def _derived_post_ids_from_frozen_index(
+    derived_list: list[DerivedItem],
+    index: NodePostIndex,
+) -> frozenset[str]:
+    """Intersect the node post_ids of every derived_list source node found in index (frozenset counterpart of _derived_post_ids_from_sources); empty frozenset if none is found."""
+    common: frozenset[str] | None = None
+    for _topic_point, source_node in derived_list:
+        if source_node not in index:
+            continue  # source nodes missing from the index do not constrain the result
+        pids = index[source_node][0]  # element 0 of the index tuple is the node's own post_ids
+        common = pids if common is None else common & pids
+    return common if common is not None else frozenset()
+
+
+def build_node_post_index(account_name: str, base_dir: Path | None = None) -> NodePostIndex:
+    """
+    Build the node index of an account's persona tree once, so batched calc_node_conditional_ratio calls can reuse it.
+    Each value is a (node post_ids, parent post_ids) pair of frozensets, avoiding repeated list->set conversion and copying.
+    """
+    raw = _build_node_index(account_name, base_dir)
+    return {k: (frozenset(a), frozenset(b)) for k, (a, b) in raw.items()}
+
+
 def calc_node_conditional_ratio(
     account_name: str,
     derived_list: list[DerivedItem],
     tree_node_name: str,
     base_dir: Path | None = None,
+    node_post_index: NodePostIndex | None = None,
+    target_ratio: float | None = None,
 ) -> float:
     """
     计算人设树节点 N 在父节点 P 下的条件概率。
@@ -100,6 +128,8 @@ def calc_node_conditional_ratio(
         derived_list: 已推导列表,每项 (已推导的选题点, 推导来源人设树节点)
         tree_node_name: 人设树节点 N 的名称(字符串匹配)
         base_dir: 可选,input 根目录;不传则使用相对本文件的 ../input
+        node_post_index: 可选,由 build_node_post_index 预构建;批量对多节点计算时传入可避免重复读盘与遍历整棵树
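+        target_ratio: optional early-exit threshold. As soon as any combination's conditional ratio reaches this value the function returns that ratio immediately, so the result may be lower than the true maximum over all combinations; omit it when the exact maximum is required.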
 
     计算规则:
         已推导的帖子集合:从 derived_list 中先取「最多选题点」的交集,再逐步减少到 1 个选题点,
@@ -108,24 +138,47 @@ def calc_node_conditional_ratio(
         分子 = |已推导的帖子集合 ∩ N 的 post_ids|,分母 = |已推导的帖子集合 ∩ P 的 post_ids|;
         条件概率 = 分子/分母,且 ≤1;分母为 0 时该情况跳过。
     """
-    index = _build_node_index(account_name, base_dir)
+    index = node_post_index if node_post_index is not None else build_node_post_index(account_name, base_dir)
     if tree_node_name not in index:
         return 0.0
-    n_pids, p_pids = index[tree_node_name]
-    set_n = set(n_pids)
-    set_p = set(p_pids)
+    set_n, set_p = index[tree_node_name]
+
+    # 关键优化(不改变搜索空间/结果):
+    # - derived_list 里重复的 source_node 对“交集”没有任何影响,但会把 L 变大导致 2^L 爆炸
+    # - 不在 index 里的 source_node 原本也会被跳过,提前过滤可减少组合规模
+    # - 组合内交集直接对 frozenset 逐步 &,避免 list(combo)/函数调用开销
+    seen_sources: set[str] = set()
+    source_sets: list[frozenset[str]] = []
+    for _topic, source_node in derived_list:
+        if source_node in seen_sources:
+            continue
+        seen_sources.add(source_node)
+        tup = index.get(source_node)
+        if tup is None:
+            continue
+        source_sets.append(tup[0])
+
+    if not source_sets:
+        return 0.0
+
+    # 将更小的集合放前面:交集会更快“变小”,每次 & 的成本更低(仍然枚举全部子集)
+    source_sets.sort(key=len)
 
     max_ratio = 0.0
-    # 从「最多选题点」到 1 个选题点:对每种子集大小,取所有组合,分别算条件概率后取最大
-    for k in range(len(derived_list), 0, -1):
-        for combo in itertools.combinations(derived_list, k):
-            derived_post_ids = _derived_post_ids_from_sources(list(combo), index)
+    # 从 1 个选题点到「最多选题点」:对每种子集大小,取所有组合,分别算条件概率后取最大
+    for k in range(1, len(source_sets) + 1):
+        for combo_sets in itertools.combinations(source_sets, k):
+            derived_post_ids = combo_sets[0]
+            for s in combo_sets[1:]:
+                derived_post_ids = derived_post_ids & s
             den = len(derived_post_ids & set_p)
             if den == 0:
                 continue
             num = len(derived_post_ids & set_n)
             ratio = min(1.0, num / den)
             max_ratio = max(max_ratio, ratio)
+            if target_ratio is not None and max_ratio >= target_ratio:
+                return round(max_ratio, 4)
     return round(max_ratio, 4)