| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 工序拆解 Agent(LangChain)
- 输入:单个小红书内容 JSON 文件路径(结构见 aigc_data/化妆师川川/*.json)。
- 输出:把还原的「完整工序序列」逐步写入 decode_process/output/<channel_content_id>.json。
- 实现结构对齐 aiddit/pattern_global/MergeAgentLangChain.py,但只针对单条输入、
- 纯文件持久化、不依赖 BaseMergeAgent / 数据库。
- """
- import asyncio
- import json
- import os
- from pathlib import Path
- from typing import Any, Dict, List, Tuple, Type
- from dotenv import load_dotenv
- load_dotenv(Path(__file__).resolve().parents[2] / ".env")
- from langchain.agents import create_agent
- from langchain.chat_models import init_chat_model
- from agent_tools import (
- think_and_plan,
- add_step,
- add_step_input,
- add_step_output,
- update_step,
- update_step_input,
- update_step_output,
- delete_step,
- delete_step_input,
- delete_step_output,
- get_current_workflow,
- finalize_workflow,
- )
- from workflow_store import WorkflowContext
- from visualize_workflow import render_html
# ============================================================================
# Model pricing table (USD per million tokens).
# ============================================================================
# Keys are init_chat_model-style "provider:model" identifiers; values hold the
# per-million-token USD price for input and output. Models absent from this
# table are costed at $0.0 by calculate_cost().
MODEL_PRICING = {
    "google_genai:gemini-3-flash-preview": {"input": 0.50, "output": 3.00},
}
- # ============================================================================
- # Token 统计
- # ============================================================================
def count_token_usage(result: dict) -> dict:
    """Tally token usage from an agent run and print Gemini cache-hit diagnostics.

    Each turn's input_token_details dict is dumped as-is, because the cache
    field name (cache_read / cached_content / cached) varies between
    langchain_google_genai versions — printing the whole dict is the most
    robust way to inspect it.

    Args:
        result: Agent invoke result; its "messages" list is scanned for
            AIMessage entries carrying usage_metadata.

    Returns:
        {"input_tokens", "output_tokens", "cached_tokens"} totals.
    """
    from langchain_core.messages import AIMessage

    sum_in = sum_out = sum_cached = 0
    per_turn = []
    for position, message in enumerate(result["messages"]):
        # Only AI turns with usage_metadata contribute to the totals.
        if not (isinstance(message, AIMessage) and getattr(message, "usage_metadata", None)):
            continue
        usage = message.usage_metadata
        in_tok = usage.get("input_tokens", 0) or 0
        out_tok = usage.get("output_tokens", 0) or 0
        detail = usage.get("input_token_details", {}) or {}
        # First non-zero cache counter wins, whatever the field is called.
        cached_tok = (
            detail.get("cache_read", 0)
            or detail.get("cached_content", 0)
            or detail.get("cached", 0)
            or 0
        )
        sum_in += in_tok
        sum_out += out_tok
        sum_cached += cached_tok
        per_turn.append((position, in_tok, out_tok, cached_tok, detail))

    if not per_turn:
        print(f"total_input={sum_in}, total_output={sum_out} (no usage_metadata found)")
    else:
        divider = "─" * 80
        print(divider)
        print(
            f"{'msg_idx':<8}{'input':>10}{'output':>10}{'cached':>10}{'hit_rate':>10} details"
        )
        print(divider)
        for position, in_tok, out_tok, cached_tok, detail in per_turn:
            rate = (cached_tok / in_tok * 100) if in_tok else 0.0
            print(
                f"{position:<8}{in_tok:>10}{out_tok:>10}{cached_tok:>10}{rate:>9.1f}% {detail}"
            )
        overall = (sum_cached / sum_in * 100) if sum_in else 0.0
        print(divider)
        print(
            f"TOTAL input={sum_in} output={sum_out} "
            f"cached={sum_cached} (overall_hit={overall:.1f}%)"
        )
        print(divider)

    return {
        "input_tokens": sum_in,
        "output_tokens": sum_out,
        "cached_tokens": sum_cached,
    }
def calculate_cost(model_name: str, input_tokens: int, output_tokens: int) -> float:
    """Price a run in USD from the MODEL_PRICING table.

    Unknown (or empty) pricing entries cost nothing rather than raising, so an
    unpriced model never breaks a batch run.
    """
    rates = MODEL_PRICING.get(model_name)
    if not rates:
        return 0.0
    # Same arithmetic ordering as the pricing table's unit (per million tokens).
    dollars = (
        input_tokens * rates["input"] / 1_000_000
        + output_tokens * rates["output"] / 1_000_000
    )
    return round(dollars, 6)
- # ============================================================================
- # DecodeProcessAgent
- # ============================================================================
# Directory containing this module, and the system-prompt markdown file that
# ships alongside it.
_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
_PROMPT_FILE = os.path.join(_CURRENT_DIR, "decode_process_prompt.md")
def _get_output_dir() -> str:
    """Resolve the output directory, re-read on every call.

    The DECODE_OUTPUT_DIR environment variable (when set and non-empty) takes
    precedence so an adapter can switch directories per case; otherwise fall
    back to ./output next to this module.
    """
    override = os.environ.get("DECODE_OUTPUT_DIR")
    if override:
        return override
    return os.path.join(_CURRENT_DIR, "output")
# Backward compatibility: the module-level constant stays importable by the
# original main. NOTE it is captured once at import time and does not follow
# later DECODE_OUTPUT_DIR changes — call _get_output_dir() for the live value.
_OUTPUT_DIR = _get_output_dir()
def _load_system_prompt() -> str:
    """Read the agent's system prompt markdown from disk (UTF-8)."""
    return Path(_PROMPT_FILE).read_text(encoding="utf-8")
def _build_user_content(title: str, body_text: str, images: List[str]) -> List[Dict[str, Any]]:
    """Assemble the multimodal user-message content array.

    Produces one leading text part (instruction + title + body + image count)
    followed by one image_url part per image, in order.
    """
    header = (
        "请根据下面这条内容的标题、正文与图片,按系统提示中的规则拆解完整工序。\n"
        "工作流:think_and_plan -> 多轮 add_step / add_step_input / add_step_output "
        "-> finalize_workflow。\n\n"
        f"标题: {title}\n"
        f"正文:\n{body_text}\n\n"
        f"共 {len(images)} 张图片,下面按顺序给出(依次对应 图1、图2 ...)。"
    )
    parts: List[Dict[str, Any]] = [{"type": "text", "text": header}]
    parts.extend({"type": "image_url", "image_url": url} for url in images)
    return parts
def _transient_error_types() -> Tuple[Type[BaseException], ...]:
    """Collect the exception classes treated as transient network errors.

    Used by run_batch to decide whether a failed case deserves a retry. Every
    HTTP library on the langchain-google-genai path defines its own exception
    hierarchy, often with names that shadow builtins without being the same
    class (requests.exceptions.ConnectionError is not
    builtins.ConnectionError), so each candidate library is imported
    separately. Any library that is not installed is simply skipped and does
    not affect the others.
    """
    collected: List[Type[BaseException]] = [
        ConnectionError, ConnectionResetError, ConnectionAbortedError, TimeoutError,
    ]
    try:
        import httpx
    except ImportError:
        pass
    else:
        collected += [
            httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadError,
            httpx.WriteError, httpx.ConnectTimeout, httpx.ReadTimeout,
            httpx.NetworkError, httpx.PoolTimeout,
        ]
    try:
        # Some google-genai SDK paths (OAuth flow, metadata probing) still go
        # through requests / urllib3.
        import requests
    except ImportError:
        pass
    else:
        collected += [
            requests.exceptions.ConnectionError,
            requests.exceptions.ChunkedEncodingError,
            requests.exceptions.ReadTimeout,
            requests.exceptions.ConnectTimeout,
        ]
    try:
        import urllib3
    except ImportError:
        pass
    else:
        collected += [
            urllib3.exceptions.ProtocolError,
            urllib3.exceptions.NewConnectionError,
            urllib3.exceptions.ReadTimeoutError,
        ]
    try:
        from google.api_core import exceptions as gae
    except ImportError:
        pass
    else:
        collected += [gae.ServiceUnavailable, gae.DeadlineExceeded, gae.RetryError]
    return tuple(collected)
class DecodeProcessAgent:
    """LangChain agent that decomposes one Xiaohongshu post into a process workflow."""

    def __init__(self, model_name: str = "google_genai:gemini-3-flash-preview"):
        # init_chat_model-style "provider:model" identifier; also the key used
        # against MODEL_PRICING when computing cost.
        self.model_name = model_name

    async def run_batch(
        self,
        input_dir: str,
        skip_existing: bool = True,
        concurrency: int = 1,
        max_retries: int = 3,
    ) -> Dict[str, Any]:
        """Process every *.json file in a directory, **sequentially** in this process.

        History: an older version used ProcessPoolExecutor to dodge concurrent
        pollution of the class-level WorkflowContext singleton, but that had
        three side effects on Windows:
          1. child-process stdout bypassed the parent's Tee/run.log;
          2. several large concurrent requests easily hit transient network
             errors (RemoteProtocolError / ConnectionReset);
          3. Ctrl+C did not reach the children, so runs could not be stopped.
        It is now sequential in the main process: agent.run() ends with
        WorkflowContext.clear(), so serial cases do not pollute each other;
        network flakiness is covered by max_retries; Ctrl+C takes effect
        immediately via asyncio CancelledError.

        The concurrency parameter is kept, but values > 1 trigger a warning
        and are forced to 1 (semantics preserved so concurrency can come back
        if WorkflowContext ever becomes ContextVar-isolated).

        Args:
            input_dir: Input directory; each .json file is one post.
            skip_existing: Skip when output/<input_stem>.json already exists.
            concurrency: Historical parameter; current implementation forces 1.
            max_retries: Max attempts per case on transient network errors
                (default 3).

        Returns:
            {"total", "succeeded": [...], "skipped": [...], "failed": [...],
             "total_input_tokens", "total_output_tokens", "total_cost_usd"}
        """
        input_dir_path = Path(input_dir)
        if not input_dir_path.is_dir():
            raise ValueError(f"输入路径不是目录: {input_dir}")
        input_files = sorted(input_dir_path.glob("*.json"))
        if not input_files:
            raise ValueError(f"目录 {input_dir} 下没有 .json 文件")
        # Split inputs into pending work and already-done (skipped) files.
        pending: List[Path] = []
        skipped: List[str] = []
        for fp in input_files:
            if skip_existing and os.path.exists(
                os.path.join(_get_output_dir(), f"{fp.stem}.json")
            ):
                print(f"⏭ {fp.name}:output 已存在,跳过")
                skipped.append(str(fp))
            else:
                pending.append(fp)
        if not pending:
            print("\n没有待处理文件。")
            return {
                "total": len(input_files),
                "succeeded": [],
                "skipped": skipped,
                "failed": [],
                "total_input_tokens": 0,
                "total_output_tokens": 0,
                "total_cost_usd": 0.0,
            }
        if concurrency != 1:
            print(
                f"⚠ concurrency={concurrency} 被忽略:WorkflowContext 是类级单例,"
                f"同进程并发会互相污染。已强制顺序执行(concurrency=1)。"
            )
        print(
            f"\n待处理: {len(pending)} 个文件 | 并发: 1 (sequential) "
            f"| 跳过: {len(skipped)} | retry: {max_retries} 次/case"
        )
        succeeded: List[Dict[str, Any]] = []
        failed: List[Dict[str, Any]] = []
        total_in = total_out = 0
        total_cost = 0.0
        transient_excs = _transient_error_types()
        for done_count, fp in enumerate(pending, 1):
            last_err: BaseException | None = None
            for attempt in range(1, max_retries + 1):
                # Proactively clear the singleton before each attempt: the
                # previous agent.run() should already have cleared it, but an
                # abnormal exit may have skipped that — clear defensively.
                WorkflowContext.clear()
                try:
                    result = await self.run(str(fp))
                except asyncio.CancelledError:
                    # Ctrl+C / task cancellation: propagate immediately,
                    # aborting the whole batch.
                    print(f"\n⏸ cancelled by user before completing {fp.name}")
                    raise
                except transient_excs as e:
                    last_err = e
                    if attempt < max_retries:
                        wait = 2 ** attempt  # 2s, 4s, 8s ...
                        print(
                            f"⚠ [{done_count}/{len(pending)}] {fp.name} attempt {attempt}/{max_retries} "
                            f"hit transient {type(e).__name__}: {e}; retrying in {wait}s..."
                        )
                        try:
                            await asyncio.sleep(wait)
                        except asyncio.CancelledError:
                            print(f"\n⏸ cancelled by user during retry backoff")
                            raise
                        continue
                    # Retries exhausted: leave the retry loop and fall through
                    # to the failed branch below.
                    break
                except Exception as e:
                    # Non-transient errors are not retried.
                    last_err = e
                    break
                else:
                    # Success path (try/except/else): accumulate totals and
                    # record the per-case result.
                    total_in += result["input_tokens"]
                    total_out += result["output_tokens"]
                    total_cost += result["cost_usd"]
                    succeeded.append({
                        "input": str(fp),
                        "output_path": result["output_path"],
                        "html_path": result.get("html_path"),
                        "input_tokens": result["input_tokens"],
                        "output_tokens": result["output_tokens"],
                        "cost_usd": result["cost_usd"],
                        "step_count": len(result["workflow"]["steps"]),
                    })
                    retry_note = f" (attempt {attempt})" if attempt > 1 else ""
                    print(
                        f"✅ [{done_count}/{len(pending)}] {fp.name}: "
                        f"steps={len(result['workflow']['steps'])} "
                        f"tokens(in={result['input_tokens']}/out={result['output_tokens']}) "
                        f"cost=${result['cost_usd']}{retry_note}"
                    )
                    last_err = None
                    break
            if last_err is not None:
                failed.append({"input": str(fp), "error": f"{type(last_err).__name__}: {last_err}"})
                print(
                    f"❌ [{done_count}/{len(pending)}] {fp.name} 失败: "
                    f"{type(last_err).__name__}: {last_err}"
                )
        summary = {
            "total": len(input_files),
            "succeeded": succeeded,
            "skipped": skipped,
            "failed": failed,
            "total_input_tokens": total_in,
            "total_output_tokens": total_out,
            "total_cost_usd": round(total_cost, 6),
        }
        print(f"\n========== 批量运行汇总 ==========")
        print(
            f"总计: {summary['total']} | 成功: {len(succeeded)} "
            f"| 跳过: {len(skipped)} | 失败: {len(failed)}"
        )
        print(
            f"总 tokens: in={total_in}, out={total_out}, "
            f"cost=${summary['total_cost_usd']}"
        )
        if failed:
            print("失败详情:")
            for item in failed:
                print(f" - {item['input']}: {item['error']}")
        return summary

    async def run(self, input_json_path: str) -> Dict[str, Any]:
        """Decode a single post JSON file into a workflow.

        Args:
            input_json_path: Path to one post JSON with channel_content_id,
                title, body_text, and images fields.

        Returns:
            Dict with output_path, html_path (None if rendering failed),
            input/output token counts, cost_usd, and the final workflow.

        Raises:
            ValueError: When the input has no images (the decomposition is
                multimodal and requires them).
        """
        with open(input_json_path, "r", encoding="utf-8") as f:
            payload = json.load(f)
        channel_content_id = payload["channel_content_id"]
        title = payload.get("title", "")
        body_text = payload.get("body_text", "")
        images = payload.get("images", []) or []
        if not images:
            raise ValueError(f"输入 {input_json_path} 没有 images,无法做多模态工序拆解")
        # Store only image placeholders in source (never raw base64) so the
        # decode output file does not balloon to MB size.
        source_images = []
        for i, img in enumerate(images):
            if isinstance(img, str) and img.startswith("data:"):
                source_images.append(f"<image_{i + 1} (base64, {len(img) // 1024}KB)>")
            else:
                source_images.append(img)
        input_stem = Path(input_json_path).stem
        output_path = os.path.join(_get_output_dir(), f"{input_stem}.json")
        WorkflowContext.init(
            output_path=output_path,
            source_meta={
                "channel_content_id": channel_content_id,
                "title": title,
                "body_text": body_text,
                # Placeholders only; the raw base64 stays in the local
                # `images` variable for the LangChain request below.
                "images": source_images,
            },
        )
        system_prompt = _load_system_prompt()
        user_content = _build_user_content(title, body_text, images)
        model = init_chat_model(self.model_name)
        tools = [
            think_and_plan,
            add_step,
            add_step_input,
            add_step_output,
            update_step,
            update_step_input,
            update_step_output,
            delete_step,
            delete_step_input,
            delete_step_output,
            get_current_workflow,
            finalize_workflow,
        ]
        agent = create_agent(model=model, tools=tools, system_prompt=system_prompt)
        # agent.invoke is blocking; run it off the event loop thread.
        result = await asyncio.to_thread(
            agent.invoke,
            {"messages": [{"role": "user", "content": user_content}]},
        )
        usage = count_token_usage(result)
        cost = calculate_cost(self.model_name, usage["input_tokens"], usage["output_tokens"])
        final_workflow = WorkflowContext.get()
        WorkflowContext.clear()
        # HTML rendering is best-effort: a failure must not lose the workflow.
        html_path = os.path.splitext(output_path)[0] + ".html"
        try:
            with open(html_path, "w", encoding="utf-8") as f:
                f.write(render_html(final_workflow))
            print(f"[visualize] HTML 已生成 -> {html_path}")
        except Exception as e:
            html_path = None
            print(f"[visualize] HTML 生成失败(不影响工序结果): {type(e).__name__}: {e}")
        return {
            "output_path": output_path,
            "html_path": html_path,
            "input_tokens": usage["input_tokens"],
            "output_tokens": usage["output_tokens"],
            "cost_usd": cost,
            "workflow": final_workflow,
        }
# ============================================================================
# Entry point
# ============================================================================
if __name__ == "__main__":
    import argparse
    import logging
    import sys

    # Force UTF-8 console output with replacement characters so emoji/CJK
    # prints never crash on narrow terminal encodings.
    for _stream in (sys.stdout, sys.stderr):
        if hasattr(_stream, "reconfigure"):
            _stream.reconfigure(encoding="utf-8", errors="replace")
    logging.getLogger("langchain").setLevel(logging.INFO)

    DEFAULT_INPUT = os.path.join(_CURRENT_DIR, "input")
    DEFAULT_MODEL = "google_genai:gemini-3-flash-preview"

    parser = argparse.ArgumentParser(
        description="工序拆解 Agent:支持单文件或目录批量处理",
    )
    parser.add_argument(
        "--input",
        "-i",
        default=DEFAULT_INPUT,
        help=f"输入路径,可以是单个 .json 文件或包含多个 .json 的目录(默认: {DEFAULT_INPUT})",
    )
    parser.add_argument(
        "--model",
        "-m",
        default=DEFAULT_MODEL,
        help=f"模型名(默认: {DEFAULT_MODEL})",
    )
    parser.add_argument(
        "--no-skip-existing",
        action="store_true",
        help="批量模式下,即使 output 已存在也重新处理(默认会跳过已有的)",
    )
    parser.add_argument(
        "--concurrency",
        "-c",
        type=int,
        # BUGFIX: the old default of 3 (with help text promising parallel
        # subprocesses) contradicted run_batch(), which forces sequential
        # execution and warned about concurrency != 1 on every batch run.
        default=1,
        help="历史参数:run_batch 当前强制顺序执行,>1 的值会被忽略并告警(默认: 1)",
    )
    args = parser.parse_args()

    target = args.input
    agent = DecodeProcessAgent(model_name=args.model)
    if os.path.isdir(target):
        # Batch mode: one run per .json file, summary printed by run_batch.
        asyncio.run(
            agent.run_batch(
                target,
                skip_existing=not args.no_skip_existing,
                concurrency=args.concurrency,
            )
        )
    elif os.path.isfile(target):
        # Single-file mode: run once and print a compact result summary.
        result = asyncio.run(agent.run(target))
        print("\n===== 运行完成 =====")
        print(f"输出文件: {result['output_path']}")
        if result.get("html_path"):
            print(f"HTML 可视化: {result['html_path']}")
        print(f"步骤数: {len(result['workflow']['steps'])}")
        print(f"status: {result['workflow']['status']}")
        print(
            f"tokens: in={result['input_tokens']} out={result['output_tokens']} "
            f"cost=${result['cost_usd']}"
        )
    else:
        sys.exit(f"输入路径不存在: {target}")
|