#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Process-decomposition agent (LangChain).

Input: the path of a single Xiaohongshu content JSON file (structure as in
aigc_data/化妆师川川/*.json).
Output: the reconstructed "complete process sequence" is written step by step to
decode_process/output/<input_stem>.json.

The structure mirrors aiddit/pattern_global/MergeAgentLangChain.py, but it only
handles a single input, persists to plain files, and does not depend on
BaseMergeAgent or a database.
"""

import asyncio
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple, Type

from dotenv import load_dotenv

load_dotenv(Path(__file__).resolve().parents[2] / ".env")

from langchain.agents import create_agent
from langchain.chat_models import init_chat_model

from agent_tools import (
    think_and_plan,
    add_step,
    add_step_input,
    add_step_output,
    update_step,
    update_step_input,
    update_step_output,
    delete_step,
    delete_step_input,
    delete_step_output,
    get_current_workflow,
    finalize_workflow,
)
from workflow_store import WorkflowContext
from visualize_workflow import render_html

# ============================================================================
# Model pricing (USD per million tokens)
# ============================================================================
MODEL_PRICING = {
    "google_genai:gemini-3-flash-preview": {"input": 0.50, "output": 3.00},
}

# ============================================================================
# Token accounting
# ============================================================================
def count_token_usage(result: dict) -> dict:
    """Aggregate token usage from the agent result, plus Gemini cache-hit diagnostics.

    Also dumps each turn's input_token_details, so that cache_read / cached_content
    and similar fields can be inspected (field names differ across
    langchain_google_genai versions; dumping the whole dict is the most robust).
    """
    from langchain_core.messages import AIMessage

    total_input = 0
    total_output = 0
    total_cached = 0
    turns = []
    for idx, msg in enumerate(result["messages"]):
        if isinstance(msg, AIMessage) and getattr(msg, "usage_metadata", None):
            um = msg.usage_metadata
            it = um.get("input_tokens", 0) or 0
            ot = um.get("output_tokens", 0) or 0
            details = um.get("input_token_details", {}) or {}
            cached = (
                details.get("cache_read", 0)
                or details.get("cached_content", 0)
                or details.get("cached", 0)
                or 0
            )
            total_input += it
            total_output += ot
            total_cached += cached
            turns.append((idx, it, ot, cached, details))

    if turns:
        print("─" * 80)
        print(
            f"{'msg_idx':<8}{'input':>10}{'output':>10}{'cached':>10}{'hit_rate':>10} details"
        )
        print("─" * 80)
        for idx, it, ot, cached, details in turns:
            hit = (cached / it * 100) if it else 0.0
            print(
                f"{idx:<8}{it:>10}{ot:>10}{cached:>10}{hit:>9.1f}% {details}"
            )
        overall_hit = (total_cached / total_input * 100) if total_input else 0.0
        print("─" * 80)
        print(
            f"TOTAL input={total_input} output={total_output} "
            f"cached={total_cached} (overall_hit={overall_hit:.1f}%)"
        )
        print("─" * 80)
    else:
        print(f"total_input={total_input}, total_output={total_output} (no usage_metadata found)")

    return {
        "input_tokens": total_input,
        "output_tokens": total_output,
        "cached_tokens": total_cached,
    }


def calculate_cost(model_name: str, input_tokens: int, output_tokens: int) -> float:
    pricing = MODEL_PRICING.get(model_name)
    if not pricing:
        return 0.0
    cost = (
        input_tokens * pricing["input"] / 1_000_000
        + output_tokens * pricing["output"] / 1_000_000
    )
    return round(cost, 6)
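
# Worked example (illustrative numbers, based on the MODEL_PRICING table above):
# a run on gemini-3-flash-preview that consumed 120_000 input tokens and 4_000
# output tokens costs
#   120_000 * 0.50 / 1e6 + 4_000 * 3.00 / 1e6 = 0.060 + 0.012 = 0.072 USD,
# i.e. calculate_cost("google_genai:gemini-3-flash-preview", 120_000, 4_000) -> 0.072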

# ============================================================================
# DecodeProcessAgent
# ============================================================================
_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
_PROMPT_FILE = os.path.join(_CURRENT_DIR, "decode_process_prompt.md")


def _get_output_dir() -> str:
    """Output directory; the DECODE_OUTPUT_DIR env var overrides the default.

    Re-read on every run() so an adapter can switch directories per case.
    """
    return os.environ.get("DECODE_OUTPUT_DIR") or os.path.join(_CURRENT_DIR, "output")


# Backward compatibility (the module-level constant can still be referenced by the original main).
_OUTPUT_DIR = _get_output_dir()


def _load_system_prompt() -> str:
    with open(_PROMPT_FILE, "r", encoding="utf-8") as f:
        return f.read()


def _build_user_content(title: str, body_text: str, images: List[str]) -> List[Dict[str, Any]]:
    """Build the content array of the multimodal user message."""
    instruction = (
        "请根据下面这条内容的标题、正文与图片,按系统提示中的规则拆解完整工序。\n"
        "工作流:think_and_plan -> 多轮 add_step / add_step_input / add_step_output "
        "-> finalize_workflow。\n\n"
        f"标题: {title}\n"
        f"正文:\n{body_text}\n\n"
        f"共 {len(images)} 张图片,下面按顺序给出(依次对应 图1、图2 ...)。"
    )
    content: List[Dict[str, Any]] = [{"type": "text", "text": instruction}]
    for url in images:
        content.append({"type": "image_url", "image_url": url})
    return content


def _transient_error_types() -> Tuple[Type[BaseException], ...]:
    """Build the set of transient network errors (used by run_batch to decide retries).

    Every HTTP library defines its own exception hierarchy, and the class names often
    shadow builtins without being the same class (e.g.
    `requests.exceptions.ConnectionError` is not `builtins.ConnectionError`).
    So each library that may appear on the langchain-google-genai path is imported
    separately. All imports are wrapped in try/except: a missing library does not
    affect the other checks.
    """
    excs: List[Type[BaseException]] = [
        ConnectionError,
        ConnectionResetError,
        ConnectionAbortedError,
        TimeoutError,
    ]
    try:
        import httpx
        excs.extend([
            httpx.RemoteProtocolError,
            httpx.ConnectError,
            httpx.ReadError,
            httpx.WriteError,
            httpx.ConnectTimeout,
            httpx.ReadTimeout,
            httpx.NetworkError,
            httpx.PoolTimeout,
        ])
    except ImportError:
        pass
    try:
        # Some paths inside the google-genai SDK still use requests/urllib3
        # (OAuth flow, metadata probing, etc.).
        import requests
        excs.extend([
            requests.exceptions.ConnectionError,
            requests.exceptions.ChunkedEncodingError,
            requests.exceptions.ReadTimeout,
            requests.exceptions.ConnectTimeout,
        ])
    except ImportError:
        pass
    try:
        import urllib3
        excs.extend([
            urllib3.exceptions.ProtocolError,
            urllib3.exceptions.NewConnectionError,
            urllib3.exceptions.ReadTimeoutError,
        ])
    except ImportError:
        pass
    try:
        from google.api_core import exceptions as gae
        excs.extend([gae.ServiceUnavailable, gae.DeadlineExceeded, gae.RetryError])
    except ImportError:
        pass
    return tuple(excs)
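
# Quick illustration of the "same name, different class" pitfall described in
# _transient_error_types (assumes `requests` is installed; not executed anywhere
# in this module):
#   >>> import requests
#   >>> requests.exceptions.ConnectionError is ConnectionError
#   False
# so an `except ConnectionError:` clause alone would miss the requests variant.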

class DecodeProcessAgent:
    """LangChain process-decomposition agent."""

    def __init__(self, model_name: str = "google_genai:gemini-3-flash-preview"):
        self.model_name = model_name

    async def run_batch(
        self,
        input_dir: str,
        skip_existing: bool = True,
        concurrency: int = 1,
        max_retries: int = 3,
    ) -> Dict[str, Any]:
        """Process every *.json file in a directory **sequentially** (in the main process).

        History: an older version used a ProcessPoolExecutor to work around the fact
        that WorkflowContext is a class-level singleton and concurrent use within one
        process would pollute it. On Windows, however, ProcessPoolExecutor has three
        major side effects:
        (1) child-process stdout does not reach the parent's Tee/run.log;
        (2) several large concurrent requests easily trigger transient network errors
            (RemoteProtocolError / ConnectionReset);
        (3) Ctrl+C is not forwarded to the children, so the batch cannot be stopped.

        The current version runs sequentially in the main process: agent.run() calls
        WorkflowContext.clear() at the end, so serial multi-case runs do not pollute
        each other; network flakiness is covered by max_retries; Ctrl+C takes effect
        immediately via asyncio CancelledError.

        The concurrency parameter is kept, but values > 1 print a warning and are
        forced to 1 (the semantics are preserved so concurrency can be restored once
        WorkflowContext is isolated via a ContextVar).

        Args:
            input_dir: input directory; every .json file is one Xiaohongshu item.
            skip_existing: skip a file if output/<stem>.json already exists.
            concurrency: historical parameter; the current implementation forces 1.
            max_retries: max retries per case on transient network errors (default 3).

        Returns:
            {"total", "succeeded": [...], "skipped": [...], "failed": [...],
             "total_input_tokens", "total_output_tokens", "total_cost_usd"}
        """
        input_dir_path = Path(input_dir)
        if not input_dir_path.is_dir():
            raise ValueError(f"Input path is not a directory: {input_dir}")
        input_files = sorted(input_dir_path.glob("*.json"))
        if not input_files:
            raise ValueError(f"No .json files in directory {input_dir}")

        pending: List[Path] = []
        skipped: List[str] = []
        for fp in input_files:
            if skip_existing and os.path.exists(
                os.path.join(_get_output_dir(), f"{fp.stem}.json")
            ):
                print(f"⏭ {fp.name}: output already exists, skipping")
                skipped.append(str(fp))
            else:
                pending.append(fp)

        if not pending:
            print("\nNo files to process.")
            return {
                "total": len(input_files),
                "succeeded": [],
                "skipped": skipped,
                "failed": [],
                "total_input_tokens": 0,
                "total_output_tokens": 0,
                "total_cost_usd": 0.0,
            }

        if concurrency != 1:
            print(
                f"⚠ concurrency={concurrency} ignored: WorkflowContext is a class-level "
                f"singleton and in-process concurrency would cause cross-case pollution. "
                f"Forcing sequential execution (concurrency=1)."
            )

        print(
            f"\nPending: {len(pending)} files | concurrency: 1 (sequential) "
            f"| skipped: {len(skipped)} | retry: {max_retries} per case"
        )

        succeeded: List[Dict[str, Any]] = []
        failed: List[Dict[str, Any]] = []
        total_in = total_out = 0
        total_cost = 0.0
        transient_excs = _transient_error_types()

        for done_count, fp in enumerate(pending, 1):
            last_err: BaseException | None = None
            for attempt in range(1, max_retries + 1):
                # Proactively clear the singleton before each run: the previous
                # agent.run() should already have cleared it, but an abnormal exit
                # may have skipped that, so clear defensively here.
                WorkflowContext.clear()
                try:
                    result = await self.run(str(fp))
                except asyncio.CancelledError:
                    # Ctrl+C / task cancellation: propagate immediately and abort the whole batch.
                    print(f"\n⏸ cancelled by user before completing {fp.name}")
                    raise
                except transient_excs as e:
                    last_err = e
                    if attempt < max_retries:
                        wait = 2 ** attempt  # 2s, 4s, 8s ...
                        print(
                            f"⚠ [{done_count}/{len(pending)}] {fp.name} attempt {attempt}/{max_retries} "
                            f"hit transient {type(e).__name__}: {e}; retrying in {wait}s..."
                        )
                        try:
                            await asyncio.sleep(wait)
                        except asyncio.CancelledError:
                            print("\n⏸ cancelled by user during retry backoff")
                            raise
                        continue
                    # Retries exhausted: leave the loop and fall through to the failed branch.
                    break
                except Exception as e:
                    # Non-transient errors are not retried.
                    last_err = e
                    break
                else:
                    # Success.
                    total_in += result["input_tokens"]
                    total_out += result["output_tokens"]
                    total_cost += result["cost_usd"]
                    succeeded.append({
                        "input": str(fp),
                        "output_path": result["output_path"],
                        "html_path": result.get("html_path"),
                        "input_tokens": result["input_tokens"],
                        "output_tokens": result["output_tokens"],
                        "cost_usd": result["cost_usd"],
                        "step_count": len(result["workflow"]["steps"]),
                    })
                    retry_note = f" (attempt {attempt})" if attempt > 1 else ""
                    print(
                        f"✅ [{done_count}/{len(pending)}] {fp.name}: "
                        f"steps={len(result['workflow']['steps'])} "
                        f"tokens(in={result['input_tokens']}/out={result['output_tokens']}) "
                        f"cost=${result['cost_usd']}{retry_note}"
                    )
                    last_err = None
                    break

            if last_err is not None:
                failed.append({"input": str(fp), "error": f"{type(last_err).__name__}: {last_err}"})
                print(
                    f"❌ [{done_count}/{len(pending)}] {fp.name} failed: "
                    f"{type(last_err).__name__}: {last_err}"
                )

        summary = {
            "total": len(input_files),
            "succeeded": succeeded,
            "skipped": skipped,
            "failed": failed,
            "total_input_tokens": total_in,
            "total_output_tokens": total_out,
            "total_cost_usd": round(total_cost, 6),
        }
        print("\n========== Batch run summary ==========")
        print(
            f"Total: {summary['total']} | succeeded: {len(succeeded)} "
            f"| skipped: {len(skipped)} | failed: {len(failed)}"
        )
        print(
            f"Total tokens: in={total_in}, out={total_out}, "
            f"cost=${summary['total_cost_usd']}"
        )
        if failed:
            print("Failure details:")
            for item in failed:
                print(f"  - {item['input']}: {item['error']}")
        return summary
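
    # Shape of the input file that run() below expects (illustrative values; only
    # these keys are read):
    # {
    #     "channel_content_id": "...",   # required
    #     "title": "...",
    #     "body_text": "...",
    #     "images": ["https://.../1.jpg", "data:image/jpeg;base64,..."]   # must be non-empty
    # }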

    async def run(self, input_json_path: str) -> Dict[str, Any]:
        with open(input_json_path, "r", encoding="utf-8") as f:
            payload = json.load(f)

        channel_content_id = payload["channel_content_id"]
        title = payload.get("title", "")
        body_text = payload.get("body_text", "")
        images = payload.get("images", []) or []
        if not images:
            raise ValueError(f"Input {input_json_path} has no images; multimodal process decomposition is not possible")

        # source only stores image placeholders, not the raw base64, so the decoded
        # output file does not balloon to MB size.
        source_images = []
        for i, img in enumerate(images):
            if isinstance(img, str) and img.startswith("data:"):
                source_images.append(f"<image_{i + 1}: base64 omitted>")
            else:
                source_images.append(img)

        input_stem = Path(input_json_path).stem
        output_path = os.path.join(_get_output_dir(), f"{input_stem}.json")

        WorkflowContext.init(
            output_path=output_path,
            source_meta={
                "channel_content_id": channel_content_id,
                "title": title,
                "body_text": body_text,
                "images": source_images,  # placeholders; the raw base64 stays in the local `images` for LangChain
            },
        )

        system_prompt = _load_system_prompt()
        user_content = _build_user_content(title, body_text, images)

        model = init_chat_model(self.model_name)
        tools = [
            think_and_plan,
            add_step,
            add_step_input,
            add_step_output,
            update_step,
            update_step_input,
            update_step_output,
            delete_step,
            delete_step_input,
            delete_step_output,
            get_current_workflow,
            finalize_workflow,
        ]
        agent = create_agent(model=model, tools=tools, system_prompt=system_prompt)

        result = await asyncio.to_thread(
            agent.invoke,
            {"messages": [{"role": "user", "content": user_content}]},
        )

        usage = count_token_usage(result)
        cost = calculate_cost(self.model_name, usage["input_tokens"], usage["output_tokens"])

        final_workflow = WorkflowContext.get()
        WorkflowContext.clear()

        html_path = os.path.splitext(output_path)[0] + ".html"
        try:
            with open(html_path, "w", encoding="utf-8") as f:
                f.write(render_html(final_workflow))
            print(f"[visualize] HTML written -> {html_path}")
        except Exception as e:
            html_path = None
            print(f"[visualize] HTML generation failed (workflow result unaffected): {type(e).__name__}: {e}")

        return {
            "output_path": output_path,
            "html_path": html_path,
            "input_tokens": usage["input_tokens"],
            "output_tokens": usage["output_tokens"],
            "cost_usd": cost,
            "workflow": final_workflow,
        }


# ============================================================================
# Main
# ============================================================================
if __name__ == "__main__":
    import argparse
    import logging
    import sys

    for _stream in (sys.stdout, sys.stderr):
        if hasattr(_stream, "reconfigure"):
            _stream.reconfigure(encoding="utf-8", errors="replace")

    logging.getLogger("langchain").setLevel(logging.INFO)

    DEFAULT_INPUT = os.path.join(_CURRENT_DIR, "input")
    DEFAULT_MODEL = "google_genai:gemini-3-flash-preview"

    parser = argparse.ArgumentParser(
        description="Process-decomposition agent: handles a single file or a whole directory",
    )
    parser.add_argument(
        "--input", "-i",
        default=DEFAULT_INPUT,
        help=f"Input path; either a single .json file or a directory of .json files (default: {DEFAULT_INPUT})",
    )
    parser.add_argument(
        "--model", "-m",
        default=DEFAULT_MODEL,
        help=f"Model name (default: {DEFAULT_MODEL})",
    )
    parser.add_argument(
        "--no-skip-existing",
        action="store_true",
        help="In batch mode, reprocess files even if their output already exists (skipped by default)",
    )
    parser.add_argument(
        "--concurrency", "-c",
        type=int,
        default=1,
        help="Historical batch-mode parameter; the current implementation runs sequentially and forces 1 (values > 1 only print a warning)",
    )
    args = parser.parse_args()

    target = args.input
    agent = DecodeProcessAgent(model_name=args.model)

    if os.path.isdir(target):
        asyncio.run(
            agent.run_batch(
                target,
                skip_existing=not args.no_skip_existing,
                concurrency=args.concurrency,
            )
        )
    elif os.path.isfile(target):
        result = asyncio.run(agent.run(target))
        print("\n===== Run complete =====")
        print(f"Output file: {result['output_path']}")
        if result.get("html_path"):
            print(f"HTML visualization: {result['html_path']}")
        print(f"Step count: {len(result['workflow']['steps'])}")
        print(f"status: {result['workflow']['status']}")
        print(
            f"tokens: in={result['input_tokens']} out={result['output_tokens']} "
            f"cost=${result['cost_usd']}"
        )
    else:
        sys.exit(f"Input path does not exist: {target}")
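
# Example invocations (illustrative; substitute this file's actual name for
# decode_process_agent.py):
#   python decode_process_agent.py -i ./input                  # batch mode over a directory
#   python decode_process_agent.py -i ./input/some_note.json   # single-file mode
# Setting the DECODE_OUTPUT_DIR environment variable before running redirects the
# output directory (see _get_output_dir above).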