#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Build the "how" persona tree for one account.

Reads the per-dimension inputs under ``原始数据/tree`` and
``原始数据/point_tree_weight`` and writes one
``<dimension>_point_tree_how.json`` per dimension under ``处理后数据/tree``.

Optional: post IDs listed in ``原始数据/exclude_note_ids.json`` take no part in
tree building (they are dropped from the weight aggregation, and flat ``data``
rows whose ``post_ids`` are all excluded are filtered out).

Usage:
    python how_tree_data_process.py --account_name 阿里多多酱

Dependencies:
    Environment variable GEMINI_API_KEY (used to judge the relation between
    sub-categories; when the call fails the conservative default "有交集" is
    used for that run).
"""
from __future__ import annotations

import argparse
import hashlib
import json
import os
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

try:
    import httpx
except ImportError:
    httpx = None  # type: ignore

# -----------------------------------------------------------------------------
# Paths and defaults
# -----------------------------------------------------------------------------
_SCRIPT_DIR = Path(__file__).resolve().parent
_OVERALL_DIR = _SCRIPT_DIR.parent
_DEFAULT_PROMPT = _OVERALL_DIR / "prompt" / "judge_category_relation.md"
_DEFAULT_MODEL = "gemini-3-flash-preview"
# Account processed when the script is started without --account_name.
_DEFAULT_ACCOUNT = "空间点阵设计研究室"

# Thresholds used by select_constant_nodes.
_CONSTANT_MIN_RELATIVE = 0.5       # relative score needed to qualify directly
_CONSTANT_FALLBACK_RELATIVE = 0.2  # relaxed bar used when too few qualify
_CONSTANT_MAX_COUNT = 8            # upper bound on directly qualified nodes
_CONSTANT_MIN_COUNT = 3            # below this the fallback selection kicks in


def _input_base(account_name: str) -> Path:
    """Return the raw-data directory of *account_name*."""
    return _OVERALL_DIR / "input" / account_name / "原始数据"


def _output_tree_dir(account_name: str) -> Path:
    """Return the directory the processed trees of *account_name* go to."""
    return _OVERALL_DIR / "input" / account_name / "处理后数据" / "tree"


def _cache_dir(account_name: str) -> Path:
    """Return (creating it if needed) the relation-cache directory."""
    d = _output_tree_dir(account_name) / ".cache_relation" / account_name
    d.mkdir(parents=True, exist_ok=True)
    return d


def _normalize_id_list(values: List[Any]) -> set[str]:
    """Stringify and strip every entry of *values*, dropping empty results."""
    return {text for text in (str(v).strip() for v in values) if text}


def load_exclude_note_ids(base: Path) -> set[str]:
    """
    Read ``exclude_note_ids.json`` from *base*.

    The file is either a JSON array of post IDs or an object whose ``ids``
    key holds such an array. Returns an empty set when the file is missing,
    has another shape, or cannot be read/parsed (a warning is printed in the
    last case).
    """
    path = base / "exclude_note_ids.json"
    if not path.exists():
        return set()
    try:
        with open(path, "r", encoding="utf-8") as f:
            raw = json.load(f)
        if isinstance(raw, list):
            return _normalize_id_list(raw)
        if isinstance(raw, dict) and "ids" in raw:
            v = raw["ids"]
            if isinstance(v, list):
                return _normalize_id_list(v)
    except Exception as e:
        print(f"警告: 读取排除帖子列表失败 {path}: {e}")
    return set()


def _row_has_posts_after_exclude(row: Dict[str, Any], exclude_ids: set[str]) -> bool:
    """
    Tell whether a flat data row still takes part in tree building.

    A row without ``post_ids`` (or with an empty list) always takes part: it
    only contributes structure. Otherwise at least one non-empty post ID must
    survive the exclusion.
    """
    if not exclude_ids:
        return True
    pids = row.get("post_ids")
    if not isinstance(pids, list) or not pids:
        return True
    # exclude_ids holds stripped strings, so normalize each ID the same way
    # before comparing; otherwise numeric IDs would never match.
    return any(str(pid).strip() not in exclude_ids for pid in pids if pid)


def _filter_data_rows_for_exclude(
    rows: List[Dict[str, Any]], exclude_ids: set[str]
) -> List[Dict[str, Any]]:
    """Drop the rows whose posts are all excluded; no-op without exclusions."""
    if not exclude_ids:
        return rows
    return [r for r in rows if _row_has_posts_after_exclude(r, exclude_ids)]


# -----------------------------------------------------------------------------
# Flat tree JSON -> classification tree (direct elements / sub-categories)
# -----------------------------------------------------------------------------
def _find_or_create_classification(
    level: List[Dict[str, Any]], name: str
) -> Dict[str, Any]:
    """Return the node called *name* in *level*, appending a new one if absent."""
    for node in level:
        if node.get("分类名称") == name:
            return node
    node = {"分类名称": name, "直接元素": [], "子分类": []}
    level.append(node)
    return node


def _merge_path_element(
    classification_tree: List[Dict[str, Any]],
    segments: List[str],
    element: str,
) -> None:
    """
    Insert *element* under the category path *segments*, creating categories
    on the way. The element is attached (once) to the last path segment.
    """
    if not segments:
        return
    node = _find_or_create_classification(classification_tree, segments[0])
    if len(segments) == 1:
        if element and element not in node["直接元素"]:
            node["直接元素"].append(element)
    else:
        _merge_path_element(node["子分类"], segments[1:], element)


def _rows_to_classification_tree(
    rows: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], List[str]]:
    """
    Convert flat rows into ``(classification tree, root-level element names)``.

    A row with an empty ``category_path`` does not enter any category; its
    element is attached directly to the root instead.
    """
    tree: List[Dict[str, Any]] = []
    root_direct: List[str] = []
    for row in rows:
        raw_path = row.get("category_path") or ""
        element = (row.get("element_name") or "").strip()
        segments = [s for s in raw_path.strip("/").split("/") if s]
        if not segments:
            if element and element not in root_direct:
                root_direct.append(element)
            continue
        _merge_path_element(tree, segments, element)
    return tree, root_direct


def _is_classification_style(nodes: Any) -> bool:
    """Tell whether *nodes* is a non-empty list whose first item is a category node."""
    if not isinstance(nodes, list) or not nodes:
        return False
    n0 = nodes[0]
    return isinstance(n0, dict) and (
        "分类名称" in n0 or "直接元素" in n0 or "子分类" in n0
    )


def _looks_like_flat_rows(rows: List[Any]) -> bool:
    """Tell whether the first item of the non-empty list *rows* is a flat row."""
    first = rows[0]
    return isinstance(first, dict) and (
        "category_path" in first or "element_name" in first
    )


def load_classification_tree_from_file(
    tree_json_path: Path,
    exclude_note_ids: Optional[set[str]] = None,
) -> Optional[Tuple[List[Dict[str, Any]], List[str]]]:
    """
    Load a classification tree from *tree_json_path*.

    Supported layouts:
      1) flat rows (``category_path`` + ``element_name``), either as the
         top-level array or under a ``data`` key; rows with an empty path put
         their element into the root-level list;
      2) a ready classification tree, either as the top-level array or under
         ``追加分类结果`` / ``分类结果`` / ``最终分类树``.

    Returns ``(classification tree, root-level element names)``, or ``None``
    when the file cannot be read, is not valid JSON, or has an unknown layout.
    """
    try:
        with open(tree_json_path, "r", encoding="utf-8") as f:
            tree_data = json.load(f)
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        print(f"警告: 读取分类树失败 {tree_json_path}: {e}")
        return None

    excluded = exclude_note_ids or set()

    if isinstance(tree_data, dict):
        for wrapper_key in ("追加分类结果", "分类结果"):
            wrapper = tree_data.get(wrapper_key)
            if isinstance(wrapper, dict):
                final_tree = wrapper.get("最终分类树")
                if final_tree is not None:
                    return final_tree, []
        final_tree = tree_data.get("最终分类树")
        if final_tree is not None and _is_classification_style(final_tree):
            return final_tree, []
        rows = tree_data.get("data")
        if isinstance(rows, list):
            if not rows:
                return [], []
            if _looks_like_flat_rows(rows):
                return _rows_to_classification_tree(
                    _filter_data_rows_for_exclude(rows, excluded)
                )
        return None

    if isinstance(tree_data, list):
        if not tree_data:
            return [], []
        if _looks_like_flat_rows(tree_data):
            return _rows_to_classification_tree(
                _filter_data_rows_for_exclude(tree_data, excluded)
            )
        if _is_classification_style(tree_data):
            return tree_data, []
    return None


# -----------------------------------------------------------------------------
# Weights
# -----------------------------------------------------------------------------
def load_weight_scores(
    weight_file: Path,
    exclude_note_ids: Optional[set[str]] = None,
) -> Tuple[Dict[str, float], Dict[str, set], int]:
    """
    Load per-word weights and post IDs from *weight_file*.

    Returns ``(weight_map, post_ids_map, total_post_count)`` where
    ``total_post_count`` is the number of distinct non-excluded post IDs over
    all words. A word left without any post after exclusion gets weight 0.0.
    Returns ``({}, {}, 0)`` when the file is missing or cannot be processed.
    """
    if not weight_file.exists():
        print(f"警告: 权重分文件不存在: {weight_file}")
        return {}, {}, 0
    ex = exclude_note_ids or set()
    try:
        with open(weight_file, "r", encoding="utf-8") as f:
            weight_data = json.load(f)
        weight_map: Dict[str, float] = {}
        post_ids_map: Dict[str, set] = {}
        all_post_ids: set = set()
        for item in weight_data:
            word = item.get("词", "")
            if not word:
                continue
            score = item.get("归一权重分", 0.0)
            post_ids: set = set()
            for detail in item.get("权重分详情", []):
                post_id = detail.get("帖子ID", "")
                # Compare in normalized form (exclusion IDs are stripped
                # strings) but keep the ID exactly as stored in the data.
                if post_id and str(post_id).strip() not in ex:
                    post_ids.add(post_id)
            all_post_ids |= post_ids
            post_ids_map[word] = post_ids
            weight_map[word] = score if post_ids else 0.0
        return weight_map, post_ids_map, len(all_post_ids)
    except Exception as e:
        print(f"加载权重分数据失败: {e}")
        return {}, {}, 0


# -----------------------------------------------------------------------------
# Tree building, ratios, constant marking
# -----------------------------------------------------------------------------
def _sorted_ids(post_ids: Any) -> List[Any]:
    """Return *post_ids* as a list in a stable order, so output files are reproducible."""
    return sorted(post_ids, key=str)


def _collect_post_ids(node: Dict[str, Any]) -> set:
    """Return the set of post IDs referenced by the ID leaves below *node*."""
    node_type = node.get("_type")
    if node_type == "ID":
        return set(node.get("_post_ids", []))
    collected: set = set()
    if node_type == "class":
        for child in node.get("children", {}).values():
            collected |= _collect_post_ids(child)
    return collected


def _sum_leaf_scores(node: Dict[str, Any]) -> float:
    """Return the sum of ``_persona_weight_score`` over the ID leaves below *node*."""
    node_type = node.get("_type")
    if node_type == "ID":
        return node.get("_persona_weight_score", 0.0)
    total = 0.0
    if node_type == "class":
        # Plain left-to-right accumulation keeps the float result independent
        # of the interpreter's sum() implementation.
        for child in node.get("children", {}).values():
            total += _sum_leaf_scores(child)
    return total


def build_id_node(
    element: str,
    weight_map: Dict[str, float],
    post_ids_map: Dict[str, set],
) -> Dict[str, Any]:
    """Build the ID (leaf) node for *element*; unknown elements get score 0 and no posts."""
    post_ids_list = _sorted_ids(post_ids_map.get(element, set()))
    return {
        "_type": "ID",
        "_persona_weight_score": round(weight_map.get(element, 0.0), 4),
        "_post_count": len(post_ids_list),
        "_post_ids": post_ids_list,
    }


def build_tree_node_from_classification(
    classification_node: Dict,
    weight_map: Dict[str, float],
    post_ids_map: Dict[str, set],
    dimension: str,
) -> Dict[str, Any]:
    """
    Build a ``class`` node (recursively) from one classification node.

    The node's score is the sum of the scores of all ID leaves below it, its
    posts are the union of their posts. ``children`` is only present when the
    node has at least one child. A sub-category sharing its name with a
    direct element replaces that element. *dimension* is only passed through
    to the recursive calls.
    """
    children: Dict[str, Any] = {}
    for element in classification_node.get("直接元素", []):
        if element:
            children[element] = build_id_node(element, weight_map, post_ids_map)
    for sub_class in classification_node.get("子分类", []):
        sub_node = build_tree_node_from_classification(
            sub_class, weight_map, post_ids_map, dimension
        )
        sub_node_name = sub_class.get("分类名称", "")
        if sub_node_name:
            children[sub_node_name] = sub_node

    total_score = 0.0
    all_post_ids: set = set()
    for child in children.values():
        total_score += _sum_leaf_scores(child)
        all_post_ids |= _collect_post_ids(child)

    result: Dict[str, Any] = {
        "_type": "class",
        "_persona_weight_score": round(total_score, 4),
        "_post_count": len(all_post_ids),
        "_post_ids": _sorted_ids(all_post_ids),
    }
    if children:
        result["children"] = children
    return result


def _assign_ratio(node: Dict[str, Any], total_post_count: int) -> None:
    """Set ``_ratio`` (= node posts / *total_post_count*) on every class and ID node."""
    if node.get("_type") in ("class", "ID"):
        if total_post_count > 0:
            node["_ratio"] = round(node.get("_post_count", 0) / total_post_count, 4)
        else:
            node["_ratio"] = 0.0
    for child in node.get("children", {}).values():
        _assign_ratio(child, total_post_count)


def build_tree_from_classification(
    classification_tree: List[Dict],
    weight_map: Dict[str, float],
    post_ids_map: Dict[str, set],
    total_post_count: int,
    dimension: str,
    root_direct_elements: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Build the full tree (root node) for one dimension.

    Top-level classification nodes become children of the root; the names in
    *root_direct_elements* become ID children of the root unless a child of
    that name already exists. Every class/ID node receives a ``_ratio``
    relative to *total_post_count* (0.0 when that count is not positive).
    """
    root_children: Dict[str, Any] = {}
    for classification_node in classification_tree:
        node = build_tree_node_from_classification(
            classification_node, weight_map, post_ids_map, dimension
        )
        node_name = classification_node.get("分类名称", "")
        if node_name:
            root_children[node_name] = node

    # Elements whose category_path was empty hang directly below the root.
    for element in root_direct_elements or []:
        if not element or element in root_children:
            continue
        root_children[element] = build_id_node(element, weight_map, post_ids_map)

    root_post_ids: set = set()
    for child in root_children.values():
        root_post_ids |= _collect_post_ids(child)

    root: Dict[str, Any] = {
        "_type": "root",
        "_post_count": len(root_post_ids),
        "_post_ids": _sorted_ids(root_post_ids),
        "children": root_children,
    }
    _assign_ratio(root, total_post_count)
    return root


