generate_case.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. #!/usr/bin/env python3
  2. """
  3. 从 raw_cases/source.json 生成标准化的 case.json
  4. 职责:
  5. 1. 读取 raw_cases/source.json(原始 source 格式)
  6. 2. 标准化字段格式(title, body, author, images, url, note)
  7. 3. 下载图片到本地 + 上传到 OSS
  8. 4. 输出到需求目录根下的 case.json
  9. 输出格式:
  10. index, category, user_kept, user_comment, description, method,
  11. cover, title, author, body, images, url, note, published_at, feedback, _raw, workflow, capabilities
  12. """
  13. import asyncio
  14. import hashlib
  15. import json
  16. from pathlib import Path
  17. from typing import Any, Dict, List, Optional
  18. # ── OSS 工具 ──────────────────────────────────────
  # Base URL of the OSS CDN; URLs under it are already hosted and never re-uploaded.
  CDN_BASE = "https://res.cybertogether.net"


  def _is_oss_url(url: str) -> bool:
      """Return True if *url* points at the OSS CDN host given by ``CDN_BASE``.

      A bare ``url.startswith(CDN_BASE)`` would also accept look-alike hosts
      such as ``https://res.cybertogether.net.example.com/x.jpg`` or
      ``https://res.cybertogether.network/x.jpg``. The prefix must therefore
      be followed by the end of the string or a URL delimiter
      (path, query, fragment or port).
      """
      if not url.startswith(CDN_BASE):
          return False
      rest = url[len(CDN_BASE):]
      return rest == "" or rest[0] in "/?#:"
  def _ext_from_path(path: str) -> str:
      """Guess an image file extension from a URL or path; default ``"jpg"``.

      The query string and the fragment are ignored, and matching is
      case-insensitive. For example, ``https://h/a.PNG?w=1#x`` yields ``"png"``.
      Recognized extensions are returned as-is (``"jpeg"`` stays ``"jpeg"``).
      """
      # Drop "#fragment" and "?query" so they cannot hide the real suffix.
      clean = path.split("#", 1)[0].split("?", 1)[0].lower()
      for ext in ("png", "gif", "webp", "avif", "bmp", "svg", "jpg", "jpeg"):
          if clean.endswith(f".{ext}"):
              return ext
      return "jpg"
  async def _upload_bytes(data: bytes, filename: str) -> str:
      """Upload raw *data* to OSS under *filename* and return the resulting CDN URL.

      The uploader from the agent's image_cdn module is imported at call time,
      not at module import time.
      """
      from agent.tools.builtin.file.image_cdn import _upload_bytes_to_oss as upload_to_oss

      cdn_url = await upload_to_oss(data, filename)
      return cdn_url
  async def _upload_remote(url: str, cache: Dict[str, str]) -> str:
      """Download an external image, re-host it on OSS, and return the CDN URL.

      Results are memoized in *cache*. The key is the first 12 hex chars of
      the URL's MD5, and it also serves as the uploaded file's base name
      (``<key>.<ext>``). A failed download or upload leaves *cache* untouched.
      """
      digest = hashlib.md5(url.encode()).hexdigest()
      cache_key = digest[:12]
      if cache_key not in cache:
          from agent.tools.builtin.file.image_cdn import _download_image

          payload = await _download_image(url)
          cache[cache_key] = await _upload_bytes(payload, f"{cache_key}.{_ext_from_path(url)}")
      return cache[cache_key]
  async def ensure_oss_url(url: str, cache: Dict[str, str]) -> str:
      """Return a CDN URL for *url*, uploading the image first if it is external.

      CDN URLs are returned unchanged. URLs starting with ``"http"`` are
      downloaded and re-hosted through ``_upload_remote`` (memoized in *cache*).

      Raises:
          ValueError: if *url* is not a CDN URL and does not start with "http".
      """
      if not _is_oss_url(url):
          if not url.startswith("http"):
              raise ValueError(f"Invalid image URL: {url}")
          return await _upload_remote(url, cache)
      return url
  51. # ── 字段提取(各平台差异处理)────────────────────────────────────
  def _extract_author(post: Dict[str, Any], platform: str) -> str:
      """Return the author name, trying a platform-specific key first.

      ``x`` prefers ``channel_account_name`` and ``youtube`` prefers ``channel``.
      Every platform then falls back to ``author`` and finally to ``""``.
      The first truthy value wins.
      """
      keys = ["author"]
      if platform == "x":
          keys.insert(0, "channel_account_name")
      elif platform == "youtube":
          keys.insert(0, "channel")
      for key in keys:
          value = post.get(key)
          if value:
              return value
      return ""
  def _extract_url(post: Dict[str, Any], platform: str) -> str:
      """Return the post's link: the first truthy of the platform's candidate keys.

      YouTube checks ``content_link`` then ``url``. Every other platform checks
      ``url`` then ``link``. Returns ``""`` if none is set.
      """
      if platform == "youtube":
          candidates = ("content_link", "url")
      else:
          candidates = ("url", "link")
      return next((post[key] for key in candidates if post.get(key)), "")
  def _extract_body(post: Dict[str, Any], platform: str) -> str:
      """Return the post's text body, or ``""`` when absent.

      Non-YouTube platforms read only ``body_text``. YouTube prefers
      ``description`` and falls back to ``body_text``.
      """
      body = post.get("body_text")
      if platform == "youtube":
          body = post.get("description") or body
      return body or ""
  def _extract_raw_images(post: Dict[str, Any], platform: str) -> List[str]:
      """Return the post's image URLs, trying several source fields in order.

      Fields are tried in the priority order ``images`` -> ``image_url_list``
      -> ``cover_url``. The first field that yields at least one non-empty URL
      wins. A list field that contains only empty entries now falls through
      to the next field instead of returning ``[]``.

      List items may be plain URL strings or dicts carrying the URL under
      ``image_url``. Both fields are treated the same way. Empty strings,
      None, and any other types are dropped.

      Args:
          post: Raw post dict taken from a source item.
          platform: Currently unused; kept for signature parity with the
              other ``_extract_*`` helpers.

      Returns:
          A list of non-empty URL strings, possibly empty.
      """

      def _as_url(item: Any) -> str:
          # Normalize one list entry to a URL string ("" when unusable).
          if isinstance(item, dict):
              item = item.get("image_url")
          return item if isinstance(item, str) and item else ""

      for field in ("images", "image_url_list"):
          raw = post.get(field)
          if isinstance(raw, list):
              urls = [url for url in map(_as_url, raw) if url]
              if urls:
                  return urls

      # Last resort: the single cover image.
      cover = post.get("cover_url")
      return [cover] if cover else []
  def _parse_published_at(timestamp_str: str) -> Optional[str]:
      """Normalize a publish-time string to ISO 8601 with an explicit UTC offset.

      Accepted inputs:
        - ``"2026-01-09 15:53:00"`` / ``"2026-01-09T15:53:00"``, optionally
          with fractional seconds. These naive values are assumed to be UTC.
        - Any other string ``datetime.fromisoformat`` accepts, plus a trailing
          ``"Z"``. Values carrying an offset (e.g. ``"+08:00"``) are converted
          to UTC instead of being rejected.

      Returns:
          e.g. ``"2026-01-09T15:53:00+00:00"`` (suitable for a timestamptz
          column), or None when the input is not a string, is blank, or
          cannot be parsed.
      """
      from datetime import datetime, timezone

      if not isinstance(timestamp_str, str):
          return None
      text = timestamp_str.strip()
      if not text:
          return None

      # Explicit formats first: strptime's %f accepts 1-6 fractional digits,
      # which older fromisoformat versions do not.
      parsed = None
      for fmt in (
          "%Y-%m-%d %H:%M:%S",
          "%Y-%m-%dT%H:%M:%S",
          "%Y-%m-%d %H:%M:%S.%f",
          "%Y-%m-%dT%H:%M:%S.%f",
      ):
          try:
              parsed = datetime.strptime(text, fmt)
              break
          except ValueError:
              continue

      try:
          if parsed is None:
              # fromisoformat only understands "Z" on Python 3.11+; spell it out.
              iso_text = text[:-1] + "+00:00" if text[-1] in "zZ" else text
              parsed = datetime.fromisoformat(iso_text)
          if parsed.tzinfo is None:
              parsed = parsed.replace(tzinfo=timezone.utc)
          else:
              parsed = parsed.astimezone(timezone.utc)
      except (ValueError, OverflowError):
          return None
      return parsed.isoformat()
  129. # ── 单条记录标准化 ────────────────────────────────────────────────────────────
  def _safe_dirname(case_id: Any) -> str:
      """Turn *case_id* into one safe path component for the per-case images dir.

      Path separators are replaced so that an id like ``"a/../../b"`` cannot
      escape ``images_dir``. ``""``, ``"."`` and ``".."`` map to ``"_"``.
      """
      name = str(case_id).replace("/", "_").replace("\\", "_")
      return "_" if name in ("", ".", "..") else name


  async def _fetch_and_upload_image(
      img_url: str,
      local_path: Path,
      upload_cache: Dict[str, str],
      tag: str,
  ) -> str:
      """Keep a local copy of one image at *local_path* and return its OSS CDN URL.

      An existing local file is not re-downloaded. Any exception raised by the
      download/upload helpers propagates; the caller treats it as "skip this
      image".
      """
      data: Optional[bytes] = None
      if local_path.exists():
          print(f" 📁 {tag} 本地已存在 {local_path.name}")
      else:
          print(f" 📥 {tag} 下载图片...")
          from agent.tools.builtin.file.image_cdn import _download_image

          data = await _download_image(img_url)
          # Write to a sibling temp file and rename, so an interrupted run never
          # leaves a truncated file that later runs would treat as cached.
          tmp_path = local_path.with_name(local_path.name + ".part")
          tmp_path.write_bytes(data)
          tmp_path.replace(local_path)
          print(f" 📥 {tag} 已保存 {local_path.name} ({len(data)} bytes)")

      if _is_oss_url(img_url):
          print(f" ☁️ {tag} 已是 CDN URL")
          return img_url

      print(f" ☁️ {tag} 上传 OSS...")
      if data is not None and img_url.startswith("http"):
          # Reuse the bytes just downloaded instead of letting ensure_oss_url()
          # fetch the same URL a second time. The cache key and file name use
          # the same md5-prefix scheme as _upload_remote, so the shared
          # upload_cache stays consistent.
          key = hashlib.md5(img_url.encode()).hexdigest()[:12]
          if key not in upload_cache:
              upload_cache[key] = await _upload_bytes(data, f"{key}.{_ext_from_path(img_url)}")
          cdn_url = upload_cache[key]
      else:
          cdn_url = await ensure_oss_url(img_url, upload_cache)
      print(f" ☁️ {tag} 上传完成")
      return cdn_url


  async def normalize_source_item(
      source_item: Dict[str, Any],
      index: int,
      upload_cache: Dict[str, str],
      images_dir: Path,
  ) -> Dict[str, Any]:
      """Convert one raw source item into the standardized case dict.

      Side effects: the function saves the post's images as
      ``images_dir/<case dir>/NN.<ext>``, skipping files that already exist.
      It uploads non-CDN images to OSS, memoized in *upload_cache*, and prints
      progress. A failing image is logged and left out of ``images``.

      Args:
          source_item: One source entry. Reads ``platform``, ``post``,
              ``case_id``, ``channel_content_id`` and ``source_url``.
          index: Value stored in the case's ``index`` field (the caller
              passes 1-based positions).
          upload_cache: Shared URL-hash -> CDN URL memo across all items.
          images_dir: Root directory for per-case local image copies.

      Returns:
          The case dict that is written into case.json.
      """
      platform = source_item.get("platform") or ""
      post = source_item.get("post") or {}
      # `or` rather than a .get() default: an explicit null/"" case_id must
      # still get a per-item fallback, otherwise images from different cases
      # would be written straight into images_dir and collide.
      case_id = source_item.get("case_id") or f"{platform}_{source_item.get('channel_content_id', '')}"

      # `or ""` so a null title cannot break later string handling.
      title = post.get("title") or ""
      author = _extract_author(post, platform)
      body = _extract_body(post, platform)
      url = _extract_url(post, platform) or source_item.get("source_url", "")

      # Engagement metrics; fields a platform does not provide stay None.
      feedback = {
          key: post.get(key)
          for key in ("like_count", "collect_count", "comment_count", "share_count")
      }
      likes = feedback["like_count"] or 0
      comments = feedback["comment_count"] or 0

      published_at = _parse_published_at(post.get("publish_timestamp", ""))

      # Images: local copy + OSS upload, one at a time, best-effort per image.
      raw_images = _extract_raw_images(post, platform)
      total = len(raw_images)
      case_dir = images_dir / _safe_dirname(case_id)
      case_dir.mkdir(parents=True, exist_ok=True)
      images: List[str] = []
      for idx, img_url in enumerate(raw_images):
          tag = f"[{idx + 1}/{total}]"
          try:
              local_path = case_dir / f"{idx:02d}.{_ext_from_path(img_url)}"
              images.append(await _fetch_and_upload_image(img_url, local_path, upload_cache, tag))
          except Exception as e:
              print(f" ⚠ {tag} 图片处理失败: {str(e)[:60]}")

      # Best-effort: also re-host external images referenced inside the body;
      # on failure keep the original body but report why.
      try:
          from agent.tools.builtin.file.image_cdn import replace_image_urls

          body = await replace_image_urls(body)
      except Exception as e:
          print(f" ⚠ body 图片替换失败: {str(e)[:60]}")

      return {
          "index": index,
          "category": "",
          "user_kept": False,
          "user_comment": "",
          "description": "",
          "method": "",
          "cover": images[0] if images else "",
          "title": title,
          "author": author,
          "body": body,
          "images": images,
          "url": url,
          "note": f"platform={platform} | likes={likes} | comments={comments}",
          "published_at": published_at,  # ISO 8601 string with +00:00 offset, or None
          "feedback": feedback,
          "_raw": {
              # Unsanitized id: generate_case_from_source matches on it.
              "case_id": case_id,
              "platform": platform,
              "channel_content_id": source_item.get("channel_content_id", ""),
          },
          "workflow": None,
          "capabilities": None,
      }
  220. # ── 主入口 ────────────────────────────────
  def _load_existing_cases(output_file: Path) -> Dict[Any, Dict[str, Any]]:
      """Read an existing case.json and index its cases by ``_raw.case_id``.

      Returns an empty dict when the file does not exist. A malformed file is
      reported with a warning. Whatever was indexed before the error is still
      returned.
      """
      existing: Dict[Any, Dict[str, Any]] = {}
      if not output_file.exists():
          return existing
      try:
          with open(output_file, "r", encoding="utf-8") as f:
              existing_data = json.load(f)
          for case in existing_data.get("cases", []):
              # `or {}` tolerates an explicit "_raw": null entry.
              case_id = (case.get("_raw") or {}).get("case_id")
              if case_id:
                  existing[case_id] = case
          print(f"Found {len(existing)} existing cases, will preserve workflow and capabilities")
      except Exception as e:
          print(f"Warning: Failed to read existing case.json: {e}")
      return existing


  def _write_json_atomic(path: Path, data: Any) -> None:
      """Write *data* as indented UTF-8 JSON, replacing *path* only once fully written.

      case.json carries preserved workflow/capabilities, so a crash mid-write
      must not leave it truncated. The data goes to a sibling temp file that is
      then renamed over *path*.
      """
      path.parent.mkdir(parents=True, exist_ok=True)
      tmp_path = path.with_name(path.name + ".tmp")
      try:
          with open(tmp_path, "w", encoding="utf-8") as f:
              json.dump(data, f, ensure_ascii=False, indent=2)
          tmp_path.replace(path)
      except BaseException:
          if tmp_path.exists():
              tmp_path.unlink()
          raise


  async def generate_case_from_source(
      raw_cases_dir: Path,
      output_file: Optional[Path] = None,
  ) -> Dict[str, Any]:
      """Build the standardized case.json from ``<raw_cases_dir>/source.json``.

      Cases already present in the output file are matched by
      ``_raw.case_id``. They keep their non-null ``workflow`` and
      ``capabilities``. Items that fail to normalize are reported and left
      out of the output.

      Args:
          raw_cases_dir: Directory containing source.json. Local image copies
              go under its ``images/`` subdirectory. ``str`` is accepted.
          output_file: Where to write case.json. Defaults to
              ``raw_cases_dir.parent / "case.json"``. ``str`` is accepted.

      Returns:
          ``{"total_cases": <int>, "output_file": <str path>}``.

      Raises:
          FileNotFoundError: if source.json does not exist.
      """
      raw_cases_dir = Path(raw_cases_dir)
      source_file = raw_cases_dir / "source.json"
      if not source_file.exists():
          raise FileNotFoundError(f"source.json not found: {source_file}")

      with open(source_file, "r", encoding="utf-8") as f:
          source_data = json.load(f)
      sources = source_data.get("sources", [])
      print(f"Processing {len(sources)} sources...")

      output_file = Path(output_file) if output_file is not None else raw_cases_dir.parent / "case.json"
      existing_cases = _load_existing_cases(output_file)

      images_dir = raw_cases_dir / "images"
      images_dir.mkdir(parents=True, exist_ok=True)

      cases: List[Dict[str, Any]] = []
      upload_cache: Dict[str, str] = {}
      for idx, source_item in enumerate(sources, 1):
          try:
              case = await normalize_source_item(
                  source_item=source_item,
                  index=idx,
                  upload_cache=upload_cache,
                  images_dir=images_dir,
              )
              # Guard against a null title so logging cannot drop the case.
              label = str(case.get("title") or "")[:40]
              case_id = (case.get("_raw") or {}).get("case_id")
              existing = existing_cases.get(case_id) if case_id else None
              if existing is not None:
                  for field in ("workflow", "capabilities"):
                      if existing.get(field) is not None:
                          case[field] = existing[field]
                  print(f" [{idx}] {label} (preserved workflow & capabilities)")
              else:
                  print(f" [{idx}] {label}")
              cases.append(case)
          except Exception as e:
              print(f" [{idx}] ✗ 失败: {e}")

      _write_json_atomic(output_file, {"total": len(cases), "cases": cases})
      return {
          "total_cases": len(cases),
          "output_file": str(output_file),
      }
  if __name__ == "__main__":
      # CLI entry: python generate_case.py <raw_cases_dir>
      import sys

      cli_args = sys.argv[1:]
      if not cli_args:
          print("Usage: python generate_case.py <raw_cases_dir>")
          sys.exit(1)
      raw_cases_dir = Path(cli_args[0])
      result = asyncio.run(generate_case_from_source(raw_cases_dir))
      print(f"\n✓ Generated {result['total_cases']} cases")
      print(f"→ {result['output_file']}")