apply_to_grounding_agent.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756
"""
Stage 2 (agent version): ground each capability against the content-category libraries.

Replaces the original three-step flow -- step1 (extract-draft-query) + step2
(search API) + step3 (grounding LLM) -- with a single agent call: the Claude
Agent SDK launches an agent equipped with Read/Grep/Glob tools, lets it read
the two category-library JSON files by itself, and requires a JSON code block
in its final output; the Python side extracts that JSON from the agent's text
output and writes it back into case.json.

Input (per capability):
    {body, inputs, outputs, ...}  (no longer depends on apply_to_draft; the
    agent reads the ``body`` field itself)

JSON emitted by the agent:
    {
      "matched": [{ category_path, category_type: "实质"|"形式", action, ability_type,
                    matched_elements, structured_content }, ...],
      "suggested_additions": [{ category_type: "实质", parent_path, suggested_level, ...}, ...]
    }

Written back to case.json (flat, aligned with the existing grounded-cap structure):
    capability.apply_to         = the "matched" array (each item carries its own
                                  category_type field distinguishing substance/form)
    capability.suggest_apply_to = the "suggested_additions" array (category_type
                                  of each item is always "实质")
"""
import asyncio
import json
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from dotenv import load_dotenv

# Project root + .env
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent
load_dotenv(PROJECT_ROOT / ".env")

# Category-library JSON paths (moved from inputs/ to script/resource/ as references)
RESOURCE_DIR = Path(__file__).resolve().parent / "resource"
SHI_JSON_PATH = RESOURCE_DIR / "分类库导出_实质_20260512_132218.json"
XING_JSON_PATH = RESOURCE_DIR / "分类库导出_形式_20260512_170623.json"

# Concurrency cap (agent calls are slower and costlier than a single LLM call;
# keep concurrency low to avoid blowing the OAuth quota)
DEFAULT_MAX_CONCURRENT = int(os.getenv("GROUNDING_AGENT_MAX_CONCURRENT", "2"))
# Maximum conversation turns for a single-capability agent run
DEFAULT_MAX_TURNS = int(os.getenv("GROUNDING_AGENT_MAX_TURNS", "100"))
# Retries per capability on failure (each retry starts a fresh ClaudeSDKClient,
# i.e. a fresh context)
DEFAULT_MAX_RETRIES = int(os.getenv("GROUNDING_AGENT_MAX_RETRIES", "2"))
# Default model (uses OAuth / Max subscription)
DEFAULT_MODEL = os.getenv("GROUNDING_AGENT_MODEL", "claude-sonnet-4-6")
  41. def load_prompt_template(name: str = "apply_to_grounding_agent") -> str:
  42. """加载 prompts/<name>.prompt 模板(agent 或 oneshot 通用)"""
  43. prompt_path = Path(__file__).resolve().parent.parent / "prompts" / f"{name}.prompt"
  44. with open(prompt_path, "r", encoding="utf-8") as f:
  45. content = f.read()
  46. if content.startswith("---"):
  47. parts = content.split("---", 2)
  48. if len(parts) >= 3:
  49. content = parts[2]
  50. return content.strip()
  51. def render_oneshot_prompt(template: str, capability: Dict[str, Any], shi_json: str, xing_json: str) -> str:
  52. """渲染 oneshot prompt:把 body + 两个分类库全文注入"""
  53. body = capability.get("body") or ""
  54. return (
  55. template
  56. .replace("{capability的body信息}", body)
  57. .replace("{shi_json}", shi_json)
  58. .replace("{xing_json}", xing_json)
  59. )
