#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
extract_feature_categories.py

Extract feature names and their category levels from 过去帖子_pattern聚合结果.json.
"""
  6. import json
  7. from pathlib import Path
  8. from typing import Dict, List, Any, Optional, Set
  9. import sys
  10. import re
  11. # 添加项目根目录到路径
  12. project_root = Path(__file__).parent.parent.parent
  13. sys.path.insert(0, str(project_root))
  14. from script.detail import get_xiaohongshu_detail
  15. from script.data_processing.path_config import PathConfig
  16. def extract_post_id_from_filename(filename: str) -> str:
  17. """从文件名中提取帖子ID
  18. 格式: 68a6b96f000000001d006058.json
  19. """
  20. return filename.replace('.json', '')
  21. def get_post_detail(post_id: str) -> Optional[Dict]:
  22. """获取帖子详情"""
  23. try:
  24. detail = get_xiaohongshu_detail(post_id)
  25. return detail
  26. except Exception as e:
  27. print(f" 警告: 获取帖子 {post_id} 详情失败: {e}")
  28. return None
  29. def get_current_post_ids(current_posts_dir: Path) -> Set[str]:
  30. """
  31. 获取当前帖子目录中的所有帖子ID
  32. Args:
  33. current_posts_dir: 当前帖子目录路径
  34. Returns:
  35. 当前帖子ID集合
  36. """
  37. if not current_posts_dir.exists():
  38. print(f"警告: 当前帖子目录不存在: {current_posts_dir}")
  39. return set()
  40. json_files = list(current_posts_dir.glob("*.json"))
  41. if not json_files:
  42. print(f"警告: 当前帖子目录为空: {current_posts_dir}")
  43. return set()
  44. print(f"\n正在获取当前帖子ID...")
  45. print(f"找到 {len(json_files)} 个当前帖子")
  46. post_ids = set()
  47. for file_path in json_files:
  48. post_id = extract_post_id_from_filename(file_path.name)
  49. if post_id:
  50. post_ids.add(post_id)
  51. print(f"提取到 {len(post_ids)} 个帖子ID")
  52. return post_ids
  53. def get_earliest_publish_time(current_posts_dir: Path) -> Optional[str]:
  54. """
  55. 获取当前帖子目录中最早的发布时间
  56. Args:
  57. current_posts_dir: 当前帖子目录路径
  58. Returns:
  59. 最早的发布时间字符串,格式为 "YYYY-MM-DD HH:MM:SS"
  60. """
  61. if not current_posts_dir.exists():
  62. print(f"警告: 当前帖子目录不存在: {current_posts_dir}")
  63. return None
  64. json_files = list(current_posts_dir.glob("*.json"))
  65. if not json_files:
  66. print(f"警告: 当前帖子目录为空: {current_posts_dir}")
  67. return None
  68. print(f"\n正在获取当前帖子的发布时间...")
  69. print(f"找到 {len(json_files)} 个当前帖子")
  70. earliest_time = None
  71. for file_path in json_files:
  72. post_id = extract_post_id_from_filename(file_path.name)
  73. if not post_id:
  74. continue
  75. try:
  76. detail = get_post_detail(post_id)
  77. if detail and 'publish_time' in detail:
  78. publish_time = detail['publish_time']
  79. if earliest_time is None or publish_time < earliest_time:
  80. earliest_time = publish_time
  81. print(f" 更新最早时间: {publish_time} (帖子: {post_id})")
  82. except Exception as e:
  83. print(f" 警告: 获取帖子 {post_id} 发布时间失败: {e}")
  84. if earliest_time:
  85. print(f"\n当前帖子最早发布时间: {earliest_time}")
  86. else:
  87. print("\n警告: 未能获取到任何当前帖子的发布时间")
  88. return earliest_time
  89. def collect_all_post_ids(data: Dict) -> Set[str]:
  90. """
  91. 收集数据中的所有帖子ID
  92. Args:
  93. data: 聚合结果数据
  94. Returns:
  95. 帖子ID集合
  96. """
  97. post_ids = set()
  98. def traverse_node(node):
  99. if isinstance(node, dict):
  100. # 检查是否有帖子列表
  101. if "帖子列表" in node and isinstance(node["帖子列表"], list):
  102. post_ids.update(node["帖子列表"])
  103. # 检查是否有特征列表
  104. if "特征列表" in node and isinstance(node["特征列表"], list):
  105. for feature in node["特征列表"]:
  106. if "帖子id" in feature:
  107. post_ids.add(feature["帖子id"])
  108. # 递归遍历
  109. for key, value in node.items():
  110. if key not in ["_meta", "帖子数", "特征数", "帖子列表"]:
  111. traverse_node(value)
  112. elif isinstance(node, list):
  113. for item in node:
  114. traverse_node(item)
  115. for category in ["灵感点列表", "目的点", "关键点列表"]:
  116. if category in data:
  117. traverse_node(data[category])
  118. return post_ids
  119. def filter_data_by_post_ids(data: Dict, exclude_post_ids: Set[str]) -> tuple[Dict, Set[str]]:
  120. """
  121. 根据帖子ID过滤数据(新规则:排除当前帖子ID)
  122. Args:
  123. data: 原始聚合结果数据
  124. exclude_post_ids: 要排除的帖子ID集合
  125. Returns:
  126. (过滤后的数据, 被过滤掉的帖子ID集合)
  127. """
  128. # 收集所有帖子ID
  129. all_post_ids = collect_all_post_ids(data)
  130. print(f"\n数据中包含 {len(all_post_ids)} 个不同的帖子")
  131. # 过滤帖子
  132. print(f"\n正在应用帖子ID过滤,排除当前帖子目录中的 {len(exclude_post_ids)} 个帖子...")
  133. filtered_post_ids = all_post_ids & exclude_post_ids # 交集:需要过滤的
  134. valid_post_ids = all_post_ids - exclude_post_ids # 差集:保留的
  135. if filtered_post_ids:
  136. print(f" ⚠️ 过滤掉 {len(filtered_post_ids)} 个当前帖子:")
  137. for post_id in sorted(list(filtered_post_ids)[:10]): # 最多显示10个
  138. print(f" - {post_id}")
  139. if len(filtered_post_ids) > 10:
  140. print(f" ... 还有 {len(filtered_post_ids) - 10} 个")
  141. print(f"\n过滤统计: 过滤掉 {len(filtered_post_ids)} 个帖子,保留 {len(valid_post_ids)} 个帖子")
  142. # 过滤数据
  143. filtered_data = filter_node_by_post_ids(data, valid_post_ids)
  144. return filtered_data, filtered_post_ids
  145. def filter_data_by_time(data: Dict, time_filter: str) -> tuple[Dict, Set[str]]:
  146. """
  147. 根据发布时间过滤数据(旧规则:基于时间)
  148. Args:
  149. data: 原始聚合结果数据
  150. time_filter: 时间过滤阈值
  151. Returns:
  152. (过滤后的数据, 被过滤掉的帖子ID集合)
  153. """
  154. # 收集所有帖子ID
  155. all_post_ids = collect_all_post_ids(data)
  156. print(f"\n数据中包含 {len(all_post_ids)} 个不同的帖子")
  157. # 获取所有帖子的详情
  158. print("正在获取帖子详情...")
  159. post_details = {}
  160. for i, post_id in enumerate(all_post_ids, 1):
  161. print(f"[{i}/{len(all_post_ids)}] 获取帖子 {post_id} 的详情...")
  162. detail = get_post_detail(post_id)
  163. if detail:
  164. post_details[post_id] = detail
  165. # 根据时间过滤(过滤掉发布时间晚于等于阈值的帖子,避免穿越)
  166. print(f"\n正在应用时间过滤 (< {time_filter}),避免使用晚于当前帖子的数据...")
  167. filtered_post_ids = set()
  168. valid_post_ids = set()
  169. for post_id, detail in post_details.items():
  170. publish_time = detail.get('publish_time', '')
  171. if publish_time < time_filter:
  172. valid_post_ids.add(post_id)
  173. else:
  174. filtered_post_ids.add(post_id)
  175. print(f" ⚠️ 过滤掉帖子 {post_id} (发布时间: {publish_time},晚于阈值)")
  176. print(f"\n过滤统计: 过滤掉 {len(filtered_post_ids)} 个帖子(穿越),保留 {len(valid_post_ids)} 个帖子")
  177. # 过滤数据
  178. filtered_data = filter_node_by_post_ids(data, valid_post_ids)
  179. return filtered_data, filtered_post_ids
  180. def filter_node_by_post_ids(node: Any, valid_post_ids: Set[str]) -> Any:
  181. """
  182. 递归过滤节点,只保留有效帖子的数据
  183. Args:
  184. node: 当前节点
  185. valid_post_ids: 有效的帖子ID集合
  186. Returns:
  187. 过滤后的节点
  188. """
  189. if isinstance(node, dict):
  190. filtered_node = {}
  191. # 处理特征列表
  192. if "特征列表" in node:
  193. filtered_features = []
  194. for feature in node["特征列表"]:
  195. if "帖子id" in feature and feature["帖子id"] in valid_post_ids:
  196. filtered_features.append(feature)
  197. if filtered_features:
  198. filtered_node["特征列表"] = filtered_features
  199. # 更新元数据
  200. if "_meta" in node:
  201. filtered_node["_meta"] = node["_meta"].copy()
  202. filtered_node["帖子数"] = len(set(f["帖子id"] for f in filtered_features if "帖子id" in f))
  203. filtered_node["特征数"] = len(filtered_features)
  204. # 更新帖子列表
  205. filtered_node["帖子列表"] = list(set(f["帖子id"] for f in filtered_features if "帖子id" in f))
  206. # 递归处理子节点
  207. for key, value in node.items():
  208. if key in ["特征列表", "_meta", "帖子数", "特征数", "帖子列表"]:
  209. continue
  210. filtered_child = filter_node_by_post_ids(value, valid_post_ids)
  211. if filtered_child: # 只添加非空的子节点
  212. filtered_node[key] = filtered_child
  213. return filtered_node if filtered_node else None
  214. elif isinstance(node, list):
  215. return [filter_node_by_post_ids(item, valid_post_ids) for item in node]
  216. else:
  217. return node
  218. def extract_categories_from_node(node: Dict, current_path: List[str], result: Dict[str, Dict]):
  219. """
  220. 递归遍历树形结构,提取特征名称及其分类路径
  221. Args:
  222. node: 当前节点
  223. current_path: 当前分类路径(从下到上)
  224. result: 结果字典,用于存储特征名称到分类的映射
  225. """
  226. # 如果当前节点包含"特征列表"
  227. if "特征列表" in node:
  228. for feature in node["特征列表"]:
  229. feature_name = feature.get("特征名称")
  230. if feature_name:
  231. # 将分类路径存储到结果中
  232. result[feature_name] = {
  233. "所属分类": current_path.copy()
  234. }
  235. # 递归处理子节点
  236. for key, value in node.items():
  237. # 跳过特殊字段
  238. if key in ["特征列表", "_meta", "帖子数", "特征数", "帖子列表"]:
  239. continue
  240. # 如果值是字典,继续递归
  241. if isinstance(value, dict):
  242. # 将当前key添加到路径中
  243. new_path = [key] + current_path
  244. extract_categories_from_node(value, new_path, result)
  245. def process_category(category_data: Dict, category_key: str) -> Dict[str, Dict]:
  246. """
  247. 处理单个分类(灵感点列表/目的点/关键点列表)
  248. Args:
  249. category_data: 分类数据
  250. category_key: 分类键名
  251. Returns:
  252. 特征名称到分类的映射字典
  253. """
  254. result = {}
  255. if isinstance(category_data, dict):
  256. extract_categories_from_node(category_data, [], result)
  257. return result
  258. def build_category_hierarchy_from_node(
  259. node: Dict,
  260. category_hierarchy: Dict[str, Dict],
  261. current_level: int = 1,
  262. parent_categories: List[str] = None
  263. ):
  264. """
  265. 递归构建分类层级结构
  266. Args:
  267. node: 当前节点
  268. category_hierarchy: 分类层级字典
  269. current_level: 当前层级(从1开始)
  270. parent_categories: 父级分类列表(从顶到下)
  271. """
  272. if parent_categories is None:
  273. parent_categories = []
  274. # 遍历当前节点的所有键
  275. for key, value in node.items():
  276. # 跳过特殊字段
  277. if key in ["特征列表", "_meta", "帖子数", "特征数", "帖子列表"]:
  278. continue
  279. if isinstance(value, dict):
  280. # 初始化当前分类的信息
  281. if key not in category_hierarchy:
  282. category_hierarchy[key] = {
  283. "几级分类": current_level,
  284. "是否是叶子分类": False,
  285. "下一级": []
  286. }
  287. # 收集下一级的分类名称和特征名称
  288. next_level_items = []
  289. # 检查是否有子分类
  290. has_sub_categories = False
  291. for sub_key, sub_value in value.items():
  292. if sub_key not in ["特征列表", "_meta", "帖子数", "特征数", "帖子列表"]:
  293. if isinstance(sub_value, dict):
  294. has_sub_categories = True
  295. next_level_items.append({
  296. "节点类型": "分类",
  297. "节点名称": sub_key
  298. })
  299. # 如果有特征列表,添加特征名称
  300. if "特征列表" in value:
  301. for feature in value["特征列表"]:
  302. feature_name = feature.get("特征名称")
  303. if feature_name:
  304. next_level_items.append({
  305. "节点类型": "特征",
  306. "节点名称": feature_name
  307. })
  308. # 更新下一级列表
  309. category_hierarchy[key]["下一级"] = next_level_items
  310. # 如果没有子分类,标记为叶子分类
  311. if not has_sub_categories:
  312. category_hierarchy[key]["是否是叶子分类"] = True
  313. # 递归处理子节点
  314. new_parent_categories = parent_categories + [key]
  315. build_category_hierarchy_from_node(
  316. value,
  317. category_hierarchy,
  318. current_level + 1,
  319. new_parent_categories
  320. )
  321. def build_category_hierarchy(category_data: Dict) -> Dict[str, Dict]:
  322. """
  323. 构建分类名称到下一级的映射关系
  324. Args:
  325. category_data: 分类数据
  326. Returns:
  327. 分类层级映射字典
  328. """
  329. category_hierarchy = {}
  330. if isinstance(category_data, dict):
  331. build_category_hierarchy_from_node(category_data, category_hierarchy)
  332. return category_hierarchy
  333. def main():
  334. # 使用路径配置
  335. config = PathConfig()
  336. # 确保输出目录存在
  337. config.ensure_dirs()
  338. # 获取路径
  339. input_file = config.pattern_cluster_file
  340. current_posts_dir = config.current_posts_dir
  341. output_file_1 = config.feature_category_mapping_file
  342. output_file_2 = config.category_hierarchy_file
  343. print(f"账号: {config.account_name}")
  344. print(f"过滤模式: {config.filter_mode}")
  345. print(f"输入文件: {input_file}")
  346. print(f"当前帖子目录: {current_posts_dir}")
  347. print(f"输出文件1: {output_file_1}")
  348. print(f"输出文件2: {output_file_2}")
  349. print()
  350. # 读取输入文件
  351. print(f"\n正在读取文件: {input_file}")
  352. with open(input_file, "r", encoding="utf-8") as f:
  353. data = json.load(f)
  354. # 根据配置的过滤模式应用过滤
  355. filtered_post_ids = set()
  356. filter_mode = config.filter_mode
  357. if filter_mode == "exclude_current_posts":
  358. # 新规则:排除当前帖子目录中的帖子ID
  359. print("\n" + "="*60)
  360. print("应用过滤规则: 排除当前帖子ID")
  361. current_post_ids = get_current_post_ids(current_posts_dir)
  362. if current_post_ids:
  363. data, filtered_post_ids = filter_data_by_post_ids(data, current_post_ids)
  364. else:
  365. print("\n未找到当前帖子ID,跳过过滤")
  366. elif filter_mode == "time_based":
  367. # 旧规则:基于发布时间过滤
  368. print("\n" + "="*60)
  369. print("应用过滤规则: 基于发布时间")
  370. earliest_time = get_earliest_publish_time(current_posts_dir)
  371. if earliest_time:
  372. data, filtered_post_ids = filter_data_by_time(data, earliest_time)
  373. else:
  374. print("\n未能获取时间信息,跳过过滤")
  375. elif filter_mode == "none":
  376. print("\n过滤模式: none,不应用任何过滤")
  377. else:
  378. print(f"\n警告: 未知的过滤模式 '{filter_mode}',不应用过滤")
  379. # 处理结果1: 特征名称到分类的映射
  380. output_1 = {}
  381. # 处理灵感点列表
  382. if "灵感点列表" in data:
  383. print("正在处理: 灵感点列表 (特征名称映射)")
  384. output_1["灵感点"] = process_category(data["灵感点列表"], "灵感点列表")
  385. print(f" 提取了 {len(output_1['灵感点'])} 个特征")
  386. # 处理目的点
  387. if "目的点" in data:
  388. print("正在处理: 目的点 (特征名称映射)")
  389. output_1["目的点"] = process_category(data["目的点"], "目的点")
  390. print(f" 提取了 {len(output_1['目的点'])} 个特征")
  391. # 处理关键点列表
  392. if "关键点列表" in data:
  393. print("正在处理: 关键点列表 (特征名称映射)")
  394. output_1["关键点"] = process_category(data["关键点列表"], "关键点列表")
  395. print(f" 提取了 {len(output_1['关键点'])} 个特征")
  396. # 保存结果1
  397. print(f"\n正在保存结果到: {output_file_1}")
  398. with open(output_file_1, "w", encoding="utf-8") as f:
  399. json.dump(output_1, f, ensure_ascii=False, indent=4)
  400. print("完成!")
  401. if filtered_post_ids:
  402. print(f"\n总计 (特征名称映射,已过滤掉 {len(filtered_post_ids)} 个帖子):")
  403. else:
  404. print(f"\n总计 (特征名称映射):")
  405. for category, features in output_1.items():
  406. print(f" {category}: {len(features)} 个特征")
  407. # 处理结果2: 分类层级映射
  408. print("\n" + "="*60)
  409. print("开始生成分类层级映射...")
  410. output_2 = {}
  411. # 处理灵感点列表
  412. if "灵感点列表" in data:
  413. print("正在处理: 灵感点列表 (分类层级)")
  414. output_2["灵感点"] = build_category_hierarchy(data["灵感点列表"])
  415. print(f" 提取了 {len(output_2['灵感点'])} 个分类")
  416. # 处理目的点
  417. if "目的点" in data:
  418. print("正在处理: 目的点 (分类层级)")
  419. output_2["目的点"] = build_category_hierarchy(data["目的点"])
  420. print(f" 提取了 {len(output_2['目的点'])} 个分类")
  421. # 处理关键点列表
  422. if "关键点列表" in data:
  423. print("正在处理: 关键点列表 (分类层级)")
  424. output_2["关键点"] = build_category_hierarchy(data["关键点列表"])
  425. print(f" 提取了 {len(output_2['关键点'])} 个分类")
  426. # 保存结果2
  427. print(f"\n正在保存结果到: {output_file_2}")
  428. with open(output_file_2, "w", encoding="utf-8") as f:
  429. json.dump(output_2, f, ensure_ascii=False, indent=4)
  430. print("完成!")
  431. if filtered_post_ids:
  432. print(f"\n总计 (分类层级映射,已过滤掉 {len(filtered_post_ids)} 个帖子):")
  433. else:
  434. print(f"\n总计 (分类层级映射):")
  435. for category, hierarchies in output_2.items():
  436. print(f" {category}: {len(hierarchies)} 个分类")
  437. if __name__ == "__main__":
  438. main()