extract_decode_workflow.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. """
  2. Decode-Workflow Step(旁路):用 DecodeProcessAgent (LangChain + Gemini) 提取工序
  3. 跟现有 `extract_workflow.py` 并存:
  4. - 不替换 case.json 里的 workflow_groups(那个仍由 extract_workflow.py 产出)
  5. - 输出落到 output/<NNN>/decode_workflows/<case_id>.json + <case_id>.html
  6. - 旁路存储 — 让你能对比 DecodeProcessAgent vs 当前 extract_workflow 的产出质量
  7. 使用方式(通过 run_pipeline.py):
  8. python run_pipeline.py --index 107 --only-step decode-workflow
  9. 依赖(首次用需要装):
  10. pip install langchain langchain-google-genai
  11. .env 中加 GOOGLE_API_KEY=...
  12. """
  13. import asyncio
  14. import json
  15. import os
  16. import sys
  17. from pathlib import Path
  18. from typing import Any, Dict, List, Optional
  19. # 默认模型(DecodeProcessAgent 的初始配置)
  20. DEFAULT_DECODE_MODEL = os.getenv("DECODE_WORKFLOW_MODEL", "google_genai:gemini-3-flash-preview")
  21. # 并发 case 数(注意:新版 DecodeProcessAgent.run_batch 已改为同进程顺序执行,
  22. # 值 >1 会被忽略并打 warning。env var 保留为了向后兼容,默认值改为 1。
  23. # 想恢复真并发需先把 WorkflowContext 改成 ContextVar 隔离)
  24. DEFAULT_DECODE_CONCURRENCY = int(os.getenv("DECODE_WORKFLOW_CONCURRENCY", "1"))
  25. # 是否把 decode 输出合并回 case.json 的 workflow_groups[0].workflow.steps(默认开 — 用户的要求)
  26. # 0 时只走旁路,不动 case.json
  27. DEFAULT_MERGE_TO_CASE = os.getenv("DECODE_MERGE_TO_CASE", "1") not in ("0", "false", "False", "")
  28. # 每个 case 最多传给 agent 的图片张数 — 只对 base64 模式生效(防 Gemini 1M context 爆炸);
  29. # URL 模式不限(LangChain → Google File API 走 image tokens 不会爆)
  30. DEFAULT_MAX_IMAGES_PER_CASE = int(os.getenv("DECODE_MAX_IMAGES_PER_CASE", "50"))
  31. def _resolve_local_images(case: Dict[str, Any], case_file: Path, max_images: int = DEFAULT_MAX_IMAGES_PER_CASE) -> Optional[List[str]]:
  32. """
  33. 可选:用本地图片(raw_cases/images/<case_id>/)转 base64 data URI 替代远端 URL。
  34. 通过环境变量 DECODE_USE_LOCAL_IMAGES=1 opt-in 启用(默认走 URL,因为 base64 会让
  35. agent 每轮 invoke 重传 user message 累积 → 撞 Gemini 1M context)。
  36. 返回 None 表示本地没有图片目录或目录为空,调用方应 fallback 到 case.images URL 列表。
  37. """
  38. raw = case.get("_raw") or {}
  39. case_id = raw.get("case_id") or raw.get("channel_content_id")
  40. if not case_id:
  41. return None
  42. images_dir = case_file.parent / "raw_cases" / "images" / case_id
  43. if not images_dir.is_dir():
  44. return None
  45. img_exts = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
  46. all_img_files = sorted(p for p in images_dir.iterdir() if p.suffix.lower() in img_exts)
  47. if not all_img_files:
  48. return None
  49. truncated = len(all_img_files) > max_images
  50. img_files = all_img_files[:max_images]
  51. import base64
  52. import mimetypes
  53. data_uris: List[str] = []
  54. for p in img_files:
  55. try:
  56. mime = mimetypes.guess_type(p.name)[0] or "image/jpeg"
  57. b64 = base64.b64encode(p.read_bytes()).decode("ascii")
  58. data_uris.append(f"data:{mime};base64,{b64}")
  59. except Exception as e:
  60. print(f" [decode-workflow] WARN: failed to read local image {p.name}: {e}", flush=True)
  61. if truncated:
  62. print(
  63. f" [decode-workflow] truncated images: kept first {len(img_files)}/{len(all_img_files)} "
  64. f"(防 Gemini 1M context 爆炸;通过环境变量 DECODE_MAX_IMAGES_PER_CASE 调整)",
  65. flush=True,
  66. )
  67. return data_uris or None
  68. def _build_decode_input(case: Dict[str, Any], case_file: Optional[Path] = None) -> Dict[str, Any]:
  69. """
  70. 从 case_item 构造 DecodeProcessAgent 期望的输入 JSON。
  71. images 字段**默认用 case.images 里的远端 URL** —— LangChain 收到 URL 后会
  72. download → bytes → 上传给 Gemini File API,按 image tokens 算(一张 ~258 tokens),
  73. 不会撞 1M context。
  74. **本地 base64 模式是 opt-in**(DECODE_USE_LOCAL_IMAGES=1),因为 base64 data URI 在
  75. LangChain 内部可能当 text 塞进 history,每轮 invoke 累积会撞 Gemini 1M context(实测
  76. 14 张图 9 个 step 必爆)。
  77. """
  78. raw = case.get("_raw") or {}
  79. images = case.get("images") or []
  80. use_local = os.getenv("DECODE_USE_LOCAL_IMAGES", "0") not in ("0", "false", "False", "")
  81. if use_local and case_file is not None:
  82. local = _resolve_local_images(case, case_file)
  83. if local:
  84. print(
  85. f" [decode-workflow] case index={case.get('index')}: "
  86. f"using {len(local)} local images (base64, opt-in via DECODE_USE_LOCAL_IMAGES)",
  87. flush=True,
  88. )
  89. images = local
  90. return {
  91. "channel_content_id": raw.get("channel_content_id") or raw.get("case_id") or f"case_{case.get('index', 0)}",
  92. "title": case.get("title") or "",
  93. "body_text": case.get("body") or case.get("body_text") or "",
  94. "images": images,
  95. }
  96. def _ensure_decode_agent_importable() -> Path:
  97. """确保 decode_workflow_agent 目录在 sys.path 内(它的 import 是相对的)。"""
  98. decode_dir = Path(__file__).resolve().parent / "decode_workflow_agent"
  99. if not decode_dir.exists():
  100. raise FileNotFoundError(f"decode_workflow_agent 目录不存在: {decode_dir}")
  101. if str(decode_dir) not in sys.path:
  102. sys.path.insert(0, str(decode_dir))
  103. return decode_dir
  104. def _setup_gemini_env() -> None:
  105. """
  106. 兼容多种 Gemini API key 命名,把它们 alias 到 langchain-google-genai 期望的 GOOGLE_API_KEY。
  107. langchain-google-genai 默认只读 GOOGLE_API_KEY;用户 .env 里可能是 GEMINI_API_KEY 等命名。
  108. 同时检测 GEMINI_API_BASE — 若设了,警告"langchain-google-genai 不支持 base URL 覆盖"。
  109. """
  110. if not os.environ.get("GOOGLE_API_KEY"):
  111. for alt in ("GEMINI_API_KEY", "gemini_api_key", "GOOGLE_GENAI_API_KEY"):
  112. val = os.environ.get(alt)
  113. if val:
  114. os.environ["GOOGLE_API_KEY"] = val
  115. print(f"[decode-workflow] aliased {alt} → GOOGLE_API_KEY", flush=True)
  116. break
  117. else:
  118. print(
  119. "[decode-workflow] ⚠ no Gemini API key found in env "
  120. "(tried GOOGLE_API_KEY / GEMINI_API_KEY / GOOGLE_GENAI_API_KEY). "
  121. "Gemini calls will likely fail. Set GOOGLE_API_KEY=<your_key> in .env",
  122. flush=True,
  123. )
  124. if os.environ.get("GEMINI_API_BASE"):
  125. print(
  126. "[decode-workflow] ⚠⚠⚠ GEMINI_API_BASE is set — but langchain-google-genai uses\n"
  127. " Google's official SDK and does NOT honor base URL overrides; all calls go to\n"
  128. " generativelanguage.googleapis.com regardless.\n"
  129. " If your GEMINI_API_KEY is for a 3rd-party proxy (OneAPI / yescode / OpenRouter):\n"
  130. " → DecodeProcessAgent will get 401/404 from Google's real endpoint.\n"
  131. " → Consider using OpenRouter Gemini via the main pipeline's --model gemini instead.\n"
  132. " If your key is from Google AI Studio (https://aistudio.google.com): you can ignore this.",
  133. flush=True,
  134. )
  135. def _merge_decode_into_case(
  136. case_file: Path,
  137. target_case_index: int,
  138. decode_workflow: Dict[str, Any],
  139. ) -> bool:
  140. """
  141. 把 decode 输出原样存到 case.cases[i].decode_workflow 顶层字段(跟 workflow_groups 平级)。
  142. 设计选择:不动 workflow_groups / capability,让旧字段保留作对照;decode 输出作为新的
  143. 独立字段存在,下游需要的话单独读 case.decode_workflow.steps。
  144. Returns:
  145. True 成功合并,False 没找到目标 case(不报错,只警告)
  146. """
  147. with open(case_file, "r", encoding="utf-8") as f:
  148. case_data = json.load(f)
  149. target = None
  150. for case in case_data.get("cases", []):
  151. if case.get("index") == target_case_index:
  152. target = case
  153. break
  154. if not target:
  155. print(
  156. f" [decode-workflow] merge SKIPPED: case index={target_case_index} not in case.json",
  157. flush=True,
  158. )
  159. return False
  160. # 只塞 steps + 元信息(不塞 source — source.images 可能含 base64,几 MB 级,污染 case.json)
  161. target["decode_workflow"] = {
  162. "channel_content_id": (decode_workflow or {}).get("channel_content_id"),
  163. "steps": (decode_workflow or {}).get("steps", []) or [],
  164. "summary": (decode_workflow or {}).get("summary", ""),
  165. "status": (decode_workflow or {}).get("status", ""),
  166. }
  167. with open(case_file, "w", encoding="utf-8") as f:
  168. json.dump(case_data, f, ensure_ascii=False, indent=2)
  169. step_count = len(target["decode_workflow"]["steps"])
  170. status = target["decode_workflow"]["status"]
  171. print(
  172. f" [decode-workflow] merged into case.json: "
  173. f"case index={target_case_index} → cases[].decode_workflow "
  174. f"({step_count} steps, status={status!r}, source stripped)",
  175. flush=True,
  176. )
  177. return True
