  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. How 人设树处理:从「原始数据/tree」与「原始数据/point_tree_weight」生成与
  5. 「实质_point_tree_how.json」同结构的输出。
  6. 可选:`原始数据/exclude_note_ids.json` 中的帖子 ID 不参与建树(权重聚合与带 post_ids 的扁平 data 行会过滤)。
  7. 用法:
  8. python how_tree_data_process.py --account_name 阿里多多酱
  9. 依赖: 环境变量 GEMINI_API_KEY(用于子分类关系判断;缺失时采用保守默认「有交集」)
  10. """
  11. from __future__ import annotations
  12. import argparse
  13. import hashlib
  14. import json
  15. import os
  16. import re
  17. import sys
  18. from pathlib import Path
  19. from typing import Any, Dict, List, Optional, Tuple
  20. try:
  21. import httpx
  22. except ImportError:
  23. httpx = None # type: ignore
# -----------------------------------------------------------------------------
# Paths
# -----------------------------------------------------------------------------
_SCRIPT_DIR = Path(__file__).resolve().parent
_OVERALL_DIR = _SCRIPT_DIR.parent  # project root: one level above this script
_DEFAULT_PROMPT = _OVERALL_DIR / "prompt" / "judge_category_relation.md"
_DEFAULT_MODEL = "gemini-3-flash-preview"
  31. def _input_base(account_name: str) -> Path:
  32. return _OVERALL_DIR / "input" / account_name / "原始数据"
  33. def _output_tree_dir(account_name: str) -> Path:
  34. return _OVERALL_DIR / "input" / account_name / "处理后数据" / "tree"
  35. def _cache_dir(account_name: str) -> Path:
  36. d = _output_tree_dir(account_name) / ".cache_relation" / account_name
  37. d.mkdir(parents=True, exist_ok=True)
  38. return d
  39. def load_exclude_note_ids(base: Path) -> set[str]:
  40. """
  41. 读取「原始数据/exclude_note_ids.json」,格式为帖子 ID 字符串数组。
  42. 文件不存在或解析失败时返回空集合。
  43. """
  44. path = base / "exclude_note_ids.json"
  45. if not path.exists():
  46. return set()
  47. try:
  48. with open(path, "r", encoding="utf-8") as f:
  49. raw = json.load(f)
  50. if isinstance(raw, list):
  51. return {str(x).strip() for x in raw if str(x).strip()}
  52. if isinstance(raw, dict) and "ids" in raw:
  53. v = raw["ids"]
  54. if isinstance(v, list):
  55. return {str(x).strip() for x in v if str(x).strip()}
  56. except Exception as e:
  57. print(f"警告: 读取排除帖子列表失败 {path}: {e}")
  58. return set()
  59. def _row_has_posts_after_exclude(row: Dict[str, Any], exclude_ids: set[str]) -> bool:
  60. """
  61. 若行带 post_ids 且去掉排除 ID 后仍非空,则参与建树;
  62. 若无 post_ids 或为空列表,则仍参与(仅结构、由权重侧决定帖子)。
  63. """
  64. if not exclude_ids:
  65. return True
  66. pids = row.get("post_ids")
  67. if not isinstance(pids, list) or not pids:
  68. return True
  69. return any(pid not in exclude_ids for pid in pids if pid)
  70. def _filter_data_rows_for_exclude(
  71. rows: List[Dict[str, Any]], exclude_ids: set[str]
  72. ) -> List[Dict[str, Any]]:
  73. if not exclude_ids:
  74. return rows
  75. return [r for r in rows if _row_has_posts_after_exclude(r, exclude_ids)]
  76. # -----------------------------------------------------------------------------
  77. # 扁平 tree JSON -> 分类树(直接元素 / 子分类)
  78. # -----------------------------------------------------------------------------
  79. def _find_or_create_classification(
  80. level: List[Dict[str, Any]], name: str
  81. ) -> Dict[str, Any]:
  82. for node in level:
  83. if node.get("分类名称") == name:
  84. return node
  85. node = {"分类名称": name, "直接元素": [], "子分类": []}
  86. level.append(node)
  87. return node
  88. def _merge_path_element(
  89. classification_tree: List[Dict[str, Any]],
  90. segments: List[str],
  91. element: str,
  92. ) -> None:
  93. if not segments:
  94. return
  95. name = segments[0]
  96. node = _find_or_create_classification(classification_tree, name)
  97. if len(segments) == 1:
  98. if element and element not in node["直接元素"]:
  99. node["直接元素"].append(element)
  100. else:
  101. _merge_path_element(node["子分类"], segments[1:], element)
  102. def _rows_to_classification_tree(
  103. rows: List[Dict[str, Any]],
  104. ) -> Tuple[List[Dict[str, Any]], List[str]]:
  105. """
  106. 返回 (分类树, root 下直接挂接的元素名列表)。
  107. category_path 为空时(常见于「意图」维度),元素不进入子分类,直接挂在 root 下。
  108. """
  109. tree: List[Dict[str, Any]] = []
  110. root_direct: List[str] = []
  111. for row in rows:
  112. raw_path = row.get("category_path") or ""
  113. element = (row.get("element_name") or "").strip()
  114. segments = [s for s in raw_path.strip("/").split("/") if s]
  115. if not segments:
  116. if element and element not in root_direct:
  117. root_direct.append(element)
  118. continue
  119. _merge_path_element(tree, segments, element)
  120. return tree, root_direct
  121. def _is_classification_style(nodes: Any) -> bool:
  122. if not isinstance(nodes, list) or not nodes:
  123. return False
  124. n0 = nodes[0]
  125. return isinstance(n0, dict) and (
  126. "分类名称" in n0 or "直接元素" in n0 or "子分类" in n0
  127. )
  128. def load_classification_tree_from_file(
  129. tree_json_path: Path,
  130. exclude_note_ids: Optional[set[str]] = None,
  131. ) -> Optional[Tuple[List[Dict[str, Any]], List[str]]]:
  132. """
  133. 支持:
  134. 1) 扁平 data[](category_path + element_name;path 空则元素进 root 直挂列表)
  135. 2) 与 generate_how_point_tree 相同的 classification_tree JSON
  136. 返回 (最终分类树, root 直挂元素名);解析失败返回 None。
  137. """
  138. with open(tree_json_path, "r", encoding="utf-8") as f:
  139. tree_data = json.load(f)
  140. if isinstance(tree_data, dict):
  141. if "追加分类结果" in tree_data:
  142. ft = tree_data["追加分类结果"].get("最终分类树")
  143. if ft is not None:
  144. return ft, []
  145. if "分类结果" in tree_data:
  146. ft = tree_data["分类结果"].get("最终分类树")
  147. if ft is not None:
  148. return ft, []
  149. ft = tree_data.get("最终分类树")
  150. if ft is not None and _is_classification_style(ft):
  151. return ft, []
  152. rows = tree_data.get("data")
  153. if isinstance(rows, list):
  154. if not rows:
  155. return [], []
  156. if isinstance(rows[0], dict) and (
  157. "category_path" in rows[0] or "element_name" in rows[0]
  158. ):
  159. rows = _filter_data_rows_for_exclude(rows, exclude_note_ids or set())
  160. return _rows_to_classification_tree(rows)
  161. if isinstance(tree_data, list):
  162. if not tree_data:
  163. return [], []
  164. if isinstance(tree_data[0], dict) and (
  165. "category_path" in tree_data[0] or "element_name" in tree_data[0]
  166. ):
  167. rows = _filter_data_rows_for_exclude(tree_data, exclude_note_ids or set())
  168. return _rows_to_classification_tree(rows)
  169. if _is_classification_style(tree_data):
  170. return tree_data, []
  171. return None
  172. # -----------------------------------------------------------------------------
  173. # 权重(与 generate_how_point_tree.load_weight_scores 一致)
  174. # -----------------------------------------------------------------------------
  175. def load_weight_scores(
  176. weight_file: Path,
  177. exclude_note_ids: Optional[set[str]] = None,
  178. ) -> Tuple[Dict[str, float], Dict[str, set], int]:
  179. if not weight_file.exists():
  180. print(f"警告: 权重分文件不存在: {weight_file}")
  181. return {}, {}, 0
  182. ex = exclude_note_ids or set()
  183. try:
  184. with open(weight_file, "r", encoding="utf-8") as f:
  185. weight_data = json.load(f)
  186. weight_map: Dict[str, float] = {}
  187. post_ids_map: Dict[str, set] = {}
  188. all_post_ids: set = set()
  189. for item in weight_data:
  190. word = item.get("词", "")
  191. score = item.get("归一权重分", 0.0)
  192. if word:
  193. post_ids: set = set()
  194. for detail in item.get("权重分详情", []):
  195. post_id = detail.get("帖子ID", "")
  196. if post_id and post_id not in ex:
  197. post_ids.add(post_id)
  198. all_post_ids.add(post_id)
  199. post_ids_map[word] = post_ids
  200. if not post_ids:
  201. weight_map[word] = 0.0
  202. else:
  203. weight_map[word] = score
  204. total_post_count = len(all_post_ids)
  205. return weight_map, post_ids_map, total_post_count
  206. except Exception as e:
  207. print(f"加载权重分数据失败: {e}")
  208. return {}, {}, 0
  209. # -----------------------------------------------------------------------------
  210. # 建树、概率、常量标记(与 generate_how_point_tree 一致)
  211. # -----------------------------------------------------------------------------
  212. def build_id_node(
  213. element: str,
  214. weight_map: Dict[str, float],
  215. post_ids_map: Dict[str, set],
  216. ) -> Dict[str, Any]:
  217. post_ids = post_ids_map.get(element, set())
  218. post_ids_list = list(post_ids)
  219. return {
  220. "_type": "ID",
  221. "_persona_weight_score": round(weight_map.get(element, 0.0), 4),
  222. "_post_count": len(post_ids_list),
  223. "_post_ids": post_ids_list,
  224. }
def build_tree_node_from_classification(
    classification_node: Dict,
    weight_map: Dict[str, float],
    post_ids_map: Dict[str, set],
    dimension: str,
) -> Dict[str, Any]:
    """Recursively convert one classification node into a "class" tree node.

    Direct elements ("直接元素") become ID leaf children; sub-classifications
    ("子分类") are converted recursively. The resulting node aggregates its
    subtree's persona weight score (sum over leaves, rounded to 4 decimals)
    and the union of post IDs.

    NOTE(review): *dimension* is only passed through to recursive calls and
    never read here — presumably kept for interface parity with
    generate_how_point_tree; confirm before removing.
    """
    direct_elements = classification_node.get("直接元素", [])
    sub_classifications = classification_node.get("子分类", [])
    children: Dict[str, Any] = {}
    # Leaf children first, then sub-classes — this fixes the children order
    # in the serialized output.
    for element in direct_elements:
        if element:
            children[element] = build_id_node(element, weight_map, post_ids_map)
    for sub_class in sub_classifications:
        sub_node = build_tree_node_from_classification(
            sub_class, weight_map, post_ids_map, dimension
        )
        sub_node_name = sub_class.get("分类名称", "")
        if sub_node_name:
            children[sub_node_name] = sub_node

    def count_leaves_and_sum_scores(node: Dict) -> Tuple[int, float]:
        # Returns (leaf count, summed leaf weight score) for a subtree.
        if node.get("_type") == "ID":
            return (1, node.get("_persona_weight_score", 0.0))
        if node.get("_type") == "class":
            leaf_count = 0
            total_score = 0.0
            for child in node.get("children", {}).values():
                c, s = count_leaves_and_sum_scores(child)
                leaf_count += c
                total_score += s
            return (leaf_count, total_score)
        return (0, 0.0)

    def collect_post_ids(node: Dict) -> set:
        # Union of all post IDs in a subtree.
        if node.get("_type") == "ID":
            ids = node.get("_post_ids", [])
            return set(ids) if isinstance(ids, list) else ids
        if node.get("_type") == "class":
            all_post_ids = set()
            for child in node.get("children", {}).values():
                all_post_ids |= collect_post_ids(child)
            return all_post_ids
        return set()

    total_score = 0.0
    for child in children.values():
        _, score = count_leaves_and_sum_scores(child)
        total_score += score
    # Posts are deduplicated across children, so _post_count is NOT the sum
    # of the children's counts.
    all_post_ids = set()
    for child in children.values():
        all_post_ids |= collect_post_ids(child)
    total_post_count = len(all_post_ids)
    result: Dict[str, Any] = {
        "_type": "class",
        "_persona_weight_score": round(total_score, 4),
        "_post_count": total_post_count,
        "_post_ids": list(all_post_ids),
    }
    # A childless class node carries no "children" key at all.
    if children:
        result["children"] = children
    return result
def build_tree_from_classification(
    classification_tree: List[Dict],
    weight_map: Dict[str, float],
    post_ids_map: Dict[str, set],
    total_post_count: int,
    dimension: str,
    root_direct_elements: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Assemble the full point tree: a "root" node over converted class nodes,
    then annotate every class/ID node with its ``_ratio``.

    ``_ratio`` is computed against the GLOBAL *total_post_count* (from the
    weight file), not against the parent node.
    """
    root_children: Dict[str, Any] = {}
    for classification_node in classification_tree:
        node = build_tree_node_from_classification(
            classification_node, weight_map, post_ids_map, dimension
        )
        node_name = classification_node.get("分类名称", "")
        if node_name:
            root_children[node_name] = node
    # Intent-like dimensions: elements with an empty category_path become
    # direct ID children of root (existing names are not overwritten).
    if root_direct_elements:
        for element in root_direct_elements:
            if not element or element in root_children:
                continue
            root_children[element] = build_id_node(element, weight_map, post_ids_map)

    def collect_post_ids(node: Dict) -> set:
        # Union of all post IDs in a subtree.
        if node.get("_type") == "ID":
            ids = node.get("_post_ids", [])
            return set(ids) if isinstance(ids, list) else ids
        if node.get("_type") == "class":
            all_post_ids = set()
            for child in node.get("children", {}).values():
                all_post_ids |= collect_post_ids(child)
            return all_post_ids
        return set()

    root_post_ids: set = set()
    for child in root_children.values():
        root_post_ids |= collect_post_ids(child)
    root_post_count = len(root_post_ids)
    root: Dict[str, Any] = {
        "_type": "root",
        "_post_count": root_post_count,
        "_post_ids": list(root_post_ids),
        "children": root_children,
    }

    def calculate_ratio(node: Dict[str, Any], parent_post_count: int | None = None) -> None:
        # NOTE(review): parent_post_count is threaded through the recursion
        # but never read — ratios are global (post_count / total_post_count),
        # not parent-relative. Presumably kept for parity with
        # generate_how_point_tree; confirm before removing.
        node_type = node.get("_type")
        post_count = node.get("_post_count", 0)
        if node_type == "root":
            pass  # root itself carries no _ratio
        elif node_type in ("class", "ID"):
            if total_post_count > 0:
                node["_ratio"] = round(post_count / total_post_count, 4)
            else:
                node["_ratio"] = 0.0
        children = node.get("children", {})
        for child in children.values():
            calculate_ratio(child, post_count)

    calculate_ratio(root)
    return root
  340. def collect_id_nodes(
  341. node: Dict[str, Any],
  342. node_path: List[str],
  343. id_nodes: List[Tuple[Dict[str, Any], List[str]]],
  344. ) -> None:
  345. node_type = node.get("_type")
  346. children = node.get("children", {})
  347. if node_type == "ID":
  348. id_nodes.append((node, node_path.copy()))
  349. elif node_type in ("class", "root"):
  350. for child_name, child_node in children.items():
  351. collect_id_nodes(child_node, node_path + [child_name], id_nodes)
  352. def collect_class_nodes(
  353. node: Dict[str, Any],
  354. node_path: List[str],
  355. class_nodes: List[Tuple[Dict[str, Any], List[str]]],
  356. is_root: bool = False,
  357. ) -> None:
  358. node_type = node.get("_type")
  359. children = node.get("children", {})
  360. is_first_level = len(node_path) == 1
  361. if node_type == "class":
  362. if not is_root and not is_first_level:
  363. class_nodes.append((node, node_path.copy()))
  364. for child_name, child_node in children.items():
  365. collect_class_nodes(child_node, node_path + [child_name], class_nodes, False)
  366. elif node_type == "root":
  367. for child_name, child_node in children.items():
  368. collect_class_nodes(child_node, node_path + [child_name], class_nodes, False)
  369. def select_constant_nodes(
  370. candidates: List[Tuple[Dict[str, Any], List[str]]],
  371. ) -> set:
  372. if not candidates:
  373. return set()
  374. max_score = 0.0
  375. for node, _ in candidates:
  376. score = node.get("_persona_weight_score", 0.0)
  377. if score > max_score:
  378. max_score = score
  379. candidate_scores = []
  380. for node, path in candidates:
  381. score = node.get("_persona_weight_score", 0.0)
  382. relative_score = (
  383. score / max_score if max_score > 0 else (1.0 if len(candidates) == 1 else 0.0)
  384. )
  385. candidate_scores.append((node, path, relative_score, score))
  386. qualified_candidates = [
  387. (node, path, rel_score, score)
  388. for node, path, rel_score, score in candidate_scores
  389. if rel_score >= 0.5
  390. ]
  391. if len(qualified_candidates) > 8:
  392. qualified_candidates.sort(key=lambda x: x[2], reverse=True)
  393. constant_nodes = qualified_candidates[:8]
  394. else:
  395. constant_nodes = qualified_candidates.copy()
  396. if len(constant_nodes) < 3:
  397. filtered_candidates = [
  398. (node, path, rel_score, score)
  399. for node, path, rel_score, score in candidate_scores
  400. if rel_score >= 0.2
  401. ]
  402. filtered_candidates.sort(key=lambda x: x[2], reverse=True)
  403. constant_nodes = filtered_candidates[: min(3, len(filtered_candidates))]
  404. return {tuple(path) for _, path, _, _ in constant_nodes}
  405. def mark_constant_nodes(tree: Dict[str, Any]) -> None:
  406. id_nodes: List[Tuple[Dict[str, Any], List[str]]] = []
  407. collect_id_nodes(tree, [], id_nodes)
  408. class_nodes: List[Tuple[Dict[str, Any], List[str]]] = []
  409. collect_class_nodes(tree, [], class_nodes, is_root=True)
  410. id_constant_paths = select_constant_nodes(id_nodes)
  411. class_constant_paths = select_constant_nodes(class_nodes)
  412. constant_paths = id_constant_paths | class_constant_paths
  413. def mark_node(node: Dict[str, Any], path: List[str], is_root: bool = False) -> None:
  414. node_type = node.get("_type")
  415. children = node.get("children", {})
  416. is_first_level = len(path) == 1
  417. if node_type == "ID":
  418. node["_is_constant"] = tuple(path) in constant_paths
  419. elif node_type == "class":
  420. if not is_root and not is_first_level:
  421. node["_is_constant"] = tuple(path) in constant_paths
  422. for child_name, child_node in children.items():
  423. mark_node(child_node, path + [child_name], False)
  424. elif node_type == "root":
  425. for child_name, child_node in children.items():
  426. mark_node(child_node, path + [child_name], False)
  427. mark_node(tree, [], True)
  428. def get_cache_key(parent_category: str, child_categories: List[str]) -> str:
  429. sorted_categories = sorted(child_categories)
  430. content = f"{parent_category}|||{','.join(sorted_categories)}"
  431. return hashlib.md5(content.encode("utf-8")).hexdigest()
  432. def _try_parse_json_text(text: str) -> Dict[str, Any]:
  433. text = text.strip()
  434. text = re.sub(r"^```(?:json)?\s*", "", text)
  435. text = re.sub(r"\s*```\s*$", "", text)
  436. return json.loads(text)
  437. def _gemini_json_call(
  438. system_prompt: str,
  439. user_prompt: str,
  440. model: str,
  441. ) -> str:
  442. if httpx is None:
  443. raise RuntimeError("需要安装 httpx: pip install httpx")
  444. api_key = os.getenv("GEMINI_API_KEY")
  445. if not api_key:
  446. raise ValueError("GEMINI_API_KEY 未设置")
  447. base_url = os.getenv("GEMINI_API_BASE", "https://generativelanguage.googleapis.com/v1beta")
  448. url = f"{base_url}/models/{model}:generateContent"
  449. payload: Dict[str, Any] = {
  450. "contents": [{"role": "user", "parts": [{"text": user_prompt}]}],
  451. "systemInstruction": {"parts": [{"text": system_prompt}]},
  452. "generationConfig": {
  453. "temperature": 0,
  454. "maxOutputTokens": 4096,
  455. "responseMimeType": "application/json",
  456. },
  457. }
  458. with httpx.Client(timeout=120.0) as client:
  459. r = client.post(url, params={"key": api_key}, json=payload)
  460. r.raise_for_status()
  461. data = r.json()
  462. candidates = data.get("candidates") or []
  463. if not candidates:
  464. raise RuntimeError(f"Gemini 无候选输出: {data}")
  465. parts = (candidates[0].get("content") or {}).get("parts") or []
  466. text = "".join(p.get("text", "") for p in parts)
  467. if not text.strip():
  468. raise RuntimeError("Gemini 返回空文本")
  469. return text
  470. def load_cached_relation(account_name: str, cache_key: str) -> Optional[Dict]:
  471. cache_file = _cache_dir(account_name) / f"{cache_key}.json"
  472. if cache_file.exists():
  473. try:
  474. with open(cache_file, "r", encoding="utf-8") as f:
  475. return json.load(f)
  476. except Exception as e:
  477. print(f"读取缓存失败: {e}")
  478. return None
  479. def save_cached_relation(account_name: str, cache_key: str, relation_data: Dict) -> None:
  480. cache_file = _cache_dir(account_name) / f"{cache_key}.json"
  481. try:
  482. with open(cache_file, "w", encoding="utf-8") as f:
  483. json.dump(relation_data, f, ensure_ascii=False, indent=2)
  484. except Exception as e:
  485. print(f"保存缓存失败: {e}")
  486. def judge_category_relation(
  487. parent_category: str,
  488. child_categories: List[str],
  489. account_name: str,
  490. prompt_path: Path,
  491. model: str,
  492. ) -> Dict[str, Any]:
  493. cache_key = get_cache_key(parent_category, child_categories)
  494. cached = load_cached_relation(account_name, cache_key)
  495. if cached:
  496. return cached
  497. if not prompt_path.exists():
  498. raise FileNotFoundError(f"Prompt 文件不存在: {prompt_path}")
  499. prompt_template = prompt_path.read_text(encoding="utf-8")
  500. system_prompt = (
  501. prompt_template.replace("{parent_category}", parent_category).replace(
  502. "{child_categories}", json.dumps(child_categories, ensure_ascii=False)
  503. )
  504. )
  505. user_prompt = "请分析父分类和子分类列表的关系,判断它们是互斥还是有交集,并以JSON格式输出结果。"
  506. try:
  507. raw = _gemini_json_call(system_prompt, user_prompt, model=model)
  508. result = _try_parse_json_text(raw)
  509. save_cached_relation(account_name, cache_key, result)
  510. return result
  511. except Exception as e:
  512. print(f"调用LLM判断分类关系失败: {e}")
  513. result = {
  514. "relation": "有交集",
  515. "confidence": 0.5,
  516. "reasoning": f"LLM调用失败,默认判断为有交集: {str(e)}",
  517. }
  518. save_cached_relation(account_name, cache_key, result)
  519. return result