async def ground_single_capability_oneshot(
    capability: Dict[str, Any],
    template: str,
    llm_call: Any,
    model: str,
    shi_json: str,
    xing_json: str,
    task_label: str = "",
    max_retries: int = DEFAULT_MAX_RETRIES,
) -> Tuple[Optional[Dict[str, Any]], float, int]:
    """
    Oneshot mode: stuff "category libraries + capability body" into a single
    prompt and call the LLM once for JSON.
    Shares the retry / recover / schema-validation / log format with agent mode.
    Returns:
        (parsed_output, cost, attempts_used)
    """
    cap_id = capability.get("capability_id") or "?"
    label = task_label or cap_id
    full_prompt = render_oneshot_prompt(template, capability, shi_json, xing_json)
    last_reason: Optional[str] = None
    last_in_tok = 0
    last_out_tok = 0
    last_cache_read = 0
    print(
        f" [grounding-oneshot] cap={label} → starting (input ~{len(full_prompt):,} chars, model={model})",
        flush=True,
    )
    for attempt in range(max_retries + 1):
        if attempt > 0:
            print(
                f" [grounding-oneshot] cap={label} → RETRY {attempt}/{max_retries} "
                f"(prev: {last_reason})",
                flush=True,
            )
        try:
            response = await llm_call(
                messages=[{"role": "user", "content": full_prompt}],
                model=model,
                temperature=0.1,
                max_tokens=16384,
            )
        except Exception as e:
            last_reason = f"LLM call failed: {type(e).__name__}: {e}"
            print(f" [grounding-oneshot] cap={label} → {last_reason}", flush=True)
            # Try to surface the server-side response body for diagnostics
            for attr in ("response", "body"):
                obj = getattr(e, attr, None)
                if obj is not None:
                    try:
                        text = obj.text if hasattr(obj, "text") else str(obj)
                        print(f" server body: {text[:500]}", flush=True)
                    except Exception:
                        pass
            continue
        # Extract content — may be a plain string or a list of content blocks
        content = response.get("content", "")
        if isinstance(content, list):
            content = "".join(
                (b.get("text") or "") if isinstance(b, dict) else str(b) for b in content
            )
        elif not isinstance(content, str):
            content = str(content)
        # Usage accounting — tolerates both dict-like and attribute-style usage objects
        usage = response.get("usage") or {}
        if hasattr(usage, "__dict__") and not isinstance(usage, dict):
            usage = {k: getattr(usage, k) for k in dir(usage)
                     if not k.startswith("_") and not callable(getattr(usage, k))}
        last_in_tok = usage.get("input_tokens") or usage.get("prompt_tokens") or 0
        last_out_tok = usage.get("output_tokens") or usage.get("completion_tokens") or 0
        last_cache_read = usage.get("cache_read_input_tokens") or usage.get("cached_tokens") or 0
        # Extract JSON from the raw model output
        parsed = extract_json_from_response(content)
        if parsed is None:
            last_reason = "JSON parse FAILED (no valid JSON block found)"
            print(f" [grounding-oneshot] cap={label} → {last_reason}", flush=True)
            print(f" raw output total: {len(content)} chars", flush=True)
            print(f" raw output head: {content[:1000]}", flush=True)
            if len(content) > 1000:
                print(f" raw output tail: {content[-500:]}", flush=True)
            continue
        # Schema validation — advisory only; a violation is logged, not fatal
        schema_err = _validate_against_schema(parsed)
        schema_warn = f" ⚠ schema violation: {schema_err}" if schema_err else ""
        matched_all = parsed.get("matched", []) or []
        shi_n = sum(1 for m in matched_all if isinstance(m, dict) and m.get("category_type") == "实质")
        xing_n = sum(1 for m in matched_all if isinstance(m, dict) and m.get("category_type") == "形式")
        suggest_n = len(parsed.get("suggested_additions", []) or [])
        cache_note = f" ({last_cache_read:,} cached)" if last_cache_read else ""
        print(
            f" [grounding-oneshot] cap={label} → done: in={last_in_tok:,}{cache_note} out={last_out_tok:,} "
            f"matched={len(matched_all)} (substance={shi_n} form={xing_n}) suggested={suggest_n}{schema_warn}",
            flush=True,
        )
        return parsed, 0.0, attempt + 1
    # All retries exhausted
    print(
        f" [grounding-oneshot] cap={label} → ✗ FAILED after {max_retries + 1} attempts: {last_reason}",
        flush=True,
    )
    return None, 0.0, max_retries + 1
  160. def render_prompt(template: str, capability: Dict[str, Any]) -> str:
  161. """渲染 prompt:把 capability.body 填进「做法描述」占位符 + 两个分类库的绝对路径。
  162. 新版 prompt 只用 body 字段(不再喂整个 capability JSON)。inputs/outputs 不传给 agent,
  163. 保持 prompt 简洁,跟用户提供的 prompt_input.txt 模板一致。
  164. """
  165. body = capability.get("body") or ""
  166. return (
  167. template
  168. # 新占位符(用户最新 prompt 模板用的)
  169. .replace("{capability的body信息}", body)
  170. # 旧占位符(兼容,万一未来需要换回完整 JSON)
  171. .replace("{capability_json}", body)
  172. .replace("{shi_path}", str(SHI_JSON_PATH.resolve()))
  173. .replace("{xing_path}", str(XING_JSON_PATH.resolve()))
  174. )
  175. def _find_balanced_json_object(text: str, start_idx: int = 0) -> Optional[str]:
  176. """
  177. 从 text[start_idx:] 找一个花括号配平的 JSON 对象,返回原始片段(含两端 {})。
  178. 用字符串状态机正确处理 JSON 字符串内的 { } 不算 brace。
  179. 找不到返回 None。
  180. """
  181. start = text.find("{", start_idx)
  182. if start < 0:
  183. return None
  184. depth = 0
  185. in_string = False
  186. escape_next = False
  187. for i in range(start, len(text)):
  188. c = text[i]
  189. if escape_next:
  190. escape_next = False
  191. continue
  192. if c == "\\":
  193. escape_next = True
  194. continue
  195. if c == '"':
  196. in_string = not in_string
  197. continue
  198. if in_string:
  199. continue
  200. if c == "{":
  201. depth += 1
  202. elif c == "}":
  203. depth -= 1
  204. if depth == 0:
  205. return text[start:i + 1]
  206. return None # 没配平
  207. def extract_json_from_response(text: str) -> Optional[Dict[str, Any]]:
  208. """
  209. 从 agent 的最终输出文本中提取 JSON。多策略,尽量宽容:
  210. 1. ```json ... ``` 代码块(findall 拿所有,倒序试,取第一个 valid 的 —— agent 倾向"先预览后正式")
  211. 2. 花括号配平 —— 不依赖 ``` 包裹,从文本任意位置找完整 {…} 对象(多个候选倒序试,取后面的优先)
  212. 3. 贪婪 regex 兜底 —— 取整段文本里 { 到最后一个 } 之间的内容
  213. """
  214. # 策略 1:``` 代码块(允许 ```json 或纯 ```)
  215. blocks = re.findall(r"```(?:json|JSON)?\s*\n?(\{[\s\S]*?\})\s*```", text)
  216. for raw in reversed(blocks):
  217. try:
  218. return json.loads(raw)
  219. except json.JSONDecodeError:
  220. continue
  221. # 策略 2:花括号配平 — 扫描所有可能的起点,倒序试(最终结果通常在文本末尾)
  222. candidates: List[str] = []
  223. cursor = 0
  224. while cursor < len(text):
  225. nxt = text.find("{", cursor)
  226. if nxt < 0:
  227. break
  228. obj = _find_balanced_json_object(text, nxt)
  229. if obj is None:
  230. break
  231. candidates.append(obj)
  232. cursor = nxt + len(obj)
  233. for raw in reversed(candidates):
  234. try:
  235. return json.loads(raw)
  236. except json.JSONDecodeError:
  237. continue
  238. # 策略 3:贪婪兜底(最长 { ... } 块,可能解析失败但试一下)
  239. greedy = re.search(r"\{[\s\S]*\}", text)
  240. if greedy:
  241. try:
  242. return json.loads(greedy.group(0))
  243. except json.JSONDecodeError:
  244. pass
  245. return None
