how_tree_data_process.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. How 人设树处理:从「原始数据/tree」与「原始数据/point_tree_weight」生成与
  5. 「实质_point_tree_how.json」同结构的输出。
  6. 用法:
  7. python how_tree_data_process.py --account_name 阿里多多酱
  8. 依赖: 环境变量 GEMINI_API_KEY(用于子分类关系判断;缺失时采用保守默认「有交集」)
  9. """
  10. from __future__ import annotations
  11. import hashlib
  12. import json
  13. import os
  14. import re
  15. import sys
  16. from pathlib import Path
  17. from typing import Any, Dict, List, Optional, Tuple
  18. try:
  19. import httpx
  20. except ImportError:
  21. httpx = None # type: ignore
  22. # -----------------------------------------------------------------------------
  23. # 路径
  24. # -----------------------------------------------------------------------------
  25. _SCRIPT_DIR = Path(__file__).resolve().parent
  26. _OVERALL_DIR = _SCRIPT_DIR.parent
  27. _DEFAULT_PROMPT = _OVERALL_DIR / "prompt" / "judge_category_relation.md"
  28. _DEFAULT_MODEL = "gemini-3-flash-preview"
  29. _GLOBAL_CONSTANT_START_LEVEL = 3
  30. _LOCAL_CONSTANT_START_LEVEL = 4
  31. def _input_base(account_name: str) -> Path:
  32. return _OVERALL_DIR / "input" / account_name / "原始数据"
  33. def _output_tree_dir(account_name: str) -> Path:
  34. return _OVERALL_DIR / "input" / account_name / "处理后数据" / "tree"
  35. def _cache_dir(account_name: str) -> Path:
  36. d = _output_tree_dir(account_name) / ".cache_relation" / account_name
  37. d.mkdir(parents=True, exist_ok=True)
  38. return d
  39. # -----------------------------------------------------------------------------
  40. # 扁平 tree JSON -> 分类树(直接元素 / 子分类)
  41. # -----------------------------------------------------------------------------
  42. def _find_or_create_classification(
  43. level: List[Dict[str, Any]], name: str
  44. ) -> Dict[str, Any]:
  45. for node in level:
  46. if node.get("分类名称") == name:
  47. return node
  48. node = {"分类名称": name, "直接元素": [], "子分类": []}
  49. level.append(node)
  50. return node
  51. def _merge_path_element(
  52. classification_tree: List[Dict[str, Any]],
  53. segments: List[str],
  54. element: str,
  55. ) -> None:
  56. if not segments:
  57. return
  58. name = segments[0]
  59. node = _find_or_create_classification(classification_tree, name)
  60. if len(segments) == 1:
  61. if element and element not in node["直接元素"]:
  62. node["直接元素"].append(element)
  63. else:
  64. _merge_path_element(node["子分类"], segments[1:], element)
  65. def _rows_to_classification_tree(
  66. rows: List[Dict[str, Any]],
  67. ) -> Tuple[List[Dict[str, Any]], List[str]]:
  68. """
  69. 返回 (分类树, root 下直接挂接的元素名列表)。
  70. category_path 为空时(常见于「意图」维度),元素不进入子分类,直接挂在 root 下。
  71. """
  72. tree: List[Dict[str, Any]] = []
  73. root_direct: List[str] = []
  74. for row in rows:
  75. raw_path = row.get("category_path") or ""
  76. element = (row.get("element_name") or "").strip()
  77. segments = [s for s in raw_path.strip("/").split("/") if s]
  78. if not segments:
  79. if element and element not in root_direct:
  80. root_direct.append(element)
  81. continue
  82. _merge_path_element(tree, segments, element)
  83. return tree, root_direct
  84. def _is_classification_style(nodes: Any) -> bool:
  85. if not isinstance(nodes, list) or not nodes:
  86. return False
  87. n0 = nodes[0]
  88. return isinstance(n0, dict) and (
  89. "分类名称" in n0 or "直接元素" in n0 or "子分类" in n0
  90. )
def load_classification_tree_from_file(
    tree_json_path: Path,
) -> Optional[Tuple[List[Dict[str, Any]], List[str]]]:
    """Load a classification tree from one of several supported JSON layouts.

    Supported layouts:
      1) flat ``data`` rows (category_path + element_name; rows with an empty
         path put their element into the root-level direct list), and
      2) the classification_tree JSON produced by generate_how_point_tree.

    Returns (final classification tree, root-level direct element names), or
    None when the layout is not recognized.
    """
    with open(tree_json_path, "r", encoding="utf-8") as f:
        tree_data = json.load(f)
    if isinstance(tree_data, dict):
        # Wrapped layouts: prefer the appended result ("追加分类结果"), then
        # the plain result ("分类结果"); each carries a "最终分类树" tree.
        if "追加分类结果" in tree_data:
            ft = tree_data["追加分类结果"].get("最终分类树")
            if ft is not None:
                return ft, []
        if "分类结果" in tree_data:
            ft = tree_data["分类结果"].get("最终分类树")
            if ft is not None:
                return ft, []
        # Top-level "最终分类树" key, accepted only if it looks like a tree.
        ft = tree_data.get("最终分类树")
        if ft is not None and _is_classification_style(ft):
            return ft, []
        # Flat layout: {"data": [{"category_path": ..., "element_name": ...}, ...]}
        rows = tree_data.get("data")
        if isinstance(rows, list):
            if not rows:
                return [], []
            if isinstance(rows[0], dict) and (
                "category_path" in rows[0] or "element_name" in rows[0]
            ):
                return _rows_to_classification_tree(rows)
    if isinstance(tree_data, list):
        # Bare list: either flat rows or an already-built classification tree.
        if not tree_data:
            return [], []
        if isinstance(tree_data[0], dict) and (
            "category_path" in tree_data[0] or "element_name" in tree_data[0]
        ):
            return _rows_to_classification_tree(tree_data)
        if _is_classification_style(tree_data):
            return tree_data, []
    # Unrecognized layout.
    return None
  132. # -----------------------------------------------------------------------------
  133. # 权重(与 generate_how_point_tree.load_weight_scores 一致)
  134. # -----------------------------------------------------------------------------
  135. def load_weight_scores(
  136. weight_file: Path,
  137. ) -> Tuple[Dict[str, float], Dict[str, set], int]:
  138. if not weight_file.exists():
  139. print(f"警告: 权重分文件不存在: {weight_file}")
  140. return {}, {}, 0
  141. try:
  142. with open(weight_file, "r", encoding="utf-8") as f:
  143. weight_data = json.load(f)
  144. weight_map: Dict[str, float] = {}
  145. post_ids_map: Dict[str, set] = {}
  146. all_post_ids: set = set()
  147. for item in weight_data:
  148. word = item.get("词", "")
  149. score = item.get("归一权重分", 0.0)
  150. if word:
  151. post_ids: set = set()
  152. for detail in item.get("权重分详情", []):
  153. post_id = detail.get("帖子ID", "")
  154. if post_id:
  155. post_ids.add(post_id)
  156. all_post_ids.add(post_id)
  157. post_ids_map[word] = post_ids
  158. if not post_ids:
  159. weight_map[word] = 0.0
  160. else:
  161. weight_map[word] = score
  162. total_post_count = len(all_post_ids)
  163. return weight_map, post_ids_map, total_post_count
  164. except Exception as e:
  165. print(f"加载权重分数据失败: {e}")
  166. return {}, {}, 0
  167. # -----------------------------------------------------------------------------
  168. # 建树、概率、常量标记(与 generate_how_point_tree 一致)
  169. # -----------------------------------------------------------------------------
  170. def build_id_node(
  171. element: str,
  172. weight_map: Dict[str, float],
  173. post_ids_map: Dict[str, set],
  174. ) -> Dict[str, Any]:
  175. post_ids = post_ids_map.get(element, set())
  176. post_ids_list = list(post_ids)
  177. return {
  178. "_type": "ID",
  179. "_persona_weight_score": round(weight_map.get(element, 0.0), 4),
  180. "_post_count": len(post_ids_list),
  181. "_post_ids": post_ids_list,
  182. }
  183. def build_tree_node_from_classification(
  184. classification_node: Dict,
  185. weight_map: Dict[str, float],
  186. post_ids_map: Dict[str, set],
  187. dimension: str,
  188. ) -> Dict[str, Any]:
  189. direct_elements = classification_node.get("直接元素", [])
  190. sub_classifications = classification_node.get("子分类", [])
  191. children: Dict[str, Any] = {}
  192. for element in direct_elements:
  193. if element:
  194. children[element] = build_id_node(element, weight_map, post_ids_map)
  195. for sub_class in sub_classifications:
  196. sub_node = build_tree_node_from_classification(
  197. sub_class, weight_map, post_ids_map, dimension
  198. )
  199. sub_node_name = sub_class.get("分类名称", "")
  200. if sub_node_name:
  201. children[sub_node_name] = sub_node
  202. def count_leaves_and_sum_scores(node: Dict) -> Tuple[int, float]:
  203. if node.get("_type") == "ID":
  204. return (1, node.get("_persona_weight_score", 0.0))
  205. if node.get("_type") == "class":
  206. leaf_count = 0
  207. total_score = 0.0
  208. for child in node.get("children", {}).values():
  209. c, s = count_leaves_and_sum_scores(child)
  210. leaf_count += c
  211. total_score += s
  212. return (leaf_count, total_score)
  213. return (0, 0.0)
  214. def collect_post_ids(node: Dict) -> set:
  215. if node.get("_type") == "ID":
  216. ids = node.get("_post_ids", [])
  217. return set(ids) if isinstance(ids, list) else ids
  218. if node.get("_type") == "class":
  219. all_post_ids = set()
  220. for child in node.get("children", {}).values():
  221. all_post_ids |= collect_post_ids(child)
  222. return all_post_ids
  223. return set()
  224. total_score = 0.0
  225. for child in children.values():
  226. _, score = count_leaves_and_sum_scores(child)
  227. total_score += score
  228. all_post_ids = set()
  229. for child in children.values():
  230. all_post_ids |= collect_post_ids(child)
  231. total_post_count = len(all_post_ids)
  232. result: Dict[str, Any] = {
  233. "_type": "class",
  234. "_persona_weight_score": round(total_score, 4),
  235. "_post_count": total_post_count,
  236. "_post_ids": list(all_post_ids),
  237. }
  238. if children:
  239. result["children"] = children
  240. return result
  241. def build_tree_from_classification(
  242. classification_tree: List[Dict],
  243. weight_map: Dict[str, float],
  244. post_ids_map: Dict[str, set],
  245. total_post_count: int,
  246. dimension: str,
  247. root_direct_elements: Optional[List[str]] = None,
  248. ) -> Dict[str, Any]:
  249. root_children: Dict[str, Any] = {}
  250. for classification_node in classification_tree:
  251. node = build_tree_node_from_classification(
  252. classification_node, weight_map, post_ids_map, dimension
  253. )
  254. node_name = classification_node.get("分类名称", "")
  255. if node_name:
  256. root_children[node_name] = node
  257. # 意图等:category_path 为空时元素直接作为 root 的子节点(_type=ID)
  258. if root_direct_elements:
  259. for element in root_direct_elements:
  260. if not element or element in root_children:
  261. continue
  262. root_children[element] = build_id_node(element, weight_map, post_ids_map)
  263. def collect_post_ids(node: Dict) -> set:
  264. if node.get("_type") == "ID":
  265. ids = node.get("_post_ids", [])
  266. return set(ids) if isinstance(ids, list) else ids
  267. if node.get("_type") == "class":
  268. all_post_ids = set()
  269. for child in node.get("children", {}).values():
  270. all_post_ids |= collect_post_ids(child)
  271. return all_post_ids
  272. return set()
  273. root_post_ids: set = set()
  274. for child in root_children.values():
  275. root_post_ids |= collect_post_ids(child)
  276. root_post_count = len(root_post_ids)
  277. root: Dict[str, Any] = {
  278. "_type": "root",
  279. "_post_count": root_post_count,
  280. "_post_ids": list(root_post_ids),
  281. "children": root_children,
  282. }
  283. def calculate_ratio(node: Dict[str, Any], parent_post_count: int | None = None) -> None:
  284. node_type = node.get("_type")
  285. post_count = node.get("_post_count", 0)
  286. if node_type == "root":
  287. pass
  288. elif node_type in ("class", "ID"):
  289. if total_post_count > 0:
  290. node["_ratio"] = round(post_count / total_post_count, 4)
  291. else:
  292. node["_ratio"] = 0.0
  293. children = node.get("children", {})
  294. for child in children.values():
  295. calculate_ratio(child, post_count)
  296. calculate_ratio(root)
  297. return root
  298. def collect_id_nodes(
  299. node: Dict[str, Any],
  300. node_path: List[str],
  301. id_nodes: List[Tuple[Dict[str, Any], List[str]]],
  302. min_level: int,
  303. ) -> None:
  304. node_type = node.get("_type")
  305. children = node.get("children", {})
  306. if node_type == "ID":
  307. if len(node_path) >= min_level:
  308. id_nodes.append((node, node_path.copy()))
  309. elif node_type in ("class", "root"):
  310. for child_name, child_node in children.items():
  311. collect_id_nodes(child_node, node_path + [child_name], id_nodes, min_level)
  312. def collect_class_nodes(
  313. node: Dict[str, Any],
  314. node_path: List[str],
  315. class_nodes: List[Tuple[Dict[str, Any], List[str]]],
  316. min_level: int,
  317. is_root: bool = False,
  318. ) -> None:
  319. node_type = node.get("_type")
  320. children = node.get("children", {})
  321. if node_type == "class":
  322. if not is_root and len(node_path) >= min_level:
  323. class_nodes.append((node, node_path.copy()))
  324. for child_name, child_node in children.items():
  325. collect_class_nodes(
  326. child_node, node_path + [child_name], class_nodes, min_level, False
  327. )
  328. elif node_type == "root":
  329. for child_name, child_node in children.items():
  330. collect_class_nodes(
  331. child_node, node_path + [child_name], class_nodes, min_level, False
  332. )
  333. def select_constant_nodes(
  334. candidates: List[Tuple[Dict[str, Any], List[str]]],
  335. ) -> set:
  336. if not candidates:
  337. return set()
  338. max_score = 0.0
  339. for node, _ in candidates:
  340. score = node.get("_persona_weight_score", 0.0)
  341. if score > max_score:
  342. max_score = score
  343. candidate_scores = []
  344. for node, path in candidates:
  345. score = node.get("_persona_weight_score", 0.0)
  346. relative_score = (
  347. score / max_score if max_score > 0 else (1.0 if len(candidates) == 1 else 0.0)
  348. )
  349. candidate_scores.append((node, path, relative_score, score))
  350. qualified_candidates = [
  351. (node, path, rel_score, score)
  352. for node, path, rel_score, score in candidate_scores
  353. if rel_score >= 0.5
  354. ]
  355. if len(qualified_candidates) > 8:
  356. qualified_candidates.sort(key=lambda x: x[2], reverse=True)
  357. constant_nodes = qualified_candidates[:8]
  358. else:
  359. constant_nodes = qualified_candidates.copy()
  360. if len(constant_nodes) < 3:
  361. filtered_candidates = [
  362. (node, path, rel_score, score)
  363. for node, path, rel_score, score in candidate_scores
  364. if rel_score >= 0.2
  365. ]
  366. filtered_candidates.sort(key=lambda x: x[2], reverse=True)
  367. constant_nodes = filtered_candidates[: min(3, len(filtered_candidates))]
  368. return {tuple(path) for _, path, _, _ in constant_nodes}
  369. def mark_constant_nodes(tree: Dict[str, Any], dimension: str) -> None:
  370. # 意图维度常见为 root 直挂节点,不受 level3 起算限制。
  371. min_level = 1 if dimension == "意图" else _GLOBAL_CONSTANT_START_LEVEL
  372. id_nodes: List[Tuple[Dict[str, Any], List[str]]] = []
  373. collect_id_nodes(tree, [], id_nodes, min_level)
  374. class_nodes: List[Tuple[Dict[str, Any], List[str]]] = []
  375. collect_class_nodes(tree, [], class_nodes, min_level, is_root=True)
  376. id_constant_paths = select_constant_nodes(id_nodes)
  377. class_constant_paths = select_constant_nodes(class_nodes)
  378. constant_paths = id_constant_paths | class_constant_paths
  379. def mark_node(node: Dict[str, Any], path: List[str], is_root: bool = False) -> None:
  380. node_type = node.get("_type")
  381. children = node.get("children", {})
  382. if node_type == "ID":
  383. if len(path) >= min_level:
  384. node["_is_constant"] = tuple(path) in constant_paths
  385. elif node_type == "class":
  386. if not is_root and len(path) >= min_level:
  387. node["_is_constant"] = tuple(path) in constant_paths
  388. for child_name, child_node in children.items():
  389. mark_node(child_node, path + [child_name], False)
  390. elif node_type == "root":
  391. for child_name, child_node in children.items():
  392. mark_node(child_node, path + [child_name], False)
  393. mark_node(tree, [], True)
  394. def get_cache_key(parent_category: str, child_categories: List[str]) -> str:
  395. sorted_categories = sorted(child_categories)
  396. content = f"{parent_category}|||{','.join(sorted_categories)}"
  397. return hashlib.md5(content.encode("utf-8")).hexdigest()
  398. def _try_parse_json_text(text: str) -> Dict[str, Any]:
  399. text = text.strip()
  400. text = re.sub(r"^```(?:json)?\s*", "", text)
  401. text = re.sub(r"\s*```\s*$", "", text)
  402. return json.loads(text)
  403. def _gemini_json_call(
  404. system_prompt: str,
  405. user_prompt: str,
  406. model: str,
  407. ) -> str:
  408. if httpx is None:
  409. raise RuntimeError("需要安装 httpx: pip install httpx")
  410. api_key = os.getenv("GEMINI_API_KEY")
  411. if not api_key:
  412. raise ValueError("GEMINI_API_KEY 未设置")
  413. base_url = os.getenv("GEMINI_API_BASE", "https://generativelanguage.googleapis.com/v1beta")
  414. url = f"{base_url}/models/{model}:generateContent"
  415. payload: Dict[str, Any] = {
  416. "contents": [{"role": "user", "parts": [{"text": user_prompt}]}],
  417. "systemInstruction": {"parts": [{"text": system_prompt}]},
  418. "generationConfig": {
  419. "temperature": 0,
  420. "maxOutputTokens": 4096,
  421. "responseMimeType": "application/json",
  422. },
  423. }
  424. with httpx.Client(timeout=120.0) as client:
  425. r = client.post(url, params={"key": api_key}, json=payload)
  426. r.raise_for_status()
  427. data = r.json()
  428. candidates = data.get("candidates") or []
  429. if not candidates:
  430. raise RuntimeError(f"Gemini 无候选输出: {data}")
  431. parts = (candidates[0].get("content") or {}).get("parts") or []
  432. text = "".join(p.get("text", "") for p in parts)
  433. if not text.strip():
  434. raise RuntimeError("Gemini 返回空文本")
  435. return text
  436. def load_cached_relation(account_name: str, cache_key: str) -> Optional[Dict]:
  437. cache_file = _cache_dir(account_name) / f"{cache_key}.json"
  438. if cache_file.exists():
  439. try:
  440. with open(cache_file, "r", encoding="utf-8") as f:
  441. return json.load(f)
  442. except Exception as e:
  443. print(f"读取缓存失败: {e}")
  444. return None
  445. def save_cached_relation(account_name: str, cache_key: str, relation_data: Dict) -> None:
  446. cache_file = _cache_dir(account_name) / f"{cache_key}.json"
  447. try:
  448. with open(cache_file, "w", encoding="utf-8") as f:
  449. json.dump(relation_data, f, ensure_ascii=False, indent=2)
  450. except Exception as e:
  451. print(f"保存缓存失败: {e}")
  452. def judge_category_relation(
  453. parent_category: str,
  454. child_categories: List[str],
  455. account_name: str,
  456. prompt_path: Path,
  457. model: str,
  458. ) -> Dict[str, Any]:
  459. cache_key = get_cache_key(parent_category, child_categories)
  460. cached = load_cached_relation(account_name, cache_key)
  461. if cached:
  462. return cached
  463. if not prompt_path.exists():
  464. raise FileNotFoundError(f"Prompt 文件不存在: {prompt_path}")
  465. prompt_template = prompt_path.read_text(encoding="utf-8")
  466. system_prompt = (
  467. prompt_template.replace("{parent_category}", parent_category).replace(
  468. "{child_categories}", json.dumps(child_categories, ensure_ascii=False)
  469. )
  470. )
  471. user_prompt = "请分析父分类和子分类列表的关系,判断它们是互斥还是有交集,并以JSON格式输出结果。"
  472. try:
  473. raw = _gemini_json_call(system_prompt, user_prompt, model=model)
  474. result = _try_parse_json_text(raw)
  475. save_cached_relation(account_name, cache_key, result)
  476. return result
  477. except Exception as e:
  478. print(f"调用LLM判断分类关系失败: {e}")
  479. result = {
  480. "relation": "有交集",
  481. "confidence": 0.5,
  482. "reasoning": f"LLM调用失败,默认判断为有交集: {str(e)}",
  483. }
  484. save_cached_relation(account_name, cache_key, result)
  485. return result
def mark_local_constant_nodes(
    tree: Dict[str, Any],
    account_name: str,
    prompt_path: Path,
    model: str,
) -> None:
    """Mark `_is_local_constant` on sub-classes of dominant class nodes, in place.

    A class node qualifies as a "dominant parent" when it sits at depth
    >= _LOCAL_CONSTANT_START_LEVEL - 1, its `_ratio` is >= 0.5, and it has at
    least two class children. For each such parent, an LLM call (cached)
    judges the children's relation:
      - "互斥" (mutually exclusive): every class child is a local constant;
      - otherwise: a child is a local constant iff it covers more than half
        of the parent's posts (False for all when the parent has no posts).
    The judgement is also recorded on the parent node itself.
    """
    def process_node(node: Dict[str, Any], path: List[str], is_root: bool = False) -> None:
        node_type = node.get("_type")
        children = node.get("children", {})
        if node_type == "root":
            for child_name, child_node in children.items():
                process_node(child_node, path + [child_name], False)
        elif node_type == "class":
            ratio = node.get("_ratio", 0.0)
            # Local-constant marking starts at level 4, so only parents at
            # level 3 or deeper (with a dominant ratio) are considered.
            if len(path) >= (_LOCAL_CONSTANT_START_LEVEL - 1) and ratio >= 0.5:
                sub_class_nodes = [
                    (name, cn)
                    for name, cn in children.items()
                    if cn.get("_type") == "class"
                ]
                if len(sub_class_nodes) >= 2:
                    child_categories = [name for name, _ in sub_class_nodes]
                    parent_category = path[-1] if path else "根分类"
                    relation_result = judge_category_relation(
                        parent_category, child_categories, account_name, prompt_path, model
                    )
                    # Default to overlap when the LLM reply lacks a relation.
                    relation = relation_result.get("relation", "有交集")
                    node["_child_categories_relation"] = relation
                    node["_child_categories_relation_detail"] = relation_result
                    if relation == "互斥":
                        # Mutually exclusive children are all locally constant.
                        for child_name, child_node in sub_class_nodes:
                            child_node["_is_local_constant"] = True
                    else:
                        parent_post_count = node.get("_post_count", 0)
                        if parent_post_count > 0:
                            # Overlapping: constant only when a child covers
                            # more than half of the parent's posts.
                            for child_name, child_node in sub_class_nodes:
                                child_post_count = child_node.get("_post_count", 0)
                                child_node["_is_local_constant"] = (
                                    child_post_count / parent_post_count > 0.5
                                )
                        else:
                            for child_name, child_node in sub_class_nodes:
                                child_node["_is_local_constant"] = False
            # Recurse regardless of whether this parent was judged.
            for child_name, child_node in children.items():
                process_node(child_node, path + [child_name], False)
    process_node(tree, [], True)
  533. def discover_dimensions(tree_dir: Path) -> List[str]:
  534. dims: List[str] = []
  535. if not tree_dir.is_dir():
  536. return dims
  537. for p in sorted(tree_dir.glob("*_tree.json")):
  538. name = p.name
  539. if name.endswith("_tree.json"):
  540. dim = name[: -len("_tree.json")]
  541. if dim:
  542. dims.append(dim)
  543. return dims
def process_account(
    account_name: str,
    prompt_path: Optional[Path] = None,
    model: Optional[str] = None,
    dimensions: Optional[List[str]] = None,
) -> None:
    """Process every dimension of an account into `{dim}_point_tree_how.json`.

    For each dimension: loads the raw tree and weight-score files, builds the
    weighted point tree, marks global constants and (LLM-assisted) local
    constants, then writes the result under the account's output tree
    directory. Exits the process (code 1) when no dimension files are found;
    dimensions with missing or unparseable inputs are skipped with a warning.
    """
    prompt_path = prompt_path or _DEFAULT_PROMPT
    model = model or _DEFAULT_MODEL
    base = _input_base(account_name)
    tree_dir = base / "tree"
    weight_dir = base / "point_tree_weight"
    out_dir = _output_tree_dir(account_name)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Auto-discover dimensions from *_tree.json files unless given explicitly.
    dims = dimensions if dimensions is not None else discover_dimensions(tree_dir)
    if not dims:
        print(f"未在 {tree_dir} 找到 *_tree.json")
        sys.exit(1)
    print(f"账号: {account_name}")
    print(f"输出目录: {out_dir}")
    print(f"维度: {dims}")
    print(f"Gemini 模型: {model}")
    for dimension in dims:
        tree_file = tree_dir / f"{dimension}_tree.json"
        weight_file = weight_dir / f"{dimension}_tree_weight_score.json"
        if not tree_file.exists():
            print(f"跳过维度 {dimension}:缺少 {tree_file}")
            continue
        weight_map, post_ids_map, total_post_count = load_weight_scores(weight_file)
        if not weight_map:
            print(f"跳过维度 {dimension}:无法加载权重分 {weight_file}")
            continue
        loaded = load_classification_tree_from_file(tree_file)
        if loaded is None:
            print(f"跳过维度 {dimension}:无法解析分类树 {tree_file}")
            continue
        classification_tree, root_direct_elements = loaded
        print(
            f"处理 {dimension}: 分类顶层 {len(classification_tree)} 类, "
            f"root 直挂 {len(root_direct_elements)} 词, 权重词 {len(weight_map)}"
        )
        tree = build_tree_from_classification(
            classification_tree,
            weight_map,
            post_ids_map,
            total_post_count,
            dimension,
            root_direct_elements=root_direct_elements or None,
        )
        # Global constants first, then LLM-assisted local constants.
        mark_constant_nodes(tree, dimension)
        mark_local_constant_nodes(tree, account_name, prompt_path, model)
        result = {dimension: tree}
        out_file = out_dir / f"{dimension}_point_tree_how.json"
        with open(out_file, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        print(f"已写入 {out_file}")
  599. def main(account_name) -> None:
  600. process_account(
  601. account_name
  602. )
  603. if __name__ == "__main__":
  604. main(account_name="空间点阵设计研究室")