def mark_local_constant_nodes(
    tree: Dict[str, Any],
    account_name: str,
    prompt_path: Path,
    model: str,
) -> None:
    """Mark ``_is_local_constant`` on sub-classes of dominant class nodes.

    For every class node whose global ``_ratio`` is >= 0.5 and that has at
    least two class-typed children, judge_category_relation (LLM, cached)
    decides whether those children are mutually exclusive ("互斥"):
      - mutually exclusive: every sub-class is flagged local-constant;
      - overlapping: a sub-class is flagged only when it holds more than half
        of the parent's posts (False for all when the parent has no posts).
    The relation and its raw detail are stored on the parent node.
    """
    def process_node(node: Dict[str, Any], path: List[str], is_root: bool = False) -> None:
        node_type = node.get("_type")
        children = node.get("children", {})
        is_first_level = len(path) == 1
        if node_type == "root":
            for child_name, child_node in children.items():
                process_node(child_node, path + [child_name], False)
        elif node_type == "class":
            ratio = node.get("_ratio", 0.0)
            # NOTE(review): `is_first_level or len(path) > 1` holds for every
            # non-empty path, and class nodes always have a non-empty path,
            # so the only effective gate here is `ratio >= 0.5`.
            if (is_first_level or len(path) > 1) and ratio >= 0.5:
                sub_class_nodes = [
                    (name, cn)
                    for name, cn in children.items()
                    if cn.get("_type") == "class"
                ]
                # Only consult the LLM when there are >= 2 sibling classes to
                # compare.
                if len(sub_class_nodes) >= 2:
                    child_categories = [name for name, _ in sub_class_nodes]
                    parent_category = path[-1] if path else "根分类"
                    relation_result = judge_category_relation(
                        parent_category, child_categories, account_name, prompt_path, model
                    )
                    relation = relation_result.get("relation", "有交集")
                    node["_child_categories_relation"] = relation
                    node["_child_categories_relation_detail"] = relation_result
                    if relation == "互斥":
                        for child_name, child_node in sub_class_nodes:
                            child_node["_is_local_constant"] = True
                    else:
                        parent_post_count = node.get("_post_count", 0)
                        if parent_post_count > 0:
                            for child_name, child_node in sub_class_nodes:
                                child_post_count = child_node.get("_post_count", 0)
                                child_node["_is_local_constant"] = (
                                    child_post_count / parent_post_count > 0.5
                                )
                        else:
                            for child_name, child_node in sub_class_nodes:
                                child_node["_is_local_constant"] = False
            # Recurse into all children regardless of whether this node was
            # dominant.
            for child_name, child_node in children.items():
                process_node(child_node, path + [child_name], False)

    process_node(tree, [], True)
  567. def discover_dimensions(tree_dir: Path) -> List[str]:
  568. dims: List[str] = []
  569. if not tree_dir.is_dir():
  570. return dims
  571. for p in sorted(tree_dir.glob("*_tree.json")):
  572. name = p.name
  573. if name.endswith("_tree.json"):
  574. dim = name[: -len("_tree.json")]
  575. if dim:
  576. dims.append(dim)
  577. return dims
