extract_feature_categories.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 从过去帖子_pattern聚合结果.json中提取特征名称及其对应的分类层级
  5. """
  6. import json
  7. from pathlib import Path
  8. from typing import Dict, List, Any, Optional, Set
  9. import sys
  10. import re
  11. # 添加项目根目录到路径
  12. project_root = Path(__file__).parent.parent.parent
  13. sys.path.insert(0, str(project_root))
  14. from script.detail import get_xiaohongshu_detail
  15. def extract_post_id_from_filename(filename: str) -> str:
  16. """从文件名中提取帖子ID"""
  17. match = re.match(r'^([^_]+)_', filename)
  18. if match:
  19. return match.group(1)
  20. return ""
  21. def get_post_detail(post_id: str) -> Optional[Dict]:
  22. """获取帖子详情"""
  23. try:
  24. detail = get_xiaohongshu_detail(post_id)
  25. return detail
  26. except Exception as e:
  27. print(f" 警告: 获取帖子 {post_id} 详情失败: {e}")
  28. return None
  29. def get_earliest_publish_time(current_posts_dir: Path) -> Optional[str]:
  30. """
  31. 获取当前帖子目录中最早的发布时间
  32. Args:
  33. current_posts_dir: 当前帖子目录路径
  34. Returns:
  35. 最早的发布时间字符串,格式为 "YYYY-MM-DD HH:MM:SS"
  36. """
  37. if not current_posts_dir.exists():
  38. print(f"警告: 当前帖子目录不存在: {current_posts_dir}")
  39. return None
  40. json_files = list(current_posts_dir.glob("*.json"))
  41. if not json_files:
  42. print(f"警告: 当前帖子目录为空: {current_posts_dir}")
  43. return None
  44. print(f"\n正在获取当前帖子的发布时间...")
  45. print(f"找到 {len(json_files)} 个当前帖子")
  46. earliest_time = None
  47. for file_path in json_files:
  48. post_id = extract_post_id_from_filename(file_path.name)
  49. if not post_id:
  50. continue
  51. try:
  52. detail = get_post_detail(post_id)
  53. if detail and 'publish_time' in detail:
  54. publish_time = detail['publish_time']
  55. if earliest_time is None or publish_time < earliest_time:
  56. earliest_time = publish_time
  57. print(f" 更新最早时间: {publish_time} (帖子: {post_id})")
  58. except Exception as e:
  59. print(f" 警告: 获取帖子 {post_id} 发布时间失败: {e}")
  60. if earliest_time:
  61. print(f"\n当前帖子最早发布时间: {earliest_time}")
  62. else:
  63. print("\n警告: 未能获取到任何当前帖子的发布时间")
  64. return earliest_time
  65. def collect_all_post_ids(data: Dict) -> Set[str]:
  66. """
  67. 收集数据中的所有帖子ID
  68. Args:
  69. data: 聚合结果数据
  70. Returns:
  71. 帖子ID集合
  72. """
  73. post_ids = set()
  74. def traverse_node(node):
  75. if isinstance(node, dict):
  76. # 检查是否有帖子列表
  77. if "帖子列表" in node and isinstance(node["帖子列表"], list):
  78. post_ids.update(node["帖子列表"])
  79. # 检查是否有特征列表
  80. if "特征列表" in node and isinstance(node["特征列表"], list):
  81. for feature in node["特征列表"]:
  82. if "帖子id" in feature:
  83. post_ids.add(feature["帖子id"])
  84. # 递归遍历
  85. for key, value in node.items():
  86. if key not in ["_meta", "帖子数", "特征数", "帖子列表"]:
  87. traverse_node(value)
  88. elif isinstance(node, list):
  89. for item in node:
  90. traverse_node(item)
  91. for category in ["灵感点列表", "目的点", "关键点列表"]:
  92. if category in data:
  93. traverse_node(data[category])
  94. return post_ids
  95. def filter_data_by_time(data: Dict, time_filter: str) -> tuple[Dict, Set[str]]:
  96. """
  97. 根据发布时间过滤数据
  98. Args:
  99. data: 原始聚合结果数据
  100. time_filter: 时间过滤阈值
  101. Returns:
  102. (过滤后的数据, 被过滤掉的帖子ID集合)
  103. """
  104. # 收集所有帖子ID
  105. all_post_ids = collect_all_post_ids(data)
  106. print(f"\n数据中包含 {len(all_post_ids)} 个不同的帖子")
  107. # 获取所有帖子的详情
  108. print("正在获取帖子详情...")
  109. post_details = {}
  110. for i, post_id in enumerate(all_post_ids, 1):
  111. print(f"[{i}/{len(all_post_ids)}] 获取帖子 {post_id} 的详情...")
  112. detail = get_post_detail(post_id)
  113. if detail:
  114. post_details[post_id] = detail
  115. # 根据时间过滤(过滤掉发布时间晚于等于阈值的帖子,避免穿越)
  116. print(f"\n正在应用时间过滤 (< {time_filter}),避免使用晚于当前帖子的数据...")
  117. filtered_post_ids = set()
  118. valid_post_ids = set()
  119. for post_id, detail in post_details.items():
  120. publish_time = detail.get('publish_time', '')
  121. if publish_time < time_filter:
  122. valid_post_ids.add(post_id)
  123. else:
  124. filtered_post_ids.add(post_id)
  125. print(f" ⚠️ 过滤掉帖子 {post_id} (发布时间: {publish_time},晚于阈值)")
  126. print(f"\n过滤统计: 过滤掉 {len(filtered_post_ids)} 个帖子(穿越),保留 {len(valid_post_ids)} 个帖子")
  127. # 过滤数据
  128. filtered_data = filter_node_by_post_ids(data, valid_post_ids)
  129. return filtered_data, filtered_post_ids
  130. def filter_node_by_post_ids(node: Any, valid_post_ids: Set[str]) -> Any:
  131. """
  132. 递归过滤节点,只保留有效帖子的数据
  133. Args:
  134. node: 当前节点
  135. valid_post_ids: 有效的帖子ID集合
  136. Returns:
  137. 过滤后的节点
  138. """
  139. if isinstance(node, dict):
  140. filtered_node = {}
  141. # 处理特征列表
  142. if "特征列表" in node:
  143. filtered_features = []
  144. for feature in node["特征列表"]:
  145. if "帖子id" in feature and feature["帖子id"] in valid_post_ids:
  146. filtered_features.append(feature)
  147. if filtered_features:
  148. filtered_node["特征列表"] = filtered_features
  149. # 更新元数据
  150. if "_meta" in node:
  151. filtered_node["_meta"] = node["_meta"].copy()
  152. filtered_node["帖子数"] = len(set(f["帖子id"] for f in filtered_features if "帖子id" in f))
  153. filtered_node["特征数"] = len(filtered_features)
  154. # 更新帖子列表
  155. filtered_node["帖子列表"] = list(set(f["帖子id"] for f in filtered_features if "帖子id" in f))
  156. # 递归处理子节点
  157. for key, value in node.items():
  158. if key in ["特征列表", "_meta", "帖子数", "特征数", "帖子列表"]:
  159. continue
  160. filtered_child = filter_node_by_post_ids(value, valid_post_ids)
  161. if filtered_child: # 只添加非空的子节点
  162. filtered_node[key] = filtered_child
  163. return filtered_node if filtered_node else None
  164. elif isinstance(node, list):
  165. return [filter_node_by_post_ids(item, valid_post_ids) for item in node]
  166. else:
  167. return node
  168. def extract_categories_from_node(node: Dict, current_path: List[str], result: Dict[str, Dict]):
  169. """
  170. 递归遍历树形结构,提取特征名称及其分类路径
  171. Args:
  172. node: 当前节点
  173. current_path: 当前分类路径(从下到上)
  174. result: 结果字典,用于存储特征名称到分类的映射
  175. """
  176. # 如果当前节点包含"特征列表"
  177. if "特征列表" in node:
  178. for feature in node["特征列表"]:
  179. feature_name = feature.get("特征名称")
  180. if feature_name:
  181. # 将分类路径存储到结果中
  182. result[feature_name] = {
  183. "所属分类": current_path.copy()
  184. }
  185. # 递归处理子节点
  186. for key, value in node.items():
  187. # 跳过特殊字段
  188. if key in ["特征列表", "_meta", "帖子数", "特征数", "帖子列表"]:
  189. continue
  190. # 如果值是字典,继续递归
  191. if isinstance(value, dict):
  192. # 将当前key添加到路径中
  193. new_path = [key] + current_path
  194. extract_categories_from_node(value, new_path, result)
  195. def process_category(category_data: Dict, category_key: str) -> Dict[str, Dict]:
  196. """
  197. 处理单个分类(灵感点列表/目的点/关键点列表)
  198. Args:
  199. category_data: 分类数据
  200. category_key: 分类键名
  201. Returns:
  202. 特征名称到分类的映射字典
  203. """
  204. result = {}
  205. if isinstance(category_data, dict):
  206. extract_categories_from_node(category_data, [], result)
  207. return result
  208. def build_category_hierarchy_from_node(
  209. node: Dict,
  210. category_hierarchy: Dict[str, Dict],
  211. current_level: int = 1,
  212. parent_categories: List[str] = None
  213. ):
  214. """
  215. 递归构建分类层级结构
  216. Args:
  217. node: 当前节点
  218. category_hierarchy: 分类层级字典
  219. current_level: 当前层级(从1开始)
  220. parent_categories: 父级分类列表(从顶到下)
  221. """
  222. if parent_categories is None:
  223. parent_categories = []
  224. # 遍历当前节点的所有键
  225. for key, value in node.items():
  226. # 跳过特殊字段
  227. if key in ["特征列表", "_meta", "帖子数", "特征数", "帖子列表"]:
  228. continue
  229. if isinstance(value, dict):
  230. # 初始化当前分类的信息
  231. if key not in category_hierarchy:
  232. category_hierarchy[key] = {
  233. "几级分类": current_level,
  234. "是否是叶子分类": False,
  235. "下一级": []
  236. }
  237. # 收集下一级的分类名称和特征名称
  238. next_level_items = []
  239. # 检查是否有子分类
  240. has_sub_categories = False
  241. for sub_key, sub_value in value.items():
  242. if sub_key not in ["特征列表", "_meta", "帖子数", "特征数", "帖子列表"]:
  243. if isinstance(sub_value, dict):
  244. has_sub_categories = True
  245. next_level_items.append({
  246. "节点类型": "分类",
  247. "节点名称": sub_key
  248. })
  249. # 如果有特征列表,添加特征名称
  250. if "特征列表" in value:
  251. for feature in value["特征列表"]:
  252. feature_name = feature.get("特征名称")
  253. if feature_name:
  254. next_level_items.append({
  255. "节点类型": "特征",
  256. "节点名称": feature_name
  257. })
  258. # 更新下一级列表
  259. category_hierarchy[key]["下一级"] = next_level_items
  260. # 如果没有子分类,标记为叶子分类
  261. if not has_sub_categories:
  262. category_hierarchy[key]["是否是叶子分类"] = True
  263. # 递归处理子节点
  264. new_parent_categories = parent_categories + [key]
  265. build_category_hierarchy_from_node(
  266. value,
  267. category_hierarchy,
  268. current_level + 1,
  269. new_parent_categories
  270. )
  271. def build_category_hierarchy(category_data: Dict) -> Dict[str, Dict]:
  272. """
  273. 构建分类名称到下一级的映射关系
  274. Args:
  275. category_data: 分类数据
  276. Returns:
  277. 分类层级映射字典
  278. """
  279. category_hierarchy = {}
  280. if isinstance(category_data, dict):
  281. build_category_hierarchy_from_node(category_data, category_hierarchy)
  282. return category_hierarchy
  283. def main():
  284. # 输入输出文件路径(默认使用项目根目录下的 data/data_1117 目录)
  285. script_dir = Path(__file__).parent
  286. project_root = script_dir.parent.parent
  287. data_dir = project_root / "data" / "data_1118"
  288. input_file = data_dir / "过去帖子_pattern聚合结果.json"
  289. current_posts_dir = data_dir / "当前帖子_what解构结果"
  290. output_file_1 = data_dir / "特征名称_分类映射.json"
  291. output_file_2 = data_dir / "分类层级映射.json"
  292. # 获取当前帖子的最早发布时间
  293. earliest_time = get_earliest_publish_time(current_posts_dir)
  294. # 读取输入文件
  295. print(f"\n正在读取文件: {input_file}")
  296. with open(input_file, "r", encoding="utf-8") as f:
  297. data = json.load(f)
  298. # 如果有时间过滤,应用过滤
  299. filtered_post_ids = set()
  300. if earliest_time:
  301. print("\n" + "="*60)
  302. print("开始应用时间过滤...")
  303. data, filtered_post_ids = filter_data_by_time(data, earliest_time)
  304. if filtered_post_ids:
  305. print(f"\n⚠️ 警告: 以下 {len(filtered_post_ids)} 个帖子因发布时间晚于阈值被过滤:")
  306. for post_id in sorted(filtered_post_ids):
  307. print(f" - {post_id}")
  308. else:
  309. print("\n未启用时间过滤")
  310. # 处理结果1: 特征名称到分类的映射
  311. output_1 = {}
  312. # 处理灵感点列表
  313. if "灵感点列表" in data:
  314. print("正在处理: 灵感点列表 (特征名称映射)")
  315. output_1["灵感点"] = process_category(data["灵感点列表"], "灵感点列表")
  316. print(f" 提取了 {len(output_1['灵感点'])} 个特征")
  317. # 处理目的点
  318. if "目的点" in data:
  319. print("正在处理: 目的点 (特征名称映射)")
  320. output_1["目的点"] = process_category(data["目的点"], "目的点")
  321. print(f" 提取了 {len(output_1['目的点'])} 个特征")
  322. # 处理关键点列表
  323. if "关键点列表" in data:
  324. print("正在处理: 关键点列表 (特征名称映射)")
  325. output_1["关键点"] = process_category(data["关键点列表"], "关键点列表")
  326. print(f" 提取了 {len(output_1['关键点'])} 个特征")
  327. # 保存结果1
  328. print(f"\n正在保存结果到: {output_file_1}")
  329. with open(output_file_1, "w", encoding="utf-8") as f:
  330. json.dump(output_1, f, ensure_ascii=False, indent=4)
  331. print("完成!")
  332. if earliest_time:
  333. print(f"\n总计 (特征名称映射,已过滤掉发布时间 >= {earliest_time} 的帖子):")
  334. else:
  335. print(f"\n总计 (特征名称映射):")
  336. for category, features in output_1.items():
  337. print(f" {category}: {len(features)} 个特征")
  338. # 处理结果2: 分类层级映射
  339. print("\n" + "="*60)
  340. print("开始生成分类层级映射...")
  341. output_2 = {}
  342. # 处理灵感点列表
  343. if "灵感点列表" in data:
  344. print("正在处理: 灵感点列表 (分类层级)")
  345. output_2["灵感点"] = build_category_hierarchy(data["灵感点列表"])
  346. print(f" 提取了 {len(output_2['灵感点'])} 个分类")
  347. # 处理目的点
  348. if "目的点" in data:
  349. print("正在处理: 目的点 (分类层级)")
  350. output_2["目的点"] = build_category_hierarchy(data["目的点"])
  351. print(f" 提取了 {len(output_2['目的点'])} 个分类")
  352. # 处理关键点列表
  353. if "关键点列表" in data:
  354. print("正在处理: 关键点列表 (分类层级)")
  355. output_2["关键点"] = build_category_hierarchy(data["关键点列表"])
  356. print(f" 提取了 {len(output_2['关键点'])} 个分类")
  357. # 保存结果2
  358. print(f"\n正在保存结果到: {output_file_2}")
  359. with open(output_file_2, "w", encoding="utf-8") as f:
  360. json.dump(output_2, f, ensure_ascii=False, indent=4)
  361. print("完成!")
  362. if earliest_time:
  363. print(f"\n总计 (分类层级映射,已过滤掉发布时间 >= {earliest_time} 的帖子):")
  364. else:
  365. print(f"\n总计 (分类层级映射):")
  366. for category, hierarchies in output_2.items():
  367. print(f" {category}: {len(hierarchies)} 个分类")
  368. if __name__ == "__main__":
  369. main()