# Write lock for case.json shared across cases (original note says it guards
# multi-process/concurrent corruption; NOTE(review): an asyncio.Lock only
# serializes tasks within one event loop/process — it cannot protect writes
# from separate processes).
_case_write_lock = asyncio.Lock()


async def _merge_decode_into_case_async(case_file: Path, target_case_index: int, decode_workflow: Dict[str, Any]) -> bool:
    """Async wrapper for _merge_decode_into_case that serializes case.json writes."""
    async with _case_write_lock:
        return _merge_decode_into_case(case_file, target_case_index, decode_workflow)
async def extract_decode_workflow(
    case_file: Path,
    llm_call: Any = None,  # legacy-signature compat; unused (DecodeProcessAgent drives its own LangChain stack)
    model: str = DEFAULT_DECODE_MODEL,
    max_concurrent: int = DEFAULT_DECODE_CONCURRENCY,
    case_index: Optional[int] = None,  # single-case mode: only run the case whose case.index == case_index
    merge_to_case: Optional[bool] = None,  # merge output back into case.json; None → read env DECODE_MERGE_TO_CASE
    skip_existing: bool = False,  # True = skip when decode_workflows/case_<idx>.json exists; default re-runs everything
) -> Dict[str, Any]:
    """
    Top-level entry point: run DecodeProcessAgent on each case in case.json and
    write per-case output under <case_dir>/decode_workflows/.

    When case_index is given, only that one case runs; otherwise all cases run.
    When merge_to_case is true, each finished case's output is written back into
    case.json under cases[i].decode_workflow (via _merge_decode_into_case;
    workflow_groups is left untouched).

    Returns:
        Stats dict: total, succeeded, skipped, failed, merged, total_cost
        (plus token totals / output_dir on most paths; the early "no cases" and
        "no eligible cases" returns omit the merged key).
    """
    case_file = Path(case_file)
    if not case_file.exists():
        raise FileNotFoundError(f"case.json 不存在: {case_file}")
    if merge_to_case is None:
        merge_to_case = DEFAULT_MERGE_TO_CASE
    print(f"[decode-workflow] merge_to_case = {merge_to_case}", flush=True)
    # Pre-run snapshot: when merging back into case.json, take a history backup
    # first so the merge can be rolled back if the result is unsatisfactory.
    if merge_to_case:
        try:
            from examples.process_pipeline.script.case_history import snapshot_case_file
            snap = snapshot_case_file(case_file, step="decode_workflow_merge")
            if snap:
                print(f"[decode-workflow] snapshot saved → {snap.name}", flush=True)
        except Exception as e:
            print(f"[decode-workflow] snapshot SKIPPED: {type(e).__name__}: {e}", flush=True)
    # Output directory: output/<NNN>/decode_workflows/
    output_dir = case_file.parent / "decode_workflows"
    output_dir.mkdir(parents=True, exist_ok=True)
    inputs_dir = output_dir / "_inputs"  # agent input files live here (used for audit + the skip check)
    inputs_dir.mkdir(parents=True, exist_ok=True)
    # DecodeProcessAgent reads this env var and drops each output file there.
    os.environ["DECODE_OUTPUT_DIR"] = str(output_dir.resolve())
    print(f"[decode-workflow] output_dir = {output_dir}", flush=True)
    # Accept GEMINI_API_KEY-style .env names + warn when GEMINI_API_BASE points at a proxy.
    _setup_gemini_env()
    # Lazy import — langchain is never imported unless this step actually runs.
    _ensure_decode_agent_importable()
    try:
        from DecodeProcessAgent import DecodeProcessAgent  # noqa: I001
    except ImportError as e:
        raise RuntimeError(
            f"无法导入 DecodeProcessAgent。请确保安装了依赖:\n"
            f" pip install langchain langchain-google-genai\n"
            f"以及 .env 中配置 GOOGLE_API_KEY=...\n"
            f"原始错误: {type(e).__name__}: {e}"
        ) from e
    # Load case.json.
    with open(case_file, "r", encoding="utf-8") as f:
        case_data = json.load(f)
    cases = case_data.get("cases") or []
    if not cases:
        print("[decode-workflow] no cases to process", flush=True)
        return {"total": 0, "succeeded": 0, "skipped": 0, "failed": 0, "total_cost": 0.0}
    # Single-case filter.
    if case_index is not None:
        target = [c for c in cases if c.get("index") == case_index]
        if not target:
            available = [c.get("index") for c in cases]
            raise ValueError(
                f"case index={case_index} not found in {case_file}. Available indices: {available}"
            )
        print(f"[decode-workflow] filtered to single case index={case_index}", flush=True)
        cases = target
    # Write each case as a DecodeProcessAgent input file under _inputs/.
    # Files are named case_<index>.json — DecodeProcessAgent names its output after
    # the input stem, so decode_workflows/case_<idx>.json + .html follow suit.
    print(f"[decode-workflow] preparing {len(cases)} input files in {inputs_dir}", flush=True)
    prepared_paths: List[tuple] = []  # [(input_path, case_index), ...]
    for case in cases:
        decode_input = _build_decode_input(case, case_file=case_file)
        if not decode_input["images"]:
            # Skip image-less cases — DecodeProcessAgent would raise on them.
            print(f" [decode-workflow] skip case index={case.get('index')}: no images", flush=True)
            continue
        idx = case.get("index", 0)
        input_path = inputs_dir / f"case_{idx}.json"
        input_path.write_text(json.dumps(decode_input, ensure_ascii=False, indent=2), encoding="utf-8")
        prepared_paths.append((input_path, idx))
    if not prepared_paths:
        print("[decode-workflow] no eligible cases (all skipped — no images)", flush=True)
        return {"total": len(cases), "succeeded": 0, "skipped": len(cases), "failed": 0, "total_cost": 0.0}
    agent = DecodeProcessAgent(model_name=model)
    merged_count = 0
    # Single-case mode: call agent.run() directly instead of run_batch
    # (avoids picking up stale leftover files in _inputs/).
    if case_index is not None or len(prepared_paths) == 1:
        target_path, target_idx = prepared_paths[0]
        print(f"[decode-workflow] single-case mode: run {target_path.name}", flush=True)
        try:
            result = await agent.run(str(target_path))
            # Merge back into case.json.
            if merge_to_case and result.get("workflow"):
                try:
                    ok = await _merge_decode_into_case_async(case_file, target_idx, result["workflow"])
                    if ok:
                        merged_count = 1
                except Exception as e:
                    print(f" [decode-workflow] merge FAILED: {type(e).__name__}: {e}", flush=True)
            return {
                "total": 1,
                "succeeded": 1,
                "skipped": 0,
                "failed": 0,
                "merged": merged_count,
                "total_input_tokens": result.get("input_tokens", 0),
                "total_output_tokens": result.get("output_tokens", 0),
                "total_cost": result.get("cost_usd", 0.0),
                "output_dir": str(output_dir),
            }
        except Exception as e:
            print(f" [decode-workflow] FAILED: {type(e).__name__}: {e}", flush=True)
            # Fallback: the agent raised, but WorkflowContext may already have written
            # steps incrementally to disk — try reading case_<idx>.json and do a
            # partial merge (merged even when status == 'in_progress').
            partial_merged = 0
            if merge_to_case:
                disk_path = output_dir / f"case_{target_idx}.json"
                if disk_path.exists():
                    try:
                        with open(disk_path, "r", encoding="utf-8") as f:
                            partial_workflow = json.load(f)
                        ok = await _merge_decode_into_case_async(case_file, target_idx, partial_workflow)
                        if ok:
                            partial_merged = 1
                            print(
                                f" [decode-workflow] ⚠ partial merge: case index={target_idx} status="
                                f"{partial_workflow.get('status')!r} (agent 异常前的部分进度已存)",
                                flush=True,
                            )
                    except Exception as me:
                        print(
                            f" [decode-workflow] partial merge attempt FAILED: {type(me).__name__}: {me}",
                            flush=True,
                        )
            return {
                "total": 1, "succeeded": 0, "skipped": 0, "failed": 1, "merged": partial_merged,
                "total_input_tokens": 0, "total_output_tokens": 0,
                "total_cost": 0.0, "output_dir": str(output_dir),
                "error": f"{type(e).__name__}: {e}",
            }
    # Batch mode: run_batch processes the whole _inputs/ directory (sequential
    # in-process + built-in retry + skip_existing). Guarded with try/except so that
    # even if run_batch dies mid-way (network issue / SDK bug) the side-channel files
    # already produced still get merged below.
    summary = {"succeeded": [], "skipped": [], "failed": [], "total_input_tokens": 0, "total_output_tokens": 0, "total_cost_usd": 0.0}
    run_batch_error: Optional[Exception] = None
    try:
        summary = await agent.run_batch(
            input_dir=str(inputs_dir),
            skip_existing=skip_existing,
            concurrency=max_concurrent,
        )
    except Exception as e:
        run_batch_error = e
        print(
            f"\n[decode-workflow] ⚠ run_batch 中途异常: {type(e).__name__}: {e}\n"
            f" → 仍会尝试 merge 已生成的旁路文件到 case.json",
            flush=True,
        )
    # Merge each successful case back into case.json — run_batch's summary carries no
    # workflow data, so read the output files from disk. This deliberately does NOT
    # rely on summary: it scans for existing case_<idx>.json files, which also rescues
    # completed cases when run_batch crashed part-way through.
    if merge_to_case:
        for path, idx in prepared_paths:
            decode_out_path = output_dir / f"case_{idx}.json"
            if not decode_out_path.exists():
                continue  # this case failed or was never reached
            try:
                with open(decode_out_path, "r", encoding="utf-8") as f:
                    decode_workflow = json.load(f)
                ok = await _merge_decode_into_case_async(case_file, idx, decode_workflow)
                if ok:
                    merged_count += 1
            except Exception as e:
                print(
                    f" [decode-workflow] merge case index={idx} FAILED: {type(e).__name__}: {e}",
                    flush=True,
                )
    result = {
        "total": len(cases),
        "succeeded": len(summary.get("succeeded", [])),
        "skipped": len(summary.get("skipped", [])),
        "failed": len(summary.get("failed", [])),
        "merged": merged_count,
        "total_input_tokens": summary.get("total_input_tokens", 0),
        "total_output_tokens": summary.get("total_output_tokens", 0),
        "total_cost": summary.get("total_cost_usd", 0.0),
        "output_dir": str(output_dir),
    }
    if run_batch_error is not None:
        result["run_batch_error"] = f"{type(run_batch_error).__name__}: {run_batch_error}"
    return result
  379. if __name__ == "__main__":
  380. import argparse
  381. p = argparse.ArgumentParser(description="DecodeProcessAgent 旁路 workflow 提取")
  382. p.add_argument("--case-file", required=True, help="case.json 路径")
  383. p.add_argument("--model", default=DEFAULT_DECODE_MODEL)
  384. p.add_argument("--concurrency", type=int, default=DEFAULT_DECODE_CONCURRENCY)
  385. args = p.parse_args()
  386. result = asyncio.run(extract_decode_workflow(
  387. case_file=Path(args.case_file),
  388. model=args.model,
  389. max_concurrent=args.concurrency,
  390. ))
  391. print(json.dumps(result, ensure_ascii=False, indent=2))