async def _ground_single_capability_attempt(
    capability: Dict[str, Any],
    template: str,
    model: str,
    max_turns: int,
    label: str,
    attempt_no: int,  # 0 = first try, >= 1 = retry
) -> Tuple[Optional[Dict[str, Any]], int, Optional[str]]:
    """
    A single agent-call attempt.
    Returns:
        (parsed_output, turns_used, failure_reason)
        - parsed != None & failure_reason == None: success (or recover)
        - parsed == None & failure_reason != None: failure, retryable
    """
    from claude_agent_sdk import (
        AssistantMessage,
        ClaudeAgentOptions,
        ClaudeSDKClient,
        ClaudeSDKError,
        ResultMessage,
        TextBlock,
    )
    # Blank out the parent process's API key so the SDK subprocess is forced onto OAuth
    override_env = {
        "ANTHROPIC_API_KEY": "",
        "ANTHROPIC_BASE_URL": "",
        "ANTHROPIC_AUTH_TOKEN": "",
    }
    full_prompt = render_prompt(template, capability)
    stderr_lines: List[str] = []
    def _capture_stderr(line: str) -> None:
        # Keep CLI stderr so a failure can print its tail in the log
        if line:
            stderr_lines.append(line)
    options = ClaudeAgentOptions(
        model=model,
        allowed_tools=["Read", "Grep", "Glob"],
        max_turns=max_turns,
        cwd=str(RESOURCE_DIR),
        env=override_env,
        stderr=_capture_stderr,
        setting_sources=[],
    )
    text_parts: List[str] = []
    is_error = False
    turns_used = 0
    tool_use_count = 0
    attempt_label = f" attempt#{attempt_no+1}" if attempt_no > 0 else ""
    print(f" [grounding-agent] cap={label}{attempt_label} → starting (max_turns={max_turns})", flush=True)
    try:
        async with ClaudeSDKClient(options=options) as client:
            await client.query(full_prompt)
            async for msg in client.receive_response():
                if isinstance(msg, AssistantMessage):
                    for block in msg.content:
                        if hasattr(block, "thinking"):
                            # skip thinking blocks entirely
                            continue
                        elif isinstance(block, TextBlock):
                            text_parts.append(block.text)
                            preview = block.text.replace("\n", " ").strip()[:160]
                            if preview:
                                print(f" [cap={label}] text: {preview}", flush=True)
                        elif hasattr(block, "name") and hasattr(block, "input"):
                            # duck-typed tool-use block: log tool name + truncated input
                            tool_use_count += 1
                            tool_input_str = json.dumps(block.input, ensure_ascii=False)[:160]
                            print(
                                f" [cap={label}] tool#{tool_use_count} {block.name}({tool_input_str})",
                                flush=True,
                            )
                elif isinstance(msg, ResultMessage):
                    is_error = msg.is_error
                    turns_used = msg.num_turns
    except ClaudeSDKError as e:
        reason = f"SDK ERROR {type(e).__name__}: {e}"
        print(f" [grounding-agent] cap={label} → {reason}", flush=True)
        if stderr_lines:
            print(f" CLI stderr tail: {stderr_lines[-3:]}", flush=True)
        return None, turns_used, reason
    # Recover strategy: even when is_error=True, still try to parse the JSON
    content = "".join(text_parts).strip()
    parsed = extract_json_from_response(content)
    if parsed is None:
        reason = (
            "is_error=True (likely max_turns) AND no parseable JSON in output"
            if is_error
            else "JSON parse FAILED (model output didn't contain valid JSON block)"
        )
        print(f" [grounding-agent] cap={label} → {reason}", flush=True)
        print(f" raw output total: {len(content)} chars", flush=True)
        print(f" raw output head: {content[:1000]}", flush=True)
        if len(content) > 1000:
            print(f" raw output tail: {content[-500:]}", flush=True)
        return None, turns_used, reason
    # Schema validation — advisory only
    schema_err = _validate_against_schema(parsed)
    schema_warn = f" ⚠ schema violation: {schema_err}" if schema_err else ""
    matched_all = parsed.get("matched", []) or []
    shi_n = sum(1 for m in matched_all if isinstance(m, dict) and m.get("category_type") == "实质")
    xing_n = sum(1 for m in matched_all if isinstance(m, dict) and m.get("category_type") == "形式")
    suggest_n = len(parsed.get("suggested_additions", []) or [])
    if is_error:
        print(
            f" [grounding-agent] cap={label} → ⚠ RECOVERED: turns={turns_used} tools={tool_use_count} "
            f"matched={len(matched_all)} (substance={shi_n} form={xing_n}) suggested={suggest_n} "
            f"(is_error=True, but JSON captured before truncation){schema_warn}",
            flush=True,
        )
    else:
        print(
            f" [grounding-agent] cap={label} → done: turns={turns_used} tools={tool_use_count} "
            f"matched={len(matched_all)} (substance={shi_n} form={xing_n}) suggested={suggest_n}{schema_warn}",
            flush=True,
        )
    return parsed, turns_used, None  # success
  360. async def ground_single_capability(
  361. capability: Dict[str, Any],
  362. template: str,
  363. model: str = DEFAULT_MODEL,
  364. max_turns: int = DEFAULT_MAX_TURNS,
  365. task_label: str = "",
  366. max_retries: int = DEFAULT_MAX_RETRIES,
  367. ) -> Tuple[Optional[Dict[str, Any]], float, int]:
  368. """
  369. 对单个 capability 跑 agent grounding,失败时自动 retry。
  370. Returns:
  371. (parsed_output, cost, turns_used)
  372. parsed_output 为 None 表示解析失败或 agent 出错且重试用完
  373. cost 当前总为 0.0(OAuth 模式不计费),保留接口便于后续切换
  374. """
  375. try:
  376. import claude_agent_sdk # noqa: F401 — 早失败,避免 retry 循环里反复 ImportError
  377. except ImportError as e:
  378. raise RuntimeError(f"claude_agent_sdk not installed: {e}") from e
  379. cap_id = capability.get("capability_id") or "?"
  380. label = task_label or cap_id
  381. last_reason: Optional[str] = None
  382. last_turns_used = 0
  383. for attempt in range(max_retries + 1):
  384. if attempt > 0:
  385. print(
  386. f" [grounding-agent] cap={label} → RETRY {attempt}/{max_retries} "
  387. f"(prev failure: {last_reason})",
  388. flush=True,
  389. )
  390. parsed, turns_used, failure_reason = await _ground_single_capability_attempt(
  391. capability, template, model, max_turns, label, attempt
  392. )
  393. last_turns_used = turns_used
  394. if failure_reason is None:
  395. # 成功(包括 recover 路径)
  396. if attempt > 0:
  397. print(
  398. f" [grounding-agent] cap={label} → ✓ recovered after retry#{attempt}",
  399. flush=True,
  400. )
  401. return parsed, 0.0, turns_used
  402. # 失败:记录原因,进入下一次重试
  403. last_reason = failure_reason
  404. # 所有 retry 用完
  405. print(
  406. f" [grounding-agent] cap={label} → ✗ FAILED after {max_retries + 1} attempts: {last_reason}",
  407. flush=True,
  408. )
  409. return None, 0.0, last_turns_used
  410. def _validate_against_schema(parsed: Dict[str, Any]) -> Optional[str]:
  411. """
  412. 用 apply_to_grounding_agent.schema.json 校验 agent 输出。
  413. 返回 None 表示通过,否则返回错误字符串。
  414. 任何加载/调用异常都吞掉(返回 None) — schema 校验是 nice-to-have,不阻塞主流程。
  415. """
  416. try:
  417. from examples.process_pipeline.script.schema_manager import validate_with_schema
  418. return validate_with_schema(parsed, "apply_to_grounding_agent")
  419. except Exception as e:
  420. # schema 文件不存在 / schema_manager 异常 — 静默跳过,不影响主流程
  421. return None
