| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- How 人设树处理:从「原始数据/tree」与「原始数据/point_tree_weight」生成与
- 「实质_point_tree_how.json」同结构的输出。
- 可选:`原始数据/exclude_note_ids.json` 中的帖子 ID 不参与建树(权重聚合与带 post_ids 的扁平 data 行会过滤)。
- 用法:
- python how_tree_data_process.py --account_name 阿里多多酱
- 依赖: 环境变量 GEMINI_API_KEY(用于子分类关系判断;缺失时采用保守默认「有交集」)
- """
- from __future__ import annotations
- import argparse
- import hashlib
- import json
- import os
- import re
- import sys
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple
- try:
- import httpx
- except ImportError:
- httpx = None # type: ignore
- # -----------------------------------------------------------------------------
- # 路径
- # -----------------------------------------------------------------------------
- _SCRIPT_DIR = Path(__file__).resolve().parent
- _OVERALL_DIR = _SCRIPT_DIR.parent
- _DEFAULT_PROMPT = _OVERALL_DIR / "prompt" / "judge_category_relation.md"
- _DEFAULT_MODEL = "gemini-3-flash-preview"
def _input_base(account_name: str) -> Path:
    """Return the raw-data directory ``input/<account_name>/原始数据`` for an account."""
    return _OVERALL_DIR.joinpath("input", account_name, "原始数据")
def _output_tree_dir(account_name: str) -> Path:
    """Return the output directory ``input/<account_name>/处理后数据/tree`` for an account."""
    return _OVERALL_DIR.joinpath("input", account_name, "处理后数据", "tree")
def _cache_dir(account_name: str) -> Path:
    """Return the relation-cache directory for an account, creating it if needed.

    The directory is ``<output tree dir>/.cache_relation/<account_name>``.
    """
    cache_root = _output_tree_dir(account_name).joinpath(".cache_relation", account_name)
    cache_root.mkdir(parents=True, exist_ok=True)
    return cache_root
def load_exclude_note_ids(base: Path) -> set[str]:
    """Load the set of post IDs to exclude from ``base/exclude_note_ids.json``.

    The file may hold either a JSON array of IDs or an object with an ``ids``
    array. Every entry is converted to ``str`` and stripped; entries that end
    up empty are dropped.

    Returns:
        The set of normalized IDs, or an empty set when the file is missing,
        has an unsupported shape, or cannot be read/parsed (a warning is
        printed in the latter case).
    """
    path = base / "exclude_note_ids.json"
    if not path.exists():
        return set()
    try:
        with open(path, "r", encoding="utf-8") as f:
            payload = json.load(f)
        id_list = None
        if isinstance(payload, list):
            id_list = payload
        elif isinstance(payload, dict) and "ids" in payload:
            id_list = payload["ids"]
        if isinstance(id_list, list):
            normalized = (str(item).strip() for item in id_list)
            return {text for text in normalized if text}
    except Exception as e:
        # Best effort: a broken exclude file must not abort the run.
        print(f"警告: 读取排除帖子列表失败 {path}: {e}")
    return set()
- def _row_has_posts_after_exclude(row: Dict[str, Any], exclude_ids: set[str]) -> bool:
- """
- 若行带 post_ids 且去掉排除 ID 后仍非空,则参与建树;
- 若无 post_ids 或为空列表,则仍参与(仅结构、由权重侧决定帖子)。
- """
- if not exclude_ids:
- return True
- pids = row.get("post_ids")
- if not isinstance(pids, list) or not pids:
- return True
- return any(pid not in exclude_ids for pid in pids if pid)
- def _filter_data_rows_for_exclude(
- rows: List[Dict[str, Any]], exclude_ids: set[str]
- ) -> List[Dict[str, Any]]:
- if not exclude_ids:
- return rows
- return [r for r in rows if _row_has_posts_after_exclude(r, exclude_ids)]
- # -----------------------------------------------------------------------------
- # 扁平 tree JSON -> 分类树(直接元素 / 子分类)
- # -----------------------------------------------------------------------------
- def _find_or_create_classification(
- level: List[Dict[str, Any]], name: str
- ) -> Dict[str, Any]:
- for node in level:
- if node.get("分类名称") == name:
- return node
- node = {"分类名称": name, "直接元素": [], "子分类": []}
- level.append(node)
- return node
- def _merge_path_element(
- classification_tree: List[Dict[str, Any]],
- segments: List[str],
- element: str,
- ) -> None:
- if not segments:
- return
- name = segments[0]
- node = _find_or_create_classification(classification_tree, name)
- if len(segments) == 1:
- if element and element not in node["直接元素"]:
- node["直接元素"].append(element)
- else:
- _merge_path_element(node["子分类"], segments[1:], element)
- def _rows_to_classification_tree(
- rows: List[Dict[str, Any]],
- ) -> Tuple[List[Dict[str, Any]], List[str]]:
- """
- 返回 (分类树, root 下直接挂接的元素名列表)。
- category_path 为空时(常见于「意图」维度),元素不进入子分类,直接挂在 root 下。
- """
- tree: List[Dict[str, Any]] = []
- root_direct: List[str] = []
- for row in rows:
- raw_path = row.get("category_path") or ""
- element = (row.get("element_name") or "").strip()
- segments = [s for s in raw_path.strip("/").split("/") if s]
- if not segments:
- if element and element not in root_direct:
- root_direct.append(element)
- continue
- _merge_path_element(tree, segments, element)
- return tree, root_direct
- def _is_classification_style(nodes: Any) -> bool:
- if not isinstance(nodes, list) or not nodes:
- return False
- n0 = nodes[0]
- return isinstance(n0, dict) and (
- "分类名称" in n0 or "直接元素" in n0 or "子分类" in n0
- )
def load_classification_tree_from_file(
    tree_json_path: Path,
    exclude_note_ids: Optional[set[str]] = None,
) -> Optional[Tuple[List[Dict[str, Any]], List[str]]]:
    """Load a classification tree from a JSON file.

    Supported layouts:
      1. A dict wrapping a ready-made tree under ``追加分类结果`` / ``分类结果``
         (each holding ``最终分类树``) or directly under ``最终分类树``.
      2. Flat rows (``category_path`` + ``element_name``), either as the
         top-level list or under the ``data`` key of a dict. Rows whose posts
         are all in ``exclude_note_ids`` are dropped; rows with an empty path
         contribute root-level elements.
      3. A top-level list that already has classification-node shape.

    Args:
        tree_json_path: Path of the JSON file to read.
        exclude_note_ids: Post IDs to ignore when filtering flat rows.

    Returns:
        ``(classification_tree, root_direct_elements)``, or ``None`` when the
        file is not valid JSON or matches none of the supported layouts.

    Raises:
        OSError: If the file cannot be opened.
    """
    try:
        with open(tree_json_path, "r", encoding="utf-8") as f:
            tree_data = json.load(f)
    except ValueError as e:
        # Covers JSONDecodeError and UnicodeDecodeError: honour the documented
        # "None on parse failure" contract instead of aborting the whole run.
        print(f"警告: 解析分类树文件失败 {tree_json_path}: {e}")
        return None

    excluded = exclude_note_ids or set()

    def looks_like_flat_rows(rows: List[Any]) -> bool:
        first = rows[0]
        return isinstance(first, dict) and (
            "category_path" in first or "element_name" in first
        )

    if isinstance(tree_data, dict):
        # Wrapped results take precedence, the appended result first.
        for wrapper_key in ("追加分类结果", "分类结果"):
            wrapper = tree_data.get(wrapper_key)
            # A null / non-object wrapper is skipped rather than crashing.
            if isinstance(wrapper, dict):
                final_tree = wrapper.get("最终分类树")
                if final_tree is not None:
                    return final_tree, []
        final_tree = tree_data.get("最终分类树")
        if final_tree is not None and _is_classification_style(final_tree):
            return final_tree, []
        rows = tree_data.get("data")
        if isinstance(rows, list):
            if not rows:
                return [], []
            if looks_like_flat_rows(rows):
                return _rows_to_classification_tree(
                    _filter_data_rows_for_exclude(rows, excluded)
                )
        return None

    if isinstance(tree_data, list):
        if not tree_data:
            return [], []
        if looks_like_flat_rows(tree_data):
            return _rows_to_classification_tree(
                _filter_data_rows_for_exclude(tree_data, excluded)
            )
        if _is_classification_style(tree_data):
            return tree_data, []
    return None
- # -----------------------------------------------------------------------------
- # 权重(与 generate_how_point_tree.load_weight_scores 一致)
- # -----------------------------------------------------------------------------
def load_weight_scores(
    weight_file: Path,
    exclude_note_ids: Optional[set[str]] = None,
) -> Tuple[Dict[str, float], Dict[str, set], int]:
    """Load per-word weight scores and the posts each word appears in.

    The file is expected to hold a JSON array of items with the keys ``词``
    (word), ``归一权重分`` (normalized score) and ``权重分详情`` (list of details,
    each carrying ``帖子ID``). Items with an empty word are ignored.

    Args:
        weight_file: Path of the weight-score JSON file.
        exclude_note_ids: Post IDs (stripped strings) that must not be counted.

    Returns:
        ``(weight_map, post_ids_map, total_post_count)`` where ``weight_map``
        maps word -> score (forced to 0.0 when the word has no remaining
        posts), ``post_ids_map`` maps word -> set of remaining post IDs, and
        ``total_post_count`` is the number of distinct remaining post IDs.
        ``({}, {}, 0)`` is returned when the file is missing or unreadable.
    """
    if not weight_file.exists():
        print(f"警告: 权重分文件不存在: {weight_file}")
        return {}, {}, 0
    ex = exclude_note_ids or set()
    try:
        with open(weight_file, "r", encoding="utf-8") as f:
            weight_data = json.load(f)
        weight_map: Dict[str, float] = {}
        post_ids_map: Dict[str, set] = {}
        all_post_ids: set = set()
        for item in weight_data:
            word = item.get("词", "")
            if not word:
                continue
            score = item.get("归一权重分", 0.0)
            post_ids: set = set()
            # ``or []`` tolerates an explicit null detail list.
            for detail in item.get("权重分详情") or []:
                post_id = detail.get("帖子ID", "")
                # The exclude set holds stripped strings; normalize before the
                # membership test so numeric or padded IDs are excluded too.
                if post_id and str(post_id).strip() not in ex:
                    post_ids.add(post_id)
            all_post_ids |= post_ids
            post_ids_map[word] = post_ids
            # A word with no remaining posts carries no weight; a null score
            # is treated as 0.0 so later rounding cannot fail.
            weight_map[word] = score if (post_ids and score is not None) else 0.0
        return weight_map, post_ids_map, len(all_post_ids)
    except Exception as e:
        # Best effort: callers skip the dimension when nothing could be loaded.
        print(f"加载权重分数据失败: {e}")
        return {}, {}, 0
- # -----------------------------------------------------------------------------
- # 建树、概率、常量标记(与 generate_how_point_tree 一致)
- # -----------------------------------------------------------------------------
def build_id_node(
    element: str,
    weight_map: Dict[str, float],
    post_ids_map: Dict[str, set],
) -> Dict[str, Any]:
    """Build the leaf (``_type == "ID"``) node for one element.

    Args:
        element: Element (word) name used to look up weight and posts.
        weight_map: Word -> weight score; missing words score 0.0.
        post_ids_map: Word -> set of post IDs; missing words have no posts.

    Returns:
        A dict with ``_type``, ``_persona_weight_score`` (rounded to 4
        decimals), ``_post_count`` and ``_post_ids``.
    """
    # Sort (by string form, so mixed ID types cannot break the comparison) to
    # make the emitted JSON reproducible; plain set iteration order changes
    # between interpreter runs.
    post_ids_list = sorted(post_ids_map.get(element, set()), key=str)
    # ``or 0.0`` guards against an explicit null score in the weight data.
    score = weight_map.get(element, 0.0) or 0.0
    return {
        "_type": "ID",
        "_persona_weight_score": round(score, 4),
        "_post_count": len(post_ids_list),
        "_post_ids": post_ids_list,
    }
def build_tree_node_from_classification(
    classification_node: Dict,
    weight_map: Dict[str, float],
    post_ids_map: Dict[str, set],
    dimension: str,
) -> Dict[str, Any]:
    """Recursively turn one classification node into a ``class`` tree node.

    Direct elements (``直接元素``) become ``ID`` children; sub-classifications
    (``子分类``) become nested ``class`` children keyed by ``分类名称``
    (unnamed sub-classifications are dropped). A sub-classification that
    shares its name with a direct element replaces that element.

    Args:
        classification_node: Node with optional ``直接元素`` / ``子分类`` keys.
        weight_map: Word -> weight score.
        post_ids_map: Word -> set of post IDs.
        dimension: Passed through to recursive calls; not otherwise used here.

    Returns:
        A dict with ``_type == "class"``, the summed leaf score (rounded to 4
        decimals), the distinct post count and IDs of the subtree, and
        ``children`` when there are any.
    """
    children: Dict[str, Any] = {}
    for element in classification_node.get("直接元素", []):
        if element:
            children[element] = build_id_node(element, weight_map, post_ids_map)
    for sub_class in classification_node.get("子分类", []):
        sub_node = build_tree_node_from_classification(
            sub_class, weight_map, post_ids_map, dimension
        )
        sub_node_name = sub_class.get("分类名称", "")
        if sub_node_name:
            children[sub_node_name] = sub_node

    def sum_leaf_scores(node: Dict) -> float:
        """Sum ``_persona_weight_score`` over every ID leaf below ``node``."""
        node_type = node.get("_type")
        if node_type == "ID":
            return node.get("_persona_weight_score", 0.0)
        subtotal = 0.0
        if node_type == "class":
            for child in node.get("children", {}).values():
                subtotal += sum_leaf_scores(child)
        return subtotal

    total_score = 0.0
    all_post_ids: set = set()
    for child in children.values():
        total_score += sum_leaf_scores(child)
        # Each child already carries the union of the post IDs in its own
        # subtree, so there is no need to walk down to the leaves again.
        all_post_ids.update(child.get("_post_ids", []))

    result: Dict[str, Any] = {
        "_type": "class",
        "_persona_weight_score": round(total_score, 4),
        "_post_count": len(all_post_ids),
        # Sorted by string form so the emitted JSON is reproducible.
        "_post_ids": sorted(all_post_ids, key=str),
    }
    if children:
        result["children"] = children
    return result
def build_tree_from_classification(
    classification_tree: List[Dict],
    weight_map: Dict[str, float],
    post_ids_map: Dict[str, set],
    total_post_count: int,
    dimension: str,
    root_direct_elements: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build the full output tree (``_type == "root"``) for one dimension.

    Args:
        classification_tree: Top-level classification nodes.
        weight_map: Word -> weight score.
        post_ids_map: Word -> set of post IDs.
        total_post_count: Denominator for every node's ``_ratio``.
        dimension: Passed through to the per-node builder.
        root_direct_elements: Elements without a category path; each becomes
            an ``ID`` child of the root unless a child of that name exists.

    Returns:
        The root node. Every ``class`` / ``ID`` node below it gets a
        ``_ratio`` of ``_post_count / total_post_count`` rounded to 4
        decimals (0.0 when ``total_post_count`` is not positive).
    """
    root_children: Dict[str, Any] = {}
    for classification_node in classification_tree:
        node = build_tree_node_from_classification(
            classification_node, weight_map, post_ids_map, dimension
        )
        node_name = classification_node.get("分类名称", "")
        if node_name:
            root_children[node_name] = node

    # Elements with an empty category_path hang directly off the root.
    for element in root_direct_elements or []:
        if element and element not in root_children:
            root_children[element] = build_id_node(element, weight_map, post_ids_map)

    # Each child already carries the union of the post IDs of its subtree.
    root_post_ids: set = set()
    for child in root_children.values():
        root_post_ids.update(child.get("_post_ids", []))

    root: Dict[str, Any] = {
        "_type": "root",
        "_post_count": len(root_post_ids),
        # Sorted by string form so the emitted JSON is reproducible.
        "_post_ids": sorted(root_post_ids, key=str),
        "children": root_children,
    }

    def assign_ratio(node: Dict[str, Any]) -> None:
        """Set ``_ratio`` on every class/ID node in the subtree of ``node``."""
        if node.get("_type") in ("class", "ID"):
            if total_post_count > 0:
                node["_ratio"] = round(node.get("_post_count", 0) / total_post_count, 4)
            else:
                node["_ratio"] = 0.0
        for child in node.get("children", {}).values():
            assign_ratio(child)

    assign_ratio(root)
    return root
