extract_features_from_posts.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 从过去帖子_what解构结果目录中提取特征名称及其来源信息
  5. """
  6. import json
  7. from pathlib import Path
  8. from typing import Dict, List, Optional, Set
  9. import re
  10. import sys
  11. # 添加项目根目录到路径
  12. project_root = Path(__file__).parent.parent.parent
  13. sys.path.insert(0, str(project_root))
  14. from script.detail import get_xiaohongshu_detail
  15. from script.data_processing.path_config import PathConfig
  16. def extract_post_id_from_filename(filename: str) -> str:
  17. """从文件名中提取帖子ID"""
  18. match = re.match(r'^([^_]+)_', filename)
  19. if match:
  20. return match.group(1)
  21. return ""
  22. def get_post_detail(post_id: str) -> Optional[Dict]:
  23. """
  24. 获取帖子详情
  25. Args:
  26. post_id: 帖子ID
  27. Returns:
  28. 帖子详情字典,如果获取失败则返回None
  29. """
  30. try:
  31. detail = get_xiaohongshu_detail(post_id)
  32. return detail
  33. except Exception as e:
  34. print(f" 警告: 获取帖子 {post_id} 详情失败: {e}")
  35. return None
  36. def extract_features_from_point(point_data: Dict, post_id: str, point_name: str, point_description: str) -> List[Dict]:
  37. """
  38. 从单个点(灵感点/目的点/关键点)中提取特征信息
  39. Args:
  40. point_data: 点的数据
  41. post_id: 帖子ID
  42. point_name: 点的名称
  43. point_description: 点的描述
  44. Returns:
  45. 特征列表
  46. """
  47. features = []
  48. # 检查是否有"提取的特征"字段
  49. if "提取的特征" in point_data and isinstance(point_data["提取的特征"], list):
  50. for feature in point_data["提取的特征"]:
  51. if "特征名称" in feature:
  52. features.append({
  53. "特征名称": feature["特征名称"],
  54. "点的名称": point_name,
  55. "点的描述": point_description,
  56. "帖子id": post_id
  57. })
  58. return features
  59. def process_single_file(file_path: Path) -> Dict[str, Dict[str, List[Dict]]]:
  60. """
  61. 处理单个JSON文件,提取所有特征信息
  62. Args:
  63. file_path: JSON文件路径
  64. Returns:
  65. 包含灵感点、目的点、关键点的特征字典
  66. """
  67. result = {
  68. "灵感点": {},
  69. "目的点": {},
  70. "关键点": {}
  71. }
  72. # 从文件名提取帖子ID
  73. post_id = extract_post_id_from_filename(file_path.name)
  74. try:
  75. with open(file_path, "r", encoding="utf-8") as f:
  76. data = json.load(f)
  77. # 提取三点解构数据
  78. if "三点解构" not in data:
  79. return result
  80. three_points = data["三点解构"]
  81. # 处理灵感点
  82. if "灵感点" in three_points:
  83. inspiration = three_points["灵感点"]
  84. # 处理全新内容
  85. if "全新内容" in inspiration and isinstance(inspiration["全新内容"], list):
  86. for item in inspiration["全新内容"]:
  87. point_name = item.get("灵感点", "")
  88. point_desc = item.get("描述", "")
  89. features = extract_features_from_point(item, post_id, point_name, point_desc)
  90. for feature in features:
  91. feature_name = feature["特征名称"]
  92. if feature_name not in result["灵感点"]:
  93. result["灵感点"][feature_name] = []
  94. result["灵感点"][feature_name].append({
  95. "点的名称": feature["点的名称"],
  96. "点的描述": feature["点的描述"],
  97. "帖子id": feature["帖子id"]
  98. })
  99. # 处理共性差异
  100. if "共性差异" in inspiration and isinstance(inspiration["共性差异"], list):
  101. for item in inspiration["共性差异"]:
  102. point_name = item.get("灵感点", "")
  103. point_desc = item.get("描述", "")
  104. features = extract_features_from_point(item, post_id, point_name, point_desc)
  105. for feature in features:
  106. feature_name = feature["特征名称"]
  107. if feature_name not in result["灵感点"]:
  108. result["灵感点"][feature_name] = []
  109. result["灵感点"][feature_name].append({
  110. "点的名称": feature["点的名称"],
  111. "点的描述": feature["点的描述"],
  112. "帖子id": feature["帖子id"]
  113. })
  114. # 处理共性内容
  115. if "共性内容" in inspiration and isinstance(inspiration["共性内容"], list):
  116. for item in inspiration["共性内容"]:
  117. point_name = item.get("灵感点", "")
  118. point_desc = item.get("描述", "")
  119. features = extract_features_from_point(item, post_id, point_name, point_desc)
  120. for feature in features:
  121. feature_name = feature["特征名称"]
  122. if feature_name not in result["灵感点"]:
  123. result["灵感点"][feature_name] = []
  124. result["灵感点"][feature_name].append({
  125. "点的名称": feature["点的名称"],
  126. "点的描述": feature["点的描述"],
  127. "帖子id": feature["帖子id"]
  128. })
  129. # 处理目的点
  130. if "目的点" in three_points:
  131. purpose = three_points["目的点"]
  132. if "purposes" in purpose and isinstance(purpose["purposes"], list):
  133. for item in purpose["purposes"]:
  134. point_name = item.get("目的点", "")
  135. point_desc = item.get("描述", "")
  136. features = extract_features_from_point(item, post_id, point_name, point_desc)
  137. for feature in features:
  138. feature_name = feature["特征名称"]
  139. if feature_name not in result["目的点"]:
  140. result["目的点"][feature_name] = []
  141. result["目的点"][feature_name].append({
  142. "点的名称": feature["点的名称"],
  143. "点的描述": feature["点的描述"],
  144. "帖子id": feature["帖子id"]
  145. })
  146. # 处理关键点
  147. if "关键点" in three_points:
  148. key_points = three_points["关键点"]
  149. if "key_points" in key_points and isinstance(key_points["key_points"], list):
  150. for item in key_points["key_points"]:
  151. point_name = item.get("关键点", "")
  152. point_desc = item.get("描述", "")
  153. features = extract_features_from_point(item, post_id, point_name, point_desc)
  154. for feature in features:
  155. feature_name = feature["特征名称"]
  156. if feature_name not in result["关键点"]:
  157. result["关键点"][feature_name] = []
  158. result["关键点"][feature_name].append({
  159. "点的名称": feature["点的名称"],
  160. "点的描述": feature["点的描述"],
  161. "帖子id": feature["帖子id"]
  162. })
  163. except Exception as e:
  164. print(f"处理文件 {file_path.name} 时出错: {e}")
  165. return result
  166. def merge_results(all_results: List[Dict]) -> Dict:
  167. """
  168. 合并所有文件的提取结果
  169. Args:
  170. all_results: 所有文件的结果列表
  171. Returns:
  172. 合并后的结果
  173. """
  174. merged = {
  175. "灵感点": {},
  176. "目的点": {},
  177. "关键点": {}
  178. }
  179. for result in all_results:
  180. for category in ["灵感点", "目的点", "关键点"]:
  181. for feature_name, sources in result[category].items():
  182. if feature_name not in merged[category]:
  183. merged[category][feature_name] = {"来源": []}
  184. merged[category][feature_name]["来源"].extend(sources)
  185. return merged
  186. def convert_to_array_format(
  187. merged_dict: Dict,
  188. fetch_details: bool = True,
  189. time_filter: Optional[str] = None,
  190. exclude_post_ids: Optional[Set[str]] = None
  191. ) -> Dict:
  192. """
  193. 将字典格式转换为数组格式,并添加帖子详情
  194. Args:
  195. merged_dict: 字典格式的结果
  196. fetch_details: 是否获取帖子详情,默认为True
  197. time_filter: 时间过滤阈值,只保留发布时间<该时间的帖子,格式为 "YYYY-MM-DD HH:MM:SS"
  198. exclude_post_ids: 要排除的帖子ID集合
  199. Returns:
  200. 数组格式的结果
  201. """
  202. result = {
  203. "灵感点": [],
  204. "目的点": [],
  205. "关键点": []
  206. }
  207. # 收集所有需要获取详情的帖子ID
  208. post_ids = set()
  209. if fetch_details:
  210. for category in ["灵感点", "目的点", "关键点"]:
  211. for feature_name, data in merged_dict[category].items():
  212. for source in data["来源"]:
  213. post_ids.add(source["帖子id"])
  214. # 批量获取帖子详情
  215. print(f"\n正在获取 {len(post_ids)} 个帖子的详情...")
  216. post_details = {}
  217. for i, post_id in enumerate(post_ids, 1):
  218. print(f"[{i}/{len(post_ids)}] 获取帖子 {post_id} 的详情...")
  219. detail = get_post_detail(post_id)
  220. if detail:
  221. post_details[post_id] = detail
  222. print(f"成功获取 {len(post_details)} 个帖子详情")
  223. # 应用过滤规则
  224. filtered_count = 0
  225. # 1. 如果启用帖子ID过滤
  226. if exclude_post_ids:
  227. print(f"\n正在应用帖子ID过滤,排除 {len(exclude_post_ids)} 个当前帖子...")
  228. before_count = len(post_details)
  229. post_details = {pid: detail for pid, detail in post_details.items() if pid not in exclude_post_ids}
  230. filtered_count = before_count - len(post_details)
  231. if filtered_count > 0:
  232. print(f" ⚠️ 过滤掉 {filtered_count} 个当前帖子")
  233. print(f"保留 {len(post_details)} 个帖子")
  234. # 2. 如果启用时间过滤(过滤掉发布时间晚于等于阈值的帖子,避免穿越)
  235. elif time_filter:
  236. print(f"\n正在应用时间过滤 (< {time_filter}),避免使用晚于当前帖子的数据...")
  237. filtered_post_ids = set()
  238. for post_id, detail in post_details.items():
  239. publish_time = detail.get('publish_time', '')
  240. if publish_time < time_filter:
  241. filtered_post_ids.add(post_id)
  242. else:
  243. filtered_count += 1
  244. print(f" ⚠️ 过滤掉帖子 {post_id} (发布时间: {publish_time},晚于阈值)")
  245. print(f"过滤掉 {filtered_count} 个帖子(穿越),保留 {len(filtered_post_ids)} 个帖子")
  246. # 更新post_details,只保留符合时间条件的
  247. post_details = {pid: detail for pid, detail in post_details.items() if pid in filtered_post_ids}
  248. # 转换为数组格式并添加帖子详情
  249. for category in ["灵感点", "目的点", "关键点"]:
  250. for feature_name, data in merged_dict[category].items():
  251. # 为每个来源添加帖子详情
  252. enhanced_sources = []
  253. for source in data["来源"]:
  254. # 如果启用过滤,跳过不符合条件的帖子
  255. if fetch_details and (time_filter or exclude_post_ids) and source["帖子id"] not in post_details:
  256. continue
  257. enhanced_source = source.copy()
  258. if fetch_details and source["帖子id"] in post_details:
  259. enhanced_source["帖子详情"] = post_details[source["帖子id"]]
  260. enhanced_sources.append(enhanced_source)
  261. # 只添加有来源的特征
  262. if enhanced_sources:
  263. result[category].append({
  264. "特征名称": feature_name,
  265. "特征来源": enhanced_sources
  266. })
  267. return result
  268. def get_current_post_ids(current_posts_dir: Path) -> Set[str]:
  269. """
  270. 获取当前帖子目录中的所有帖子ID
  271. Args:
  272. current_posts_dir: 当前帖子目录路径
  273. Returns:
  274. 当前帖子ID集合
  275. """
  276. if not current_posts_dir.exists():
  277. print(f"警告: 当前帖子目录不存在: {current_posts_dir}")
  278. return set()
  279. json_files = list(current_posts_dir.glob("*.json"))
  280. if not json_files:
  281. print(f"警告: 当前帖子目录为空: {current_posts_dir}")
  282. return set()
  283. print(f"\n正在获取当前帖子ID...")
  284. print(f"找到 {len(json_files)} 个当前帖子")
  285. post_ids = set()
  286. for file_path in json_files:
  287. post_id = extract_post_id_from_filename(file_path.name)
  288. if post_id:
  289. post_ids.add(post_id)
  290. print(f"提取到 {len(post_ids)} 个帖子ID")
  291. return post_ids
  292. def get_earliest_publish_time(current_posts_dir: Path) -> Optional[str]:
  293. """
  294. 获取当前帖子目录中最早的发布时间
  295. Args:
  296. current_posts_dir: 当前帖子目录路径
  297. Returns:
  298. 最早的发布时间字符串,格式为 "YYYY-MM-DD HH:MM:SS"
  299. """
  300. if not current_posts_dir.exists():
  301. print(f"警告: 当前帖子目录不存在: {current_posts_dir}")
  302. return None
  303. json_files = list(current_posts_dir.glob("*.json"))
  304. if not json_files:
  305. print(f"警告: 当前帖子目录为空: {current_posts_dir}")
  306. return None
  307. print(f"\n正在获取当前帖子的发布时间...")
  308. print(f"找到 {len(json_files)} 个当前帖子")
  309. earliest_time = None
  310. for file_path in json_files:
  311. post_id = extract_post_id_from_filename(file_path.name)
  312. if not post_id:
  313. continue
  314. try:
  315. detail = get_post_detail(post_id)
  316. if detail and 'publish_time' in detail:
  317. publish_time = detail['publish_time']
  318. if earliest_time is None or publish_time < earliest_time:
  319. earliest_time = publish_time
  320. print(f" 更新最早时间: {publish_time} (帖子: {post_id})")
  321. except Exception as e:
  322. print(f" 警告: 获取帖子 {post_id} 发布时间失败: {e}")
  323. if earliest_time:
  324. print(f"\n当前帖子最早发布时间: {earliest_time}")
  325. else:
  326. print("\n警告: 未能获取到任何当前帖子的发布时间")
  327. return earliest_time
  328. def main():
  329. # 使用路径配置
  330. config = PathConfig()
  331. # 确保输出目录存在
  332. config.ensure_dirs()
  333. # 获取路径
  334. input_dir = config.historical_posts_dir
  335. current_posts_dir = config.current_posts_dir
  336. output_file = config.feature_source_mapping_file
  337. print(f"账号: {config.account_name}")
  338. print(f"过滤模式: {config.filter_mode}")
  339. print(f"过去帖子目录: {input_dir}")
  340. print(f"当前帖子目录: {current_posts_dir}")
  341. print(f"输出文件: {output_file}")
  342. print()
  343. print(f"\n正在扫描目录: {input_dir}")
  344. # 获取所有JSON文件
  345. json_files = list(input_dir.glob("*.json"))
  346. print(f"找到 {len(json_files)} 个JSON文件")
  347. # 处理所有文件
  348. all_results = []
  349. for i, file_path in enumerate(json_files, 1):
  350. print(f"处理文件 [{i}/{len(json_files)}]: {file_path.name}")
  351. result = process_single_file(file_path)
  352. all_results.append(result)
  353. # 合并结果
  354. print("\n正在合并结果...")
  355. merged_result = merge_results(all_results)
  356. # 根据配置的过滤模式应用过滤
  357. filter_mode = config.filter_mode
  358. time_filter = None
  359. exclude_post_ids = None
  360. if filter_mode == "exclude_current_posts":
  361. # 新规则:排除当前帖子ID
  362. print("\n应用过滤规则: 排除当前帖子ID")
  363. exclude_post_ids = get_current_post_ids(current_posts_dir)
  364. elif filter_mode == "time_based":
  365. # 旧规则:基于发布时间
  366. print("\n应用过滤规则: 基于发布时间")
  367. time_filter = get_earliest_publish_time(current_posts_dir)
  368. elif filter_mode == "none":
  369. print("\n过滤模式: none,不应用任何过滤")
  370. else:
  371. print(f"\n警告: 未知的过滤模式 '{filter_mode}',不应用过滤")
  372. # 转换为数组格式(带过滤)
  373. print("正在转换为数组格式...")
  374. final_result = convert_to_array_format(
  375. merged_result,
  376. fetch_details=True,
  377. time_filter=time_filter,
  378. exclude_post_ids=exclude_post_ids
  379. )
  380. # 统计信息
  381. print(f"\n提取统计:")
  382. for category in ["灵感点", "目的点", "关键点"]:
  383. feature_count = len(final_result[category])
  384. source_count = sum(len(item["特征来源"]) for item in final_result[category])
  385. print(f" {category}: {feature_count} 个特征, {source_count} 个来源")
  386. # 保存结果
  387. print(f"\n正在保存结果到: {output_file}")
  388. with open(output_file, "w", encoding="utf-8") as f:
  389. json.dump(final_result, f, ensure_ascii=False, indent=4)
  390. print("完成!")
  391. if __name__ == "__main__":
  392. main()