def process_account(
    account_name: str,
    prompt_path: Optional[Path] = None,
    model: Optional[str] = None,
    dimensions: Optional[List[str]] = None,
) -> None:
    """End-to-end pipeline for one account.

    For every dimension (auto-discovered from ``原始数据/tree/*_tree.json``
    unless *dimensions* is given): load weights and the classification tree,
    build the point tree, mark global and local constant nodes, and write
    ``<dim>_point_tree_how.json`` into ``处理后数据/tree``.

    Exits the process (sys.exit(1)) when no dimensions are found; dimensions
    with missing/unparseable inputs are skipped with a printed notice.
    """
    prompt_path = prompt_path or _DEFAULT_PROMPT
    model = model or _DEFAULT_MODEL
    base = _input_base(account_name)
    tree_dir = base / "tree"
    weight_dir = base / "point_tree_weight"
    out_dir = _output_tree_dir(account_name)
    out_dir.mkdir(parents=True, exist_ok=True)
    dims = dimensions if dimensions is not None else discover_dimensions(tree_dir)
    if not dims:
        print(f"未在 {tree_dir} 找到 *_tree.json")
        sys.exit(1)
    # Posts listed in 原始数据/exclude_note_ids.json are filtered out of both
    # the weight aggregation and the flat data rows.
    exclude_note_ids = load_exclude_note_ids(base)
    print(f"账号: {account_name}")
    print(f"输出目录: {out_dir}")
    print(f"维度: {dims}")
    print(f"Gemini 模型: {model}")
    print(f"排除帖子 ID 数: {len(exclude_note_ids)}")
    for dimension in dims:
        tree_file = tree_dir / f"{dimension}_tree.json"
        weight_file = weight_dir / f"{dimension}_tree_weight_score.json"
        if not tree_file.exists():
            print(f"跳过维度 {dimension}:缺少 {tree_file}")
            continue
        weight_map, post_ids_map, total_post_count = load_weight_scores(
            weight_file, exclude_note_ids=exclude_note_ids
        )
        # An empty weight map means the weight file was missing or unreadable.
        if not weight_map:
            print(f"跳过维度 {dimension}:无法加载权重分 {weight_file}")
            continue
        loaded = load_classification_tree_from_file(
            tree_file, exclude_note_ids=exclude_note_ids
        )
        if loaded is None:
            print(f"跳过维度 {dimension}:无法解析分类树 {tree_file}")
            continue
        classification_tree, root_direct_elements = loaded
        print(
            f"处理 {dimension}: 分类顶层 {len(classification_tree)} 类, "
            f"root 直挂 {len(root_direct_elements)} 词, 权重词 {len(weight_map)}"
        )
        tree = build_tree_from_classification(
            classification_tree,
            weight_map,
            post_ids_map,
            total_post_count,
            dimension,
            root_direct_elements=root_direct_elements or None,
        )
        # Global constants first, then LLM-assisted local constants.
        mark_constant_nodes(tree)
        mark_local_constant_nodes(tree, account_name, prompt_path, model)
        result = {dimension: tree}
        out_file = out_dir / f"{dimension}_point_tree_how.json"
        with open(out_file, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        print(f"已写入 {out_file}")
  639. def main(account_name) -> None:
  640. process_account(
  641. account_name
  642. )
  643. if __name__ == "__main__":
  644. main(account_name="空间点阵设计研究室")