extract_sources.py

  1. """
  2. 从 raw_cases/case_*.json 中提取 source_url / 帖子链接,
  3. 解析 channel_content_id,再从 .cache/content_search 中查找对应的原始帖子数据。
  4. 主函数:extract_sources_to_json(raw_cases_dir)
  5. - 扫描该目录下所有 case_{platform}.json
  6. - 解析每个 "工序发现[].帖子链接"(新格式)或 "cases[].source_url"(旧格式)
  7. - 从项目根的 .cache/content_search/*.json 中匹配 channel_content_id
  8. - 把匹配到的完整 post 写入 {raw_cases_dir}/source.json
  9. - 同时下载图片到 {raw_cases_dir}/images/{case_id}/
  10. """
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import asyncio
import aiohttp
from urllib.parse import urlparse, parse_qs, urlencode

# ── URL → (platform, content_id) parsing ────────────────────────────────
_URL_PATTERNS = [
    # Bilibili: https://www.bilibili.com/video/BV1xxx
    ("bili", re.compile(r"bilibili\.com/video/(BV[\w]+)")),
    # Xiaohongshu: https://www.xiaohongshu.com/explore/{id} or /discovery/item/{id}
    ("xhs", re.compile(r"xiaohongshu\.com/(?:explore|discovery/item)/([a-f0-9]+)")),
    # YouTube: https://www.youtube.com/watch?v={id} or https://youtu.be/{id}
    ("youtube", re.compile(r"(?:youtube\.com/watch\?v=|youtu\.be/)([\w-]+)")),
    # X/Twitter: https://x.com/{user}/status/{id} or twitter.com
    ("x", re.compile(r"(?:x\.com|twitter\.com)/[^/]+/status/(\d+)")),
    # Zhihu: https://zhuanlan.zhihu.com/p/{id} or zhihu.com/question/{qid}/answer/{aid}
    ("zhihu", re.compile(r"zhuanlan\.zhihu\.com/p/(\d+)")),
    ("zhihu", re.compile(r"zhihu\.com/question/\d+/answer/(\d+)")),
    # WeChat Official Accounts (GZH): use __biz or the rest of the URL as the id (fallback)
    ("gzh", re.compile(r"mp\.weixin\.qq\.com/s[/?]([^\s\"']+)")),
]


def parse_url(url: str) -> Optional[Tuple[str, str]]:
    """Parse (platform, content_id) out of a URL. Returns None if the URL cannot be parsed."""
    if not url or not isinstance(url, str):
        return None
    for platform, pat in _URL_PATTERNS:
        m = pat.search(url)
        if m:
            return platform, m.group(1)
    return None
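
# Illustrative usage (the IDs below are made-up placeholders, not real posts):
#   parse_url("https://www.bilibili.com/video/BV1xxxxxxxxx")      -> ("bili", "BV1xxxxxxxxx")
#   parse_url("https://x.com/someuser/status/1234567890")         -> ("x", "1234567890")
#   parse_url("https://example.com/not-a-supported-platform")     -> None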


# ── Extract all links from a case file ────────────────────────────────
def extract_urls_from_case(case_data: Any) -> List[str]:
    """Return every URL that appears in a case file, supporting both the new and the old format."""
    urls: List[str] = []
    if not isinstance(case_data, dict):
        return urls
    # New format: 工序发现[].帖子链接
    for item in case_data.get("工序发现", []) or []:
        if isinstance(item, dict):
            link = item.get("帖子链接") or item.get("source_url")
            if link:
                urls.append(link)
    # Old format: cases[].source_url
    for item in case_data.get("cases", []) or []:
        if isinstance(item, dict):
            link = item.get("source_url") or item.get("帖子链接")
            if link:
                urls.append(link)
    return urls
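
# Illustrative input shapes (field values are placeholders):
#   new format: {"工序发现": [{"帖子链接": "https://www.bilibili.com/video/BV1xxxxxxxxx"}]}
#   old format: {"cases": [{"source_url": "https://zhuanlan.zhihu.com/p/123456789"}]}
# Either shape yields a one-element list containing the respective URL.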


# ── Build a (platform, content_id) → post index from the cache ────────────────────────────────
def _normalize_url(url: str) -> Optional[str]:
    """Normalize a URL: sort its query parameters and strip the trailing slash."""
    try:
        parsed = urlparse(url)
        if parsed.query:
            # Sort the query parameters
            params = parse_qs(parsed.query, keep_blank_values=True)
            sorted_query = urlencode(sorted(params.items()), doseq=True)
            normalized = f"{parsed.scheme}://{parsed.netloc}{parsed.path}?{sorted_query}"
        else:
            normalized = f"{parsed.scheme}://{parsed.netloc}{parsed.path}"
        return normalized.rstrip('/')
    except Exception:
        return None
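
# Illustrative behaviour (hypothetical URLs):
#   _normalize_url("https://example.com/post/")          -> "https://example.com/post"
#   _normalize_url("https://example.com/post?b=2&a=1")   -> "https://example.com/post?a=1&b=2"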


def build_cache_index(cache_dir: Path, trace_ids: Optional[List[str]] = None) -> Dict[Tuple[str, str], Dict[str, Any]]:
    """
    Build a (platform, channel_content_id) -> post mapping.

    Args:
        cache_dir: path to the cache directory
        trace_ids: optional list of trace_ids. If given, only those specific cache
            files are loaded; otherwise every cache file is scanned.

    Returns:
        A dict mapping (platform, content_id) -> post
    """
    index: Dict[Tuple[str, str], Dict[str, Any]] = {}
    if not cache_dir.exists():
        return index
    # If trace_ids were given, load only those specific files
    if trace_ids:
        cache_files = [cache_dir / f"{tid}.json" for tid in trace_ids if tid]
        cache_files = [f for f in cache_files if f.exists()]
    else:
        # Otherwise scan every cache file
        cache_files = list(cache_dir.glob("*.json"))
    for cache_file in cache_files:
        try:
            with open(cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception:
            continue
        for key, entry in data.items():
            if not key.startswith("search:"):
                continue
            platform = key.split(":", 1)[1]
            # New format: entry = {"history": [...], "latest_index": n}
            # Old format: entry = {"keyword": ..., "posts": [...]}
            if isinstance(entry, dict) and "history" in entry:
                post_lists = [h.get("posts", []) for h in entry.get("history", [])]
            elif isinstance(entry, dict) and "posts" in entry:
                post_lists = [entry.get("posts", [])]
            else:
                continue
            for posts in post_lists:
                for post in posts or []:
                    if not isinstance(post, dict):
                        continue
                    cid = post.get("channel_content_id")
                    # YouTube uses video_id instead of channel_content_id
                    if not cid and post.get("video_id"):
                        cid = post.get("video_id")
                        post["channel_content_id"] = cid  # backfill the field
                    if cid:
                        # Index under (platform, content_id)
                        index[(platform, str(cid))] = post
                    # Additionally index by the link / url field (used for GZH and similar platforms)
                    link = post.get("link") or post.get("url") or ""
                    if link:
                        index[("url", link)] = post
                        # Normalized-URL index (query parameters sorted, trailing slash stripped)
                        norm = _normalize_url(link)
                        if norm:
                            index[("norm_url", norm)] = post
    return index
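
# Illustrative cache shapes handled above (keys and values are placeholders):
#   new entry: {"search:bili": {"history": [{"posts": [{...}]}], "latest_index": 0}}
#   old entry: {"search:bili": {"keyword": "...", "posts": [{...}]}}
# Each post ends up indexed under up to three keys:
#   ("bili", "<channel_content_id>"), ("url", "<link>"), ("norm_url", "<normalized link>")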


# ── Main entry point ────────────────────────────────
def extract_sources_to_json(
    raw_cases_dir: Path,
    cache_dir: Optional[Path] = None,
    output_name: str = "source.json",
    trace_ids: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    Scan the case_*.json files under raw_cases_dir, look up the original posts
    in the cache, and write them to {raw_cases_dir}/{output_name}.
    Returns a dict of statistics.
    """
    raw_cases_dir = Path(raw_cases_dir)
    if cache_dir is None:
        # Project root: three levels above this script's directory
        project_root = Path(__file__).resolve().parent.parent.parent.parent
        cache_dir = project_root / ".cache" / "content_search"
    cache_dir = Path(cache_dir)

    # 1. Build the cache index
    cache_index = build_cache_index(cache_dir, trace_ids=trace_ids)

    # 2. Load the existing source.json (if any)
    output_file = raw_cases_dir / output_name
    existing_sources = []
    existing_ids = set()  # set of (platform, channel_content_id), used for deduplication
    if output_file.exists():
        try:
            with open(output_file, "r", encoding="utf-8") as f:
                existing_data = json.load(f)
            existing_sources = existing_data.get("sources", [])
            # Build the set of IDs that are already present
            for src in existing_sources:
                key = (src.get("platform"), src.get("channel_content_id"))
                existing_ids.add(key)
        except Exception as e:
            print(f"Warning: Failed to load existing source.json: {e}")

    # 2.1 Load the existing filtered_cases.json (if any)
    filtered_output_file = raw_cases_dir / "filtered_cases.json"
    existing_filtered_sources = []
    existing_filtered_ids = set()
    if filtered_output_file.exists():
        try:
            with open(filtered_output_file, "r", encoding="utf-8") as f:
                filtered_data = json.load(f)
            existing_filtered_sources = filtered_data.get("sources", [])
            for src in existing_filtered_sources:
                key = (src.get("platform"), src.get("channel_content_id"))
                existing_filtered_ids.add(key)
        except Exception as e:
            print(f"Warning: Failed to load existing filtered_cases.json: {e}")

    # 3. Scan all case files
    matched: List[Dict[str, Any]] = []
    unmatched: List[Dict[str, Any]] = []
    seen_keys: set = set(existing_ids)  # start from the IDs that already exist
    for case_file in sorted(raw_cases_dir.glob("case_*.json")):
        # Skip the output file itself (in case source.json was mistakenly named case_*)
        if case_file.name == output_name:
            continue
        try:
            with open(case_file, "r", encoding="utf-8") as f:
                case_data = json.load(f)
        except Exception as e:
            # Try to automatically repair malformed JSON
            try:
                from examples.process_pipeline.script.fix_json_quotes import try_fix_and_parse
                with open(case_file, "r", encoding="utf-8") as f:
                    raw_content = f.read()
                success, case_data, fix_desc = try_fix_and_parse(raw_content)
                if success:
                    # Repair succeeded: write the fixed JSON back to the file
                    with open(case_file, "w", encoding="utf-8") as f:
                        json.dump(case_data, f, ensure_ascii=False, indent=2)
                    print(f" 🔧 [Auto-Fix] Fixed {case_file.name}: {fix_desc}")
                else:
                    unmatched.append({"case_file": case_file.name, "error": str(e)})
                    continue
            except Exception:
                unmatched.append({"case_file": case_file.name, "error": str(e)})
                continue
        urls = extract_urls_from_case(case_data)
        for url in urls:
            # Parse the URL into platform and content_id
            parsed = parse_url(url)
            if not parsed:
                unmatched.append({
                    "case_file": case_file.name,
                    "url": url,
                    "reason": "url_parse_failed",
                })
                continue
            platform, cid = parsed
            key = (platform, cid)
            if key in seen_keys:
                continue
            seen_keys.add(key)
            # Multi-level matching:
            # 1. exact (platform, content_id) match
            # 2. full-URL match
            # 3. normalized-URL match
            post = cache_index.get(key)
            if not post:
                # 2. full URL
                post = cache_index.get(("url", url))
            if not post:
                # 3. normalized URL
                norm = _normalize_url(url)
                if norm:
                    post = cache_index.get(("norm_url", norm))
            if post:
                # Always derive case_id from the channel_content_id stored in the cache
                # so that the case_id stays consistent with the cache ID
                actual_cid = post.get("channel_content_id") or post.get("video_id") or cid
                actual_case_id = f"{platform}_{actual_cid}"
                matched.append({
                    "case_id": actual_case_id,
                    "case_file": case_file.name,
                    "platform": platform,
                    "channel_content_id": str(actual_cid),
                    "source_url": url,
                    "post": post,
                })
            else:
                unmatched.append({
                    "case_id": f"{platform}_{cid}",  # ID in the unified format
                    "case_file": case_file.name,
                    "platform": platform,
                    "channel_content_id": cid,
                    "source_url": url,
                    "reason": "not_in_cache",
                })

    # 4. Merge the existing data with the newly matched data
    all_sources = existing_sources + matched

    # 5. Filter out stale posts published before 2025-10
    from datetime import datetime as _dt
    cutoff_ts = int(_dt(2025, 10, 1).timestamp())  # 2025-10-01 in the local timezone
    before_filter = len(all_sources)
    filtered_sources = [s for s in all_sources if _is_before_cutoff(s, cutoff_ts)]
    all_sources = [s for s in all_sources if not _is_before_cutoff(s, cutoff_ts)]
    filtered_count = before_filter - len(all_sources)

    # 6. Convert timestamps to a human-readable format
    _convert_timestamps(all_sources)
    _convert_timestamps(filtered_sources)

    # 7. Write source.json
    output = {
        "total": len(all_sources),
        "cache_dir": str(cache_dir),
        "sources": all_sources,
    }
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output, f, ensure_ascii=False, indent=2)

    # 8. Write filtered_cases.json (the filtered-out posts, deduplicated and appended)
    if filtered_sources:
        for fs in filtered_sources:
            key = (fs.get("platform"), fs.get("channel_content_id"))
            if key not in existing_filtered_ids:
                existing_filtered_sources.append(fs)
                existing_filtered_ids.add(key)
        filtered_output = {
            "total": len(existing_filtered_sources),
            "reason": "publish_timestamp before 2025-10-01",
            "sources": existing_filtered_sources,
        }
        with open(filtered_output_file, "w", encoding="utf-8") as f:
            json.dump(filtered_output, f, ensure_ascii=False, indent=2)

    # 9. Download images to raw_cases/images/{case_id}/
    images_downloaded = download_images_for_sources(matched, raw_cases_dir)

    # Return statistics (the unmatched count is included for logging)
    return {
        "total_matched": len(matched),
        "total_existing": len(existing_sources),
        "total_unmatched": len(unmatched),
        "filtered_outdated": filtered_count,
        "images_downloaded": images_downloaded,
        "output_file": str(output_file),
    }
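
# Illustrative structure of the resulting source.json (values are placeholders):
#   {"total": 42, "cache_dir": "/path/to/.cache/content_search",
#    "sources": [{"case_id": "bili_BV1xxxxxxxxx", "case_file": "case_bili.json",
#                 "platform": "bili", "channel_content_id": "BV1xxxxxxxxx",
#                 "source_url": "https://www.bilibili.com/video/BV1xxxxxxxxx",
#                 "post": {...}}, ...]}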


# ── Image download ────────────────────────────────
def _get_image_urls_from_post(post: Dict[str, Any]) -> List[str]:
    """Extract every image URL from a post."""
    urls = []
    images = post.get("images", [])
    if isinstance(images, list):
        for img in images:
            if isinstance(img, str) and img.startswith("http"):
                urls.append(img)
            elif isinstance(img, dict) and "url" in img:
                urls.append(img["url"])
    image_url_list = post.get("image_url_list", [])
    if isinstance(image_url_list, list):
        for img_obj in image_url_list:
            if isinstance(img_obj, dict) and "image_url" in img_obj:
                urls.append(img_obj["image_url"])
    return urls
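
# Illustrative post shapes handled above (URLs are placeholders):
#   {"images": ["https://example.com/a.jpg", {"url": "https://example.com/b.png"}]}
#   {"image_url_list": [{"image_url": "https://example.com/c.webp"}]}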


def _is_before_cutoff(source: Dict[str, Any], cutoff_ts: int) -> bool:
    """Return True if the post was published before the cutoff timestamp (in seconds).

    If the timestamp is 0 or missing, return False (keep the post).
    """
    post = source.get("post", {})
    if not isinstance(post, dict):
        return False
    ts = post.get("publish_timestamp")
    # No timestamp, or timestamp is 0: keep the post
    if not ts or ts == 0:
        return False
    try:
        ts = int(ts)
        # Convert milliseconds to seconds
        if ts > 1000000000000:
            ts = ts / 1000
        return ts < cutoff_ts
    except (ValueError, TypeError):
        pass
    # Try to parse the string format "2025-05-02 19:25:30"
    try:
        from datetime import datetime
        dt = datetime.strptime(str(ts), "%Y-%m-%d %H:%M:%S")
        return dt.timestamp() < cutoff_ts
    except Exception:
        # Parsing failed: keep the post
        return False


def _format_timestamp(ts: Any) -> Optional[str]:
    """Convert a timestamp (seconds or milliseconds) to a human-readable string."""
    from datetime import datetime
    if ts is None or ts == 0 or ts == "":
        return None
    try:
        ts = int(ts)
        # Millisecond-precision timestamp
        if ts > 1000000000000:
            ts = ts / 1000
        dt = datetime.fromtimestamp(ts)
        return dt.strftime("%Y-%m-%d %H:%M:%S")
    except Exception:
        return None


def _convert_timestamps(sources: List[Dict[str, Any]]) -> None:
    """Replace the timestamp fields of each post in the source list with a human-readable format."""
    timestamp_fields = ["publish_timestamp", "modify_timestamp", "update_timestamp"]
    for src in sources:
        post = src.get("post", {})
        if not isinstance(post, dict):
            continue
        for field in timestamp_fields:
            if field in post:
                readable = _format_timestamp(post.get(field))
                if readable:
                    post[field] = readable


def download_images_for_sources(sources: List[Dict[str, Any]], raw_cases_dir: Path) -> int:
    """
    Download the images of the newly matched sources to raw_cases/images/{case_id}/.

    Returns:
        The total number of images that were downloaded successfully.
    """
    import urllib.request
    import urllib.error
    images_base = raw_cases_dir / "images"
    total_downloaded = 0
    # Set headers so requests are not rejected (X/Twitter requires a User-Agent)
    opener = urllib.request.build_opener()
    opener.addheaders = [("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")]
    urllib.request.install_opener(opener)
    for src in sources:
        case_id = src.get("case_id", "")
        post = src.get("post", {})
        image_urls = _get_image_urls_from_post(post)
        if not image_urls:
            continue
        case_dir = images_base / case_id
        case_dir.mkdir(parents=True, exist_ok=True)
        for idx, url in enumerate(image_urls):
            # Determine the file extension
            ext = ".jpg"
            if ".png" in url.lower():
                ext = ".png"
            elif ".webp" in url.lower():
                ext = ".webp"
            save_path = case_dir / f"{idx:02d}{ext}"
            if save_path.exists():
                total_downloaded += 1
                continue
            try:
                urllib.request.urlretrieve(url, str(save_path))
                total_downloaded += 1
            except Exception:
                # Skip failed downloads without interrupting the pipeline
                pass
    return total_downloaded
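
# Resulting layout (illustrative): {raw_cases_dir}/images/{case_id}/00.jpg, 01.png, 02.webp, ...
# Files that already exist on disk are counted toward the total but not re-downloaded.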


if __name__ == "__main__":
    # CLI: python extract_sources.py <raw_cases_dir> [cache_dir]
    import sys
    if len(sys.argv) < 2:
        print("Usage: python extract_sources.py <raw_cases_dir> [cache_dir]")
        sys.exit(1)
    raw_cases_dir = Path(sys.argv[1])
    cache_dir = Path(sys.argv[2]) if len(sys.argv) > 2 else None
    result = extract_sources_to_json(raw_cases_dir, cache_dir=cache_dir)
    print(f"[OK] Matched: {result['total_matched']}, Unmatched: {result['total_unmatched']}")
    print(f" Output: {raw_cases_dir / 'source.json'}")