def collect_id_nodes(
    node: Dict[str, Any],
    node_path: List[str],
    id_nodes: List[Tuple[Dict[str, Any], List[str]]],
) -> None:
    """Append every ``ID`` node under ``node`` to ``id_nodes`` with its path.

    Nodes are visited depth-first in child insertion order. Each entry is
    ``(node, path)`` where ``path`` is ``node_path`` extended by the child
    names leading to the node. Nodes of any other type are ignored.
    """
    pending = [(node, node_path)]
    while pending:
        current, current_path = pending.pop()
        kind = current.get("_type")
        if kind == "ID":
            id_nodes.append((current, current_path.copy()))
        elif kind in ("class", "root"):
            child_items = list(current.get("children", {}).items())
            # Push in reverse so children are popped in their original order.
            for child_name, child_node in reversed(child_items):
                pending.append((child_node, current_path + [child_name]))
def collect_class_nodes(
    node: Dict[str, Any],
    node_path: List[str],
    class_nodes: List[Tuple[Dict[str, Any], List[str]]],
    is_root: bool = False,
) -> None:
    """Append ``class`` nodes under ``node`` to ``class_nodes`` with their paths.

    A class node is skipped (but still descended into) when it is the node
    the traversal was started on with ``is_root=True``, or when its path has
    exactly one segment (a first-level category). Nodes that are neither
    ``class`` nor ``root`` are ignored.
    """
    kind = node.get("_type")
    if kind not in ("class", "root"):
        return
    if kind == "class" and not (is_root or len(node_path) == 1):
        class_nodes.append((node, node_path.copy()))
    for child_name, child_node in node.get("children", {}).items():
        collect_class_nodes(child_node, node_path + [child_name], class_nodes, False)
