run_pipeline.py

  1. """
  2. Pipeline: Research → Source → Generate-Case → Decode-Workflow [→ Apply-Grounding]
  3. ================================================================================
  4. CLI 速查
  5. ────────────────────────────────────────────────────────────────────────────────
  6. 必填:
  7. --index N 需求索引 (0-based)。输出目录 output/{(index+1):03d}/。
  8. 步骤拓扑(线性 5 步):
  9. research → source → generate-case → decode-workflow → apply-grounding
  10. ↑──────┘
  11. (phase-1 内部 research⇄source 循环,对 CLI 透明)
  12. 默认行为:
  13. 跑 research → source → generate-case → decode-workflow(4 步)。
  14. apply-grounding 仅手工触发(--only-step apply-grounding 或显式 --end-at)。
  15. 模式选择(互斥):
  16. 默认 跑 research..decode-workflow
  17. --only-step STEP 只跑单步
  18. --start-from / --end-at 区间跑(含两端),任填一端
  19. 可选参数:
  20. --case-index N 仅 decode-workflow / apply-grounding 支持
  21. --platforms xhs,zhihu,gzh,youtube,douyin,sph research 阶段平台过滤
  22. --skip-existing 仅在某 case 还没生成 decode 输出时才跑(增量模式)。
  23. 默认行为是全覆盖:每次跑都把所有 case 重新生成。
  24. 仅对 decode-workflow 批量模式生效;单 case 模式本身就总是重跑。
  25. --use-claude-sdk apply-grounding 走 Anthropic 官方 SDK
  26. --model {claude,gpt,gemini} apply-grounding 走 OpenRouter 后端
  27. 注:--use-claude-sdk 与 --model 互斥,且只在 active_steps 含 apply-grounding 时生效
  28. 典型用法:
  29. # 默认跑完前 4 步
  30. python run_pipeline.py --index 107
  31. # 只重跑某 case 的 decode-workflow
  32. python run_pipeline.py --index 107 --only-step decode-workflow --case-index 3
  33. # 单独跑 apply-grounding(默认全 case)
  34. python run_pipeline.py --index 107 --only-step apply-grounding
  35. # 只采集 xhs / zhihu
  36. python run_pipeline.py --index 107 --platforms xhs,zhihu
  37. """
import argparse
import asyncio
import json
import os
import sys
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set

PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
os.chdir(PROJECT_ROOT)
sys.path.insert(0, str(PROJECT_ROOT))

from dotenv import load_dotenv

load_dotenv()

from agent.llm.prompts import SimplePrompt
from agent.core.runner import AgentRunner, RunConfig
from agent.tools.builtin.knowledge import KnowledgeConfig
from agent.trace import FileSystemTraceStore, Trace, Message
from agent.llm import create_qwen_llm_call
from agent.utils import setup_logging
from examples.process_research.config import (
    TRACE_STORE_PATH, SKILLS_DIR, LOG_LEVEL, LOG_FILE,
)
from examples.process_pipeline.script.case_history import set_run_id, snapshot_case_file
from examples.process_pipeline.script.extract_sources import extract_sources_to_json
from examples.process_pipeline.script.generate_case import generate_case_from_source
from examples.process_pipeline.script.extract_decode_workflow import extract_decode_workflow
from examples.process_pipeline.script.apply_to_grounding_agent import apply_grounding
from examples.process_pipeline.script.validate_schema import validate_case
from examples.process_pipeline.script.fix_json_quotes import try_fix_and_parse

# ──── Topology / Constants ────────────────────────────────────────────────────
STEPS: List[str] = [
    "research", "source", "generate-case", "decode-workflow", "apply-grounding",
]
DEFAULT_END = "decode-workflow"  # apply-grounding is not part of the default pipeline
CASE_INDEX_STEPS = {"decode-workflow", "apply-grounding"}
LLM_CONFIGURABLE_STEPS = {"apply-grounding"}
TARGET_QUALIFIED_CASES = 15
MAX_RESEARCH_ROUNDS = 50
QWEN_MODEL = "qwen3.5-plus"
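
# TARGET_QUALIFIED_CASES is the per-platform quota the phase-1 research ⇄ source loop checks
# after every round; MAX_RESEARCH_ROUNDS caps that loop; QWEN_MODEL drives the researcher agent.
# DEFAULT_END keeps apply-grounding out of the default step range (see _resolve_active_steps).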


# ──── Logging Tee ─────────────────────────────────────────────────────────────
class _Tee:
    def __init__(self, *streams):
        self.streams = streams

    def write(self, s):
        for st in self.streams:
            try:
                st.write(s)
            except Exception:
                pass
        self.flush()

    def flush(self):
        for st in self.streams:
            try:
                st.flush()
            except Exception:
                pass

    def isatty(self):
        return False
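
# _Tee fans every write/flush out to all wrapped streams and swallows per-stream errors, so a
# broken log file cannot kill console output; _setup_run_log uses it to mirror stdout/stderr
# into history/<run_id>/run.log. isatty() returns False, presumably so downstream libraries
# treat the stream as non-interactive and skip terminal control codes.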


# ──── Pipeline Context ────────────────────────────────────────────────────────
@dataclass
class PipelineContext:
    args: argparse.Namespace
    requirement: str
    output_dir: Path
    raw_cases_dir: Path
    base_dir: Path
    runner_qwen: AgentRunner
    active_steps: Set[str]
    costs_breakdown: Dict[str, float] = field(default_factory=dict)
    errors: List[str] = field(default_factory=list)
    total_cost: float = 0.0
    phase1_trace_ids: Dict[str, Optional[str]] = field(default_factory=dict)

    def track(self, name: str, cost: float) -> None:
        self.total_cost += cost
        self.costs_breakdown[name] = round(self.costs_breakdown.get(name, 0.0) + cost, 4)

    def error(self, msg: str) -> None:
        print(f"⚠️ [Error] {msg}")
        self.errors.append(msg)
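
# PipelineContext is the single mutable bag threaded through all steps: track() accumulates
# per-step cost into costs_breakdown / total_cost, error() both prints and records a message,
# and everything collected here ends up in run_metrics.json via _save_metrics.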


# ──── researcher runner (the only LLM-driven step left) ──────────────────────
async def run_agent_task(
    runner: AgentRunner,
    prompt_name: str,
    kwargs: dict,
    task_name: str,
    model_name: str,
    *,
    start_trace_id: Optional[str] = None,
    additional_messages: Optional[list] = None,
):
    """
    Run prompt → agent loop; watch for write_file/write_json and immediately schema-validate
    the output, auto-fixing JSON syntax where possible. On failure, retry up to 3 attempts:
    continue the existing trace (with error feedback) when possible, otherwise restart.
    Returns (cost, errors, trace_id).
    """
    prompt_path = Path(__file__).parent / "prompts" / f"{prompt_name}.prompt"
    prompt = SimplePrompt(prompt_path)
    base_messages = prompt.build_messages(**kwargs)
    if additional_messages:
        base_messages = list(base_messages) + list(additional_messages)
    out_file = kwargs.get("output_file")
    knowledge = KnowledgeConfig(
        enable_completion_extraction=False, enable_extraction=False, enable_injection=False,
    )

    def _instant_validate() -> Optional[str]:
        """Validate right after a file write, auto-repairing JSON syntax if needed. Returns an error description or None."""
        if not out_file or not Path(out_file).exists():
            return None
        try:
            with open(out_file, "r", encoding="utf-8") as f:
                raw = f.read()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                ok, data, desc = try_fix_and_parse(raw)
                if not ok:
                    return "JSON parse failed, auto-fix unsuccessful"
                with open(out_file, "w", encoding="utf-8") as f:
                    json.dump(data, f, ensure_ascii=False, indent=2)
                print(f" 🔧 [Auto-Fix] {desc}")
            filename = Path(out_file).name
            if filename.startswith("case_"):
                err = validate_case(data)
                if err:
                    print(f" ⚠️ [Validation] {err}")
                    return err
            print(f" ✅ [Validation] {filename} OK")
            return None
        except Exception as e:
            return str(e)

    async def _run_attempt(messages: list, attempt_name: str, trace_id: Optional[str], temperature: float):
        cost = 0.0
        errs: List[str] = []
        last_tid = trace_id
        cfg = RunConfig(
            model=prompt.config.get("model") or model_name,
            temperature=temperature,
            name=attempt_name,
            agent_type=prompt_name,
            tool_groups=["core", "content"] if prompt_name == "researcher" else ["core"],
            trace_id=trace_id,
            knowledge=knowledge,
        )
        try:
            async for item in runner.run(messages=messages, config=cfg):
                if isinstance(item, Trace):
                    last_tid = item.trace_id
                    if item.status == "completed":
                        cost += item.total_cost
                    elif item.status == "failed":
                        errs.append(f"{attempt_name} failed: {item.error_message}")
                elif isinstance(item, Message) and item.role == "tool":
                    content = item.content if isinstance(item.content, dict) else {}
                    if content.get("tool_name") in ("write_file", "write_json"):
                        print(f" 💾 [Write] {task_name}")
                        _instant_validate()
        except Exception as e:
            errs.append(f"{attempt_name} crashed: {type(e).__name__}: {e}")
            print(f"❌ [Exception] {attempt_name}: {e}")
        return cost, errs, last_tid

    total_cost = 0.0
    total_errors: List[str] = []
    last_trace_id = start_trace_id
    last_validation_error: Optional[str] = None
    default_temp = prompt.config.get("temperature") or 0.3
    for attempt in range(3):
        if attempt == 0:
            print(f"🚀 [Launch] {task_name}")
            cost, errs, last_trace_id = await _run_attempt(
                base_messages, f"{task_name}_A0", last_trace_id, default_temp,
            )
        elif last_trace_id and last_validation_error:
            # Continue on the previous trace and tell the agent what went wrong
            print(f"🔄 [Continue {attempt}/2] {task_name}")
            fix_msg = [{
                "role": "user",
                "content": (
                    f"【系统校验失败】你上一次写入的文件 `{out_file}` 未通过 schema 校验。\n"
                    f"错误:{last_validation_error}\n\n"
                    f"请重新读取该文件,根据错误修复后再次 write_json 到 `{out_file}`。"
                    f"只改有问题的部分,不要丢弃已正确的内容。"
                ),
            }]
            cost, errs, last_trace_id = await _run_attempt(
                fix_msg, f"{task_name}_Fix{attempt}", last_trace_id, 0.1,
            )
        else:
            # No trace to resume; start over from scratch
            print(f"🔄 [Retry {attempt}/2] {task_name}")
            if out_file and Path(out_file).exists():
                Path(out_file).unlink()
            cost, errs, last_trace_id = await _run_attempt(
                base_messages, f"{task_name}_A{attempt}", None, default_temp,
            )
        total_cost += cost
        total_errors.extend(errs)
        # Recovery: if the output file was never written, force the agent to write one
        if out_file and not Path(out_file).exists() and last_trace_id:
            print(f"⚠️ [Recovery] {task_name} missing output, forcing wrap-up")
            rec_msg = [{
                "role": "user",
                "content": (
                    f"【系统强制指令】任务终止但未写入文件。请立刻调用 write_json,"
                    f"将已搜集到的结构化内容作为 json_data 写入到 `{out_file}`。"
                    f"即使空也要写入。"
                ),
            }]
            r_cost, r_errs, last_trace_id = await _run_attempt(
                rec_msg, f"{task_name}_Rec", last_trace_id, 0.1,
            )
            total_cost += r_cost
            total_errors.extend(r_errs)
        # Final validation
        if not out_file or not str(out_file).endswith(".json"):
            return total_cost, total_errors, last_trace_id
        if not Path(out_file).exists():
            print(f"❌ [Missing] {task_name}: no output file after recovery")
            last_validation_error = None
            continue
        err = _instant_validate()
        if err is None:
            return total_cost, total_errors, last_trace_id
        last_validation_error = err
        total_errors.append(f"{task_name} validation: {err}")
    print(f"❌ [Retry Limit] {task_name} exhausted retries")
    return total_cost, total_errors, last_trace_id
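
# Note on run_agent_task's retry protocol: attempt 0 runs the base prompt as-is; if schema
# validation fails and a trace id survives, the next attempt continues that trace at
# temperature 0.1 with the validation error fed back; otherwise the partial output file is
# deleted and the attempt restarts from the base messages. A missing output file triggers
# one forced write_json recovery turn per attempt.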


# ──── Step 1+2: Research ⇄ Source ─────────────────────────────────────────────
def _build_research_feedback(
    platform: str,
    last_platform_count: Dict[str, int],
    last_src_stats: Dict[str, Any],
) -> List[Dict[str, str]]:
    p_count = last_platform_count.get(platform, 0)
    p_filtered = [
        d for d in last_src_stats.get("filtered_details", [])
        if d.get("platform") == platform
    ]
    reason_summary = last_src_stats.get("filtered_reasons", {})
    lines = [
        f"【系统反馈】你在上一轮提取的有效案例数量未达标。",
        f"当前 {platform.upper()} 合格案例:{p_count}/{TARGET_QUALIFIED_CASES}",
    ]
    if reason_summary:
        lines.append(f"过滤统计:{dict(reason_summary)}")
    if p_filtered:
        lines.append(f"\n以下是你提交的被过滤掉的帖子(共 {len(p_filtered)} 条):")
        for item in p_filtered[:10]:
            lines.append(
                f" - [{item['case_id']}] {item['title']} → 原因: {item['filter_reason']}"
            )
        if len(p_filtered) > 10:
            lines.append(f" ... 还有 {len(p_filtered) - 10} 条未列出")
    lines.append(
        "\n请继续搜索并提取更多**全新的、不同的**高质量案例,**追加**写入到原文件。"
        "不要重复之前已找过的案例!针对过滤原因,确保正文详实、评分准确。"
    )
    return [{"role": "user", "content": "\n".join(lines)}]
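
# The feedback built above lists at most 10 filtered posts (plus a count of the rest) and is
# injected as an extra user message when _run_research_round resumes a platform's researcher
# trace in rounds after the first.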


async def _run_research_round(
    ctx: PipelineContext,
    active_platforms: List[str],
    round_idx: int,
    platform_traces: Dict[str, Optional[str]],
    last_src_stats: Optional[Dict[str, Any]],
    last_platform_count: Dict[str, int],
) -> None:
    """Run one research round, all platforms in parallel; store the resulting trace ids in platform_traces."""
    tasks = []
    for p in active_platforms:
        task_desc = (
            f"渠道:{p.upper()}。核心需求:{ctx.requirement}。"
            f"目标:至少收集 {TARGET_QUALIFIED_CASES} 条高质量案例(评分>=70、正文充实)。"
        )
        out_file = str(ctx.raw_cases_dir / f"case_{p}.json")
        kwargs = {"task": task_desc, "output_file": out_file}
        additional_msgs = None
        if round_idx > 0 and last_src_stats:
            additional_msgs = _build_research_feedback(p, last_platform_count, last_src_stats)
        tasks.append(run_agent_task(
            ctx.runner_qwen, "researcher", kwargs,
            f"P1_Research_{p}_R{round_idx+1}", QWEN_MODEL,
            start_trace_id=platform_traces[p],
            additional_messages=additional_msgs,
        ))
    results = await asyncio.gather(*tasks)
    for (cost, errs, tid), p in zip(results, active_platforms):
        ctx.track(f"P1_Research_{p}", cost)
        platform_traces[p] = tid
        ctx.phase1_trace_ids[f"P1_Research_{p}"] = tid
        ctx.errors.extend(errs)
        if not (ctx.raw_cases_dir / f"case_{p}.json").exists():
            ctx.error(f"Missing case file for {p}; agent likely hit max_iterations")


def _extract_sources(ctx: PipelineContext, trace_ids: Optional[List[str]]) -> Optional[Dict[str, Any]]:
    try:
        stats = extract_sources_to_json(ctx.raw_cases_dir, trace_ids=trace_ids)
        print(
            f"📎 [Source] matched={stats['total_matched']}, "
            f"filtered={stats['filtered_total']} → {ctx.raw_cases_dir / 'source.json'}"
        )
        for reason, cnt in stats.get("filtered_reasons", {}).items():
            print(f" - {reason}: {cnt}")
        return stats
    except Exception as e:
        ctx.error(f"Source extraction failed: {type(e).__name__}: {e}")
        return None
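
# _extract_sources only wraps extract_sources_to_json and reports its stats; the per-platform
# qualified-case counting that drives the loop below is done against the resulting
# raw_cases/source.json.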


async def run_research_source_loop(ctx: PipelineContext) -> None:
    """Full phase-1: research ⇄ source loop; after each round, a platform stops once it has enough qualified cases."""
    platforms = [p.strip() for p in ctx.args.platforms.split(",") if p.strip()]
    print(f"\n--- Phase 1: Research ⇄ Source loop ({QWEN_MODEL}) ---")
    print(f"📡 Platforms: {platforms}")
    platform_traces: Dict[str, Optional[str]] = {p: None for p in platforms}
    active_platforms = list(platforms)
    last_src_stats: Optional[Dict[str, Any]] = None
    last_platform_count: Dict[str, int] = {}
    round_idx = 0
    while active_platforms and round_idx < MAX_RESEARCH_ROUNDS:
        print(f"\n >>> [Round {round_idx+1}] Active: {active_platforms}")
        await _run_research_round(
            ctx, active_platforms, round_idx, platform_traces,
            last_src_stats, last_platform_count,
        )
        trace_id_list = [tid for tid in ctx.phase1_trace_ids.values() if tid]
        last_src_stats = _extract_sources(ctx, trace_id_list)
        source_file = ctx.raw_cases_dir / "source.json"
        if not source_file.exists():
            print(" ⚠️ source.json not found, continuing loop")
            round_idx += 1
            continue
        with open(source_file, "r", encoding="utf-8") as f:
            source_data = json.load(f)
        platform_count: Dict[str, int] = {}
        for s in source_data.get("sources", []):
            p = s.get("platform")
            if p:
                platform_count[p] = platform_count.get(p, 0) + 1
        last_platform_count = platform_count
        print(f" 📊 Target: >={TARGET_QUALIFIED_CASES}/platform")
        next_active = []
        for p in platforms:
            count = platform_count.get(p, 0)
            if p in active_platforms:
                print(f" - {p}: {count}/{TARGET_QUALIFIED_CASES}")
            if count < TARGET_QUALIFIED_CASES:
                next_active.append(p)
        active_platforms = next_active
        if not active_platforms:
            print(f" ✅ All platforms reached target {TARGET_QUALIFIED_CASES}")
            break
        round_idx += 1
    if round_idx >= MAX_RESEARCH_ROUNDS and active_platforms:
        print(f" ⚠️ Max {MAX_RESEARCH_ROUNDS} rounds reached. Remaining: {active_platforms}")


async def run_research_only(ctx: PipelineContext) -> None:
    """Single-step research: run one round over all --platforms, without the source-validation loop."""
    platforms = [p.strip() for p in ctx.args.platforms.split(",") if p.strip()]
    if not platforms:
        print(" ❌ No platforms specified")
        sys.exit(1)
    print(f"\n--- Single Step: Research ({QWEN_MODEL}) ---")
    print(f"📡 Platforms: {platforms}")
    platform_traces: Dict[str, Optional[str]] = {p: None for p in platforms}
    await _run_research_round(ctx, platforms, 0, platform_traces, None, {})


async def run_source_only(ctx: PipelineContext) -> None:
    """Single-step source: extract source.json from existing case_*.json files."""
    print(f"\n--- Single Step: Source Extraction ---")
    _extract_sources(ctx, trace_ids=None)
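
# For reference, the two single-step entry points above correspond to invocations such as:
#   python run_pipeline.py --index 107 --only-step research --platforms xhs
#   python run_pipeline.py --index 107 --only-step source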


# ──── Step 3: generate-case ──────────────────────────────────────────────────
async def run_generate_case(ctx: PipelineContext) -> None:
    print(f"\n--- Phase 1.3: Generate case.json ---")
    source_file = ctx.raw_cases_dir / "source.json"
    if not source_file.exists():
        ctx.error("source.json not found; run research/source first")
        return
    try:
        result = await generate_case_from_source(ctx.raw_cases_dir)
        print(f"📦 [Generate Case] cases={result['total_cases']} → {result['output_file']}")
    except Exception as e:
        ctx.error(f"Generate case failed: {type(e).__name__}: {e}")


# ──── Step 4: decode-workflow ────────────────────────────────────────────────
async def run_decode_workflow(ctx: PipelineContext) -> None:
    print(f"\n--- Phase 2: Decode Workflow (Gemini + LangChain) ---")
    case_file = ctx.output_dir / "case.json"
    if not case_file.exists():
        ctx.error("case.json not found; run generate-case first")
        return
    try:
        result = await extract_decode_workflow(
            case_file=case_file,
            case_index=ctx.args.case_index,  # None = all cases
            skip_existing=ctx.args.skip_existing,
        )
        ctx.track("Decode_Workflow", result.get("total_cost", 0.0))
        print(
            f" ✓ decode-workflow: succeeded={result['succeeded']}/{result['total']} "
            f"skipped={result['skipped']} failed={result['failed']} "
            f"merged={result.get('merged', 0)} cost=${result['total_cost']}"
        )
        print(f" output_dir: {result['output_dir']}")
    except ImportError as e:
        ctx.error(
            f"decode-workflow dependencies not installed: {e}; "
            f"requires `pip install langchain langchain-google-genai` and GOOGLE_API_KEY in .env"
        )
    except Exception as e:
        ctx.error(f"Decode workflow failed: {type(e).__name__}: {e}")
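
# --case-index and --skip-existing are passed straight through here; which cases are redone
# and what counts as "existing output" is decided inside extract_decode_workflow.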


# ──── Step 5: apply-grounding ────────────────────────────────────────────────
def _build_main_llm_call(args: argparse.Namespace) -> tuple:
    """Create the llm_call used for grounding, based on --use-claude-sdk / --model."""
    if args.use_claude_sdk:
        model = "claude-sonnet-4-6"
        from agent.llm.claude_code_oauth import create_claude_code_oauth_llm_call
        print(f"✅ apply-grounding via Claude Agent SDK (OAuth): {model}")
        return create_claude_code_oauth_llm_call(model=model), model
    model_map = {
        "claude": "claude-sonnet-4-6",
        "gpt": "gpt-5.4",
        "gemini": "~google/gemini-pro-latest",
    }
    model = model_map.get(args.model, "claude-sonnet-4-6")
    from agent.llm.openrouter import create_openrouter_llm_call
    print(f"✅ apply-grounding via OpenRouter: {model}")
    return create_openrouter_llm_call(model=model), model


def _filter_case_to_single(case_file: Path, case_index: int) -> Path:
    """Pick the specified case, write it to a temp file, and return the temp file path."""
    with open(case_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    cases = data.get("cases", [])
    target = next((c for c in cases if c.get("index") == case_index), None)
    if not target:
        raise ValueError(f"Case with index {case_index} not found in {case_file.name}")
    data["cases"] = [target]
    temp_file = case_file.parent / f"case_temp_{case_index}.json"
    with open(temp_file, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
    print(f" [Target] case {case_index}: {target.get('title', 'untitled')[:30]}")
    return temp_file


def _merge_single_case_back(case_file: Path, temp_file: Path, case_index: int) -> None:
    """Merge the modified case from the temp file back into the original case.json, snapshotting before the write."""
    with open(temp_file, "r", encoding="utf-8") as f:
        updated_case = json.load(f)["cases"][0]
    with open(case_file, "r", encoding="utf-8") as f:
        original = json.load(f)
    for i, c in enumerate(original["cases"]):
        if c.get("index") == case_index:
            original["cases"][i] = updated_case
            break
    snap = snapshot_case_file(case_file, step="grounding_merge")
    if snap:
        print(f" [snapshot] {snap.name}")
    with open(case_file, "w", encoding="utf-8") as f:
        json.dump(original, f, ensure_ascii=False, indent=2)
    temp_file.unlink()
    print(f" ✓ Merged case {case_index} back to case.json")
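
# Single-case grounding round-trip: _filter_case_to_single writes case_temp_<index>.json,
# apply_grounding runs against it, and _merge_single_case_back snapshots case.json before
# splicing the updated case back in; on failure, run_apply_grounding deletes the temp file
# instead of merging.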


async def run_apply_grounding(ctx: PipelineContext) -> None:
    print(f"\n--- Phase 2: Apply Grounding ---")
    case_file = ctx.output_dir / "case.json"
    if not case_file.exists():
        ctx.error("case.json not found; run generate-case first")
        return
    llm_call, model = _build_main_llm_call(ctx.args)
    if ctx.args.case_index is not None:
        try:
            target_file = _filter_case_to_single(case_file, ctx.args.case_index)
        except ValueError as e:
            ctx.error(str(e))
            return
    else:
        target_file = case_file
    try:
        result = await apply_grounding(
            target_file, llm_call, model=model, max_concurrent=3,
        )
        ctx.track("Apply_Grounding", result.get("total_cost", 0.0))
        print(
            f"🗺️ [Grounding] grounded={result['grounded']}/{result['total']} "
            f"cost=${result.get('total_cost', 0.0):.4f}"
        )
    except Exception as e:
        ctx.error(f"Apply grounding failed: {type(e).__name__}: {e}")
        if ctx.args.case_index is not None and target_file != case_file:
            target_file.unlink(missing_ok=True)
        return
    if ctx.args.case_index is not None:
        _merge_single_case_back(case_file, target_file, ctx.args.case_index)


# ──── CLI parsing & validation ───────────────────────────────────────────────
def _parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="AIGC Process Pipeline (5-step)")
    parser.add_argument("--index", type=int, required=True,
                        help="Index of requirement in db_requirements.json (0-based)")
    parser.add_argument("--platforms", type=str, default="xhs,zhihu,gzh,youtube,douyin,sph",
                        help="Comma-separated platforms for research step")
    parser.add_argument("--case-index", type=int, default=None,
                        help="Re-run a single case in decode-workflow / apply-grounding")
    parser.add_argument("--skip-existing", action="store_true",
                        help="Skip cases whose output already exists "
                             "(default: re-run / overwrite everything; "
                             "currently affects decode-workflow batch mode)")
    parser.add_argument("--only-step", type=str, choices=STEPS,
                        help="Run only a single step (mutex with --start-from/--end-at)")
    parser.add_argument("--start-from", type=str, choices=STEPS,
                        help="Start from this step (inclusive)")
    parser.add_argument("--end-at", type=str, choices=STEPS,
                        help="End at this step (inclusive)")
    llm_group = parser.add_mutually_exclusive_group()
    llm_group.add_argument("--use-claude-sdk", action="store_true",
                           help="apply-grounding via Anthropic SDK (mutex with --model)")
    llm_group.add_argument("--model", type=str, choices=["claude", "gpt", "gemini"],
                           default="claude",
                           help="apply-grounding LLM family via OpenRouter")
    args = parser.parse_args()
    if args.only_step and (args.start_from or args.end_at):
        parser.error("--only-step is mutually exclusive with --start-from / --end-at")
    return args


def _resolve_active_steps(args: argparse.Namespace) -> Set[str]:
    if args.only_step:
        return {args.only_step}
    start = args.start_from or STEPS[0]
    end = args.end_at or DEFAULT_END
    start_idx = STEPS.index(start)
    end_idx = STEPS.index(end)
    if start_idx > end_idx:
        print(f"❌ --start-from '{start}' is after --end-at '{end}'")
        sys.exit(1)
    return set(STEPS[start_idx:end_idx + 1])
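
# By the logic above: no step flags → {research, source, generate-case, decode-workflow};
# --only-step source → {source}; --start-from decode-workflow (no --end-at) → {decode-workflow},
# since DEFAULT_END caps the range; --end-at apply-grounding → all five steps.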


def _validate_args(args: argparse.Namespace, active_steps: Set[str]) -> None:
    if args.case_index is not None and not (active_steps & CASE_INDEX_STEPS):
        print(
            f"❌ --case-index only applies to {sorted(CASE_INDEX_STEPS)}; "
            f"none of those are in active steps {sorted(active_steps)}"
        )
        sys.exit(1)
    # --use-claude-sdk / --model only take effect when apply-grounding is active
    llm_flags_explicit = (
        args.use_claude_sdk
        or any(a == "--model" or a.startswith("--model=") for a in sys.argv[1:])
    )
    if llm_flags_explicit and not (active_steps & LLM_CONFIGURABLE_STEPS):
        print(
            f"⚠️ --use-claude-sdk / --model are only used by apply-grounding; "
            f"ignored because active steps {sorted(active_steps)} don't include it"
        )


# ──── Bootstrap ──────────────────────────────────────────────────────────────
def _setup_run_log(output_dir: Path) -> Path:
    """Create history/<run_id>/ for this run and tee stdout/stderr into its run.log."""
    run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    set_run_id(run_id)
    run_dir = output_dir / "history" / run_id
    run_dir.mkdir(parents=True, exist_ok=True)
    log_path = run_dir / "run.log"
    log_file = open(log_path, "w", encoding="utf-8")
    initial_case = output_dir / "case.json"
    if initial_case.exists():
        snapshot_case_file(initial_case, step="run_start")
    sys.stdout = _Tee(sys.__stdout__, log_file)
    sys.stderr = _Tee(sys.__stderr__, log_file)
    import atexit
    atexit.register(log_file.close)
    print(f"[run-log] tee active → {log_path}")
    return log_path


def _load_requirement(base_dir: Path, index: int) -> str:
    req_path = base_dir / "db_requirements.json"
    with open(req_path, encoding="utf-8") as f:
        reqs = json.load(f)
    if index < 0 or index >= len(reqs):
        print(f"❌ Index {index} out of bounds (db has {len(reqs)} entries)")
        sys.exit(1)
    return reqs[index]


def _save_metrics(ctx: PipelineContext, elapsed_sec: float) -> None:
    metrics_file = ctx.base_dir / "run_metrics.json"
    metrics_data: List[Any] = []
    if metrics_file.exists():
        try:
            with open(metrics_file, "r", encoding="utf-8") as f:
                metrics_data = json.load(f)
        except json.JSONDecodeError:
            pass
    metrics_data.append({
        "index": ctx.args.index,
        "requirement": ctx.requirement[:80] + "...",
        "duration_seconds": round(elapsed_sec, 2),
        "total_cost_usd": round(ctx.total_cost, 4),
        "costs_breakdown": ctx.costs_breakdown,
        "trace_ids": {k: v for k, v in ctx.phase1_trace_ids.items() if v},
        "errors": ctx.errors,
        "active_steps": sorted(ctx.active_steps),
        "timestamp": datetime.now().isoformat(),
    })
    with open(metrics_file, "w", encoding="utf-8") as f:
        json.dump(metrics_data, f, indent=2, ensure_ascii=False)
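
# run_metrics.json (at base_dir) is treated as an append-only list: each run adds one record
# with the keys written above (index, requirement, duration_seconds, total_cost_usd,
# costs_breakdown, trace_ids, errors, active_steps, timestamp).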


# ──── Main dispatch ──────────────────────────────────────────────────────────
async def main() -> None:
    args = _parse_args()
    active_steps = _resolve_active_steps(args)
    _validate_args(args, active_steps)
    base_dir = Path(__file__).parent
    requirement = _load_requirement(base_dir, args.index)
    output_dir = base_dir / "output" / f"{(args.index+1):03d}"
    raw_cases_dir = output_dir / "raw_cases"
    output_dir.mkdir(parents=True, exist_ok=True)
    raw_cases_dir.mkdir(parents=True, exist_ok=True)
    _setup_run_log(output_dir)
    setup_logging(level=LOG_LEVEL, file=LOG_FILE)
    print("=" * 60)
    print(f"Pipeline | Demand: [{args.index+1:03d}] {requirement[:40]}...")
    print(f"Active steps: {' → '.join(s for s in STEPS if s in active_steps)}")
    print("=" * 60)
    # Load agent presets if available
    presets_path = base_dir / "presets.json"
    if presets_path.exists():
        from agent.core.presets import load_presets_from_json
        load_presets_from_json(str(presets_path))
        print("✅ Loaded agent presets")
    store = FileSystemTraceStore(base_path=TRACE_STORE_PATH)
    runner_qwen = AgentRunner(
        trace_store=store,
        llm_call=create_qwen_llm_call(model=QWEN_MODEL),
        skills_dir=SKILLS_DIR,
    )
    ctx = PipelineContext(
        args=args,
        requirement=requirement,
        output_dir=output_dir,
        raw_cases_dir=raw_cases_dir,
        base_dir=base_dir,
        runner_qwen=runner_qwen,
        active_steps=active_steps,
    )
    start_time = time.time()
    try:
        # research ⇄ source are coupled: run the loop when both are active, otherwise run whichever one is
        if "research" in active_steps and "source" in active_steps:
            await run_research_source_loop(ctx)
        elif "research" in active_steps:
            await run_research_only(ctx)
        elif "source" in active_steps:
            await run_source_only(ctx)
        if "generate-case" in active_steps:
            await run_generate_case(ctx)
        if "decode-workflow" in active_steps:
            await run_decode_workflow(ctx)
        if "apply-grounding" in active_steps:
            await run_apply_grounding(ctx)
        elapsed = time.time() - start_time
        _save_metrics(ctx, elapsed)
        print(f"\n📊 [Metrics] Completed in {elapsed:.1f}s. Total cost: ${ctx.total_cost:.4f}")
        if ctx.errors:
            print(f"⚠️ {len(ctx.errors)} error(s) encountered:")
            for e in ctx.errors[:10]:
                print(f" - {e}")
    finally:
        print("✅ Pipeline run finished.")


if __name__ == "__main__":
    asyncio.run(main())