#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Process-decomposition agent (LangChain).

Input: path to a single Xiaohongshu content JSON file (see aigc_data/化妆师川川/*.json
for the structure).
Output: writes the reconstructed "complete process sequence" step by step to
decode_process/output/<channel_content_id>.json.

The structure mirrors aiddit/pattern_global/MergeAgentLangChain.py, but targets a
single input, persists purely to files, and depends on neither BaseMergeAgent nor
a database.
"""
import asyncio
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple, Type

from dotenv import load_dotenv

load_dotenv(Path(__file__).resolve().parents[2] / ".env")

from langchain.agents import create_agent
from langchain.chat_models import init_chat_model

from agent_tools import (
    think_and_plan,
    add_step,
    add_step_input,
    add_step_output,
    update_step,
    update_step_input,
    update_step_output,
    delete_step,
    delete_step_input,
    delete_step_output,
    get_current_workflow,
    finalize_workflow,
)
from workflow_store import WorkflowContext
from visualize_workflow import render_html

# ============================================================================
# Model pricing (USD per million tokens)
# ============================================================================
MODEL_PRICING = {
    "google_genai:gemini-3-flash-preview": {"input": 0.50, "output": 3.00},
}


# ============================================================================
# Token accounting
# ============================================================================
def count_token_usage(result: dict) -> dict:
    """Tally token usage from the agent result, plus Gemini cache-hit diagnostics.

    Also dumps each turn's input_token_details so that cache_read / cached_content
    and similar fields are visible (field names differ across
    langchain_google_genai versions, so dumping the whole dict is the most
    robust option).
    """
    from langchain_core.messages import AIMessage

    total_input = 0
    total_output = 0
    total_cached = 0
    turns = []
    for idx, msg in enumerate(result["messages"]):
        if isinstance(msg, AIMessage) and getattr(msg, "usage_metadata", None):
            um = msg.usage_metadata
            it = um.get("input_tokens", 0) or 0
            ot = um.get("output_tokens", 0) or 0
            details = um.get("input_token_details", {}) or {}
            cached = (
                details.get("cache_read", 0)
                or details.get("cached_content", 0)
                or details.get("cached", 0)
                or 0
            )
            total_input += it
            total_output += ot
            total_cached += cached
            turns.append((idx, it, ot, cached, details))
    if turns:
        print("─" * 80)
        print(
            f"{'msg_idx':<8}{'input':>10}{'output':>10}{'cached':>10}{'hit_rate':>10} details"
        )
        print("─" * 80)
        for idx, it, ot, cached, details in turns:
            hit = (cached / it * 100) if it else 0.0
            print(
                f"{idx:<8}{it:>10}{ot:>10}{cached:>10}{hit:>9.1f}% {details}"
            )
        overall_hit = (total_cached / total_input * 100) if total_input else 0.0
        print("─" * 80)
        print(
            f"TOTAL input={total_input} output={total_output} "
            f"cached={total_cached} (overall_hit={overall_hit:.1f}%)"
        )
        print("─" * 80)
    else:
        print(f"total_input={total_input}, total_output={total_output} (no usage_metadata found)")
    return {
        "input_tokens": total_input,
        "output_tokens": total_output,
        "cached_tokens": total_cached,
    }
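# For reference, a usage_metadata dict from langchain_google_genai typically
# looks like the sketch below (the exact shape is an assumption and varies by
# version — which is exactly why the code above probes several detail keys):
#   {"input_tokens": 1234, "output_tokens": 56, "total_tokens": 1290,
#    "input_token_details": {"cache_read": 1024}}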
def calculate_cost(model_name: str, input_tokens: int, output_tokens: int) -> float:
    pricing = MODEL_PRICING.get(model_name)
    if not pricing:
        return 0.0
    cost = (
        input_tokens * pricing["input"] / 1_000_000
        + output_tokens * pricing["output"] / 1_000_000
    )
    return round(cost, 6)
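# Worked example with the pricing table above (token counts are made up):
#   calculate_cost("google_genai:gemini-3-flash-preview", 120_000, 8_000)
#   = 120_000 * 0.50 / 1e6 + 8_000 * 3.00 / 1e6
#   = 0.06 + 0.024
#   = 0.084 (USD)
# Unknown model names fall through to 0.0 rather than raising.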
# ============================================================================
# DecodeProcessAgent
# ============================================================================
_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
_PROMPT_FILE = os.path.join(_CURRENT_DIR, "decode_process_prompt.md")


def _get_output_dir() -> str:
    """Honor the DECODE_OUTPUT_DIR env var — re-read on every run() so an adapter can switch directories per case."""
    return os.environ.get("DECODE_OUTPUT_DIR") or os.path.join(_CURRENT_DIR, "output")


# Backward compatibility (the module-level constant can still be referenced by the original main)
_OUTPUT_DIR = _get_output_dir()
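# Example of the per-case override (the path is hypothetical):
#   os.environ["DECODE_OUTPUT_DIR"] = "/tmp/decode_cases/case_001"
#   # every subsequent _get_output_dir() call now returns the overridden path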
def _load_system_prompt() -> str:
    with open(_PROMPT_FILE, "r", encoding="utf-8") as f:
        return f.read()


def _build_user_content(title: str, body_text: str, images: List[str]) -> List[Dict[str, Any]]:
    """Build the content array of the multimodal user message."""
    instruction = (
        "Based on the title, body text, and images of the post below, decompose the "
        "complete process according to the rules in the system prompt.\n"
        "Workflow: think_and_plan -> multiple rounds of add_step / add_step_input / "
        "add_step_output -> finalize_workflow.\n\n"
        f"Title: {title}\n"
        f"Body:\n{body_text}\n\n"
        f"{len(images)} images follow in order (corresponding to Image 1, Image 2, ...)."
    )
    content: List[Dict[str, Any]] = [{"type": "text", "text": instruction}]
    for url in images:
        content.append({"type": "image_url", "image_url": url})
    return content
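# Sketch of the resulting content array (URLs are placeholders; the image_url
# entries may be https URLs or data: URIs — run() passes both through unchanged):
#   [{"type": "text", "text": "Based on the title, body text, ..."},
#    {"type": "image_url", "image_url": "https://example.com/1.jpg"},
#    {"type": "image_url", "image_url": "data:image/jpeg;base64,..."}]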
def _transient_error_types() -> Tuple[Type[BaseException], ...]:
    """Build the set of transient network errors (used by run_batch's retry check).

    Every HTTP library defines its own exception hierarchy, and the class names
    often shadow builtins without being the same class (e.g.
    `requests.exceptions.ConnectionError` is not `builtins.ConnectionError`).
    So every library that can appear on the langchain-google-genai path has to
    be imported and listed separately. All imports are try/except-guarded — a
    missing library does not affect the other checks.
    """
    excs: List[Type[BaseException]] = [
        ConnectionError, ConnectionResetError, ConnectionAbortedError, TimeoutError,
    ]
    try:
        import httpx

        excs.extend([
            httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadError,
            httpx.WriteError, httpx.ConnectTimeout, httpx.ReadTimeout,
            httpx.NetworkError, httpx.PoolTimeout,
        ])
    except ImportError:
        pass
    try:
        # Some paths inside the google-genai SDK still use requests/urllib3
        # (OAuth flow, metadata probing, etc.)
        import requests

        excs.extend([
            requests.exceptions.ConnectionError,
            requests.exceptions.ChunkedEncodingError,
            requests.exceptions.ReadTimeout,
            requests.exceptions.ConnectTimeout,
        ])
    except ImportError:
        pass
    try:
        import urllib3

        excs.extend([
            urllib3.exceptions.ProtocolError,
            urllib3.exceptions.NewConnectionError,
            urllib3.exceptions.ReadTimeoutError,
        ])
    except ImportError:
        pass
    try:
        from google.api_core import exceptions as gae

        excs.extend([gae.ServiceUnavailable, gae.DeadlineExceeded, gae.RetryError])
    except ImportError:
        pass
    return tuple(excs)
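# Quick illustration of the name-shadowing pitfall the docstring describes
# (assuming requests is installed):
#   import requests
#   requests.exceptions.ConnectionError is ConnectionError   # False — distinct classes
# so except-ing only builtins.ConnectionError would miss the requests variant entirely.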
class DecodeProcessAgent:
    """LangChain process-decomposition agent."""

    def __init__(self, model_name: str = "google_genai:gemini-3-flash-preview"):
        self.model_name = model_name

    async def run_batch(
        self,
        input_dir: str,
        skip_existing: bool = True,
        concurrency: int = 1,
        max_retries: int = 3,
    ) -> Dict[str, Any]:
        """Process every *.json file in a directory, **sequentially** (in the main process).

        History: an older version used a ProcessPoolExecutor for multi-process
        concurrency, to work around concurrent pollution of the class-level
        WorkflowContext singleton within one process. But ProcessPoolExecutor has
        three major side effects on Windows:
        ① child-process stdout bypasses the parent's Tee/run.log;
        ② several large concurrent requests easily trigger transient network
          errors (RemoteProtocolError / ConnectionReset);
        ③ Ctrl+C is not forwarded to child processes, so runs cannot be stopped.
        The current version runs sequentially in the main process again:
        agent.run() calls WorkflowContext.clear() at the end, so serial
        multi-case runs have no pollution problem; network flakiness is covered
        by max_retries; Ctrl+C takes effect immediately via asyncio
        CancelledError.

        The concurrency parameter is kept, but any value > 1 prints a warning
        and is forced to 1 (the semantics are preserved so concurrency can be
        restored if WorkflowContext ever moves to ContextVar-based isolation).

        Args:
            input_dir: input directory; each .json file is one Xiaohongshu post.
            skip_existing: skip a file if output/<input_stem>.json already exists.
            concurrency: historical parameter; the current implementation forces 1.
            max_retries: max retries per case on transient network errors (default 3).

        Returns:
            {"total", "succeeded": [...], "skipped": [...], "failed": [...],
             "total_input_tokens", "total_output_tokens", "total_cost_usd"}
        """
        input_dir_path = Path(input_dir)
        if not input_dir_path.is_dir():
            raise ValueError(f"Input path is not a directory: {input_dir}")
        input_files = sorted(input_dir_path.glob("*.json"))
        if not input_files:
            raise ValueError(f"No .json files under directory {input_dir}")
        pending: List[Path] = []
        skipped: List[str] = []
        for fp in input_files:
            if skip_existing and os.path.exists(
                os.path.join(_get_output_dir(), f"{fp.stem}.json")
            ):
                print(f"⏭ {fp.name}: output already exists, skipping")
                skipped.append(str(fp))
            else:
                pending.append(fp)
        if not pending:
            print("\nNo files to process.")
            return {
                "total": len(input_files),
                "succeeded": [],
                "skipped": skipped,
                "failed": [],
                "total_input_tokens": 0,
                "total_output_tokens": 0,
                "total_cost_usd": 0.0,
            }
        if concurrency != 1:
            print(
                f"⚠ concurrency={concurrency} ignored: WorkflowContext is a class-level "
                f"singleton, and concurrent runs in one process would pollute each other. "
                f"Forcing sequential execution (concurrency=1)."
            )
        print(
            f"\nPending: {len(pending)} files | concurrency: 1 (sequential) "
            f"| skipped: {len(skipped)} | retries: {max_retries} per case"
        )
        succeeded: List[Dict[str, Any]] = []
        failed: List[Dict[str, Any]] = []
        total_in = total_out = 0
        total_cost = 0.0
        transient_excs = _transient_error_types()
        for done_count, fp in enumerate(pending, 1):
            last_err: BaseException | None = None
            for attempt in range(1, max_retries + 1):
                # Proactively clear the singleton before each run: the previous
                # agent.run() should have cleared it at the end, but an abnormal
                # exit may have skipped that, so clear defensively here.
                WorkflowContext.clear()
                try:
                    result = await self.run(str(fp))
                except asyncio.CancelledError:
                    # Ctrl+C / task cancellation: propagate immediately, aborting the whole batch
                    print(f"\n⏸ cancelled by user before completing {fp.name}")
                    raise
                except transient_excs as e:
                    last_err = e
                    if attempt < max_retries:
                        wait = 2 ** attempt  # 2s, 4s, 8s ...
                        print(
                            f"⚠ [{done_count}/{len(pending)}] {fp.name} attempt {attempt}/{max_retries} "
                            f"hit transient {type(e).__name__}: {e}; retrying in {wait}s..."
                        )
                        try:
                            await asyncio.sleep(wait)
                        except asyncio.CancelledError:
                            print("\n⏸ cancelled by user during retry backoff")
                            raise
                        continue
                    # Retries exhausted and still failing: leave the retry loop and take the failed branch
                    break
                except Exception as e:
                    # Non-transient errors are not retried
                    last_err = e
                    break
                else:
                    # Success
                    total_in += result["input_tokens"]
                    total_out += result["output_tokens"]
                    total_cost += result["cost_usd"]
                    succeeded.append({
                        "input": str(fp),
                        "output_path": result["output_path"],
                        "html_path": result.get("html_path"),
                        "input_tokens": result["input_tokens"],
                        "output_tokens": result["output_tokens"],
                        "cost_usd": result["cost_usd"],
                        "step_count": len(result["workflow"]["steps"]),
                    })
                    retry_note = f" (attempt {attempt})" if attempt > 1 else ""
                    print(
                        f"✅ [{done_count}/{len(pending)}] {fp.name}: "
                        f"steps={len(result['workflow']['steps'])} "
                        f"tokens(in={result['input_tokens']}/out={result['output_tokens']}) "
                        f"cost=${result['cost_usd']}{retry_note}"
                    )
                    last_err = None
                    break
            if last_err is not None:
                failed.append({"input": str(fp), "error": f"{type(last_err).__name__}: {last_err}"})
                print(
                    f"❌ [{done_count}/{len(pending)}] {fp.name} failed: "
                    f"{type(last_err).__name__}: {last_err}"
                )
        summary = {
            "total": len(input_files),
            "succeeded": succeeded,
            "skipped": skipped,
            "failed": failed,
            "total_input_tokens": total_in,
            "total_output_tokens": total_out,
            "total_cost_usd": round(total_cost, 6),
        }
        print("\n========== Batch run summary ==========")
        print(
            f"Total: {summary['total']} | succeeded: {len(succeeded)} "
            f"| skipped: {len(skipped)} | failed: {len(failed)}"
        )
        print(
            f"Total tokens: in={total_in}, out={total_out}, "
            f"cost=${summary['total_cost_usd']}"
        )
        if failed:
            print("Failure details:")
            for item in failed:
                print(f"  - {item['input']}: {item['error']}")
        return summary
    async def run(self, input_json_path: str) -> Dict[str, Any]:
        with open(input_json_path, "r", encoding="utf-8") as f:
            payload = json.load(f)
        channel_content_id = payload["channel_content_id"]
        title = payload.get("title", "")
        body_text = payload.get("body_text", "")
        images = payload.get("images", []) or []
        if not images:
            raise ValueError(
                f"Input {input_json_path} has no images; cannot do multimodal process decomposition"
            )
        # source stores only image placeholders, not the raw base64 — keeps the
        # decode output file from ballooning to MB size
        source_images = []
        for i, img in enumerate(images):
            if isinstance(img, str) and img.startswith("data:"):
                source_images.append(f"<image_{i + 1} (base64, {len(img) // 1024}KB)>")
            else:
                source_images.append(img)
        input_stem = Path(input_json_path).stem
        output_path = os.path.join(_get_output_dir(), f"{input_stem}.json")
        WorkflowContext.init(
            output_path=output_path,
            source_meta={
                "channel_content_id": channel_content_id,
                "title": title,
                "body_text": body_text,
                "images": source_images,  # placeholders; the raw base64 stays in the local images variable for LangChain
            },
        )
        system_prompt = _load_system_prompt()
        user_content = _build_user_content(title, body_text, images)
        model = init_chat_model(self.model_name)
        tools = [
            think_and_plan,
            add_step,
            add_step_input,
            add_step_output,
            update_step,
            update_step_input,
            update_step_output,
            delete_step,
            delete_step_input,
            delete_step_output,
            get_current_workflow,
            finalize_workflow,
        ]
        agent = create_agent(model=model, tools=tools, system_prompt=system_prompt)
        result = await asyncio.to_thread(
            agent.invoke,
            {"messages": [{"role": "user", "content": user_content}]},
        )
        usage = count_token_usage(result)
        cost = calculate_cost(self.model_name, usage["input_tokens"], usage["output_tokens"])
        final_workflow = WorkflowContext.get()
        WorkflowContext.clear()
        html_path = os.path.splitext(output_path)[0] + ".html"
        try:
            with open(html_path, "w", encoding="utf-8") as f:
                f.write(render_html(final_workflow))
            print(f"[visualize] HTML generated -> {html_path}")
        except Exception as e:
            html_path = None
            print(f"[visualize] HTML generation failed (process result unaffected): {type(e).__name__}: {e}")
        return {
            "output_path": output_path,
            "html_path": html_path,
            "input_tokens": usage["input_tokens"],
            "output_tokens": usage["output_tokens"],
            "cost_usd": cost,
            "workflow": final_workflow,
        }
# ============================================================================
# Main
# ============================================================================
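# Usage examples (paths are illustrative):
#   python DecodeProcessAgent.py -i input/note_123.json        # single file
#   python DecodeProcessAgent.py -i input/                     # whole directory
#   python DecodeProcessAgent.py -i input/ --no-skip-existing  # reprocess existing outputs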
if __name__ == "__main__":
    import argparse
    import logging
    import sys

    for _stream in (sys.stdout, sys.stderr):
        if hasattr(_stream, "reconfigure"):
            _stream.reconfigure(encoding="utf-8", errors="replace")
    logging.getLogger("langchain").setLevel(logging.INFO)

    DEFAULT_INPUT = os.path.join(_CURRENT_DIR, "input")
    DEFAULT_MODEL = "google_genai:gemini-3-flash-preview"
    parser = argparse.ArgumentParser(
        description="Process-decomposition agent: handles a single file or a whole directory",
    )
    parser.add_argument(
        "--input",
        "-i",
        default=DEFAULT_INPUT,
        help=f"Input path: a single .json file or a directory of .json files (default: {DEFAULT_INPUT})",
    )
    parser.add_argument(
        "--model",
        "-m",
        default=DEFAULT_MODEL,
        help=f"Model name (default: {DEFAULT_MODEL})",
    )
    parser.add_argument(
        "--no-skip-existing",
        action="store_true",
        help="In batch mode, reprocess files even if their output already exists (skipped by default)",
    )
    parser.add_argument(
        "--concurrency",
        "-c",
        type=int,
        default=1,
        help="Historical batch-mode concurrency (default: 1). The current implementation "
        "runs sequentially and forces any value > 1 back to 1 with a warning",
    )
    args = parser.parse_args()
    target = args.input
    agent = DecodeProcessAgent(model_name=args.model)
    if os.path.isdir(target):
        asyncio.run(
            agent.run_batch(
                target,
                skip_existing=not args.no_skip_existing,
                concurrency=args.concurrency,
            )
        )
    elif os.path.isfile(target):
        result = asyncio.run(agent.run(target))
        print("\n===== Run complete =====")
        print(f"Output file: {result['output_path']}")
        if result.get("html_path"):
            print(f"HTML visualization: {result['html_path']}")
        print(f"Steps: {len(result['workflow']['steps'])}")
        print(f"status: {result['workflow']['status']}")
        print(
            f"tokens: in={result['input_tokens']} out={result['output_tokens']} "
            f"cost=${result['cost_usd']}"
        )
    else:
        sys.exit(f"Input path does not exist: {target}")