def select_constant_nodes(
    candidates: List[Tuple[Dict[str, Any], List[str]]],
) -> set:
    """Pick the paths of the candidates that count as "constant" nodes.

    Each candidate's ``_persona_weight_score`` is divided by the highest
    score (floored at 0.0) to get a relative score. When the highest score is
    not positive, a lone candidate gets 1.0 and everything else 0.0.

    Selection:
      * candidates with relative score >= 0.5, capped at the best 8;
      * if that yields fewer than 3, instead the best (up to) 3 candidates
        with relative score >= 0.2.

    Returns:
        The selected paths as a set of tuples.
    """
    if not candidates:
        return set()

    raw_scores = [node.get("_persona_weight_score", 0.0) for node, _ in candidates]
    top_score = max([0.0] + raw_scores)
    only_one = len(candidates) == 1

    ranked = []
    for (_, path), score in zip(candidates, raw_scores):
        if top_score > 0:
            relative = score / top_score
        else:
            relative = 1.0 if only_one else 0.0
        ranked.append((path, relative))

    def by_relative(entry):
        return entry[1]

    selected = [entry for entry in ranked if entry[1] >= 0.5]
    if len(selected) > 8:
        selected = sorted(selected, key=by_relative, reverse=True)[:8]
    if len(selected) < 3:
        # Too few strong candidates: relax the threshold and keep the top 3.
        relaxed = [entry for entry in ranked if entry[1] >= 0.2]
        selected = sorted(relaxed, key=by_relative, reverse=True)[:3]
    return {tuple(path) for path, _ in selected}
