#!/usr/bin/env python3
"""
Generate a normalized case.json from raw_cases/source.json.

Responsibilities:
1. Read raw_cases/source.json (raw source format).
2. Normalize field formats (title, body, author, images, url, note).
3. Download images locally and upload them to OSS.
4. Write case.json to the root of the requirement directory.

Output fields per case:
index, category, user_kept, user_comment, description, method,
cover, title, author, body, comments, images, url, note,
published_at, feedback, _raw, workflow, capability
"""

import asyncio
import hashlib
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
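
# Illustrative input shape (a sketch inferred from the extraction helpers below;
# field names vary by platform and all values here are placeholders):
#
# {
#   "sources": [
#     {
#       "platform": "x",
#       "case_id": "x_1234567890",
#       "channel_content_id": "1234567890",
#       "source_url": "https://example.com/status/1234567890",
#       "comments": [],
#       "post": {
#         "title": "...",
#         "body_text": "...",
#         "channel_account_name": "...",
#         "images": ["https://example.com/a.jpg"],
#         "like_count": 10,
#         "comment_count": 2,
#         "publish_timestamp": "2026-01-09 15:53:00"
#       }
#     }
#   ]
# }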

# ── OSS utilities ──────────────────────────────────────
CDN_BASE = "https://res.cybertogether.net"


def _is_oss_url(url: str) -> bool:
    return url.startswith(CDN_BASE)


def _ext_from_path(path: str) -> str:
    """Guess the file extension from a URL; default to jpg."""
    p = path.split("?")[0].lower()
    for ext in ("png", "gif", "webp", "avif", "bmp", "svg", "jpg", "jpeg"):
        if p.endswith(f".{ext}"):
            return ext
    return "jpg"


async def _upload_bytes(data: bytes, filename: str) -> str:
    """Upload raw bytes to OSS and return the CDN URL."""
    from agent.tools.builtin.file.image_cdn import _upload_bytes_to_oss
    return await _upload_bytes_to_oss(data, filename)


async def _upload_remote(url: str, cache: Dict[str, str]) -> str:
    """Download an external image, upload it to OSS, and return the CDN URL."""
    key = hashlib.md5(url.encode()).hexdigest()[:12]
    if key in cache:
        return cache[key]
    from agent.tools.builtin.file.image_cdn import _download_image
    data = await _download_image(url)
    ext = _ext_from_path(url)
    cdn_url = await _upload_bytes(data, f"{key}.{ext}")
    cache[key] = cdn_url
    return cdn_url


async def ensure_oss_url(url: str, cache: Dict[str, str]) -> str:
    """Make sure an image URL is an OSS CDN URL."""
    if _is_oss_url(url):
        return url
    if url.startswith("http"):
        return await _upload_remote(url, cache)
    raise ValueError(f"Invalid image URL: {url}")
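
# Typical use (sketch): share one cache dict across a batch so the same remote
# URL is downloaded and uploaded only once; the URL below is a placeholder.
#
#   upload_cache: Dict[str, str] = {}
#   cdn_url = await ensure_oss_url("https://example.com/pic.png", upload_cache)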


# ── Field extraction (per-platform differences) ────────────────────────────────────
def _extract_author(post: Dict[str, Any], platform: str) -> str:
    """Field mapping: author / channel_account_name / channel."""
    if platform == "x":
        return post.get("channel_account_name") or post.get("author") or ""
    if platform == "youtube":
        return post.get("channel") or post.get("author") or ""
    return post.get("author") or ""


def _extract_url(post: Dict[str, Any], platform: str) -> str:
    """Field mapping: url / link / content_link."""
    if platform == "youtube":
        return post.get("content_link") or post.get("url") or ""
    return post.get("url") or post.get("link") or ""


def _extract_body(post: Dict[str, Any], platform: str) -> str:
    """Field mapping: body_text / description; for video posts, merge in video_transcript.

    On video platforms such as Douyin or WeChat Channels, post.body_text is usually
    just a few hashtags (or even empty), while the real content lives in
    video_transcript (Deepgram transcription) or captions (official YouTube subtitles).
    Concatenating the two lets downstream scoring, filtering, and agent prompts
    see the full content.

    Idempotent: if body_text was already merged by extract_sources (it contains the
    `[视频字幕]` marker), return the body as-is to avoid appending the transcript again.
    """
    if platform == "youtube":
        body = post.get("description") or post.get("body_text") or ""
    else:
        body = post.get("body_text") or ""
    body = body.strip() if isinstance(body, str) else ""
    # Already merged -> skip to avoid duplication
    if body and "[视频字幕]" in body:
        return body
    transcript = post.get("video_transcript") or post.get("captions") or ""
    transcript = transcript.strip() if isinstance(transcript, str) else ""
    if transcript and body:
        return f"{body}\n\n[视频字幕]\n{transcript}"
    return transcript or body
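
# Example (illustrative values): a body_text of "#AI #tools" plus a transcript becomes
#   "#AI #tools\n\n[视频字幕]\n<transcript text>"
# and a second pass over the merged body returns it unchanged thanks to the marker check.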


def _extract_raw_images(post: Dict[str, Any], platform: str) -> List[str]:
    """Field mapping: images / image_url_list / cover_url."""
    # Prefer the images field
    if post.get("images"):
        imgs = post["images"]
        if isinstance(imgs, list) and imgs:
            return [i for i in imgs if i]
    # Then image_url_list
    if post.get("image_url_list"):
        raw = post["image_url_list"]
        if isinstance(raw, list):
            result = []
            for item in raw:
                if isinstance(item, dict):
                    result.append(item.get("image_url") or "")
                else:
                    result.append(item or "")
            result = [u for u in result if u]
            if result:
                return result
    # Finally fall back to cover_url
    if post.get("cover_url"):
        return [post["cover_url"]]
    return []
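
# Example (illustrative URLs): image_url_list entries may be dicts or plain strings,
#   [{"image_url": "https://example.com/a.jpg"}, "https://example.com/b.jpg"]
# which the loop above flattens to
#   ["https://example.com/a.jpg", "https://example.com/b.jpg"].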


def _parse_published_at(timestamp_str: str) -> Optional[str]:
    """Parse publish_timestamp into ISO 8601 (timestamptz) format.

    Supported inputs:
    - "2026-01-09 15:53:00"
    - "2026-01-09T15:53:00"
    - an empty string or None returns None

    Returns:
        An ISO 8601 string (e.g. "2026-01-09T15:53:00+00:00") or None.
    """
    if not timestamp_str or not isinstance(timestamp_str, str):
        return None
    timestamp_str = timestamp_str.strip()
    if not timestamp_str:
        return None
    try:
        from datetime import datetime, timezone

        # Try the common formats in order
        for fmt in [
            "%Y-%m-%d %H:%M:%S",
            "%Y-%m-%dT%H:%M:%S",
            "%Y-%m-%d %H:%M:%S.%f",
            "%Y-%m-%dT%H:%M:%S.%f",
        ]:
            try:
                dt = datetime.strptime(timestamp_str, fmt)
                # Assume the input is UTC and attach the timezone
                dt = dt.replace(tzinfo=timezone.utc)
                return dt.isoformat()
            except ValueError:
                continue
        # None of the formats matched
        return None
    except Exception:
        return None


# ── Single-record normalization ────────────────────────────────────────────────────
async def normalize_source_item(
    source_item: Dict[str, Any],
    index: int,
    upload_cache: Dict[str, str],
    images_dir: Path,
) -> Dict[str, Any]:
    """Convert a single source item into the normalized case format."""
    # Extract fields from the source item
    platform = source_item.get("platform", "")
    post = source_item.get("post", {})
    case_id = source_item.get("case_id", f"{platform}_{source_item.get('channel_content_id', '')}")
    body = _extract_body(post, platform)
    title = post.get("title") or source_item.get("title") or post.get("desc") or (body[:30] + "..." if body else "") or case_id
    author = _extract_author(post, platform)
    url = _extract_url(post, platform) or source_item.get("source_url", "")
    # Collect feedback metrics (platforms differ; missing fields stay None)
    feedback = {
        "like_count": post.get("like_count"),
        "collect_count": post.get("collect_count"),
        "comment_count": post.get("comment_count"),
        "share_count": post.get("share_count"),
    }
    # Simplified counts for the note field
    likes = feedback["like_count"] or 0
    comments = feedback["comment_count"] or 0
    # Parse the publish time
    publish_timestamp = post.get("publish_timestamp", "")
    published_at = _parse_published_at(publish_timestamp)
    # Handle images: download locally + upload to OSS
    raw_images = _extract_raw_images(post, platform)
    images: List[str] = []
    case_dir = images_dir / case_id
    case_dir.mkdir(parents=True, exist_ok=True)
    for idx, img_url in enumerate(raw_images):
        ext = _ext_from_path(img_url)
        local_path = case_dir / f"{idx:02d}.{ext}"
        try:
            # Download to the local case directory
            if not local_path.exists():
                print(f"  📥 [{idx+1}/{len(raw_images)}] downloading image...")
                from agent.tools.builtin.file.image_cdn import _download_image
                data = await _download_image(img_url)
                local_path.write_bytes(data)
                print(f"  📥 [{idx+1}/{len(raw_images)}] saved {local_path.name} ({len(data)} bytes)")
            else:
                print(f"  📁 [{idx+1}/{len(raw_images)}] already present locally: {local_path.name}")
            # Upload to OSS
            if _is_oss_url(img_url):
                images.append(img_url)
                print(f"  ☁️ [{idx+1}/{len(raw_images)}] already a CDN URL")
            else:
                print(f"  ☁️ [{idx+1}/{len(raw_images)}] uploading to OSS...")
                cdn_url = await ensure_oss_url(img_url, upload_cache)
                images.append(cdn_url)
                print(f"  ☁️ [{idx+1}/{len(raw_images)}] upload done")
        except Exception as e:
            print(f"  ⚠ [{idx+1}/{len(raw_images)}] image processing failed: {str(e)[:60]}")
            images.append(img_url)
    # Fallback: also replace external image links inside the body with CDN URLs
    try:
        from agent.tools.builtin.file.image_cdn import replace_image_urls
        body = await replace_image_urls(body)
    except Exception:
        pass
    cover = images[0] if images else ""
    author_comments = source_item.get("comments") or []
    return {
        "index": index,
        "category": "",
        "user_kept": False,
        "user_comment": "",
        "description": "",
        "method": "",
        "cover": cover,
        "title": title,
        "author": author,
        "body": body,
        "comments": author_comments,
        "images": images,
        "url": url,
        "note": f"platform={platform} | likes={likes} | comments={comments}",
        "published_at": published_at,  # ISO 8601 string, nullable
        "feedback": feedback,
        "_raw": {
            "case_id": case_id,
            "platform": platform,
            "channel_content_id": source_item.get("channel_content_id", ""),
        },
        "workflow": None,
        "capability": None,
    }
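
# Resulting on-disk layout (per the paths above):
#   <raw_cases_dir>/images/<case_id>/00.jpg, 01.png, ...
# case.json itself is written by generate_case_from_source below.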


# ── Main entry point ────────────────────────────────
async def generate_case_from_source(
    raw_cases_dir: Path,
    output_file: Optional[Path] = None,
) -> Dict[str, Any]:
    """Generate a normalized case.json from raw_cases/source.json.

    If case.json already exists, existing workflow_groups are preserved.
    """
    raw_cases_dir = Path(raw_cases_dir)
    source_file = raw_cases_dir / "source.json"
    if not source_file.exists():
        raise FileNotFoundError(f"source.json not found: {source_file}")
    # Read source.json
    with open(source_file, "r", encoding="utf-8") as f:
        source_data = json.load(f)
    sources = source_data.get("sources", [])
    print(f"Processing {len(sources)} sources...")
    # Read the existing case.json, if any
    if output_file is None:
        output_file = raw_cases_dir.parent / "case.json"
    existing_cases = {}
    if output_file.exists():
        try:
            with open(output_file, "r", encoding="utf-8") as f:
                existing_data = json.load(f)
            existing_list = existing_data.get("cases", [])
            # Build a case_id -> case mapping
            for case in existing_list:
                case_id = case.get("_raw", {}).get("case_id")
                if case_id:
                    existing_cases[case_id] = case
            print(f"Found {len(existing_cases)} existing cases, will preserve workflow_groups")
        except Exception as e:
            print(f"Warning: Failed to read existing case.json: {e}")
    # Prepare the images directory
    images_dir = raw_cases_dir / "images"
    images_dir.mkdir(parents=True, exist_ok=True)
    # Normalize all source items
    cases: List[Dict[str, Any]] = []
    upload_cache: Dict[str, str] = {}
    for idx, source_item in enumerate(sources, 1):
        try:
            case = await normalize_source_item(
                source_item=source_item,
                index=idx,
                upload_cache=upload_cache,
                images_dir=images_dir,
            )
            # If this case already exists, keep its workflow_groups
            case_id = case.get("_raw", {}).get("case_id")
            if case_id and case_id in existing_cases:
                existing = existing_cases[case_id]
                if existing.get("workflow_groups") is not None:
                    case["workflow_groups"] = existing["workflow_groups"]
                    print(f"  [{idx}] {case['title'][:40]} (preserved workflow_groups)")
            else:
                print(f"  [{idx}] {case['title'][:40]}")
            cases.append(case)
        except Exception as e:
            print(f"  [{idx}] ✗ failed: {e}")
    # Write case.json
    output_data = {
        "total": len(cases),
        "cases": cases,
    }
    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)
    return {
        "total_cases": len(cases),
        "output_file": str(output_file),
    }
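
# Shape of the generated case.json (fields as produced by normalize_source_item;
# values here are placeholders):
#
# {
#   "total": 1,
#   "cases": [
#     {"index": 1, "title": "...", "cover": "...", "images": ["..."], "url": "...", ...}
#   ]
# }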


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python generate_case.py <raw_cases_dir>")
        sys.exit(1)
    raw_cases_dir = Path(sys.argv[1])
    stats = asyncio.run(generate_case_from_source(raw_cases_dir))
    print(f"\n✓ Generated {stats['total_cases']} cases")
    print(f"→ {stats['output_file']}")