  1. #!/usr/bin/env python3
  2. """
  3. 生成推导可视化数据。
  4. 输入参数:account_name, post_id, log_id
  5. - 从 input/{account_name}/解构内容/{post_id}.json 解析选题点列表
  6. - 从 output/{account_name}/推导日志/{post_id}/{log_id}/ 读取推导与评估 JSON,生成:
  7. 1. output/{account_name}/整体推导结果/{post_id}.json
  8. 2. output/{account_name}/整体推导路径可视化/{post_id}.json
  9. """
  10. import argparse
  11. import json
  12. import re
  13. from pathlib import Path
  14. from typing import Any, Dict, List, Optional
  15. def _walk_tree_children_for_persona(
  16. children: Any, persona_by_name: Dict[str, Dict[str, Any]]
  17. ) -> None:
  18. """递归遍历人设树 children,按节点名(与 input_tree_nodes 短名一致)登记 type / 常量标记。"""
  19. if not isinstance(children, dict):
  20. return
  21. for name, node in children.items():
  22. if not isinstance(node, dict):
  23. continue
  24. if name not in persona_by_name:
  25. persona_by_name[name] = {
  26. "name": name,
  27. "type": node.get("_type"),
  28. "is_constant": bool(node.get("_is_constant", False)),
  29. "is_local_constant": bool(node.get("_is_local_constant", False)),
  30. }
  31. sub = node.get("children")
  32. if isinstance(sub, dict):
  33. _walk_tree_children_for_persona(sub, persona_by_name)
  34. def build_persona_by_name_from_tree_dir(tree_dir: Path) -> Dict[str, Dict[str, Any]]:
  35. """
  36. 从 input/{account}/处理后数据/tree 下所有人设树 JSON(如 *_point_tree_how.json)构建 name -> 人设节点信息。
  37. 同名节点以首次出现为准,与 process_pipeline_tree_data.build_persona_by_name 用法一致。
  38. """
  39. persona_by_name: Dict[str, Dict[str, Any]] = {}
  40. if not tree_dir.is_dir():
  41. return persona_by_name
  42. for path in sorted(tree_dir.glob("*_point_tree_how.json")):
  43. with open(path, "r", encoding="utf-8") as f:
  44. data = json.load(f)
  45. if not isinstance(data, dict):
  46. continue
  47. for _dim, root in data.items():
  48. if not isinstance(root, dict):
  49. continue
  50. ch = root.get("children")
  51. _walk_tree_children_for_persona(ch, persona_by_name)
  52. return persona_by_name
  53. def _node_obj_for_used_tree(
  54. name: str,
  55. node: Optional[Dict[str, Any]],
  56. persona: Optional[Dict[str, Any]],
  57. ) -> Dict[str, Any]:
  58. """与 process_pipeline_tree_data._node_obj 一致:合并人设与 edge 上节点字段。"""
  59. type_val = None
  60. is_constant = False
  61. is_local_constant = False
  62. if persona is not None:
  63. type_val = persona.get("type")
  64. if "is_constant" in persona:
  65. is_constant = bool(persona["is_constant"])
  66. if "is_local_constant" in persona:
  67. is_local_constant = bool(persona["is_local_constant"])
  68. if node is not None:
  69. t = node.get("type")
  70. if t is not None and len(t) > 0:
  71. type_val = t
  72. if "is_constant" in node:
  73. is_constant = bool(node["is_constant"])
  74. if "is_local_constant" in node:
  75. is_local_constant = bool(node["is_local_constant"])
  76. return {
  77. "name": name,
  78. "type": type_val,
  79. "is_constant": is_constant,
  80. "is_local_constant": is_local_constant,
  81. }
  82. def _dedup_node_objs(nodes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
  83. seen = set()
  84. out = []
  85. for n in nodes:
  86. key = (n["name"], n.get("type"), n["is_constant"], n["is_local_constant"])
  87. if key not in seen:
  88. seen.add(key)
  89. out.append(n)
  90. return out
  91. def extract_used_tree_nodes_from_edge(
  92. edge: Dict[str, Any],
  93. persona_by_name: Dict[str, Dict[str, Any]],
  94. ) -> List[Dict[str, Any]]:
  95. """与 process_pipeline_tree_data.extract_used_tree_nodes_from_edge 一致。"""
  96. used: List[Dict[str, Any]] = []
  97. for node in edge.get("input_tree_nodes") or []:
  98. name = node.get("name")
  99. if name is None or name == "":
  100. continue
  101. persona = persona_by_name.get(name)
  102. used.append(_node_obj_for_used_tree(name, node, persona))
  103. for pn in edge.get("input_pattern_nodes") or []:
  104. for item in pn.get("match_items") or []:
  105. if item is None or item == "":
  106. continue
  107. persona = persona_by_name.get(item)
  108. used.append(_node_obj_for_used_tree(item, None, persona))
  109. return _dedup_node_objs(used)
  110. def enrich_visualize_with_used_tree_nodes(
  111. data: Dict[str, Any],
  112. persona_by_name: Dict[str, Dict[str, Any]],
  113. ) -> Dict[str, Any]:
  114. """
  115. 为 edge_list / fail_edge_list 每条 edge 增加 used_tree_nodes;
  116. 顶层 all_used_tree_nodes 聚合 edge_list 的 tree nodes,
  117. fail_all_used_tree_nodes 单独聚合 fail_edge_list 的 tree nodes。
  118. """
  119. for list_key, agg_key in (
  120. ("edge_list", "all_used_tree_nodes"),
  121. ("fail_edge_list", "fail_all_used_tree_nodes"),
  122. ):
  123. edge_list = data.get(list_key)
  124. if not edge_list:
  125. data[agg_key] = []
  126. continue
  127. all_used: List[Dict[str, Any]] = []
  128. for edge in edge_list:
  129. used = extract_used_tree_nodes_from_edge(edge, persona_by_name)
  130. edge["used_tree_nodes"] = used
  131. all_used.extend(used)
  132. data[agg_key] = _dedup_node_objs(all_used)
  133. return data
  134. def _collect_dimension_names(point_data: dict) -> dict[str, str]:
  135. """从点的 实质/形式/意图 中收集 名称 -> dimension。"""
  136. name_to_dim = {}
  137. if "实质" in point_data and point_data["实质"]:
  138. for key in ("具体元素", "具象概念", "抽象概念"):
  139. for item in (point_data["实质"].get(key) or []):
  140. n = item.get("名称")
  141. if n:
  142. name_to_dim[n] = "实质"
  143. if "形式" in point_data and point_data["形式"]:
  144. for key in ("具体元素形式", "具象概念形式", "整体形式"):
  145. for item in (point_data["形式"].get(key) or []):
  146. n = item.get("名称")
  147. if n:
  148. name_to_dim[n] = "形式"
  149. if point_data.get("意图"):
  150. for item in point_data["意图"]:
  151. n = item.get("名称")
  152. if n:
  153. name_to_dim[n] = "意图"
  154. return name_to_dim
  155. def parse_topic_points_from_deconstruct(deconstruct_path: Path) -> list[dict[str, Any]]:
  156. """
  157. 从 input/{account_name}/解构内容/{post_id}.json 解析选题点列表。
  158. - 新格式(Agent):灵感点/目的点/关键点 下为「选题点」「选题点元素」(元素名称、元素类型)。
  159. - 旧格式:「点」「分词结果」中的「词」等。
  160. 输出字段:name, point, dimension, root_source, root_sources_desc。
  161. """
  162. if not deconstruct_path.exists():
  163. raise FileNotFoundError(f"解构内容文件不存在: {deconstruct_path}")
  164. with open(deconstruct_path, "r", encoding="utf-8") as f:
  165. data = json.load(f)
  166. result_agent: list[dict[str, Any]] = []
  167. for point_type in ("灵感点", "目的点", "关键点"):
  168. for point in data.get(point_type) or []:
  169. if not isinstance(point, dict):
  170. continue
  171. root_source = (point.get("选题点") or point.get("点") or "").strip()
  172. root_sources_desc = point.get("选题点描述") or point.get("点描述") or ""
  173. for el in point.get("选题点元素") or []:
  174. if not isinstance(el, dict):
  175. continue
  176. name = (el.get("元素名称") or "").strip()
  177. if not name:
  178. continue
  179. et = el.get("元素类型") or "实质"
  180. if et not in ("实质", "形式", "意图"):
  181. et = "实质"
  182. result_agent.append(
  183. {
  184. "name": name,
  185. "point": point_type,
  186. "dimension": et,
  187. "root_source": root_source,
  188. "root_sources_desc": root_sources_desc,
  189. }
  190. )
  191. if result_agent:
  192. return result_agent
  193. result = []
  194. for point_type in ("灵感点", "目的点", "关键点"):
  195. for point in data.get(point_type) or []:
  196. root_source = point.get("点", "")
  197. root_sources_desc = point.get("点描述", "")
  198. name_to_dim = _collect_dimension_names(point)
  199. for word_item in point.get("分词结果") or []:
  200. name = word_item.get("词", "").strip()
  201. if not name:
  202. continue
  203. dimension = name_to_dim.get(name, "实质")
  204. result.append({
  205. "name": name,
  206. "point": point_type,
  207. "dimension": dimension,
  208. "root_source": root_source,
  209. "root_sources_desc": root_sources_desc,
  210. })
  211. return result
  212. def _topic_point_key(t: dict) -> tuple:
  213. return (t["name"], t["point"], t["dimension"])
  214. def load_derivation_logs(log_dir: Path) -> tuple[list[dict], list[dict]]:
  215. """
  216. 从 output/{account_name}/推导日志/{post_id}/{log_id}/ 读取所有 {轮次}_推导.json 与 {轮次}_评估.json。
  217. 返回 (推导列表按轮次序, 评估列表按轮次序)。
  218. """
  219. if not log_dir.is_dir():
  220. raise FileNotFoundError(f"推导日志目录不存在: {log_dir}")
  221. derivation_by_round = {}
  222. eval_by_round = {}
  223. for p in log_dir.glob("*.json"):
  224. base = p.stem
  225. m = re.match(r"^(\d+)_(推导|评估)$", base)
  226. if not m:
  227. continue
  228. round_num = int(m.group(1))
  229. with open(p, "r", encoding="utf-8") as f:
  230. content = json.load(f)
  231. if m.group(2) == "推导":
  232. derivation_by_round[round_num] = content
  233. else:
  234. eval_by_round[round_num] = content
  235. rounds = sorted(set(derivation_by_round) | set(eval_by_round))
  236. derivations = [derivation_by_round[r] for r in rounds if r in derivation_by_round]
  237. evals = [eval_by_round[r] for r in rounds if r in eval_by_round]
  238. return derivations, evals
def build_derivation_result(
    topic_points: list[dict],
    derivations: list[dict],
    evals: list[dict],
) -> list[dict]:
    """
    Build the overall derivation result: one entry per round, with 轮次 (round
    number), 推导成功的选题点 (points derived so far), 未推导成功的选题点 (not yet
    derived) and 本次新推导成功的选题点 (newly derived this round).

    Topic points carry the full info from ``topic_points``; a point counts as
    derived when its name appears as a matched post point in the round's
    evaluation (match_post_point and variants). A point that was previously
    matched with is_fully_derived=False and becomes is_fully_derived=True in a
    later round counts as "newly derived" again, and its matched_score /
    is_fully_derived are updated from that round's evaluation.
    - Derived points use the best (matched_score, is_fully_derived) accumulated so far.
    - Newly-derived points use the current round's evaluated values.
    - Not-yet-derived points carry no matched_score / is_fully_derived fields.
    """
    all_keys = {_topic_point_key(t) for t in topic_points}
    topic_by_key = {_topic_point_key(t): t for t in topic_points}
    # Per round: (round_num, name) -> (matched_score, is_fully_derived); when the
    # same name matches several times in one round, keep the highest matched_score.
    score_by_round_name: dict[tuple[int, str], tuple[float, bool]] = {}
    for round_idx, eval_data in enumerate(evals):
        round_num = eval_data.get("round", round_idx + 1)
        for er in eval_data.get("eval_results") or []:
            # Only rows that passed evaluation (new boolean flag or legacy string form).
            if not (er.get("is_matched") is True or er.get("match_result") == "匹配"):
                continue
            mp = (er.get("matched_post_point") or er.get("matched_post_topic") or er.get("match_post_point") or "").strip()
            if not mp:
                continue
            score = er.get("matched_score")
            if score is None:
                score = 1.0  # a missing score is treated as a full match
            else:
                try:
                    score = float(score)
                except (TypeError, ValueError):
                    score = 1.0
            is_fully = er.get("is_fully_derived", True)
            key = (round_num, mp)
            if key not in score_by_round_name or score > score_by_round_name[key][0]:
                score_by_round_name[key] = (score, bool(is_fully))
    result = []
    derived_names_so_far: set[str] = set()
    fully_derived_names_so_far: set[str] = set()  # names seen with is_fully_derived=True
    # name -> (matched_score, is_fully_derived); once is_fully_derived=True the
    # matched_score is locked and later rounds no longer update it.
    best_score_by_name: dict[str, tuple[float, bool]] = {}
    for i, (derivation, eval_data) in enumerate(zip(derivations, evals)):
        round_num = derivation.get("round", i + 1)
        eval_results = eval_data.get("eval_results") or []
        matched_post_points = set()
        for er in eval_results:
            if not (er.get("is_matched") is True or er.get("match_result") == "匹配"):
                continue
            mp = er.get("matched_post_point") or er.get("matched_post_topic") or er.get("match_post_point") or ""
            if mp and str(mp).strip():
                matched_post_points.add(str(mp).strip())
        # (score, is_fully) of each matched name in this round.
        this_round_scores: dict[str, tuple[float, bool]] = {}
        for name in matched_post_points:
            val = score_by_round_name.get((round_num, name))
            if val is not None:
                this_round_scores[name] = val
        # Newly derived this round: first-ever match, or previously is_fully=False
        # and this round is_fully=True.
        new_derived_names = set()
        for name in matched_post_points:
            score, is_fully = this_round_scores.get(name, (None, False))
            if name not in derived_names_so_far:
                new_derived_names.add(name)
            elif name not in fully_derived_names_so_far and is_fully:
                new_derived_names.add(name)
        # Update the derived set and the per-name best:
        # - first occurrence is written as-is
        # - not-yet-fully + fully this round -> upgrade to fully and lock
        # - both partial -> a higher score may overwrite
        derived_names_so_far |= matched_post_points
        for name in matched_post_points:
            val = this_round_scores.get(name)
            if val is None:
                continue
            score, is_fully = val
            if name not in best_score_by_name:
                best_score_by_name[name] = (score, is_fully)
            else:
                prev_score, prev_fully = best_score_by_name[name]
                # A fully-derived node is locked: later rounds never touch its score.
                if prev_fully:
                    pass
                else:
                    if is_fully:
                        best_score_by_name[name] = (score, True)
                    else:
                        # Both partial: keep whichever score is higher.
                        if score > prev_score:
                            best_score_by_name[name] = (score, False)
            if is_fully:
                fully_derived_names_so_far.add(name)
        derived_keys = {k for k in all_keys if topic_by_key[k]["name"] in derived_names_so_far}
        new_derived_keys = {k for k in all_keys if topic_by_key[k]["name"] in new_derived_names}
        not_derived_keys = all_keys - derived_keys
        sort_derived = sorted(derived_keys, key=lambda k: (topic_by_key[k]["name"], k[1], k[2]))
        sort_new = sorted(new_derived_keys, key=lambda k: (topic_by_key[k]["name"], k[1], k[2]))
        sort_not = sorted(not_derived_keys, key=lambda k: (topic_by_key[k]["name"], k[1], k[2]))
        # NOTE(review): redefined every round iteration; closes over score_by_round_name
        # and topic_by_key. Could be hoisted, kept as-is to preserve behavior.
        def add_score_fields(keys: set, sort_keys: list, round_for_score: int | None) -> list[dict]:
            """round_for_score: take scores from that round's evaluation; when None,
            no score fields are added."""
            out = []
            for k in sort_keys:
                if k not in keys:
                    continue
                obj = dict(topic_by_key[k])
                if round_for_score is not None:
                    name = obj.get("name", "")
                    val = score_by_round_name.get((round_for_score, name))
                    if val is not None:
                        obj["matched_score"] = val[0]
                        obj["is_fully_derived"] = val[1]
                    else:
                        obj["matched_score"] = None
                        obj["is_fully_derived"] = False
                out.append(obj)
            return out
        # Derived points: use the accumulated best (matched_score, is_fully_derived).
        derived_list = []
        for k in sort_derived:
            if k not in derived_keys:
                continue
            obj = dict(topic_by_key[k])
            name = obj.get("name", "")
            val = best_score_by_name.get(name)
            if val is not None:
                obj["matched_score"] = val[0]
                obj["is_fully_derived"] = val[1]
            else:
                obj["matched_score"] = None
                obj["is_fully_derived"] = False
            derived_list.append(obj)
        new_list = add_score_fields(new_derived_keys, sort_new, round_for_score=round_num)
        not_derived_list = [dict(topic_by_key[k]) for k in sort_not]  # no matched_score / is_fully_derived
        result.append({
            "轮次": round_num,
            "推导成功的选题点": derived_list,
            "未推导成功的选题点": not_derived_list,
            "本次新推导成功的选题点": new_list,
        })
    return result
  380. def _tree_node_display_name(raw: str) -> str:
  381. """人设节点可能是 a.b.c 路径形式,实际需要的是最后一段节点名 c。"""
  382. s = (raw or "").strip()
  383. if "." in s:
  384. return s.rsplit(".", 1)[-1].strip() or s
  385. return s
  386. def _to_tree_node(name: str, extra: dict | None = None) -> dict:
  387. d = {"name": name}
  388. if extra:
  389. d.update(extra)
  390. return d
  391. def _to_pattern_node(pattern_name: str) -> dict:
  392. """将 pattern 字符串转为 input_pattern_nodes 的一项(简化版)。"""
  393. items = [x.strip() for x in pattern_name.replace("+", " ").split() if x.strip()]
  394. return {
  395. "items": [{"name": x, "point": "关键点", "dimension": "形式", "type": "标签"} for x in items],
  396. "match_items": items,
  397. }
def build_visualize_edges(
    derivations: list[dict],
    evals: list[dict],
    topic_points: list[dict],
) -> tuple[list[dict], list[dict], list[dict], list[dict]]:
    """
    Build node_list (post topic points that passed evaluation), edge_list (the
    passing derivation paths), fail_node_list (derivation outputs that failed
    evaluation) and fail_edge_list (the corresponding failing paths).

    - node_list / edge_list: a node appears at most once per round; duplicates in
      one round keep the record with the higher matched_score. Nodes carry
      matched_score and is_fully_derived.
    - fail_node_list / fail_edge_list: same structure as node_list / edge_list,
      holding the outputs (and paths) that did not pass evaluation.
    Evaluation rows may carry path_id (matching derivation_results[].id),
    derivation_output_point (aligned with strings in the derivation output),
    matched_score and is_fully_derived; rows are not aligned by item_id.
    """
    derivations = sorted(derivations, key=lambda d: d.get("round", 0))
    evals = sorted(evals, key=lambda e: e.get("round", 0))
    topic_by_name = {t["name"]: t for t in topic_points}
    # Evaluation matches, keyed two ways:
    # (round_num, path_id, derivation_output_point) -> (matched_post_point, matched_reason, matched_score, is_fully_derived)
    match_by_path_out: dict[tuple[int, int, str], tuple[str, str, float, bool]] = {}
    match_by_round_output: dict[tuple[int, str], tuple[str, str, float, bool]] = {}  # fallback for legacy data without path_id
    for round_idx, eval_data in enumerate(evals):
        round_num = eval_data.get("round", round_idx + 1)
        for er in eval_data.get("eval_results") or []:
            # Only rows that passed evaluation (boolean flag or legacy string form).
            if not (er.get("is_matched") is True or er.get("match_result") == "匹配"):
                continue
            mp = (er.get("matched_post_point") or er.get("matched_post_topic") or er.get("match_post_point") or "").strip()
            if not mp:
                continue
            out_point = (er.get("derivation_output_point") or "").strip()
            reason = (er.get("matched_reason") or er.get("match_reason") or "").strip()
            score = er.get("matched_score")
            if score is None:
                score = 1.0  # missing score counts as a full match
            else:
                try:
                    score = float(score)
                except (TypeError, ValueError):
                    score = 1.0
            is_fully = er.get("is_fully_derived", True)
            val = (mp, reason, score, bool(is_fully))
            path_id = er.get("path_id")
            if path_id is not None and out_point:
                try:
                    match_by_path_out[(round_num, int(path_id), out_point)] = val
                except (TypeError, ValueError):
                    pass  # non-numeric path_id: rely on the round-level fallback
            if out_point:
                k = (round_num, out_point)
                if k not in match_by_round_output:
                    match_by_round_output[k] = val

    def get_match(round_num: int, path_id: int | None, out_item: str) -> tuple[str, str, float, bool] | None:
        """Look up the evaluation match for one derivation output; prefer the
        path-scoped key, then fall back to the round-level key."""
        out_item = (out_item or "").strip()
        if not out_item:
            return None
        if path_id is not None:
            # NOTE(review): match_by_path_out keys use int(path_id); if dr "id" is a
            # non-int here the lookup misses and falls through — confirm intended.
            v = match_by_path_out.get((round_num, path_id, out_item))
            if v is not None:
                return v
        return match_by_round_output.get((round_num, out_item))

    # First pass: per (round_num, mp) aggregate the best node info, independent of
    # whether the producing edge is eventually kept.
    # (round_num, mp) -> (score, is_fully_derived, derivation_output_point, method)
    best_node_info_by_round_mp: dict[tuple[int, str], tuple[float, bool, str, str]] = {}
    for round_idx, derivation in enumerate(derivations):
        round_num = derivation.get("round", round_idx + 1)
        for dr in derivation.get("derivation_results") or []:
            output_list = dr.get("output") or []
            path_id = dr.get("id")
            for out_item in output_list:
                v = get_match(round_num, path_id, out_item)
                if not v:
                    continue
                mp, _reason, score, is_fully = v
                key = (round_num, mp)
                prev = best_node_info_by_round_mp.get(key)
                if prev is None or score > prev[0]:
                    best_node_info_by_round_mp[key] = (score, bool(is_fully), out_item, dr.get("method", ""))
    edge_list = []
    fail_edge_list: list[dict] = []
    fail_node_list: list[dict] = []
    # Global dedup across rounds: once a derivation_output_point has failed, later
    # rounds do not record it again.
    fail_node_seen: set[str] = set()
    round_output_seen: set[tuple[int, str]] = set()  # (round_num, node_name) already emitted as an edge output this round
    prev_best_by_node: dict[str, tuple[float, bool]] = {}  # node_name -> (score, is_fully) of the last included round

    def _add_fail_node(out_item: str, round_num: int, method: str) -> None:
        """Append the derivation output to fail_node_list (deduped across rounds);
        name / original_word are both set to out_item."""
        if out_item in fail_node_seen:
            return
        fail_node_seen.add(out_item)
        base = dict(topic_by_name.get(
            out_item,
            {"name": out_item, "point": "", "dimension": "", "root_source": "", "root_sources_desc": ""},
        ))
        # Not matched to any post topic point: use the raw output string everywhere.
        base["name"] = out_item
        base["level"] = round_num
        base["original_word"] = out_item
        base["derivation_type"] = method
        base["matched_score"] = 0
        base["is_fully_derived"] = False
        base["derivation_output_point"] = out_item
        fail_node_list.append(base)

    for round_idx, derivation in enumerate(derivations):
        round_num = derivation.get("round", round_idx + 1)
        for dr in derivation.get("derivation_results") or []:
            output_list = dr.get("output") or []
            path_id = dr.get("id")
            method = dr.get("method", "")
            # Build the input-node views up front (needed for both success and fail edges).
            input_data = dr.get("input") or {}
            derived_nodes = input_data.get("derived_nodes") or []
            tree_nodes = input_data.get("tree_nodes") or []
            patterns = input_data.get("patterns") or []
            input_post_nodes = [{"name": x} for x in derived_nodes]
            input_tree_nodes = [_to_tree_node(_tree_node_display_name(x)) for x in tree_nodes]
            if patterns and isinstance(patterns[0], str):
                input_pattern_nodes = [_to_pattern_node(p) for p in patterns]
            elif patterns and isinstance(patterns[0], dict):
                input_pattern_nodes = patterns
            else:
                input_pattern_nodes = []
            matched: list[tuple[str, str, float, bool, str]] = []  # (mp, reason, score, is_fully, derivation_out)
            unmatched_out_items: list[str] = []
            for out_item in output_list:
                out_item = (out_item or "").strip()
                if not out_item:
                    continue
                v = get_match(round_num, path_id, out_item)
                if not v:
                    unmatched_out_items.append(out_item)
                    continue
                mp, reason, score, is_fully = v
                matched.append((mp, reason, score, is_fully, out_item))
            if not matched:
                # Nothing matched -> record a fail edge, but only when it contributes
                # at least one previously-unseen fail node.
                new_fail_items = [o for o in unmatched_out_items if o not in fail_node_seen]
                if new_fail_items:
                    for out_item in new_fail_items:
                        _add_fail_node(out_item, round_num, method)
                    fail_output_nodes = [
                        {"name": o, "matched_score": 0, "is_fully_derived": False}
                        for o in unmatched_out_items
                    ]
                    fail_detail: dict = {
                        "reason": dr.get("reason", ""),
                        "评估结果": "匹配失败",
                    }
                    if dr.get("tools"):
                        fail_detail["tools"] = dr["tools"]
                    fail_edge_list.append({
                        "name": method or f"推导-{round_num}",
                        "level": round_num,
                        "input_post_nodes": input_post_nodes,
                        "input_tree_nodes": input_tree_nodes,
                        "input_pattern_nodes": input_pattern_nodes,
                        "output_nodes": fail_output_nodes,
                        "detail": fail_detail,
                    })
                continue
            # Partial match -> register the unmatched outputs as fail nodes and, when
            # any of them is new, also record a fail edge for this path.
            new_unmatched = [o for o in unmatched_out_items if o not in fail_node_seen]
            for out_item in unmatched_out_items:
                _add_fail_node(out_item, round_num, method)
            if new_unmatched:
                fail_output_nodes = [
                    {"name": o, "matched_score": 0, "is_fully_derived": False}
                    for o in unmatched_out_items
                ]
                partial_fail_detail: dict = {
                    "reason": dr.get("reason", ""),
                    "评估结果": "部分匹配失败",
                }
                if dr.get("tools"):
                    partial_fail_detail["tools"] = dr["tools"]
                fail_edge_list.append({
                    "name": method or f"推导-{round_num}",
                    "level": round_num,
                    "input_post_nodes": input_post_nodes,
                    "input_tree_nodes": input_tree_nodes,
                    "input_pattern_nodes": input_pattern_nodes,
                    "output_nodes": fail_output_nodes,
                    "detail": partial_fail_detail,
                })
            # Output nodes are unique within a round; skip a node when an earlier
            # round already fully derived it, or when the score did not improve and
            # is_fully did not flip false->true; only keep the edge whose record
            # equals the round's best score for that node.
            output_names_this_edge = []
            for mp, reason, score, is_fully, out_item in matched:
                if (round_num, mp) in round_output_seen:
                    continue
                prev = prev_best_by_node.get(mp)
                if prev is not None:
                    prev_score, prev_fully = prev
                    if prev_fully:
                        continue
                    if not is_fully and score <= prev_score:
                        continue
                best_info = best_node_info_by_round_mp.get((round_num, mp))
                if not best_info or score < best_info[0]:
                    continue
                output_names_this_edge.append((mp, reason, score, is_fully, out_item))
            if not output_names_this_edge:
                continue
            for mp, _r, score, is_fully, _o in output_names_this_edge:
                round_output_seen.add((round_num, mp))
                prev = prev_best_by_node.get(mp)
                # Upgrade the per-node best unless it is already locked as fully derived.
                if prev is None or (not prev[1] and (is_fully or score > prev[0])):
                    prev_best_by_node[mp] = (score, is_fully)
            output_nodes = []
            reasons_list = []
            compare_detail_list = []
            for mp, reason, score, is_fully, out_item in output_names_this_edge:
                output_nodes.append({"name": mp, "matched_score": score, "is_fully_derived": is_fully})
                reasons_list.append(reason)
                compare_detail_list.append(
                    f"待比对推导选题点:{out_item} -> 帖子选题点:{mp} ({score})"
                )
            detail = {
                "reason": dr.get("reason", ""),
                "评估结果": "匹配成功",
            }
            if any(reasons_list):
                detail["匹配理由"] = reasons_list
            # NOTE(review): original indentation was ambiguous; 比对详情 is emitted
            # unconditionally here — confirm it should not be gated on reasons_list.
            detail["比对详情"] = compare_detail_list
            if dr.get("tools"):
                detail["tools"] = dr["tools"]
            edge_list.append({
                "name": method or f"推导-{round_num}",
                "level": round_num,
                "input_post_nodes": input_post_nodes,
                "input_tree_nodes": input_tree_nodes,
                "input_pattern_nodes": input_pattern_nodes,
                "output_nodes": output_nodes,
                "detail": detail,
            })
    # Generate node_list from the per-(round, mp) best info.
    # Rules: first appearance kept; kept again when is_fully_derived flips
    # false->true; kept when still partial but the score beats the last kept
    # round; otherwise skipped.
    prev_node_best: dict[str, tuple[float, bool]] = {}  # mp -> (score, is_fully) of last included round
    node_list: list[dict] = []
    for (round_num, mp), (score, is_fully, out_item, method) in sorted(
        best_node_info_by_round_mp.items(), key=lambda x: (x[0][0], x[0][1])
    ):
        prev = prev_node_best.get(mp)
        if prev is None:
            should_include = True
        else:
            prev_score, prev_fully = prev
            if prev_fully:
                should_include = False
            elif is_fully:
                should_include = True
            elif score > prev_score:
                should_include = True
            else:
                should_include = False
        if not should_include:
            continue
        prev_node_best[mp] = (score, is_fully)
        base = dict(topic_by_name.get(mp, {"name": mp, "point": "", "dimension": "", "root_source": "", "root_sources_desc": ""}))
        base["level"] = round_num
        base.setdefault("original_word", base.get("name", mp))
        base["derivation_type"] = method
        base["matched_score"] = score
        base["is_fully_derived"] = is_fully
        base["derivation_output_point"] = out_item
        node_list.append(base)
    node_list.sort(key=lambda n: (n.get("level", 0), str(n.get("name", ""))))
    fail_node_list.sort(key=lambda n: (n.get("level", 0), str(n.get("name", ""))))
    return node_list, edge_list, fail_node_list, fail_edge_list
  664. def _find_project_root() -> Path:
  665. """从脚本所在目录向上查找包含 .git 的项目根目录。"""
  666. p = Path(__file__).resolve().parent
  667. while p != p.parent:
  668. if (p / ".git").is_dir():
  669. return p
  670. p = p.parent
  671. return Path(__file__).resolve().parent
def generate_visualize_data(account_name: str, post_id: str, log_id: str, base_dir: Path | None = None) -> None:
    """
    Main flow: read the deconstructed content and derivation logs, then write two
    JSON files — the overall derivation result and the overall derivation-path
    visualization. base_dir defaults to this script's directory; when
    output/.../推导日志 is missing there, fall back to the project root's output/
    (so the script also works when launched from the repository root).
    """
    if base_dir is None:
        base_dir = Path(__file__).resolve().parent
    # NOTE(review): module docstring says input/{account}/解构内容, code reads
    # 原始数据/解构内容 — confirm which layout is current.
    input_dir = base_dir / "input" / account_name / "原始数据" / "解构内容"
    log_dir = base_dir / "output" / account_name / "推导日志" / post_id / log_id
    result_dir = base_dir / "output" / account_name / "整体推导结果"
    visualize_dir = base_dir / "output" / account_name / "整体推导路径可视化"
    # Compatibility: when the log dir is not under base_dir, try the project root's output/.
    if not log_dir.is_dir():
        project_root = _find_project_root()
        if project_root != base_dir:
            alt_log = project_root / "output" / account_name / "推导日志" / post_id / log_id
            if alt_log.is_dir():
                log_dir = alt_log
                result_dir = project_root / "output" / account_name / "整体推导结果"
                visualize_dir = project_root / "output" / account_name / "整体推导路径可视化"
    deconstruct_path = input_dir / f"{post_id}.json"
    topic_points = parse_topic_points_from_deconstruct(deconstruct_path)
    derivations, evals = load_derivation_logs(log_dir)
    if not derivations or not evals:
        raise ValueError(f"推导或评估数据为空: {log_dir}")
    # 2.1 Overall derivation result
    derivation_result = build_derivation_result(topic_points, derivations, evals)
    result_dir.mkdir(parents=True, exist_ok=True)
    result_path = result_dir / f"{post_id}.json"
    with open(result_path, "w", encoding="utf-8") as f:
        json.dump(derivation_result, f, ensure_ascii=False, indent=4)
    print(f"已写入整体推导结果: {result_path}")
    # 2.2 Overall derivation-path visualization (persona enrichment:
    # used_tree_nodes / all_used_tree_nodes sourced from 处理后数据/tree persona trees)
    node_list, edge_list, fail_node_list, fail_edge_list = build_visualize_edges(derivations, evals, topic_points)
    tree_dir = base_dir / "input" / account_name / "处理后数据" / "tree"
    persona_by_name = build_persona_by_name_from_tree_dir(tree_dir)
    if persona_by_name:
        print(
            f"已加载人设树节点: {len(persona_by_name)} 个(目录: {tree_dir.name})"
        )
    else:
        print(
            f"警告: 未从人设树目录加载到节点(请确认存在 *_point_tree_how.json): {tree_dir}"
        )
    visualize_payload: Dict[str, Any] = {
        "node_list": node_list,
        "edge_list": edge_list,
        "fail_node_list": fail_node_list,
        "fail_edge_list": fail_edge_list,
    }
    enrich_visualize_with_used_tree_nodes(visualize_payload, persona_by_name)
    visualize_path = visualize_dir / f"{post_id}.json"
    visualize_dir.mkdir(parents=True, exist_ok=True)
    with open(visualize_path, "w", encoding="utf-8") as f:
        json.dump(visualize_payload, f, ensure_ascii=False, indent=4)
    print(f"已写入整体推导路径可视化: {visualize_path}")
  728. def main(account_name, post_id, log_id):
  729. # parser = argparse.ArgumentParser(description="生成推导可视化数据")
  730. # parser.add_argument("account_name", help="账号名,如 家有大志")
  731. # parser.add_argument("post_id", help="帖子 ID")
  732. # parser.add_argument("log_id", help="推导日志 ID,如 20260303204232")
  733. # parser.add_argument("--base-dir", type=Path, default=None, help="项目根目录,默认为本脚本所在目录")
  734. # args = parser.parse_args()
  735. generate_visualize_data(account_name=account_name, post_id=post_id, log_id=log_id)
  736. if __name__ == "__main__":
  737. from tools.pattern_dimension_analyze import main as pattern_dimension_analyze_main
  738. account_name = "家有大志"
  739. items = [
  740. {"post_id": "68fb6a5c000000000302e5de", "log_id": "20260324172323"},
  741. ]
  742. # account_name="阿里多多酱"
  743. # items = [
  744. # {"post_id": "6915dfc400000000070224d9", "log_id": "20260322135142"},
  745. # {"post_id":"69002ba70000000007008bcc","log_id":"20260322213934"},
  746. # ]
  747. # account_name="摸鱼阿希"
  748. # items = [
  749. # {"post_id": "68ae91ce000000001d016b8b", "log_id": "20260322202416"},
  750. # {"post_id":"689c63ac000000001d015119","log_id":"20260322203119"},
  751. # ]
  752. # account_name = "每天心理学"
  753. # items = [
  754. # {"post_id": "6949df27000000001d03e0e9", "log_id": "20260322205512"},
  755. # {"post_id": "6951c718000000001e0105b7", "log_id": "20260322211126"},
  756. # ]
  757. # account_name = "空间点阵设计研究室"
  758. # items = [
  759. # {"post_id": "687ee6fc000000001c032bb1", "log_id": "20260322211748"},
  760. # {"post_id": "68843a4d000000001c037591", "log_id": "20260322213024"},
  761. # ]
  762. for item in items:
  763. post_id = item["post_id"]
  764. log_id = item["log_id"]
  765. main(account_name, post_id, log_id)
  766. pattern_dimension_analyze_main(account_name, post_id, log_id)