async def ground_single_case(
    case_item: Dict[str, Any],
    template: str,
    model: str = DEFAULT_MODEL,
    max_turns: int = DEFAULT_MAX_TURNS,
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    on_cap_done: Optional[Any] = None,  # async callable, called after each cap completes (for incremental persist)
    backend: str = "agent",  # "agent" | "oneshot"
    llm_call: Any = None,  # required for oneshot
    oneshot_template: Optional[str] = None,  # required for oneshot
    shi_json: Optional[str] = None,  # required for oneshot (pre-read so each cap doesn't re-read the file)
    xing_json: Optional[str] = None,  # required for oneshot
) -> Tuple[Dict[str, Any], float]:
    """
    Run agent grounding over every capability of one case.

    **Modifies case_item in place** (no copy is made) — as each cap finishes,
    its capability dict is immediately updated with apply_to / suggest_apply_to.

    on_cap_done: optional async callback invoked (with no args) after each cap
    completes. The caller may persist the case.json owning case_item inside the
    callback, achieving cap-level incremental write-back.
    """
    total_cost = 0.0
    case_index = case_item.get("index", "?")
    title = (case_item.get("title") or "")[:20] or "untitled"
    workflow_groups = case_item.get("workflow_groups")
    if not isinstance(workflow_groups, list) or not workflow_groups:
        return case_item, total_cost
    # Collect all capabilities across groups — holding references to the dicts
    # inside the original case_item.
    # Skip caps tagged via workflow.steps[*].phase == "非制作" (meta-tasks with no real deliverable).
    flat_pairs: List[tuple] = []
    skipped_non_production = 0
    for group_idx, group in enumerate(workflow_groups):
        if not isinstance(group, dict):
            continue
        workflow_id = group.get("workflow_id") or f"g{group_idx + 1}"
        # Build the step_id → phase mapping
        wf = group.get("workflow") or {}
        step_phase = {
            s.get("step_id"): s.get("phase")
            for s in (wf.get("steps") or []) if isinstance(s, dict)
        }
        capability_items = group.get("capability")
        if not isinstance(capability_items, list) or not capability_items:
            continue
        for cap_idx, cap in enumerate(capability_items):
            if not (isinstance(cap, dict) and cap.get("capability_id")):
                continue
            cap_id = cap.get("capability_id")
            # Resolve the phase via workflow_step_ref
            step_ref = cap.get("workflow_step_ref") or {}
            step_id = step_ref.get("step_id")
            phase = step_phase.get(step_id)
            if phase == "非制作":
                print(
                    f" [grounding-agent] skip cap={case_index}/{workflow_id}/{cap_id} "
                    f"(step={step_id} phase=非制作, no production output to ground)",
                    flush=True,
                )
                skipped_non_production += 1
                continue
            flat_pairs.append((group_idx, cap_idx, cap, workflow_id))
    if not flat_pairs:
        return case_item, total_cost
    print(f"\n[{case_index}] === case '{title}' grounding {len(flat_pairs)} capabilities ===", flush=True)
    sem = asyncio.Semaphore(max_concurrent)
    # Cap-level write lock: with max_concurrent > 1 several caps of the same case
    # finish concurrently; updating case_item + triggering persist must be serialized.
    write_lock = asyncio.Lock()
    grounded_count = 0
    async def _run_one(group_idx: int, cap_idx: int, cap: Dict, workflow_id: str):
        nonlocal grounded_count, total_cost
        async with sem:
            cap_id = cap.get("capability_id") or f"idx{cap_idx}"
            label = f"{case_index}/{workflow_id}/{cap_id}"
            if backend == "oneshot":
                parsed, cost, _ = await ground_single_capability_oneshot(
                    cap, oneshot_template, llm_call, model,
                    shi_json=shi_json, xing_json=xing_json, task_label=label,
                )
            else:
                parsed, cost, _turns = await ground_single_capability(
                    cap, template, model=model, max_turns=max_turns, task_label=label,
                )
            # Write back in place into the capability dict inside case_item.
            # write_lock guarantees that when several caps complete concurrently,
            # the update + persist trigger run serially.
            async with write_lock:
                total_cost += cost
                if parsed is not None:
                    target_cap = case_item["workflow_groups"][group_idx]["capability"][cap_idx]
                    if isinstance(target_cap, dict):
                        target_cap["apply_to"] = parsed.get("matched", []) or []
                        target_cap["suggest_apply_to"] = parsed.get("suggested_additions", []) or []
                        grounded_count += 1
                # Notify the caller to persist (even when parsed is None — a
                # failure is still a "completion"; the caller can decide whether
                # to persist partial progress on failure too)
                if on_cap_done is not None:
                    try:
                        await on_cap_done()
                    except Exception as e:
                        print(f" [cap={label}] persist callback failed: {type(e).__name__}: {e}", flush=True)
            return group_idx, cap_idx, parsed, cost
    tasks = [_run_one(gi, ci, cap, wid) for (gi, ci, cap, wid) in flat_pairs]
    await asyncio.gather(*tasks)
    print(
        f"[{case_index}] >>> case done: {grounded_count}/{len(flat_pairs)} caps grounded, "
        f"{skipped_non_production} skipped (phase=非制作) <<<",
        flush=True,
    )
    return case_item, total_cost
