|
|
@@ -120,8 +120,11 @@ def calc_node_conditional_ratio(
|
|
|
return min(1.0, num / den)
|
|
|
|
|
|
|
|
|
-def _pattern_nodes_and_post_count(pattern: dict[str, Any]) -> tuple[list[str], int]:
|
|
|
- """从 pattern 中解析出节点列表和 post_count。支持 nodes + post_count 或 i + post_count。"""
|
|
|
+def _pattern_nodes_and_post_count(pattern: dict[str, Any]) -> tuple[list[str], int, float]:
|
|
|
+ """
|
|
|
+ 从 pattern 中解析出节点列表和 post_count。支持 nodes + post_count 或 i + post_count。
|
|
|
+ 返回的 post_count 表示该 pattern 本身的帖子数,在条件概率计算中作为分子(即 pattern 本身的概率/占比的分子)。
|
|
|
+ """
|
|
|
nodes = pattern.get("nodes")
|
|
|
if nodes is not None and isinstance(nodes, list):
|
|
|
nodes = [str(x).strip() for x in nodes if x]
|
|
|
@@ -129,7 +132,8 @@ def _pattern_nodes_and_post_count(pattern: dict[str, Any]) -> tuple[list[str], i
|
|
|
raw = pattern.get("i") or pattern.get("pattern_str") or ""
|
|
|
nodes = [x.strip() for x in str(raw).replace("+", " ").split() if x.strip()]
|
|
|
post_count = int(pattern.get("post_count", 0))
|
|
|
- return nodes, post_count
|
|
|
+ support = pattern.get("s", 0.0)
|
|
|
+ return nodes, post_count, support
|
|
|
|
|
|
|
|
|
def calc_pattern_conditional_ratio(
|
|
|
@@ -152,24 +156,24 @@ def calc_pattern_conditional_ratio(
|
|
|
计算规则:
|
|
|
取 pattern 中「已被推导」的节点(其名称出现在 derived 的推导来源中),
|
|
|
在人设树中取这些节点的 post_ids 的交集作为分母;
|
|
|
- 分子 = pattern.post_count。
|
|
|
+ 分子 = pattern.post_count(由 _pattern_nodes_and_post_count 解析得到,表示 pattern 本身的帖子数)。
|
|
|
条件概率 = 分子/分母,且 ≤1;分母为 0 时返回 1。
|
|
|
"""
|
|
|
- pattern_nodes, post_count = _pattern_nodes_and_post_count(pattern)
|
|
|
+ pattern_nodes, post_count, pattern_s = _pattern_nodes_and_post_count(pattern)
|
|
|
if not pattern_nodes or post_count <= 0:
|
|
|
- return 1.0
|
|
|
+ return pattern_s
|
|
|
|
|
|
derived_sources = set(source for _post, source in derived_list)
|
|
|
# pattern 中已被推导的节点
|
|
|
derived_pattern_nodes = [n for n in pattern_nodes if n in derived_sources]
|
|
|
if not derived_pattern_nodes:
|
|
|
- return 1.0
|
|
|
+ return pattern_s
|
|
|
|
|
|
index = _build_node_index(account_name, base_dir)
|
|
|
# 仅使用在人设树中存在的「已被推导」节点,取它们在树中的 post_ids 的交集
|
|
|
derived_in_tree = [n for n in derived_pattern_nodes if n in index]
|
|
|
if not derived_in_tree:
|
|
|
- return 1.0
|
|
|
+ return pattern_s
|
|
|
common: set[str] | None = None
|
|
|
for name in derived_in_tree:
|
|
|
pids = set(index[name][0])
|
|
|
@@ -178,8 +182,9 @@ def calc_pattern_conditional_ratio(
|
|
|
else:
|
|
|
common &= pids
|
|
|
if common is None or len(common) == 0:
|
|
|
- return 1.0
|
|
|
+ return pattern_s
|
|
|
den = len(common)
|
|
|
+ # 分子为 pattern 本身的帖子数(post_count),分母为条件集合大小
|
|
|
return min(1.0, post_count / den)
|
|
|
|
|
|
|
|
|
@@ -192,7 +197,7 @@ def _test_with_user_example() -> None:
|
|
|
account_name = "家有大志"
|
|
|
# 已推导列表:(已推导的选题点, 推导来源人设树节点)
|
|
|
derived_list: list[DerivedItem] = [
|
|
|
- ("分享", "分享"),
|
|
|
+ # ("分享", "分享"),
|
|
|
# ("柴犬", "动物角色"),
|
|
|
]
|
|
|
|
|
|
@@ -201,7 +206,7 @@ def _test_with_user_example() -> None:
|
|
|
print(f"1) 人设树节点「恶作剧」条件概率: {r_node}")
|
|
|
|
|
|
# 2)pattern 分享+动物角色+创意表达 post_count=2 的条件概率
|
|
|
- pattern = {"i": "分享+动物角色+创意表达", "post_count": 2}
|
|
|
+ pattern = {"i": "分享+动物角色+创意表达", "post_count": 2, "s": 0.3}
|
|
|
r_pattern = calc_pattern_conditional_ratio(account_name, derived_list, pattern)
|
|
|
print(f"2) pattern 分享+动物角色+创意表达 (post_count=2) 条件概率: {r_pattern}")
|
|
|
|