| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- """
- 人设树库数据处理脚本
- 将 xiaohongshu 原始数据树库(扁平列表格式)转换为嵌套树结构格式,
- 与 家有大志/原始数据/tree/*.json 格式对齐,方便后续统一处理。
- 输入:examples_how/overall_derivation/input/xiaohongshu/原始数据/tree/
- 输出:examples_how/overall_derivation/input/xiaohongshu/tree/
- """
- import json
- import os
- from pathlib import Path
def parse_category_path(path: str) -> list[str]:
    """Split a slash-delimited category path into its non-empty segments.

    For example, ``'/表象/实体/生物/动物/宠物'`` becomes
    ``['表象', '实体', '生物', '动物', '宠物']``.
    """
    # filter(None, ...) discards the empty strings that leading, trailing
    # or doubled slashes leave behind after the split.
    pieces = path.split("/")
    return list(filter(None, pieces))
def build_tree_from_flat(elements: list, tree_name: str) -> dict:
    """Build a nested tree from a flat list of elements.

    Each element is inserted under the chain of category nodes named by its
    ``category_path``; the element itself becomes a leaf (``_type: "ID"``)
    keyed by ``element_name``. Elements with the same name under the same
    category are merged by taking the union of their ``post_ids``.

    A category node (``_type: "class"``) ends up with ``_post_ids`` equal to
    the de-duplicated union of every post id found beneath it, plus any ids
    merged directly into it by an element that shares its name.

    When one name is used both as an element and as a category at the same
    level, the node is kept as a category and retains the element's post
    ids, independent of the order in which the elements arrive.

    Args:
        elements: Dicts with ``category_path`` (slash-delimited string),
            ``element_name`` and optionally ``post_ids`` (iterable of ids;
            a missing or ``None`` value counts as empty).
        tree_name: Key under which the root node is returned.

    Returns:
        ``{tree_name: root}`` where ``root`` has ``_type`` ``"root"``,
        ``_post_count``, sorted ``_post_ids`` and ``children``. Every
        descendant carries ``_type``, ``_post_count``, ``_post_ids`` and
        ``_ratio`` (its post count divided by the root's, rounded to 4
        places; ``0.0`` when the root has no posts); category nodes that
        have children also carry ``children``.

    Raises:
        KeyError: If an element lacks ``category_path`` or ``element_name``.
    """
    # Intermediate tree: nested dicts that hold post ids as sets so merging
    # and de-duplication are cheap; converted to sorted lists at the end.
    root_children: dict = {}

    for elem in elements:
        path_segments = parse_category_path(elem["category_path"])
        elem_name = elem["element_name"]
        # An explicit null in the input is treated like a missing field.
        post_ids = elem.get("post_ids") or []

        # Walk down the path, creating category nodes as needed.
        current_children = root_children
        for seg in path_segments:
            node = current_children.get(seg)
            if node is None:
                node = {
                    "_type": "class",
                    "_post_ids_set": set(),
                    "children": {},
                }
                current_children[seg] = node
            elif node["_type"] != "class":
                # The name was first seen as a leaf element. Promote it to a
                # category (keeping its ids) instead of failing on the
                # missing "children" key; this matches what happens when the
                # category is seen first and the leaf second.
                node["_type"] = "class"
                node["children"] = {}
            current_children = node["children"]

        # Leaf: same-named entries at this level share one node.
        leaf = current_children.get(elem_name)
        if leaf is None:
            leaf = {"_type": "ID", "_post_ids_set": set()}
            current_children[elem_name] = leaf
        leaf["_post_ids_set"].update(post_ids)

    def propagate(children: dict) -> set:
        """Push post ids up from the leaves; return the union for this level."""
        all_ids: set = set()
        for node in children.values():
            if node["_type"] != "ID":
                node["_post_ids_set"].update(propagate(node["children"]))
            all_ids.update(node["_post_ids_set"])
        return all_ids

    root_ids = propagate(root_children)
    root_post_count = len(root_ids)

    def serialize(node: dict) -> dict:
        """Convert an intermediate node (and its subtree) to the output shape."""
        post_ids_list = sorted(node["_post_ids_set"])
        post_count = len(post_ids_list)
        result = {
            "_type": node["_type"],
            "_post_count": post_count,
            "_post_ids": post_ids_list,
        }
        if node["_type"] == "class" and node.get("children"):
            result["children"] = {
                child_name: serialize(child_node)
                for child_name, child_node in node["children"].items()
            }
        # Share of the root's posts covered by this node.
        result["_ratio"] = (
            round(post_count / root_post_count, 4) if root_post_count > 0 else 0.0
        )
        return result

    root_node = {
        "_type": "root",
        "_post_count": root_post_count,
        "_post_ids": sorted(root_ids),
        "children": {
            child_name: serialize(child_node)
            for child_name, child_node in root_children.items()
        },
    }
    return {tree_name: root_node}
def process_file(input_path: str, output_path: str, tree_name: str) -> None:
    """Convert one flat tree-library JSON file into the nested tree format.

    Reads ``input_path``, takes the element list stored under its top-level
    ``"data"`` key, builds the nested tree with ``build_tree_from_flat`` and
    writes the result to ``output_path`` as UTF-8 JSON (non-ASCII characters
    kept as-is, 2-space indent). A two-line summary is printed afterwards.

    Args:
        input_path: Path of the source JSON file.
        output_path: Path of the JSON file to write; missing parent
            directories are created.
        tree_name: Name used as the single top-level key of the output.

    Raises:
        KeyError: If the input JSON has no ``"data"`` key.
    """
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    elements = data["data"]
    tree = build_tree_from_flat(elements, tree_name)

    # The parent of a bare file name is ".", so this also works when
    # output_path has no directory part (os.makedirs("") would raise).
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(tree, f, ensure_ascii=False, indent=2)

    root_post_count = tree[tree_name]["_post_count"]
    print(f"[完成] {os.path.basename(input_path)} -> {output_path}")
    print(f" 根节点帖子数: {root_post_count},元素总数: {len(elements)}")
def main():
    """Convert every known raw tree file found relative to this script.

    Input files are looked up in ``input/xiaohongshu/原始数据/tree`` and the
    results are written to ``input/xiaohongshu/tree``, both relative to the
    directory containing this file. Missing inputs are reported and skipped.
    """
    script_dir = Path(__file__).parent
    input_dir = script_dir / "input/xiaohongshu/原始数据/tree"
    output_dir = script_dir / "input/xiaohongshu/tree"

    # (source file name, tree name) pairs, handled in this order.
    sources = (
        ("实质_tree.json", "实质"),
        ("形式_tree.json", "形式"),
    )

    for filename, tree_name in sources:
        input_path = input_dir / filename
        if input_path.exists():
            output_path = output_dir / f"{tree_name}_point_tree_how.json"
            process_file(str(input_path), str(output_path), tree_name)
        else:
            print(f"[跳过] 文件不存在: {input_path}")


if __name__ == "__main__":
    main()
|