def collect_id_nodes(
    node: Dict[str, Any],
    node_path: List[str],
    id_nodes: List[Tuple[Dict[str, Any], List[str]]],
) -> None:
    """Append ``(node, path)`` to *id_nodes* for every ID node at or below *node*."""
    node_type = node.get("_type")
    if node_type == "ID":
        id_nodes.append((node, node_path.copy()))
    elif node_type in ("class", "root"):
        for child_name, child_node in node.get("children", {}).items():
            collect_id_nodes(child_node, node_path + [child_name], id_nodes)


def collect_class_nodes(
    node: Dict[str, Any],
    node_path: List[str],
    class_nodes: List[Tuple[Dict[str, Any], List[str]]],
    is_root: bool = False,
) -> None:
    """
    Append ``(node, path)`` to *class_nodes* for every class node at or below
    *node*, skipping first-level classes (path length 1) and any class node
    visited with ``is_root=True``.
    """
    node_type = node.get("_type")
    if node_type not in ("class", "root"):
        return
    if node_type == "class" and not is_root and len(node_path) != 1:
        class_nodes.append((node, node_path.copy()))
    for child_name, child_node in node.get("children", {}).items():
        collect_class_nodes(child_node, node_path + [child_name], class_nodes, False)


def select_constant_nodes(
    candidates: List[Tuple[Dict[str, Any], List[str]]],
) -> set:
    """
    Pick the "constant" nodes among *candidates* and return their paths as a
    set of tuples.

    Scores are taken relative to the best candidate. Nodes at or above 0.5
    qualify (at most the top 8). If fewer than 3 were picked, the selection
    is replaced by the top 3 of the nodes at or above 0.2.
    """
    if not candidates:
        return set()
    max_score = 0.0
    for node, _ in candidates:
        score = node.get("_persona_weight_score", 0.0)
        if score > max_score:
            max_score = score

    # With no positive score at all, a lone candidate still counts as 1.0.
    zero_max_relative = 1.0 if len(candidates) == 1 else 0.0
    scored: List[Tuple[List[str], float]] = []
    for node, path in candidates:
        score = node.get("_persona_weight_score", 0.0)
        relative = score / max_score if max_score > 0 else zero_max_relative
        scored.append((path, relative))

    selected = [item for item in scored if item[1] >= _CONSTANT_MIN_RELATIVE]
    if len(selected) > _CONSTANT_MAX_COUNT:
        selected.sort(key=lambda item: item[1], reverse=True)
        selected = selected[:_CONSTANT_MAX_COUNT]

    if len(selected) < _CONSTANT_MIN_COUNT:
        selected = [item for item in scored if item[1] >= _CONSTANT_FALLBACK_RELATIVE]
        selected.sort(key=lambda item: item[1], reverse=True)
        selected = selected[:_CONSTANT_MIN_COUNT]

    return {tuple(path) for path, _ in selected}


