tree_lib_data_process.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. """
  2. 人设树库数据处理脚本
  3. 将 xiaohongshu 原始数据树库(扁平列表格式)转换为嵌套树结构格式,
  4. 与 家有大志/原始数据/tree/*.json 格式对齐,方便后续统一处理。
  5. 输入:examples_how/overall_derivation/input/xiaohongshu/原始数据/tree/
  6. 输出:examples_how/overall_derivation/input/xiaohongshu/tree/
  7. """
  8. import json
  9. import os
  10. from pathlib import Path
  11. def parse_category_path(path: str) -> list[str]:
  12. """将 '/表象/实体/生物/动物/宠物' 解析为路径段列表。"""
  13. return [seg for seg in path.split("/") if seg]
  14. def build_tree_from_flat(elements: list, tree_name: str) -> dict:
  15. """
  16. 从扁平元素列表构建嵌套树结构。
  17. 每个元素按 category_path 插入对应层级,元素本身作为叶节点(_type: "ID")。
  18. 分类节点(_type: "class")的 _post_ids 为其下所有叶节点 post_ids 的去重并集。
  19. """
  20. # 中间构建用的树,使用嵌套 dict,children 内均为同结构
  21. root_children: dict = {}
  22. for elem in elements:
  23. path_segments = parse_category_path(elem["category_path"])
  24. elem_name = elem["element_name"]
  25. post_ids = elem.get("post_ids", [])
  26. # 逐层创建或取已有的分类节点
  27. current_children = root_children
  28. for seg in path_segments:
  29. if seg not in current_children:
  30. current_children[seg] = {
  31. "_type": "class",
  32. "_post_ids_set": set(),
  33. "children": {},
  34. }
  35. current_children = current_children[seg]["children"]
  36. # 叶节点:同名元素合并 post_ids
  37. if elem_name not in current_children:
  38. current_children[elem_name] = {
  39. "_type": "ID",
  40. "_post_ids_set": set(),
  41. }
  42. current_children[elem_name]["_post_ids_set"].update(post_ids)
  43. # 自底向上传播 post_ids 并计算 _post_count
  44. def propagate(children: dict) -> set:
  45. """递归传播,返回当前层所有 post_ids 的并集。"""
  46. all_ids: set = set()
  47. for node in children.values():
  48. if node["_type"] == "ID":
  49. all_ids.update(node["_post_ids_set"])
  50. else:
  51. child_ids = propagate(node["children"])
  52. node["_post_ids_set"].update(child_ids)
  53. all_ids.update(node["_post_ids_set"])
  54. return all_ids
  55. root_ids = propagate(root_children)
  56. root_post_count = len(root_ids)
  57. # 递归序列化为目标格式,同时计算 _ratio
  58. def serialize_class(name: str, node: dict) -> dict:
  59. post_ids_list = sorted(node["_post_ids_set"])
  60. post_count = len(post_ids_list)
  61. result = {
  62. "_type": node["_type"],
  63. "_post_count": post_count,
  64. "_post_ids": post_ids_list,
  65. }
  66. if node["_type"] == "class" and node.get("children"):
  67. serialized_children = {}
  68. for child_name, child_node in node["children"].items():
  69. serialized_children[child_name] = serialize_class(child_name, child_node)
  70. result["children"] = serialized_children
  71. result["_ratio"] = (
  72. round(post_count / root_post_count, 4) if root_post_count > 0 else 0.0
  73. )
  74. return result
  75. # 构建根节点
  76. root_ids_list = sorted(root_ids)
  77. serialized_children = {}
  78. for child_name, child_node in root_children.items():
  79. serialized_children[child_name] = serialize_class(child_name, child_node)
  80. root_node = {
  81. "_type": "root",
  82. "_post_count": root_post_count,
  83. "_post_ids": root_ids_list,
  84. "children": serialized_children,
  85. }
  86. return {tree_name: root_node}
  87. def process_file(input_path: str, output_path: str, tree_name: str) -> None:
  88. with open(input_path, "r", encoding="utf-8") as f:
  89. data = json.load(f)
  90. elements = data["data"]
  91. tree = build_tree_from_flat(elements, tree_name)
  92. os.makedirs(os.path.dirname(output_path), exist_ok=True)
  93. with open(output_path, "w", encoding="utf-8") as f:
  94. json.dump(tree, f, ensure_ascii=False, indent=2)
  95. root_post_count = tree[tree_name]["_post_count"]
  96. print(f"[完成] {os.path.basename(input_path)} -> {output_path}")
  97. print(f" 根节点帖子数: {root_post_count},元素总数: {len(elements)}")
  98. def main():
  99. base_dir = Path(__file__).parent
  100. input_dir = base_dir / "input/xiaohongshu/原始数据/tree"
  101. output_dir = base_dir / "input/xiaohongshu/tree"
  102. file_map = {
  103. "实质_tree.json": "实质",
  104. "形式_tree.json": "形式",
  105. }
  106. for filename, tree_name in file_map.items():
  107. input_path = input_dir / filename
  108. output_path = output_dir / f"{tree_name}_point_tree_how.json"
  109. if not input_path.exists():
  110. print(f"[跳过] 文件不存在: {input_path}")
  111. continue
  112. process_file(str(input_path), str(output_path), tree_name)
  113. if __name__ == "__main__":
  114. main()