async def apply_grounding(
    case_file: Path,
    llm_call: Any = None,
    model: str = DEFAULT_MODEL,
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    backend: Optional[str] = None,  # "agent" | "oneshot"; None reads the env var
) -> Dict[str, Any]:
    """
    Top-level entry: walk case.json and run grounding over each case's capabilities.
    Backend selection:
        - "agent"  : Claude Agent SDK + OAuth/Max subscription; Read/Grep the libraries on demand
        - "oneshot": single LLM call (via llm_call, usually OpenRouter) with both libraries inlined into the prompt
        - None     : read env var GROUNDING_BACKEND, default "agent"
    Keeps the same signature as the old apply_grounding so run_pipeline.py can switch smoothly.
    """
    # Resolve the backend
    if backend is None:
        backend = os.getenv("GROUNDING_BACKEND", "agent").strip().lower()
    if backend not in ("agent", "oneshot"):
        raise ValueError(f"Invalid backend: {backend!r}. Choose 'agent' or 'oneshot'.")
    if backend == "oneshot" and llm_call is None:
        raise ValueError("backend='oneshot' requires llm_call (e.g. OpenRouter llm_call from run_pipeline.py)")
    if not SHI_JSON_PATH.exists():
        raise FileNotFoundError(f"实质分类库不存在: {SHI_JSON_PATH}")
    if not XING_JSON_PATH.exists():
        raise FileNotFoundError(f"形式分类库不存在: {XING_JSON_PATH}")
    with open(case_file, "r", encoding="utf-8") as f:
        case_data = json.load(f)
    cases = case_data.get("cases", [])
    template = load_prompt_template("apply_to_grounding_agent")
    # Oneshot mode: pre-read both category-library JSON files (shared by every cap)
    oneshot_template: Optional[str] = None
    shi_json: Optional[str] = None
    xing_json: Optional[str] = None
    if backend == "oneshot":
        oneshot_template = load_prompt_template("apply_to_grounding_oneshot")
        shi_json = SHI_JSON_PATH.read_text(encoding="utf-8")
        xing_json = XING_JSON_PATH.read_text(encoding="utf-8")
        print(
            f"[apply-grounding] backend=oneshot model={model} "
            f"shi_json={len(shi_json):,} chars xing_json={len(xing_json):,} chars",
            flush=True,
        )
    else:
        print(f"[apply-grounding] backend=agent model={model}", flush=True)
    # Select the cases that actually carry capabilities
    needs = [
        c for c in cases
        if isinstance(c.get("workflow_groups"), list) and any(
            isinstance(g, dict) and isinstance(g.get("capability"), list) and any(
                isinstance(cap, dict) and cap.get("capability_id")
                for cap in g.get("capability", [])
            )
            for g in c["workflow_groups"]
        )
    ]
    print(f"Grounding (agent) for {len(needs)}/{len(cases)} cases (max_concurrent={max_concurrent})", flush=True)
    if not needs:
        return {
            "total": len(cases),
            "grounded": 0,
            "total_cost": 0.0,
            "output_file": str(case_file),
        }
    # Limit concurrency between top-level cases (a single agent run is already
    # expensive; do not additionally run cases in parallel).
    # Default: only 1 case at a time; caps inside a case run with max_concurrent.
    # Note: case_sem=1 also guarantees "serial across cases" — incremental
    # write-back to case.json cannot race.
    case_sem = asyncio.Semaphore(1)
    # Snapshot case.json before running (preserve the pre-change state)
    try:
        from examples.process_pipeline.script.case_history import snapshot_case_file
        snap = snapshot_case_file(case_file, step="apply_grounding_agent")
        if snap:
            print(f" [snapshot] {snap.name}", flush=True)
    except Exception as e:
        print(f" [snapshot] skipped: {e}", flush=True)
    # Global write lock shared across cases — even if the case_sem=1 setting is
    # relaxed in the future, the I/O stays safe
    persist_lock = asyncio.Lock()
    async def _persist_case_data() -> None:
        """Serially dump case_data to case.json (used by cap-level incremental write-back AND the final/fallback flush)."""
        async with persist_lock:
            with open(case_file, "w", encoding="utf-8") as f:
                json.dump(case_data, f, ensure_ascii=False, indent=2)
    async def _process_case(case_item):
        async with case_sem:
            idx = case_item.get("index", 0)
            cid = (case_item.get("_raw") or {}).get("case_id", "unknown")
            print(f" -> [{idx}] [{cid}] agent grounding starts", flush=True)
            # Cap-level incremental persistence: ground_single_case calls this once
            # per completed cap.
            # case_item is modified in place, and case_data["cases"] holds the very
            # same reference, so dumping case_data reflects the latest progress.
            async def _on_cap_done():
                try:
                    await _persist_case_data()
                except Exception as e:
                    print(f" [case {idx}] persist failed: {type(e).__name__}: {e}", flush=True)
            grounded, cost = await ground_single_case(
                case_item, template, model=model, max_concurrent=max_concurrent,
                on_cap_done=_on_cap_done,
                backend=backend, llm_call=llm_call,
                oneshot_template=oneshot_template,
                shi_json=shi_json, xing_json=xing_json,
            )
            print(f" <- [{idx}] [{cid}] agent grounding done (case.json saved incrementally at each cap)", flush=True)
            return case_item, grounded, cost
    total_cost = 0.0
    try:
        tasks = [_process_case(c) for c in needs]
        results = await asyncio.gather(*tasks)
        total_cost = sum(c for _, _, c in results)
    except (KeyboardInterrupt, asyncio.CancelledError) as e:
        # Safety net: on Ctrl+C or upstream cancel, flush whatever is already in
        # case_data to disk (cap-level write-back already saved at each cap;
        # this is one extra belt-and-braces pass)
        print(f"\n[grounding-agent] interrupted ({type(e).__name__}) — flushing partial results to {case_file}", flush=True)
        try:
            await _persist_case_data()
        except Exception as flush_err:
            print(f" [grounding-agent] final flush failed: {flush_err}", flush=True)
        raise
    # One final dump (safety net, guaranteeing a consistent final state)
    await _persist_case_data()
    return {
        "total": len(cases),
        "grounded": len(needs),
        "total_cost": total_cost,
        "output_file": str(case_file),
    }
if __name__ == "__main__":
    # This module is not a standalone entry point; it is driven by run_pipeline.py.
    import sys
    print("Please use this module through run_pipeline.py")
    sys.exit(1)