def mark_constant_nodes(tree: Dict[str, Any]) -> None:
    """
    Set ``_is_constant`` in place on every ID node and on every class node
    below the first level. ID nodes and class nodes are ranked separately.
    """
    id_nodes: List[Tuple[Dict[str, Any], List[str]]] = []
    collect_id_nodes(tree, [], id_nodes)
    class_nodes: List[Tuple[Dict[str, Any], List[str]]] = []
    collect_class_nodes(tree, [], class_nodes, is_root=True)

    constant_paths = select_constant_nodes(id_nodes) | select_constant_nodes(class_nodes)

    def mark_node(node: Dict[str, Any], path: List[str], is_root: bool = False) -> None:
        node_type = node.get("_type")
        if node_type == "ID":
            node["_is_constant"] = tuple(path) in constant_paths
            return
        if node_type not in ("class", "root"):
            return
        # First-level classes are never candidates, so they get no flag.
        if node_type == "class" and not is_root and len(path) != 1:
            node["_is_constant"] = tuple(path) in constant_paths
        for child_name, child_node in node.get("children", {}).items():
            mark_node(child_node, path + [child_name], False)

    mark_node(tree, [], True)


# -----------------------------------------------------------------------------
# Sub-category relation (LLM call + on-disk cache)
# -----------------------------------------------------------------------------
def get_cache_key(parent_category: str, child_categories: List[str]) -> str:
    """Return a cache key that does not depend on the order of *child_categories*."""
    sorted_categories = sorted(child_categories)
    content = f"{parent_category}|||{','.join(sorted_categories)}"
    # MD5 only serves as a file-name fingerprint here, not as a security measure.
    return hashlib.md5(content.encode("utf-8")).hexdigest()