def mark_constant_nodes(tree: Dict[str, Any]) -> None:
    """Set ``_is_constant`` on the nodes of ``tree``, in place.

    ID nodes and class nodes are ranked separately by
    ``select_constant_nodes``. Every ID node gets the flag; class nodes get
    it only below the first level (first-level categories and the top node
    are left unflagged).
    """
    id_candidates: List[Tuple[Dict[str, Any], List[str]]] = []
    collect_id_nodes(tree, [], id_candidates)
    class_candidates: List[Tuple[Dict[str, Any], List[str]]] = []
    collect_class_nodes(tree, [], class_candidates, is_root=True)
    constant_paths = select_constant_nodes(id_candidates) | select_constant_nodes(
        class_candidates
    )

    def mark(node: Dict[str, Any], path: List[str], at_top: bool) -> None:
        kind = node.get("_type")
        if kind == "ID":
            node["_is_constant"] = tuple(path) in constant_paths
            return
        if kind not in ("class", "root"):
            return
        if kind == "class" and not at_top and len(path) != 1:
            node["_is_constant"] = tuple(path) in constant_paths
        for child_name, child_node in node.get("children", {}).items():
            mark(child_node, path + [child_name], False)

    mark(tree, [], True)
def get_cache_key(parent_category: str, child_categories: List[str]) -> str:
    """Return an MD5 hex digest identifying a parent / children combination.

    The child names are sorted first, so the key does not depend on their
    order.
    """
    joined_children = ",".join(sorted(child_categories))
    digest_input = f"{parent_category}|||{joined_children}".encode("utf-8")
    return hashlib.md5(digest_input).hexdigest()
