  1. """
  2. 从 raw_cases/case_*.json 中提取 source_url / 帖子链接,
  3. 解析 channel_content_id,再从 .cache/content_search 中查找对应的原始帖子数据。
  4. 主函数:extract_sources_to_json(raw_cases_dir)
  5. - 扫描该目录下所有 case_{platform}.json
  6. - 解析每个 "工序发现[].帖子链接"(新格式)或 "cases[].source_url"(旧格式)
  7. - 从项目根的 .cache/content_search/*.json 中匹配 channel_content_id
  8. - 把匹配到的完整 post 写入 {raw_cases_dir}/source.json
  9. - 同时下载图片到 {raw_cases_dir}/images/{case_id}/
  10. """
import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse, parse_qs, urlencode
# ── URL → (platform, content_id) parsing ────────────────────────────────
_URL_PATTERNS = [
    # Bilibili: https://www.bilibili.com/video/BV1xxx
    ("bili", re.compile(r"bilibili\.com/video/(BV[\w]+)")),
    # Xiaohongshu: https://www.xiaohongshu.com/explore/{id} or /discovery/item/{id}
    ("xhs", re.compile(r"xiaohongshu\.com/(?:explore|discovery/item)/([a-f0-9]+)")),
    # YouTube: https://www.youtube.com/watch?v={id} or https://youtu.be/{id}
    ("youtube", re.compile(r"(?:youtube\.com/watch\?v=|youtu\.be/)([\w-]+)")),
    # X/Twitter: https://x.com/{user}/status/{id} or twitter.com
    ("x", re.compile(r"(?:x\.com|twitter\.com)/[^/]+/status/(\d+)")),
    # Zhihu: https://zhuanlan.zhihu.com/p/{id} or zhihu.com/question/{qid}/answer/{aid}
    ("zhihu", re.compile(r"zhuanlan\.zhihu\.com/p/(\d+)")),
    ("zhihu", re.compile(r"zhihu\.com/question/\d+/answer/(\d+)")),
    # WeChat Official Accounts (GZH): use __biz or the whole URL tail as the id (fallback)
    ("gzh", re.compile(r"mp\.weixin\.qq\.com/s[/?]([^\s\"']+)")),
]
def parse_url(url: str) -> Optional[Tuple[str, str]]:
    """Parse (platform, content_id) out of a URL. Returns None if the URL cannot be parsed."""
    if not url or not isinstance(url, str):
        return None
    for platform, pat in _URL_PATTERNS:
        m = pat.search(url)
        if m:
            return platform, m.group(1)
    return None
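# Illustrative calls (the URLs and ids below are made up):
#   parse_url("https://www.bilibili.com/video/BV1xx411c7mD")  -> ("bili", "BV1xx411c7mD")
#   parse_url("https://x.com/someuser/status/1234567890123")  -> ("x", "1234567890123")
#   parse_url("https://zhuanlan.zhihu.com/p/987654321")       -> ("zhihu", "987654321")
#   parse_url("https://example.com/unsupported")              -> None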
# ── Extract every entry (URL + evaluation) from a case file ────────────────────────────────
def extract_entries_from_case(case_data: Any) -> List[Dict[str, Any]]:
    """
    Extract entries from case data; each entry carries the url and the agent's evaluation.
    Returns: [{"url": str, "evaluation": dict | None, "title": str | None, "case_id": str | None}]
    """
    entries: List[Dict[str, Any]] = []
    if not isinstance(case_data, dict):
        return entries
    # New format: 工序发现[].帖子链接
    for item in case_data.get("工序发现", []) or []:
        if isinstance(item, dict):
            link = item.get("帖子链接") or item.get("source_url")
            if link:
                entries.append({
                    "url": link,
                    "evaluation": item.get("evaluation"),
                    "title": item.get("title"),
                    "case_id": item.get("case_id"),
                })
    # Main format: cases[]
    for item in case_data.get("cases", []) or []:
        if isinstance(item, dict):
            link = item.get("source_url") or item.get("帖子链接")
            if link:
                entries.append({
                    "url": link,
                    "evaluation": item.get("evaluation"),
                    "title": item.get("title"),
                    "case_id": item.get("case_id"),
                })
    return entries
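# Sketch of the case-file shapes this function understands (field values are hypothetical):
#   new format: {"工序发现": [{"帖子链接": "https://...", "evaluation": {...}, "title": "...", "case_id": "..."}]}
#   cases[]   : {"cases":    [{"source_url": "https://...", "evaluation": {...}, "title": "...", "case_id": "..."}]}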
def extract_urls_from_case(case_data: Any) -> List[str]:
    """[Legacy] Old interface, kept for possible external callers; internally superseded by extract_entries_from_case."""
    return [e["url"] for e in extract_entries_from_case(case_data)]
# ── Filter rules (single entry point) ────────────────────────────────
# Default thresholds: minimum body_text length, lower bound on the agent score
DEFAULT_MIN_BODY_LEN = 30
DEFAULT_MIN_SCORE = 70.0
DEFAULT_CUTOFF_DATE = (2025, 10, 1)
def _is_before_cutoff(source: Dict[str, Any], cutoff_ts: int) -> bool:
    """Return True if the post was published before the cutoff timestamp (in seconds).
    If the timestamp is missing or 0, return False (keep the post).
    """
    post = source.get("post", {})
    if not isinstance(post, dict):
        return False
    ts = post.get("publish_timestamp")
    # No timestamp, or timestamp is 0: keep the post
    if not ts or ts == 0:
        return False
    try:
        ts = int(ts)
        # Convert milliseconds to seconds
        if ts > 1000000000000:
            ts = ts / 1000
        return ts < cutoff_ts
    except (ValueError, TypeError):
        pass
    # Fall back to parsing a string such as "2025-05-02 19:25:30"
    try:
        from datetime import datetime
        dt = datetime.strptime(str(ts), "%Y-%m-%d %H:%M:%S")
        return dt.timestamp() < cutoff_ts
    except Exception:
        # Parsing failed: keep the post
        return False
def _check_filters(
    source: Dict[str, Any],
    cutoff_ts: int,
    min_body_len: int,
    min_score: float,
) -> Optional[str]:
    """
    Check every filter rule against a single source.
    Returns:
        None              -- the entry passes
        a non-None string -- the reason it was rejected (stored as filter_reason)
    """
    post = source.get("post", {}) or {}
    # 1. body_text completeness
    body = post.get("body_text") or post.get("desc") or ""
    if not isinstance(body, str) or len(body.strip()) < min_body_len:
        return f"missing_body_text:len={len(body.strip()) if isinstance(body, str) else 0}"
    # 2. Agent score
    evaluation = source.get("evaluation")
    if not isinstance(evaluation, dict):
        return "missing_evaluation"
    quality = evaluation.get("quality")
    if not isinstance(quality, dict):
        return "missing_evaluation"
    score = quality.get("overall_score")
    if not isinstance(score, (int, float)):
        return "invalid_score"
    if score < min_score:
        return f"low_score:{score}"
    # 3. Outdated (reuses the timestamp parsing in _is_before_cutoff)
    if _is_before_cutoff(source, cutoff_ts):
        return "outdated"
    return None
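# filter_reason strings produced above (values are illustrative):
#   "missing_body_text:len=12" -> body_text/desc missing or shorter than min_body_len
#   "missing_evaluation"       -> no evaluation dict, or no evaluation["quality"]
#   "invalid_score"            -> quality["overall_score"] is not a number
#   "low_score:55"             -> overall_score below min_score
#   "outdated"                 -> publish_timestamp earlier than the cutoff date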
# ── Build the (platform, content_id) → post index from the cache ────────────────────────────────
def _normalize_url(url: str) -> Optional[str]:
    """Normalize a URL: sort its query parameters and strip the trailing slash."""
    try:
        parsed = urlparse(url)
        if parsed.query:
            # Sort the query parameters
            params = parse_qs(parsed.query, keep_blank_values=True)
            sorted_query = urlencode(sorted(params.items()), doseq=True)
            normalized = f"{parsed.scheme}://{parsed.netloc}{parsed.path}?{sorted_query}"
        else:
            normalized = f"{parsed.scheme}://{parsed.netloc}{parsed.path}"
        return normalized.rstrip('/')
    except Exception:
        return None
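# Illustrative normalization (the URL is made up): query params are sorted so that
# otherwise-identical links compare equal regardless of parameter order:
#   _normalize_url("https://www.youtube.com/watch?v=abc123&feature=share")
#     -> "https://www.youtube.com/watch?feature=share&v=abc123"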
def build_cache_index(cache_dir: Path, trace_ids: Optional[List[str]] = None) -> Dict[Tuple[str, str], Dict[str, Any]]:
    """
    Build a (platform, channel_content_id) -> post mapping.
    Args:
        cache_dir: path to the cache directory
        trace_ids: optional list of trace ids. If given, only those cache files are loaded;
            otherwise every cache file is scanned.
    Returns:
        a dict mapping (platform, content_id) -> post
    """
    index: Dict[Tuple[str, str], Dict[str, Any]] = {}
    if not cache_dir.exists():
        return index
    # If trace_ids is given, load only those files
    if trace_ids:
        cache_files = [cache_dir / f"{tid}.json" for tid in trace_ids if tid]
        cache_files = [f for f in cache_files if f.exists()]
    else:
        # Otherwise scan every cache file
        cache_files = list(cache_dir.glob("*.json"))
    for cache_file in cache_files:
        try:
            with open(cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception:
            continue
        for key, entry in data.items():
            if not key.startswith("search:"):
                continue
            platform = key.split(":", 1)[1]
            # New format: entry = {"history": [...], "latest_index": n}
            # Old format: entry = {"keyword": ..., "posts": [...]}
            if isinstance(entry, dict) and "history" in entry:
                post_lists = [h.get("posts", []) for h in entry.get("history", [])]
            elif isinstance(entry, dict) and "posts" in entry:
                post_lists = [entry.get("posts", [])]
            else:
                continue
            for posts in post_lists:
                for post in posts or []:
                    if not isinstance(post, dict):
                        continue
                    cid = post.get("channel_content_id")
                    # YouTube uses video_id instead of channel_content_id
                    if not cid and post.get("video_id"):
                        cid = post.get("video_id")
                        post["channel_content_id"] = cid  # backfill the field
                    if cid:
                        # Index by (platform, content_id)
                        index[(platform, str(cid))] = post
                    # Additionally index by the link / url field (used for GZH and similar platforms)
                    link = post.get("link") or post.get("url") or ""
                    if link:
                        index[("url", link)] = post
                        # Normalized-URL index (sorted query params, trailing slash stripped)
                        norm = _normalize_url(link)
                        if norm:
                            index[("norm_url", norm)] = post
    return index
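# Sketch of a cache file as read above (keys inferred from the parsing logic; values hypothetical):
#   {
#     "search:xhs": {"history": [{"posts": [{"channel_content_id": "64f...", "link": "https://...", ...}]}],
#                    "latest_index": 0},
#     "search:gzh": {"keyword": "...", "posts": [...]}    # legacy per-keyword shape
#   }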
# ── Main entry point ────────────────────────────────
def extract_sources_to_json(
    raw_cases_dir: Path,
    cache_dir: Optional[Path] = None,
    output_name: str = "source.json",
    trace_ids: Optional[List[str]] = None,
    min_body_len: int = DEFAULT_MIN_BODY_LEN,
    min_score: float = DEFAULT_MIN_SCORE,
    cutoff_date: Tuple[int, int, int] = DEFAULT_CUTOFF_DATE,
) -> Dict[str, Any]:
    """
    Scan case_*.json under raw_cases_dir, look up the original posts in the cache,
    and run the unified filters (body_text / score / freshness) on the result.
    Filter rules (all adjustable via parameters):
      - body_text empty or shorter than min_body_len chars -> filter_reason=missing_body_text
      - agent evaluation missing or overall_score < min_score -> filter_reason=missing_evaluation / invalid_score / low_score
      - publish_timestamp earlier than cutoff_date -> filter_reason=outdated
    Args:
        min_body_len: minimum body_text length in characters, default 30
        min_score: lower bound on the agent score, default 70
        cutoff_date: freshness cutoff as (year, month, day), default (2025, 10, 1)
    Returns a dict of statistics.
    """
    raw_cases_dir = Path(raw_cases_dir)
    if cache_dir is None:
        # Project root: three levels above this script's directory
        project_root = Path(__file__).resolve().parent.parent.parent.parent
        cache_dir = project_root / ".cache" / "content_search"
    cache_dir = Path(cache_dir)
    # 1. Build the cache index
    cache_index = build_cache_index(cache_dir, trace_ids=trace_ids)
    # 2. Load the existing source.json, if any
    output_file = raw_cases_dir / output_name
    existing_sources = []
    existing_ids = set()  # set of (platform, channel_content_id) keys, used for de-duplication
    if output_file.exists():
        try:
            with open(output_file, "r", encoding="utf-8") as f:
                existing_data = json.load(f)
            existing_sources = existing_data.get("sources", [])
            # Build the set of IDs we already have
            for src in existing_sources:
                key = (src.get("platform"), src.get("channel_content_id"))
                existing_ids.add(key)
        except Exception as e:
            print(f"Warning: Failed to load existing source.json: {e}")
    # 2.1 Load the existing filtered_cases.json, if any
    filtered_output_file = raw_cases_dir / "filtered_cases.json"
    existing_filtered_sources = []
    existing_filtered_ids = set()
    if filtered_output_file.exists():
        try:
            with open(filtered_output_file, "r", encoding="utf-8") as f:
                filtered_data = json.load(f)
            existing_filtered_sources = filtered_data.get("sources", [])
            # The file is written grouped by reason (see step 8 below), so flatten that shape as well
            if not existing_filtered_sources and isinstance(filtered_data.get("by_reason"), dict):
                for group in filtered_data["by_reason"].values():
                    if isinstance(group, dict):
                        existing_filtered_sources.extend(group.get("sources", []) or [])
            for src in existing_filtered_sources:
                key = (src.get("platform"), src.get("channel_content_id"))
                existing_filtered_ids.add(key)
        except Exception as e:
            print(f"Warning: Failed to load existing filtered_cases.json: {e}")
    # 3. Scan every case file
    matched: List[Dict[str, Any]] = []
    unmatched: List[Dict[str, Any]] = []
    seen_keys: set = set(existing_ids)  # start from the IDs we already have
    for case_file in sorted(raw_cases_dir.glob("case_*.json")):
        # Skip the output file itself (in case source.json was accidentally named case_*)
        if case_file.name == output_name:
            continue
        try:
            with open(case_file, "r", encoding="utf-8") as f:
                case_data = json.load(f)
        except Exception as e:
            # Try to auto-repair malformed JSON
            try:
                from examples.process_pipeline.script.fix_json_quotes import try_fix_and_parse
                with open(case_file, "r", encoding="utf-8") as f:
                    raw_content = f.read()
                success, case_data, fix_desc = try_fix_and_parse(raw_content)
                if success:
                    # Repair succeeded: write the fixed JSON back to the file
                    with open(case_file, "w", encoding="utf-8") as f:
                        json.dump(case_data, f, ensure_ascii=False, indent=2)
                    print(f"  🔧 [Auto-Fix] Fixed {case_file.name}: {fix_desc}")
                else:
                    unmatched.append({"case_file": case_file.name, "error": str(e)})
                    continue
            except Exception:
                unmatched.append({"case_file": case_file.name, "error": str(e)})
                continue
        entries = extract_entries_from_case(case_data)
        for entry in entries:
            url = entry["url"]
            evaluation = entry.get("evaluation")
            # Parse platform and content_id out of the URL
            parsed = parse_url(url)
            if not parsed:
                unmatched.append({
                    "case_file": case_file.name,
                    "url": url,
                    "reason": "url_parse_failed",
                })
                continue
            platform, cid = parsed
            key = (platform, cid)
            if key in seen_keys:
                continue
            seen_keys.add(key)
            # Multi-level matching:
            #   1. exact (platform, content_id) match
            #   2. full URL match
            #   3. normalized URL match
            post = cache_index.get(key)
            if not post:
                # 2. full URL
                post = cache_index.get(("url", url))
            if not post:
                # 3. normalized URL
                norm = _normalize_url(url)
                if norm:
                    post = cache_index.get(("norm_url", norm))
            if post:
                # Build the case_id from the channel_content_id stored in the cache,
                # so that the case_id always agrees with the cached ID
                actual_cid = post.get("channel_content_id") or post.get("video_id") or cid
                actual_case_id = f"{platform}_{actual_cid}"
                matched.append({
                    "case_id": actual_case_id,
                    "case_file": case_file.name,
                    "platform": platform,
                    "channel_content_id": str(actual_cid),
                    "source_url": url,
                    "evaluation": evaluation,
                    "post": post,
                })
            else:
                unmatched.append({
                    "case_id": f"{platform}_{cid}",  # ID in the unified format
                    "case_file": case_file.name,
                    "platform": platform,
                    "channel_content_id": cid,
                    "source_url": url,
                    "reason": "not_in_cache",
                })
    # 4. Merge existing data with the newly matched data
    all_sources = existing_sources + matched
    # 5. Unified filtering: body_text completeness / agent score / freshness
    from datetime import datetime as _dt
    cutoff_ts = int(_dt(*cutoff_date).timestamp())
    kept_sources: List[Dict[str, Any]] = []
    filtered_sources: List[Dict[str, Any]] = []
    reason_counts: Dict[str, int] = {}
    for s in all_sources:
        reason = _check_filters(s, cutoff_ts, min_body_len, min_score)
        if reason is None:
            kept_sources.append(s)
        else:
            # Record why it was filtered out
            s_copy = dict(s)
            s_copy["filter_reason"] = reason
            filtered_sources.append(s_copy)
            # Count reason categories (the part before the colon)
            category = reason.split(":", 1)[0]
            reason_counts[category] = reason_counts.get(category, 0) + 1
    all_sources = kept_sources
    filtered_count = len(filtered_sources)
    # 6. Convert timestamps to a readable format
    _convert_timestamps(all_sources)
    _convert_timestamps(filtered_sources)
    # 7. Write source.json
    output = {
        "total": len(all_sources),
        "cache_dir": str(cache_dir),
        "sources": all_sources,
    }
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output, f, ensure_ascii=False, indent=2)
    # 8. Write filtered_cases.json (posts that were filtered out, grouped by reason)
    if filtered_sources:
        for fs in filtered_sources:
            key = (fs.get("platform"), fs.get("channel_content_id"))
            if key not in existing_filtered_ids:
                existing_filtered_sources.append(fs)
                existing_filtered_ids.add(key)
        # Group by reason category
        by_reason: Dict[str, List[Dict[str, Any]]] = {}
        for fs in existing_filtered_sources:
            reason = fs.get("filter_reason", "unknown")
            category = reason.split(":", 1)[0]
            by_reason.setdefault(category, []).append(fs)
        filtered_output = {
            "total": len(existing_filtered_sources),
            "by_reason": {
                category: {
                    "count": len(items),
                    "sources": items,
                }
                for category, items in by_reason.items()
            },
        }
        with open(filtered_output_file, "w", encoding="utf-8") as f:
            json.dump(filtered_output, f, ensure_ascii=False, indent=2)
    # 9. Download images to raw_cases/images/{case_id}/
    images_downloaded = download_images_for_sources(matched, raw_cases_dir)
    # 10. Build a summary of the filtered entries (used as feedback on reruns)
    filtered_details: List[Dict[str, Any]] = []
    for fs in filtered_sources:
        post = fs.get("post", {}) or {}
        title = post.get("title") or fs.get("source_url", "")
        filtered_details.append({
            "case_id": fs.get("case_id", ""),
            "platform": fs.get("platform", ""),
            "title": title[:60] if title else "",
            "filter_reason": fs.get("filter_reason", ""),
        })
    # Return the statistics
    return {
        "total_matched": len(matched),
        "total_existing": len(existing_sources),
        "total_unmatched": len(unmatched),
        "filtered_total": filtered_count,
        "filtered_reasons": reason_counts,
        "filtered_details": filtered_details,
        "images_downloaded": images_downloaded,
        "output_file": str(output_file),
    }
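# Minimal programmatic call (the path and threshold here are hypothetical):
#   stats = extract_sources_to_json(Path("output/raw_cases"), min_score=80.0)
#   print(stats["total_matched"], stats["filtered_reasons"])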
# ── Image download ────────────────────────────────
def _get_image_urls_from_post(post: Dict[str, Any]) -> List[str]:
    """Collect every image URL found in a post."""
    urls = []
    images = post.get("images", [])
    if isinstance(images, list):
        for img in images:
            if isinstance(img, str) and img.startswith("http"):
                urls.append(img)
            elif isinstance(img, dict) and "url" in img:
                urls.append(img["url"])
    image_url_list = post.get("image_url_list", [])
    if isinstance(image_url_list, list):
        for img_obj in image_url_list:
            if isinstance(img_obj, dict) and "image_url" in img_obj:
                urls.append(img_obj["image_url"])
    return urls
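# Post image fields handled above (shapes are illustrative):
#   "images":         ["https://img.example/a.jpg", {"url": "https://img.example/b.png"}]
#   "image_url_list": [{"image_url": "https://img.example/c.webp"}]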
def _format_timestamp(ts: Any) -> Optional[str]:
    """Convert a timestamp (seconds or milliseconds) to a readable string."""
    from datetime import datetime
    if ts is None or ts == 0 or ts == "":
        return None
    try:
        ts = int(ts)
        # Millisecond-precision timestamp
        if ts > 1000000000000:
            ts = ts / 1000
        dt = datetime.fromtimestamp(ts)
        return dt.strftime("%Y-%m-%d %H:%M:%S")
    except Exception:
        return None
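# Illustrative conversion (fromtimestamp uses local time, so the exact string depends on the timezone):
#   _format_timestamp(1714649130)    -> "2024-05-02 19:25:30" on a UTC+8 machine (seconds input)
#   _format_timestamp(1714649130000) -> same result; milliseconds are divided by 1000 first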
def _convert_timestamps(sources: List[Dict[str, Any]]) -> None:
    """Replace the timestamp fields of each source's post with readable strings."""
    timestamp_fields = ["publish_timestamp", "modify_timestamp", "update_timestamp"]
    for src in sources:
        post = src.get("post", {})
        if not isinstance(post, dict):
            continue
        for field in timestamp_fields:
            if field in post:
                readable = _format_timestamp(post.get(field))
                if readable:
                    post[field] = readable
def download_images_for_sources(sources: List[Dict[str, Any]], raw_cases_dir: Path) -> int:
    """
    Download images for the newly matched sources into raw_cases/images/{case_id}/.
    Returns:
        total number of images downloaded successfully
    """
    import urllib.request
    import urllib.error
    images_base = raw_cases_dir / "images"
    total_downloaded = 0
    # Set headers so requests are not rejected (X/Twitter requires a User-Agent)
    opener = urllib.request.build_opener()
    opener.addheaders = [("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")]
    urllib.request.install_opener(opener)
    for src in sources:
        case_id = src.get("case_id", "")
        post = src.get("post", {})
        image_urls = _get_image_urls_from_post(post)
        if not image_urls:
            continue
        case_dir = images_base / case_id
        case_dir.mkdir(parents=True, exist_ok=True)
        for idx, url in enumerate(image_urls):
            # Pick the file extension
            ext = ".jpg"
            if ".png" in url.lower():
                ext = ".png"
            elif ".webp" in url.lower():
                ext = ".webp"
            save_path = case_dir / f"{idx:02d}{ext}"
            if save_path.exists():
                total_downloaded += 1
                continue
            try:
                urllib.request.urlretrieve(url, str(save_path))
                total_downloaded += 1
            except Exception:
                # A failed download is skipped rather than aborting the run
                pass
    return total_downloaded
if __name__ == "__main__":
    # CLI: python extract_sources.py <raw_cases_dir> [cache_dir]
    import sys
    if len(sys.argv) < 2:
        print("Usage: python extract_sources.py <raw_cases_dir> [cache_dir]")
        sys.exit(1)
    raw_cases_dir = Path(sys.argv[1])
    cache_dir = Path(sys.argv[2]) if len(sys.argv) > 2 else None
    result = extract_sources_to_json(raw_cases_dir, cache_dir=cache_dir)
    print(f"[OK] Matched: {result['total_matched']}, Unmatched: {result['total_unmatched']}")
    print(f"     Output: {raw_cases_dir / 'source.json'}")