|
@@ -17,8 +17,6 @@ import sys
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
from typing import Any, Dict, List, Optional, Tuple, Set
|
|
from typing import Any, Dict, List, Optional, Tuple, Set
|
|
|
|
|
|
|
|
-from Cython.Includes.libc.stdio import printf
|
|
|
|
|
-
|
|
|
|
|
# 保证直接运行或作为包加载时都能解析 utils / tools(IDE 可跳转)
|
|
# 保证直接运行或作为包加载时都能解析 utils / tools(IDE 可跳转)
|
|
|
_root = Path(__file__).resolve().parent.parent
|
|
_root = Path(__file__).resolve().parent.parent
|
|
|
if str(_root) not in sys.path:
|
|
if str(_root) not in sys.path:
|
|
@@ -105,7 +103,15 @@ def _load_round_matched_points(
|
|
|
continue
|
|
continue
|
|
|
if not item.get("is_matched"):
|
|
if not item.get("is_matched"):
|
|
|
continue
|
|
continue
|
|
|
- mp = item.get("matched_post_point")
|
|
|
|
|
|
|
+
|
|
|
|
|
+ # 根据是否已完全推导,选择不同的帖子选题点字段:
|
|
|
|
|
+ # - is_fully_derived 为 False 时,使用 derivation_output_point
|
|
|
|
|
+ # - 其他情况(True 或缺失)使用 matched_post_point(兼容旧数据)
|
|
|
|
|
+ if item.get("is_fully_derived") is False:
|
|
|
|
|
+ mp = item.get("derivation_output_point")
|
|
|
|
|
+ else:
|
|
|
|
|
+ mp = item.get("matched_post_point")
|
|
|
|
|
+
|
|
|
if mp is None:
|
|
if mp is None:
|
|
|
continue
|
|
continue
|
|
|
mp = str(mp).strip()
|
|
mp = str(mp).strip()
|
|
@@ -261,13 +267,17 @@ def _score_patterns_by_matched_points(
|
|
|
best_post_point: Optional[str] = None
|
|
best_post_point: Optional[str] = None
|
|
|
if name:
|
|
if name:
|
|
|
for post_point in matched_post_points:
|
|
for post_point in matched_post_points:
|
|
|
- score = match_lookup.get((post_point, name))
|
|
|
|
|
- if score is None:
|
|
|
|
|
- continue
|
|
|
|
|
- try:
|
|
|
|
|
- s = float(score)
|
|
|
|
|
- except (TypeError, ValueError):
|
|
|
|
|
- continue
|
|
|
|
|
|
|
+ # 如果帖子选题点与节点名称完全一致,直接视为满分匹配
|
|
|
|
|
+ if post_point == name:
|
|
|
|
|
+ s = 1.0
|
|
|
|
|
+ else:
|
|
|
|
|
+ score = match_lookup.get((post_point, name))
|
|
|
|
|
+ if score is None:
|
|
|
|
|
+ continue
|
|
|
|
|
+ try:
|
|
|
|
|
+ s = float(score)
|
|
|
|
|
+ except (TypeError, ValueError):
|
|
|
|
|
+ continue
|
|
|
if s > best_score:
|
|
if s > best_score:
|
|
|
best_score = s
|
|
best_score = s
|
|
|
best_post_point = post_point
|
|
best_post_point = post_point
|
|
@@ -397,15 +407,17 @@ class TreeIndex:
|
|
|
def find_clusters(
|
|
def find_clusters(
|
|
|
self,
|
|
self,
|
|
|
elements: List[str],
|
|
elements: List[str],
|
|
|
- min_level: int,
|
|
|
|
|
|
|
+ cluster_level: int,
|
|
|
) -> List[Dict[str, Any]]:
|
|
) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
"""
|
|
|
在所有人设树中,为给定元素列表寻找聚类节点(不再要求 dimension 一致)。
|
|
在所有人设树中,为给定元素列表寻找聚类节点(不再要求 dimension 一致)。
|
|
|
|
|
|
|
|
- 规则:
|
|
|
|
|
- - 从 min_level 开始,自上而下扫描:
|
|
|
|
|
|
|
+ 规则(固定聚类层级 cluster_level):
|
|
|
|
|
+ - 仅在 depth == cluster_level 的节点上做聚类判断:
|
|
|
* 若某节点子树中包含的元素数量 >= 2,
|
|
* 若某节点子树中包含的元素数量 >= 2,
|
|
|
且在该路径上尚未存在更高层(深度更小)的聚类节点,则将其视为一个聚类节点。
|
|
且在该路径上尚未存在更高层(深度更小)的聚类节点,则将其视为一个聚类节点。
|
|
|
|
|
+ - 对无法向上形成聚类的元素,为其寻找 depth == cluster_level 的祖先节点,
|
|
|
|
|
+ 若存在则作为该元素的「单元素聚类」节点。
|
|
|
- 返回:
|
|
- 返回:
|
|
|
[
|
|
[
|
|
|
{
|
|
{
|
|
@@ -447,7 +459,7 @@ class TreeIndex:
|
|
|
for root_name in self.roots.values():
|
|
for root_name in self.roots.values():
|
|
|
dfs_count(root_name, set())
|
|
dfs_count(root_name, set())
|
|
|
|
|
|
|
|
- # 再自上而下优先选择「更上层」聚类节点(但不低于 min_level):
|
|
|
|
|
|
|
+ # 再自上而下优先选择「更上层」聚类节点(但仅在 cluster_level 层):
|
|
|
# - 若当前节点已作为聚类节点,则其子孙不再作为聚类节点(保证尽量向上聚类);
|
|
# - 若当前节点已作为聚类节点,则其子孙不再作为聚类节点(保证尽量向上聚类);
|
|
|
# 同样需要防止意外的环导致递归过深,这里使用 visited 集合。
|
|
# 同样需要防止意外的环导致递归过深,这里使用 visited 集合。
|
|
|
clusters: Set[str] = set()
|
|
clusters: Set[str] = set()
|
|
@@ -461,8 +473,8 @@ class TreeIndex:
|
|
|
cnt = subtree_count.get(node, 0)
|
|
cnt = subtree_count.get(node, 0)
|
|
|
|
|
|
|
|
selected_here = False
|
|
selected_here = False
|
|
|
- # 仅当祖先尚未被选中、且当前节点满足条件时,选当前节点为聚类节点
|
|
|
|
|
- if (not ancestor_selected) and depth >= min_level and cnt >= 2:
|
|
|
|
|
|
|
+ # 仅当祖先尚未被选中、当前节点位于 cluster_level 层且满足条件时,选当前节点为聚类节点
|
|
|
|
|
+ if (not ancestor_selected) and depth == cluster_level and cnt >= 2:
|
|
|
clusters.add(node)
|
|
clusters.add(node)
|
|
|
selected_here = True
|
|
selected_here = True
|
|
|
|
|
|
|
@@ -492,6 +504,9 @@ class TreeIndex:
|
|
|
cur = parent
|
|
cur = parent
|
|
|
|
|
|
|
|
out: List[Dict[str, Any]] = []
|
|
out: List[Dict[str, Any]] = []
|
|
|
|
|
+ # 1)多元素聚类:仅统计真正输出的聚类节点所覆盖的元素,
|
|
|
|
|
+ # 避免把「元素数不足 2 的节点」也算作已覆盖,从而导致元素丢失。
|
|
|
|
|
+ covered_elems: Set[str] = set()
|
|
|
for node in clusters:
|
|
for node in clusters:
|
|
|
elems = sorted(cluster_to_elements.get(node) or [])
|
|
elems = sorted(cluster_to_elements.get(node) or [])
|
|
|
if len(elems) < 2:
|
|
if len(elems) < 2:
|
|
@@ -503,15 +518,19 @@ class TreeIndex:
|
|
|
"from_elements": elems,
|
|
"from_elements": elems,
|
|
|
}
|
|
}
|
|
|
)
|
|
)
|
|
|
|
|
+ for e in elems:
|
|
|
|
|
+ covered_elems.add(e)
|
|
|
|
|
|
|
|
# 2)对无法向上形成聚类的元素,给一个「单元素聚类」
|
|
# 2)对无法向上形成聚类的元素,给一个「单元素聚类」
|
|
|
- covered_elems: Set[str] = set()
|
|
|
|
|
- for elems in cluster_to_elements.values():
|
|
|
|
|
- covered_elems.update(elems)
|
|
|
|
|
uncovered = elem_set - covered_elems
|
|
uncovered = elem_set - covered_elems
|
|
|
|
|
|
|
|
|
|
+ # 将未覆盖元素按「cluster_level 层级的祖先节点」分组,确保同一个祖先节点下的
|
|
|
|
|
+ # 多个元素合并为一个聚类,而不是多个单元素聚类。
|
|
|
|
|
+ single_clusters: Dict[str, Set[str]] = {}
|
|
|
|
|
+
|
|
|
for e in uncovered:
|
|
for e in uncovered:
|
|
|
- # 单元素聚类时,cluster_node 应为「祖先节点」,不直接使用元素自身
|
|
|
|
|
|
|
+ # 单元素聚类时,cluster_node 应为「祖先节点」,不直接使用元素自身。
|
|
|
|
|
+ # 这里固定选择 depth == cluster_level 的祖先节点。
|
|
|
info_e = self.node_info.get(e) or {}
|
|
info_e = self.node_info.get(e) or {}
|
|
|
parent = info_e.get("parent")
|
|
parent = info_e.get("parent")
|
|
|
cur = parent
|
|
cur = parent
|
|
@@ -521,22 +540,23 @@ class TreeIndex:
|
|
|
visited_chain.add(cur)
|
|
visited_chain.add(cur)
|
|
|
info = self.node_info.get(cur) or {}
|
|
info = self.node_info.get(cur) or {}
|
|
|
depth = info.get("depth", 0) or 0
|
|
depth = info.get("depth", 0) or 0
|
|
|
- if depth >= min_level:
|
|
|
|
|
- # 遍历方向为从元素向上(depth 单调递减),
|
|
|
|
|
- # 每次满足 >= min_level 就更新,循环结束后得到的是
|
|
|
|
|
- # depth 最小(离元素最远)且不低于 min_level 的祖先节点。
|
|
|
|
|
|
|
+ if depth == cluster_level:
|
|
|
best_ancestor = cur
|
|
best_ancestor = cur
|
|
|
|
|
+ break
|
|
|
parent = info.get("parent")
|
|
parent = info.get("parent")
|
|
|
if parent is None:
|
|
if parent is None:
|
|
|
break
|
|
break
|
|
|
cur = parent
|
|
cur = parent
|
|
|
if best_ancestor:
|
|
if best_ancestor:
|
|
|
- out.append(
|
|
|
|
|
- {
|
|
|
|
|
- "cluster_node": best_ancestor,
|
|
|
|
|
- "from_elements": [e],
|
|
|
|
|
- }
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ single_clusters.setdefault(best_ancestor, set()).add(e)
|
|
|
|
|
+
|
|
|
|
|
+ for anc, elems in single_clusters.items():
|
|
|
|
|
+ out.append(
|
|
|
|
|
+ {
|
|
|
|
|
+ "cluster_node": anc,
|
|
|
|
|
+ "from_elements": sorted(elems),
|
|
|
|
|
+ }
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
# 为了输出更稳定,按 from_elements 的元素数量从大到小排序,数量相同再按节点名排序
|
|
# 为了输出更稳定,按 from_elements 的元素数量从大到小排序,数量相同再按节点名排序
|
|
|
out.sort(key=lambda x: (-len(x["from_elements"]), x["cluster_node"]))
|
|
out.sort(key=lambda x: (-len(x["from_elements"]), x["cluster_node"]))
|
|
@@ -554,7 +574,7 @@ def _analyze_single_round(
|
|
|
tree_index: TreeIndex,
|
|
tree_index: TreeIndex,
|
|
|
cumulative_points: List[str],
|
|
cumulative_points: List[str],
|
|
|
match_threshold: float,
|
|
match_threshold: float,
|
|
|
- min_level: int,
|
|
|
|
|
|
|
+ cluster_level: int,
|
|
|
) -> Dict[str, Any]:
|
|
) -> Dict[str, Any]:
|
|
|
"""
|
|
"""
|
|
|
对某一轮(给定累计 matched_post_point 列表)执行分析:
|
|
对某一轮(给定累计 matched_post_point 列表)执行分析:
|
|
@@ -569,7 +589,7 @@ def _analyze_single_round(
|
|
|
matched_post_points=cumulative_points,
|
|
matched_post_points=cumulative_points,
|
|
|
match_threshold=match_threshold,
|
|
match_threshold=match_threshold,
|
|
|
)
|
|
)
|
|
|
- printf(f"_score_patterns_by_matched_points len: {len(patterns)}")
|
|
|
|
|
|
|
+ print(f"_score_patterns_by_matched_points len: {len(patterns)}")
|
|
|
|
|
|
|
|
# 已推导 / 未推导 元素列表(不再按维度拆分)
|
|
# 已推导 / 未推导 元素列表(不再按维度拆分)
|
|
|
derived_elems: List[str] = []
|
|
derived_elems: List[str] = []
|
|
@@ -599,14 +619,33 @@ def _analyze_single_round(
|
|
|
|
|
|
|
|
# 已推导元素聚类
|
|
# 已推导元素聚类
|
|
|
if derived_set:
|
|
if derived_set:
|
|
|
- c = tree_index.find_clusters(derived_set, min_level=min_level)
|
|
|
|
|
|
|
+ c = tree_index.find_clusters(derived_set, cluster_level=cluster_level)
|
|
|
clusters["derived"] = c or []
|
|
clusters["derived"] = c or []
|
|
|
|
|
|
|
|
# 未推导元素聚类
|
|
# 未推导元素聚类
|
|
|
if underived_set:
|
|
if underived_set:
|
|
|
- c = tree_index.find_clusters(underived_set, min_level=min_level)
|
|
|
|
|
|
|
+ c = tree_index.find_clusters(underived_set, cluster_level=cluster_level)
|
|
|
clusters["underived"] = c or []
|
|
clusters["underived"] = c or []
|
|
|
|
|
|
|
|
|
|
+ # 在同一轮中,如果某个 cluster_node 已经在 derived 聚类里出现过,
|
|
|
|
|
+ # 则从 underived 聚类中剔除该 cluster_node,避免重复展示。
|
|
|
|
|
+ if isinstance(clusters.get("derived"), list) and isinstance(clusters.get("underived"), list):
|
|
|
|
|
+ derived_nodes = {
|
|
|
|
|
+ str(item.get("cluster_node"))
|
|
|
|
|
+ for item in clusters["derived"]
|
|
|
|
|
+ if isinstance(item, dict) and item.get("cluster_node") is not None
|
|
|
|
|
+ }
|
|
|
|
|
+ if derived_nodes:
|
|
|
|
|
+ filtered_underived = []
|
|
|
|
|
+ for item in clusters["underived"]:
|
|
|
|
|
+ if not isinstance(item, dict):
|
|
|
|
|
+ continue
|
|
|
|
|
+ node = str(item.get("cluster_node"))
|
|
|
|
|
+ if node in derived_nodes:
|
|
|
|
|
+ continue
|
|
|
|
|
+ filtered_underived.append(item)
|
|
|
|
|
+ clusters["underived"] = filtered_underived
|
|
|
|
|
+
|
|
|
return {
|
|
return {
|
|
|
"matched_post_points": list(cumulative_points),
|
|
"matched_post_points": list(cumulative_points),
|
|
|
"patterns": patterns,
|
|
"patterns": patterns,
|
|
@@ -627,7 +666,7 @@ def pattern_dimension_analyze(
|
|
|
post_id: str,
|
|
post_id: str,
|
|
|
log_id: str,
|
|
log_id: str,
|
|
|
match_threshold: float = 0.6,
|
|
match_threshold: float = 0.6,
|
|
|
- min_level: int = 2,
|
|
|
|
|
|
|
+ cluster_level: int = 2,
|
|
|
) -> Dict[str, Any]:
|
|
) -> Dict[str, Any]:
|
|
|
"""
|
|
"""
|
|
|
Pattern 维度分析主入口。
|
|
Pattern 维度分析主入口。
|
|
@@ -638,7 +677,7 @@ def pattern_dimension_analyze(
|
|
|
post_id : 帖子 ID(用于定位推导日志与帖子匹配数据)
|
|
post_id : 帖子 ID(用于定位推导日志与帖子匹配数据)
|
|
|
log_id : 推导日志目录名(../output/{account_name}/推导日志/{post_id}/{log_id}/)
|
|
log_id : 推导日志目录名(../output/{account_name}/推导日志/{post_id}/{log_id}/)
|
|
|
match_threshold : pattern 元素与 matched_post_point 的最小匹配分,默认 0.6
|
|
match_threshold : pattern 元素与 matched_post_point 的最小匹配分,默认 0.6
|
|
|
- min_level : 在人设树中搜索聚类节点时的最小层级(root 为 0 层),默认 2
|
|
|
|
|
|
|
+ cluster_level : 在人设树中搜索聚类节点的聚类层级(root 为 0 层),默认 2
|
|
|
|
|
|
|
|
"""
|
|
"""
|
|
|
eval_dir = _round_eval_dir(account_name, post_id, log_id)
|
|
eval_dir = _round_eval_dir(account_name, post_id, log_id)
|
|
@@ -652,7 +691,7 @@ def pattern_dimension_analyze(
|
|
|
"post_id": post_id,
|
|
"post_id": post_id,
|
|
|
"log_id": log_id,
|
|
"log_id": log_id,
|
|
|
"match_threshold": match_threshold,
|
|
"match_threshold": match_threshold,
|
|
|
- "min_level": min_level,
|
|
|
|
|
|
|
+ "cluster_level": cluster_level,
|
|
|
"rounds": [],
|
|
"rounds": [],
|
|
|
"message": "未在指定日志目录下找到任何评估结果文件(*_评估.json)",
|
|
"message": "未在指定日志目录下找到任何评估结果文件(*_评估.json)",
|
|
|
}
|
|
}
|
|
@@ -661,7 +700,7 @@ def pattern_dimension_analyze(
|
|
|
# pattern 库只在整体分析时读取 & 去重一次,避免每一轮重复 IO 与解析
|
|
# pattern 库只在整体分析时读取 & 去重一次,避免每一轮重复 IO 与解析
|
|
|
raw_patterns = _load_raw_patterns(account_name)
|
|
raw_patterns = _load_raw_patterns(account_name)
|
|
|
deduped_patterns = _dedupe_patterns(raw_patterns)
|
|
deduped_patterns = _dedupe_patterns(raw_patterns)
|
|
|
- printf(f"deduped_patterns len: {len(deduped_patterns)}")
|
|
|
|
|
|
|
+ print(f"deduped_patterns len: {len(deduped_patterns)}")
|
|
|
|
|
|
|
|
rounds_output: List[Dict[str, Any]] = []
|
|
rounds_output: List[Dict[str, Any]] = []
|
|
|
for info in round_infos:
|
|
for info in round_infos:
|
|
@@ -674,7 +713,7 @@ def pattern_dimension_analyze(
|
|
|
tree_index=tree_index,
|
|
tree_index=tree_index,
|
|
|
cumulative_points=cumulative_points,
|
|
cumulative_points=cumulative_points,
|
|
|
match_threshold=match_threshold,
|
|
match_threshold=match_threshold,
|
|
|
- min_level=min_level,
|
|
|
|
|
|
|
+ cluster_level=cluster_level,
|
|
|
)
|
|
)
|
|
|
analyzed["round"] = r
|
|
analyzed["round"] = r
|
|
|
rounds_output.append(analyzed)
|
|
rounds_output.append(analyzed)
|
|
@@ -684,7 +723,7 @@ def pattern_dimension_analyze(
|
|
|
"post_id": post_id,
|
|
"post_id": post_id,
|
|
|
"log_id": log_id,
|
|
"log_id": log_id,
|
|
|
"match_threshold": match_threshold,
|
|
"match_threshold": match_threshold,
|
|
|
- "min_level": min_level,
|
|
|
|
|
|
|
+ "cluster_level": cluster_level,
|
|
|
"rounds": rounds_output,
|
|
"rounds": rounds_output,
|
|
|
}
|
|
}
|
|
|
return result
|
|
return result
|
|
@@ -702,7 +741,7 @@ def main() -> None:
|
|
|
post_id=post_id,
|
|
post_id=post_id,
|
|
|
log_id=log_id,
|
|
log_id=log_id,
|
|
|
match_threshold=0.5,
|
|
match_threshold=0.5,
|
|
|
- min_level=2,
|
|
|
|
|
|
|
+ cluster_level=2,
|
|
|
)
|
|
)
|
|
|
# 控制台打印前 4000 字符,便于快速查看
|
|
# 控制台打印前 4000 字符,便于快速查看
|
|
|
# print(json.dumps(result, ensure_ascii=False, indent=2)[:4000] + "...")
|
|
# print(json.dumps(result, ensure_ascii=False, indent=2)[:4000] + "...")
|
|
@@ -710,7 +749,8 @@ def main() -> None:
|
|
|
# 写入输出文件:../output/{account_name}/推导日志/{post_id}/{log_id}/pattern_dimension_analyze.json
|
|
# 写入输出文件:../output/{account_name}/推导日志/{post_id}/{log_id}/pattern_dimension_analyze.json
|
|
|
out_dir = _round_eval_dir(account_name, post_id, log_id)
|
|
out_dir = _round_eval_dir(account_name, post_id, log_id)
|
|
|
out_dir.mkdir(parents=True, exist_ok=True)
|
|
out_dir.mkdir(parents=True, exist_ok=True)
|
|
|
- out_path = out_dir / "pattern_dimension_analyze.json"
|
|
|
|
|
|
|
+ output_file_name = f"{post_id}_pattern_dimension_analyze.json"
|
|
|
|
|
+ out_path = out_dir / output_file_name
|
|
|
with open(out_path, "w", encoding="utf-8") as f:
|
|
with open(out_path, "w", encoding="utf-8") as f:
|
|
|
json.dump(result, f, ensure_ascii=False, indent=2)
|
|
json.dump(result, f, ensure_ascii=False, indent=2)
|
|
|
print(f"\n分析结果已写入: {out_path}")
|
|
print(f"\n分析结果已写入: {out_path}")
|