- def _try_parse_json_text(text: str) -> Dict[str, Any]:
- text = text.strip()
- text = re.sub(r"^```(?:json)?\s*", "", text)
- text = re.sub(r"\s*```\s*$", "", text)
- return json.loads(text)
def _gemini_json_call(
    system_prompt: str,
    user_prompt: str,
    model: str,
) -> str:
    """Call the Gemini ``generateContent`` endpoint and return the reply text.

    The request asks for a JSON response (``responseMimeType``) at
    temperature 0. The endpoint base is taken from ``GEMINI_API_BASE`` when
    set, and the key from ``GEMINI_API_KEY``.

    Args:
        system_prompt: Text sent as the system instruction.
        user_prompt: Text sent as the single user turn.
        model: Model name inserted into the request URL.

    Returns:
        The concatenated text parts of the first candidate.

    Raises:
        RuntimeError: If ``httpx`` is not installed, or the response has no
            candidates or only empty text.
        ValueError: If ``GEMINI_API_KEY`` is not set.
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    if httpx is None:
        raise RuntimeError("需要安装 httpx: pip install httpx")
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY 未设置")
    base_url = os.getenv(
        "GEMINI_API_BASE", "https://generativelanguage.googleapis.com/v1beta"
    ).rstrip("/")
    url = f"{base_url}/models/{model}:generateContent"
    payload: Dict[str, Any] = {
        "contents": [{"role": "user", "parts": [{"text": user_prompt}]}],
        "systemInstruction": {"parts": [{"text": system_prompt}]},
        "generationConfig": {
            "temperature": 0,
            "maxOutputTokens": 4096,
            "responseMimeType": "application/json",
        },
    }
    # Send the key as a header, not as a ``?key=`` query parameter: httpx puts
    # the full request URL into HTTPStatusError messages, and callers print
    # those messages and persist them in cache files, which would leak the key.
    # NOTE(review): assumes any custom GEMINI_API_BASE proxy accepts the
    # ``x-goog-api-key`` header like the official endpoint - confirm.
    headers = {"x-goog-api-key": api_key}
    with httpx.Client(timeout=120.0) as client:
        r = client.post(url, headers=headers, json=payload)
        r.raise_for_status()
        data = r.json()
    candidates = data.get("candidates") or []
    if not candidates:
        raise RuntimeError(f"Gemini 无候选输出: {data}")
    parts = (candidates[0].get("content") or {}).get("parts") or []
    text = "".join(p.get("text", "") for p in parts)
    if not text.strip():
        raise RuntimeError("Gemini 返回空文本")
    return text
def load_cached_relation(account_name: str, cache_key: str) -> Optional[Dict]:
    """Return the cached relation stored under ``cache_key``, if any.

    Returns ``None`` when no cache file exists or it cannot be read/parsed
    (a message is printed in the latter case).
    """
    cache_file = _cache_dir(account_name) / f"{cache_key}.json"
    if not cache_file.exists():
        return None
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            cached = json.load(f)
    except Exception as e:
        print(f"读取缓存失败: {e}")
        return None
    return cached
def save_cached_relation(account_name: str, cache_key: str, relation_data: Dict) -> None:
    """Write ``relation_data`` as JSON to the cache file for ``cache_key``.

    Failures while writing are reported with a printed message and otherwise
    ignored.
    """
    cache_file = _cache_dir(account_name) / f"{cache_key}.json"
    try:
        with open(cache_file, "w", encoding="utf-8") as f:
            json.dump(relation_data, f, ensure_ascii=False, indent=2)
    except Exception as e:
        # Caching is an optimization only; never fail the run because of it.
        print(f"保存缓存失败: {e}")
def judge_category_relation(
    parent_category: str,
    child_categories: List[str],
    account_name: str,
    prompt_path: Path,
    model: str,
) -> Dict[str, Any]:
    """Ask the LLM whether the child categories of a parent overlap.

    A previously cached answer for the same parent / children combination is
    returned without calling the model. Otherwise the prompt template at
    ``prompt_path`` is filled in (``{parent_category}`` and
    ``{child_categories}`` placeholders) and sent as the system prompt.

    Args:
        parent_category: Name of the parent category.
        child_categories: Names of its child categories.
        account_name: Account whose cache directory is used.
        prompt_path: Path of the prompt template file.
        model: Model name to call.

    Returns:
        The parsed JSON object from the model. If the call or parsing fails,
        a conservative fallback with ``relation == "有交集"`` and
        ``confidence == 0.5`` is returned instead.

    Raises:
        FileNotFoundError: If there is no cached answer and ``prompt_path``
            does not exist.
    """
    cache_key = get_cache_key(parent_category, child_categories)
    cached = load_cached_relation(account_name, cache_key)
    if cached:
        return cached
    if not prompt_path.exists():
        raise FileNotFoundError(f"Prompt 文件不存在: {prompt_path}")
    prompt_template = prompt_path.read_text(encoding="utf-8")
    system_prompt = prompt_template.replace(
        "{parent_category}", parent_category
    ).replace("{child_categories}", json.dumps(child_categories, ensure_ascii=False))
    user_prompt = "请分析父分类和子分类列表的关系,判断它们是互斥还是有交集,并以JSON格式输出结果。"
    try:
        raw = _gemini_json_call(system_prompt, user_prompt, model=model)
        result = _try_parse_json_text(raw)
        # Callers use ``.get`` on the result, so anything but an object is a
        # failed judgement.
        if not isinstance(result, dict):
            raise ValueError(f"LLM 返回的 JSON 不是对象: {type(result).__name__}")
    except Exception as e:
        print(f"调用LLM判断分类关系失败: {e}")
        # Deliberately NOT cached: persisting the fallback would turn one
        # transient failure (missing key, network error) into a permanent
        # answer that later runs never re-check with the model.
        return {
            "relation": "有交集",
            "confidence": 0.5,
            "reasoning": f"LLM调用失败,默认判断为有交集: {str(e)}",
        }
    save_cached_relation(account_name, cache_key, result)
    return result
def mark_local_constant_nodes(
    tree: Dict[str, Any],
    account_name: str,
    prompt_path: Path,
    model: str,
) -> None:
    """Set ``_is_local_constant`` on sub-categories of dominant categories.

    For every ``class`` node with ``_ratio >= 0.5`` that has at least two
    ``class`` children, the relation between those children is judged via
    ``judge_category_relation`` and stored on the node
    (``_child_categories_relation`` and ``_child_categories_relation_detail``).
    If the relation is ``互斥`` every such child is flagged ``True``;
    otherwise a child is flagged ``True`` only when it covers more than half
    of the parent's posts. The tree is modified in place.
    """

    def judge_children(node: Dict[str, Any], path: List[str]) -> None:
        sub_class_nodes = [
            (name, child)
            for name, child in node.get("children", {}).items()
            if child.get("_type") == "class"
        ]
        if len(sub_class_nodes) < 2:
            return
        parent_category = path[-1] if path else "根分类"
        relation_result = judge_category_relation(
            parent_category,
            [name for name, _ in sub_class_nodes],
            account_name,
            prompt_path,
            model,
        )
        relation = relation_result.get("relation", "有交集")
        node["_child_categories_relation"] = relation
        node["_child_categories_relation_detail"] = relation_result

        parent_post_count = node.get("_post_count", 0)
        for _, child in sub_class_nodes:
            if relation == "互斥":
                child["_is_local_constant"] = True
            elif parent_post_count > 0:
                share = child.get("_post_count", 0) / parent_post_count
                child["_is_local_constant"] = share > 0.5
            else:
                child["_is_local_constant"] = False

    def walk(node: Dict[str, Any], path: List[str]) -> None:
        kind = node.get("_type")
        if kind not in ("root", "class"):
            return
        if kind == "class" and len(path) >= 1 and node.get("_ratio", 0.0) >= 0.5:
            judge_children(node, path)
        for child_name, child_node in node.get("children", {}).items():
            walk(child_node, path + [child_name])

    walk(tree, [])
def discover_dimensions(tree_dir: Path) -> List[str]:
    """List the dimension names that have a ``<dimension>_tree.json`` file.

    Returns the names in sorted file order, or an empty list when
    ``tree_dir`` is not a directory. A file named exactly ``_tree.json``
    (empty dimension name) is skipped.
    """
    if not tree_dir.is_dir():
        return []
    suffix = "_tree.json"
    file_names = (entry.name for entry in sorted(tree_dir.glob("*_tree.json")))
    stems = (
        file_name[: -len(suffix)]
        for file_name in file_names
        if file_name.endswith(suffix)
    )
    return [stem for stem in stems if stem]
def process_account(
    account_name: str,
    prompt_path: Optional[Path] = None,
    model: Optional[str] = None,
    dimensions: Optional[List[str]] = None,
) -> None:
    """Build and write the "how" point tree for every dimension of an account.

    For each dimension, ``tree/<dimension>_tree.json`` and
    ``point_tree_weight/<dimension>_tree_weight_score.json`` are read from the
    account's raw-data directory; the resulting tree is written to
    ``<dimension>_point_tree_how.json`` in the account's output tree
    directory. Dimensions with a missing tree file, unusable weights or an
    unparseable tree are skipped with a message.

    Args:
        account_name: Account whose data is processed.
        prompt_path: Prompt template for relation judging; defaults to
            ``_DEFAULT_PROMPT``.
        model: Model name; defaults to ``_DEFAULT_MODEL``.
        dimensions: Dimensions to process; discovered from the tree directory
            when ``None``.

    Raises:
        SystemExit: With status 1 when there is no dimension to process.
    """
    if not prompt_path:
        prompt_path = _DEFAULT_PROMPT
    if not model:
        model = _DEFAULT_MODEL

    base = _input_base(account_name)
    tree_dir, weight_dir = base / "tree", base / "point_tree_weight"
    out_dir = _output_tree_dir(account_name)
    out_dir.mkdir(parents=True, exist_ok=True)

    dims = discover_dimensions(tree_dir) if dimensions is None else dimensions
    if not dims:
        print(f"未在 {tree_dir} 找到 *_tree.json")
        sys.exit(1)

    exclude_note_ids = load_exclude_note_ids(base)
    for summary_line in (
        f"账号: {account_name}",
        f"输出目录: {out_dir}",
        f"维度: {dims}",
        f"Gemini 模型: {model}",
        f"排除帖子 ID 数: {len(exclude_note_ids)}",
    ):
        print(summary_line)

    for dimension in dims:
        tree_file = tree_dir / f"{dimension}_tree.json"
        weight_file = weight_dir / f"{dimension}_tree_weight_score.json"
        if not tree_file.exists():
            print(f"跳过维度 {dimension}:缺少 {tree_file}")
            continue

        weight_map, post_ids_map, total_post_count = load_weight_scores(
            weight_file, exclude_note_ids=exclude_note_ids
        )
        if not weight_map:
            print(f"跳过维度 {dimension}:无法加载权重分 {weight_file}")
            continue

        loaded = load_classification_tree_from_file(
            tree_file, exclude_note_ids=exclude_note_ids
        )
        if loaded is None:
            print(f"跳过维度 {dimension}:无法解析分类树 {tree_file}")
            continue
        classification_tree, root_direct_elements = loaded
        print(
            f"处理 {dimension}: 分类顶层 {len(classification_tree)} 类, "
            f"root 直挂 {len(root_direct_elements)} 词, 权重词 {len(weight_map)}"
        )

        tree = build_tree_from_classification(
            classification_tree,
            weight_map,
            post_ids_map,
            total_post_count,
            dimension,
            root_direct_elements=root_direct_elements or None,
        )
        # Global constants first, then the LLM-assisted local constants.
        mark_constant_nodes(tree)
        mark_local_constant_nodes(tree, account_name, prompt_path, model)

        out_file = out_dir / f"{dimension}_point_tree_how.json"
        with open(out_file, "w", encoding="utf-8") as f:
            json.dump({dimension: tree}, f, ensure_ascii=False, indent=2)
        print(f"已写入 {out_file}")
def main(account_name: str) -> None:
    """Run the full tree processing for ``account_name`` with default settings."""
    process_account(account_name)
if __name__ == "__main__":
    # Honour the documented ``--account_name`` option; running without
    # arguments keeps processing the previously hard-coded account.
    _parser = argparse.ArgumentParser(
        description="Generate <dimension>_point_tree_how.json files for one account."
    )
    _parser.add_argument(
        "--account_name",
        default="空间点阵设计研究室",
        help="Account directory name under input/ (default: %(default)s)",
    )
    main(account_name=_parser.parse_args().account_name)
|