def _try_parse_json_text(text: str) -> Dict[str, Any]:
    """Parse *text* as JSON after removing an optional Markdown code fence."""
    text = text.strip()
    text = re.sub(r"^```(?:json)?\s*", "", text)
    text = re.sub(r"\s*```\s*$", "", text)
    return json.loads(text)


def _gemini_json_call(
    system_prompt: str,
    user_prompt: str,
    model: str,
) -> str:
    """
    Send one ``generateContent`` request and return the concatenated text of
    the first candidate.

    Raises:
        RuntimeError: httpx is not installed, or the reply has no candidate
            or only empty text.
        ValueError: GEMINI_API_KEY is not set.
        httpx.HTTPError: transport failure or non-2xx response.
    """
    if httpx is None:
        raise RuntimeError("需要安装 httpx: pip install httpx")
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("GEMINI_API_KEY 未设置")
    base_url = os.getenv("GEMINI_API_BASE", "https://generativelanguage.googleapis.com/v1beta")
    url = f"{base_url}/models/{model}:generateContent"
    payload: Dict[str, Any] = {
        "contents": [{"role": "user", "parts": [{"text": user_prompt}]}],
        "systemInstruction": {"parts": [{"text": system_prompt}]},
        "generationConfig": {
            "temperature": 0,
            "maxOutputTokens": 4096,
            "responseMimeType": "application/json",
        },
    }
    # The key travels in a header rather than in the query string: httpx puts
    # the request URL into its status-error messages, and those messages are
    # printed by the caller.
    headers = {"x-goog-api-key": api_key}
    with httpx.Client(timeout=120.0) as client:
        r = client.post(url, headers=headers, json=payload)
        r.raise_for_status()
        data = r.json()
    candidates = data.get("candidates") or []
    if not candidates:
        raise RuntimeError(f"Gemini 无候选输出: {data}")
    parts = (candidates[0].get("content") or {}).get("parts") or []
    text = "".join(p.get("text", "") for p in parts)
    if not text.strip():
        raise RuntimeError("Gemini 返回空文本")
    return text


