extract_feature_categories.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 从过去帖子_pattern聚合结果.json中提取特征名称及其对应的分类层级
  5. """
  6. import json
  7. from pathlib import Path
  8. from typing import Dict, List, Any, Optional, Set
  9. import sys
  10. import re
  11. # 添加项目根目录到路径
  12. project_root = Path(__file__).parent.parent.parent
  13. sys.path.insert(0, str(project_root))
  14. from script.detail import get_xiaohongshu_detail
  15. from script.data_processing.path_config import PathConfig
  16. def extract_post_id_from_filename(filename: str) -> str:
  17. """从文件名中提取帖子ID"""
  18. match = re.match(r'^([^_]+)_', filename)
  19. if match:
  20. return match.group(1)
  21. return ""
  22. def get_post_detail(post_id: str) -> Optional[Dict]:
  23. """获取帖子详情"""
  24. try:
  25. detail = get_xiaohongshu_detail(post_id)
  26. return detail
  27. except Exception as e:
  28. print(f" 警告: 获取帖子 {post_id} 详情失败: {e}")
  29. return None
  30. def get_current_post_ids(current_posts_dir: Path) -> Set[str]:
  31. """
  32. 获取当前帖子目录中的所有帖子ID
  33. Args:
  34. current_posts_dir: 当前帖子目录路径
  35. Returns:
  36. 当前帖子ID集合
  37. """
  38. if not current_posts_dir.exists():
  39. print(f"警告: 当前帖子目录不存在: {current_posts_dir}")
  40. return set()
  41. json_files = list(current_posts_dir.glob("*.json"))
  42. if not json_files:
  43. print(f"警告: 当前帖子目录为空: {current_posts_dir}")
  44. return set()
  45. print(f"\n正在获取当前帖子ID...")
  46. print(f"找到 {len(json_files)} 个当前帖子")
  47. post_ids = set()
  48. for file_path in json_files:
  49. post_id = extract_post_id_from_filename(file_path.name)
  50. if post_id:
  51. post_ids.add(post_id)
  52. print(f"提取到 {len(post_ids)} 个帖子ID")
  53. return post_ids
  54. def get_earliest_publish_time(current_posts_dir: Path) -> Optional[str]:
  55. """
  56. 获取当前帖子目录中最早的发布时间
  57. Args:
  58. current_posts_dir: 当前帖子目录路径
  59. Returns:
  60. 最早的发布时间字符串,格式为 "YYYY-MM-DD HH:MM:SS"
  61. """
  62. if not current_posts_dir.exists():
  63. print(f"警告: 当前帖子目录不存在: {current_posts_dir}")
  64. return None
  65. json_files = list(current_posts_dir.glob("*.json"))
  66. if not json_files:
  67. print(f"警告: 当前帖子目录为空: {current_posts_dir}")
  68. return None
  69. print(f"\n正在获取当前帖子的发布时间...")
  70. print(f"找到 {len(json_files)} 个当前帖子")
  71. earliest_time = None
  72. for file_path in json_files:
  73. post_id = extract_post_id_from_filename(file_path.name)
  74. if not post_id:
  75. continue
  76. try:
  77. detail = get_post_detail(post_id)
  78. if detail and 'publish_time' in detail:
  79. publish_time = detail['publish_time']
  80. if earliest_time is None or publish_time < earliest_time:
  81. earliest_time = publish_time
  82. print(f" 更新最早时间: {publish_time} (帖子: {post_id})")
  83. except Exception as e:
  84. print(f" 警告: 获取帖子 {post_id} 发布时间失败: {e}")
  85. if earliest_time:
  86. print(f"\n当前帖子最早发布时间: {earliest_time}")
  87. else:
  88. print("\n警告: 未能获取到任何当前帖子的发布时间")
  89. return earliest_time
  90. def collect_all_post_ids(data: Dict) -> Set[str]:
  91. """
  92. 收集数据中的所有帖子ID
  93. Args:
  94. data: 聚合结果数据
  95. Returns:
  96. 帖子ID集合
  97. """
  98. post_ids = set()
  99. def traverse_node(node):
  100. if isinstance(node, dict):
  101. # 检查是否有帖子列表
  102. if "帖子列表" in node and isinstance(node["帖子列表"], list):
  103. post_ids.update(node["帖子列表"])
  104. # 检查是否有特征列表
  105. if "特征列表" in node and isinstance(node["特征列表"], list):
  106. for feature in node["特征列表"]:
  107. if "帖子id" in feature:
  108. post_ids.add(feature["帖子id"])
  109. # 递归遍历
  110. for key, value in node.items():
  111. if key not in ["_meta", "帖子数", "特征数", "帖子列表"]:
  112. traverse_node(value)
  113. elif isinstance(node, list):
  114. for item in node:
  115. traverse_node(item)
  116. for category in ["灵感点列表", "目的点", "关键点列表"]:
  117. if category in data:
  118. traverse_node(data[category])
  119. return post_ids
  120. def filter_data_by_post_ids(data: Dict, exclude_post_ids: Set[str]) -> tuple[Dict, Set[str]]:
  121. """
  122. 根据帖子ID过滤数据(新规则:排除当前帖子ID)
  123. Args:
  124. data: 原始聚合结果数据
  125. exclude_post_ids: 要排除的帖子ID集合
  126. Returns:
  127. (过滤后的数据, 被过滤掉的帖子ID集合)
  128. """
  129. # 收集所有帖子ID
  130. all_post_ids = collect_all_post_ids(data)
  131. print(f"\n数据中包含 {len(all_post_ids)} 个不同的帖子")
  132. # 过滤帖子
  133. print(f"\n正在应用帖子ID过滤,排除当前帖子目录中的 {len(exclude_post_ids)} 个帖子...")
  134. filtered_post_ids = all_post_ids & exclude_post_ids # 交集:需要过滤的
  135. valid_post_ids = all_post_ids - exclude_post_ids # 差集:保留的
  136. if filtered_post_ids:
  137. print(f" ⚠️ 过滤掉 {len(filtered_post_ids)} 个当前帖子:")
  138. for post_id in sorted(list(filtered_post_ids)[:10]): # 最多显示10个
  139. print(f" - {post_id}")
  140. if len(filtered_post_ids) > 10:
  141. print(f" ... 还有 {len(filtered_post_ids) - 10} 个")
  142. print(f"\n过滤统计: 过滤掉 {len(filtered_post_ids)} 个帖子,保留 {len(valid_post_ids)} 个帖子")
  143. # 过滤数据
  144. filtered_data = filter_node_by_post_ids(data, valid_post_ids)
  145. return filtered_data, filtered_post_ids
  146. def filter_data_by_time(data: Dict, time_filter: str) -> tuple[Dict, Set[str]]:
  147. """
  148. 根据发布时间过滤数据(旧规则:基于时间)
  149. Args:
  150. data: 原始聚合结果数据
  151. time_filter: 时间过滤阈值
  152. Returns:
  153. (过滤后的数据, 被过滤掉的帖子ID集合)
  154. """
  155. # 收集所有帖子ID
  156. all_post_ids = collect_all_post_ids(data)
  157. print(f"\n数据中包含 {len(all_post_ids)} 个不同的帖子")
  158. # 获取所有帖子的详情
  159. print("正在获取帖子详情...")
  160. post_details = {}
  161. for i, post_id in enumerate(all_post_ids, 1):
  162. print(f"[{i}/{len(all_post_ids)}] 获取帖子 {post_id} 的详情...")
  163. detail = get_post_detail(post_id)
  164. if detail:
  165. post_details[post_id] = detail
  166. # 根据时间过滤(过滤掉发布时间晚于等于阈值的帖子,避免穿越)
  167. print(f"\n正在应用时间过滤 (< {time_filter}),避免使用晚于当前帖子的数据...")
  168. filtered_post_ids = set()
  169. valid_post_ids = set()
  170. for post_id, detail in post_details.items():
  171. publish_time = detail.get('publish_time', '')
  172. if publish_time < time_filter:
  173. valid_post_ids.add(post_id)
  174. else:
  175. filtered_post_ids.add(post_id)
  176. print(f" ⚠️ 过滤掉帖子 {post_id} (发布时间: {publish_time},晚于阈值)")
  177. print(f"\n过滤统计: 过滤掉 {len(filtered_post_ids)} 个帖子(穿越),保留 {len(valid_post_ids)} 个帖子")
  178. # 过滤数据
  179. filtered_data = filter_node_by_post_ids(data, valid_post_ids)
  180. return filtered_data, filtered_post_ids
  181. def filter_node_by_post_ids(node: Any, valid_post_ids: Set[str]) -> Any:
  182. """
  183. 递归过滤节点,只保留有效帖子的数据
  184. Args:
  185. node: 当前节点
  186. valid_post_ids: 有效的帖子ID集合
  187. Returns:
  188. 过滤后的节点
  189. """
  190. if isinstance(node, dict):
  191. filtered_node = {}
  192. # 处理特征列表
  193. if "特征列表" in node:
  194. filtered_features = []
  195. for feature in node["特征列表"]:
  196. if "帖子id" in feature and feature["帖子id"] in valid_post_ids:
  197. filtered_features.append(feature)
  198. if filtered_features:
  199. filtered_node["特征列表"] = filtered_features
  200. # 更新元数据
  201. if "_meta" in node:
  202. filtered_node["_meta"] = node["_meta"].copy()
  203. filtered_node["帖子数"] = len(set(f["帖子id"] for f in filtered_features if "帖子id" in f))
  204. filtered_node["特征数"] = len(filtered_features)
  205. # 更新帖子列表
  206. filtered_node["帖子列表"] = list(set(f["帖子id"] for f in filtered_features if "帖子id" in f))
  207. # 递归处理子节点
  208. for key, value in node.items():
  209. if key in ["特征列表", "_meta", "帖子数", "特征数", "帖子列表"]:
  210. continue
  211. filtered_child = filter_node_by_post_ids(value, valid_post_ids)
  212. if filtered_child: # 只添加非空的子节点
  213. filtered_node[key] = filtered_child
  214. return filtered_node if filtered_node else None
  215. elif isinstance(node, list):
  216. return [filter_node_by_post_ids(item, valid_post_ids) for item in node]
  217. else:
  218. return node
  219. def extract_categories_from_node(node: Dict, current_path: List[str], result: Dict[str, Dict]):
  220. """
  221. 递归遍历树形结构,提取特征名称及其分类路径
  222. Args:
  223. node: 当前节点
  224. current_path: 当前分类路径(从下到上)
  225. result: 结果字典,用于存储特征名称到分类的映射
  226. """
  227. # 如果当前节点包含"特征列表"
  228. if "特征列表" in node:
  229. for feature in node["特征列表"]:
  230. feature_name = feature.get("特征名称")
  231. if feature_name:
  232. # 将分类路径存储到结果中
  233. result[feature_name] = {
  234. "所属分类": current_path.copy()
  235. }
  236. # 递归处理子节点
  237. for key, value in node.items():
  238. # 跳过特殊字段
  239. if key in ["特征列表", "_meta", "帖子数", "特征数", "帖子列表"]:
  240. continue
  241. # 如果值是字典,继续递归
  242. if isinstance(value, dict):
  243. # 将当前key添加到路径中
  244. new_path = [key] + current_path
  245. extract_categories_from_node(value, new_path, result)
  246. def process_category(category_data: Dict, category_key: str) -> Dict[str, Dict]:
  247. """
  248. 处理单个分类(灵感点列表/目的点/关键点列表)
  249. Args:
  250. category_data: 分类数据
  251. category_key: 分类键名
  252. Returns:
  253. 特征名称到分类的映射字典
  254. """
  255. result = {}
  256. if isinstance(category_data, dict):
  257. extract_categories_from_node(category_data, [], result)
  258. return result
  259. def build_category_hierarchy_from_node(
  260. node: Dict,
  261. category_hierarchy: Dict[str, Dict],
  262. current_level: int = 1,
  263. parent_categories: List[str] = None
  264. ):
  265. """
  266. 递归构建分类层级结构
  267. Args:
  268. node: 当前节点
  269. category_hierarchy: 分类层级字典
  270. current_level: 当前层级(从1开始)
  271. parent_categories: 父级分类列表(从顶到下)
  272. """
  273. if parent_categories is None:
  274. parent_categories = []
  275. # 遍历当前节点的所有键
  276. for key, value in node.items():
  277. # 跳过特殊字段
  278. if key in ["特征列表", "_meta", "帖子数", "特征数", "帖子列表"]:
  279. continue
  280. if isinstance(value, dict):
  281. # 初始化当前分类的信息
  282. if key not in category_hierarchy:
  283. category_hierarchy[key] = {
  284. "几级分类": current_level,
  285. "是否是叶子分类": False,
  286. "下一级": []
  287. }
  288. # 收集下一级的分类名称和特征名称
  289. next_level_items = []
  290. # 检查是否有子分类
  291. has_sub_categories = False
  292. for sub_key, sub_value in value.items():
  293. if sub_key not in ["特征列表", "_meta", "帖子数", "特征数", "帖子列表"]:
  294. if isinstance(sub_value, dict):
  295. has_sub_categories = True
  296. next_level_items.append({
  297. "节点类型": "分类",
  298. "节点名称": sub_key
  299. })
  300. # 如果有特征列表,添加特征名称
  301. if "特征列表" in value:
  302. for feature in value["特征列表"]:
  303. feature_name = feature.get("特征名称")
  304. if feature_name:
  305. next_level_items.append({
  306. "节点类型": "特征",
  307. "节点名称": feature_name
  308. })
  309. # 更新下一级列表
  310. category_hierarchy[key]["下一级"] = next_level_items
  311. # 如果没有子分类,标记为叶子分类
  312. if not has_sub_categories:
  313. category_hierarchy[key]["是否是叶子分类"] = True
  314. # 递归处理子节点
  315. new_parent_categories = parent_categories + [key]
  316. build_category_hierarchy_from_node(
  317. value,
  318. category_hierarchy,
  319. current_level + 1,
  320. new_parent_categories
  321. )
  322. def build_category_hierarchy(category_data: Dict) -> Dict[str, Dict]:
  323. """
  324. 构建分类名称到下一级的映射关系
  325. Args:
  326. category_data: 分类数据
  327. Returns:
  328. 分类层级映射字典
  329. """
  330. category_hierarchy = {}
  331. if isinstance(category_data, dict):
  332. build_category_hierarchy_from_node(category_data, category_hierarchy)
  333. return category_hierarchy
  334. def main():
  335. # 使用路径配置
  336. config = PathConfig()
  337. # 确保输出目录存在
  338. config.ensure_dirs()
  339. # 获取路径
  340. input_file = config.pattern_cluster_file
  341. current_posts_dir = config.current_posts_dir
  342. output_file_1 = config.feature_category_mapping_file
  343. output_file_2 = config.category_hierarchy_file
  344. print(f"账号: {config.account_name}")
  345. print(f"过滤模式: {config.filter_mode}")
  346. print(f"输入文件: {input_file}")
  347. print(f"当前帖子目录: {current_posts_dir}")
  348. print(f"输出文件1: {output_file_1}")
  349. print(f"输出文件2: {output_file_2}")
  350. print()
  351. # 读取输入文件
  352. print(f"\n正在读取文件: {input_file}")
  353. with open(input_file, "r", encoding="utf-8") as f:
  354. data = json.load(f)
  355. # 根据配置的过滤模式应用过滤
  356. filtered_post_ids = set()
  357. filter_mode = config.filter_mode
  358. if filter_mode == "exclude_current_posts":
  359. # 新规则:排除当前帖子目录中的帖子ID
  360. print("\n" + "="*60)
  361. print("应用过滤规则: 排除当前帖子ID")
  362. current_post_ids = get_current_post_ids(current_posts_dir)
  363. if current_post_ids:
  364. data, filtered_post_ids = filter_data_by_post_ids(data, current_post_ids)
  365. else:
  366. print("\n未找到当前帖子ID,跳过过滤")
  367. elif filter_mode == "time_based":
  368. # 旧规则:基于发布时间过滤
  369. print("\n" + "="*60)
  370. print("应用过滤规则: 基于发布时间")
  371. earliest_time = get_earliest_publish_time(current_posts_dir)
  372. if earliest_time:
  373. data, filtered_post_ids = filter_data_by_time(data, earliest_time)
  374. else:
  375. print("\n未能获取时间信息,跳过过滤")
  376. elif filter_mode == "none":
  377. print("\n过滤模式: none,不应用任何过滤")
  378. else:
  379. print(f"\n警告: 未知的过滤模式 '{filter_mode}',不应用过滤")
  380. # 处理结果1: 特征名称到分类的映射
  381. output_1 = {}
  382. # 处理灵感点列表
  383. if "灵感点列表" in data:
  384. print("正在处理: 灵感点列表 (特征名称映射)")
  385. output_1["灵感点"] = process_category(data["灵感点列表"], "灵感点列表")
  386. print(f" 提取了 {len(output_1['灵感点'])} 个特征")
  387. # 处理目的点
  388. if "目的点" in data:
  389. print("正在处理: 目的点 (特征名称映射)")
  390. output_1["目的点"] = process_category(data["目的点"], "目的点")
  391. print(f" 提取了 {len(output_1['目的点'])} 个特征")
  392. # 处理关键点列表
  393. if "关键点列表" in data:
  394. print("正在处理: 关键点列表 (特征名称映射)")
  395. output_1["关键点"] = process_category(data["关键点列表"], "关键点列表")
  396. print(f" 提取了 {len(output_1['关键点'])} 个特征")
  397. # 保存结果1
  398. print(f"\n正在保存结果到: {output_file_1}")
  399. with open(output_file_1, "w", encoding="utf-8") as f:
  400. json.dump(output_1, f, ensure_ascii=False, indent=4)
  401. print("完成!")
  402. if filtered_post_ids:
  403. print(f"\n总计 (特征名称映射,已过滤掉 {len(filtered_post_ids)} 个帖子):")
  404. else:
  405. print(f"\n总计 (特征名称映射):")
  406. for category, features in output_1.items():
  407. print(f" {category}: {len(features)} 个特征")
  408. # 处理结果2: 分类层级映射
  409. print("\n" + "="*60)
  410. print("开始生成分类层级映射...")
  411. output_2 = {}
  412. # 处理灵感点列表
  413. if "灵感点列表" in data:
  414. print("正在处理: 灵感点列表 (分类层级)")
  415. output_2["灵感点"] = build_category_hierarchy(data["灵感点列表"])
  416. print(f" 提取了 {len(output_2['灵感点'])} 个分类")
  417. # 处理目的点
  418. if "目的点" in data:
  419. print("正在处理: 目的点 (分类层级)")
  420. output_2["目的点"] = build_category_hierarchy(data["目的点"])
  421. print(f" 提取了 {len(output_2['目的点'])} 个分类")
  422. # 处理关键点列表
  423. if "关键点列表" in data:
  424. print("正在处理: 关键点列表 (分类层级)")
  425. output_2["关键点"] = build_category_hierarchy(data["关键点列表"])
  426. print(f" 提取了 {len(output_2['关键点'])} 个分类")
  427. # 保存结果2
  428. print(f"\n正在保存结果到: {output_file_2}")
  429. with open(output_file_2, "w", encoding="utf-8") as f:
  430. json.dump(output_2, f, ensure_ascii=False, indent=4)
  431. print("完成!")
  432. if filtered_post_ids:
  433. print(f"\n总计 (分类层级映射,已过滤掉 {len(filtered_post_ids)} 个帖子):")
  434. else:
  435. print(f"\n总计 (分类层级映射):")
  436. for category, hierarchies in output_2.items():
  437. print(f" {category}: {len(hierarchies)} 个分类")
# Run the extraction only when this file is executed as a script.
if __name__ == "__main__":
    main()