def load_cached_relation(account_name: str, cache_key: str) -> Optional[Dict]:
    """Return the cached relation for *cache_key*, or ``None`` if absent or unreadable."""
    cache_file = _cache_dir(account_name) / f"{cache_key}.json"
    if cache_file.exists():
        try:
            with open(cache_file, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception as e:
            print(f"读取缓存失败: {e}")
    return None


def save_cached_relation(account_name: str, cache_key: str, relation_data: Dict) -> None:
    """Write *relation_data* to the cache; failures are reported, not raised."""
    cache_file = _cache_dir(account_name) / f"{cache_key}.json"
    try:
        with open(cache_file, "w", encoding="utf-8") as f:
            json.dump(relation_data, f, ensure_ascii=False, indent=2)
    except Exception as e:
        print(f"保存缓存失败: {e}")


def judge_category_relation(
    parent_category: str,
    child_categories: List[str],
    account_name: str,
    prompt_path: Path,
    model: str,
) -> Dict[str, Any]:
    """
    Judge whether the sub-categories of *parent_category* are mutually
    exclusive ("互斥") or overlapping ("有交集").

    A cached answer is returned when available. Otherwise the LLM is asked
    and its answer is cached. If the call or the parsing fails, a default
    "有交集" result is returned; it is deliberately not cached, so a later
    run can retry once the LLM is reachable again.

    Raises:
        FileNotFoundError: *prompt_path* does not exist (and no cached answer).
    """
    cache_key = get_cache_key(parent_category, child_categories)
    cached = load_cached_relation(account_name, cache_key)
    if isinstance(cached, dict) and cached:
        return cached

    if not prompt_path.exists():
        raise FileNotFoundError(f"Prompt 文件不存在: {prompt_path}")
    prompt_template = prompt_path.read_text(encoding="utf-8")
    system_prompt = prompt_template.replace(
        "{parent_category}", parent_category
    ).replace(
        "{child_categories}", json.dumps(child_categories, ensure_ascii=False)
    )
    user_prompt = "请分析父分类和子分类列表的关系,判断它们是互斥还是有交集,并以JSON格式输出结果。"

    try:
        raw = _gemini_json_call(system_prompt, user_prompt, model=model)
        result = _try_parse_json_text(raw)
        if not isinstance(result, dict):
            # Callers read the result with .get(); anything else is unusable.
            raise ValueError(f"unexpected JSON type: {type(result).__name__}")
    except Exception as e:
        print(f"调用LLM判断分类关系失败: {e}")
        return {
            "relation": "有交集",
            "confidence": 0.5,
            "reasoning": f"LLM调用失败,默认判断为有交集: {str(e)}",
        }
    save_cached_relation(account_name, cache_key, result)
    return result


def mark_local_constant_nodes(
    tree: Dict[str, Any],
    account_name: str,
    prompt_path: Path,
    model: str,
) -> None:
    """
    Set ``_is_local_constant`` in place on the class children of every class
    node that has ``_ratio >= 0.5`` and at least two class children.

    For mutually exclusive children ("互斥") all of them are flagged True;
    otherwise a child is flagged when it holds more than half of the parent's
    posts. The judged relation and its detail are stored on the parent.
    """

    def process_node(node: Dict[str, Any], path: List[str]) -> None:
        node_type = node.get("_type")
        children = node.get("children", {})
        if node_type not in ("root", "class"):
            return
        # Class nodes are only reached through a parent, hence path is never
        # empty for them; the explicit check keeps that assumption visible.
        if node_type == "class" and path and node.get("_ratio", 0.0) >= 0.5:
            sub_class_nodes = [
                (name, cn) for name, cn in children.items() if cn.get("_type") == "class"
            ]
            if len(sub_class_nodes) >= 2:
                child_categories = [name for name, _ in sub_class_nodes]
                relation_result = judge_category_relation(
                    path[-1], child_categories, account_name, prompt_path, model
                )
                relation = relation_result.get("relation", "有交集")
                node["_child_categories_relation"] = relation
                node["_child_categories_relation_detail"] = relation_result
                parent_post_count = node.get("_post_count", 0)
                for _, child_node in sub_class_nodes:
                    if relation == "互斥":
                        child_node["_is_local_constant"] = True
                    elif parent_post_count > 0:
                        child_post_count = child_node.get("_post_count", 0)
                        child_node["_is_local_constant"] = (
                            child_post_count / parent_post_count > 0.5
                        )
                    else:
                        child_node["_is_local_constant"] = False
        for child_name, child_node in children.items():
            process_node(child_node, path + [child_name])

    process_node(tree, [])


# -----------------------------------------------------------------------------
# Driver
# -----------------------------------------------------------------------------
def discover_dimensions(tree_dir: Path) -> List[str]:
    """Return the sorted dimension names ``X`` of all ``X_tree.json`` files in *tree_dir*."""
    if not tree_dir.is_dir():
        return []
    suffix = "_tree.json"
    names = (p.name[: -len(suffix)] for p in sorted(tree_dir.glob("*" + suffix)))
    # A file called exactly "_tree.json" yields an empty name and is skipped.
    return [dim for dim in names if dim]


def process_account(
    account_name: str,
    prompt_path: Optional[Path] = None,
    model: Optional[str] = None,
    dimensions: Optional[List[str]] = None,
) -> None:
    """
    Build and write the how-tree of every dimension of *account_name*.

    *dimensions* defaults to whatever ``*_tree.json`` files exist. A dimension
    is skipped (with a message) when its tree file is missing, its weights
    cannot be loaded, or its tree cannot be parsed. Exits the process with
    status 1 when there is no dimension to process.
    """
    prompt_path = prompt_path or _DEFAULT_PROMPT
    model = model or _DEFAULT_MODEL
    base = _input_base(account_name)
    tree_dir = base / "tree"
    weight_dir = base / "point_tree_weight"
    out_dir = _output_tree_dir(account_name)
    out_dir.mkdir(parents=True, exist_ok=True)

    dims = dimensions if dimensions is not None else discover_dimensions(tree_dir)
    if not dims:
        print(f"未在 {tree_dir} 找到 *_tree.json")
        sys.exit(1)

    exclude_note_ids = load_exclude_note_ids(base)

    print(f"账号: {account_name}")
    print(f"输出目录: {out_dir}")
    print(f"维度: {dims}")
    print(f"Gemini 模型: {model}")
    print(f"排除帖子 ID 数: {len(exclude_note_ids)}")

    for dimension in dims:
        tree_file = tree_dir / f"{dimension}_tree.json"
        weight_file = weight_dir / f"{dimension}_tree_weight_score.json"
        if not tree_file.exists():
            print(f"跳过维度 {dimension}:缺少 {tree_file}")
            continue

        weight_map, post_ids_map, total_post_count = load_weight_scores(
            weight_file, exclude_note_ids=exclude_note_ids
        )
        if not weight_map:
            print(f"跳过维度 {dimension}:无法加载权重分 {weight_file}")
            continue

        loaded = load_classification_tree_from_file(
            tree_file, exclude_note_ids=exclude_note_ids
        )
        if loaded is None:
            print(f"跳过维度 {dimension}:无法解析分类树 {tree_file}")
            continue
        classification_tree, root_direct_elements = loaded

        print(
            f"处理 {dimension}: 分类顶层 {len(classification_tree)} 类, "
            f"root 直挂 {len(root_direct_elements)} 词, 权重词 {len(weight_map)}"
        )

        tree = build_tree_from_classification(
            classification_tree,
            weight_map,
            post_ids_map,
            total_post_count,
            dimension,
            root_direct_elements=root_direct_elements or None,
        )
        mark_constant_nodes(tree)
        mark_local_constant_nodes(tree, account_name, prompt_path, model)

        out_file = out_dir / f"{dimension}_point_tree_how.json"
        with open(out_file, "w", encoding="utf-8") as f:
            json.dump({dimension: tree}, f, ensure_ascii=False, indent=2)
        print(f"已写入 {out_file}")


def main(account_name: Optional[str] = None, argv: Optional[List[str]] = None) -> None:
    """
    Entry point.

    With *account_name* given, process that account with default settings.
    Otherwise read the settings from the command line (*argv*, defaulting to
    ``sys.argv[1:]``).
    """
    if account_name is not None:
        process_account(account_name)
        return

    parser = argparse.ArgumentParser(
        description="Build the how persona trees of one account."
    )
    parser.add_argument(
        "--account_name",
        default=_DEFAULT_ACCOUNT,
        help="account directory name under input/ (default: %(default)s)",
    )
    parser.add_argument(
        "--model",
        default=None,
        help=f"Gemini model name (default: {_DEFAULT_MODEL})",
    )
    parser.add_argument(
        "--prompt",
        type=Path,
        default=None,
        help="path of the relation-judging prompt template",
    )
    parser.add_argument(
        "--dimensions",
        nargs="+",
        default=None,
        help="dimensions to process (default: every *_tree.json found)",
    )
    args = parser.parse_args(argv)
    process_account(
        args.account_name,
        prompt_path=args.prompt,
        model=args.model,
        dimensions=args.dimensions,
    )


if __name__ == "__main__":
    main()