""" Stage 2 (Agent 版):将 capability 对接到内容分类库 替换原本 step1 (extract-draft-query) + step2 (search API) + step3 (grounding LLM) 的三段式流程为单次 agent 调用 —— 用 Claude Agent SDK 启动一个带 Read/Grep/Glob 工具的 agent,让它自己读两份分类库 JSON 并在最终输出中包含 JSON 代码块; Python 端从 agent 的文本输出里提取 JSON 并写回 case.json。 输入(每个 capability): {body, inputs, outputs, ...}(不再依赖 apply_to_draft,agent 自己看 body) Agent 输出的 JSON: { "matched": [{ category_path, category_type: "实质"|"形式", action, ability_type, matched_elements, structured_content }, ...], "suggested_additions": [{ category_type: "实质", parent_path, suggested_level, ...}, ...] } 写回 case.json(直接平铺,跟现有 grounded cap 结构对齐): capability.apply_to = matched 数组(每项自带 category_type 字段区分实质/形式) capability.suggest_apply_to = suggested_additions 数组(每项 category_type 必为"实质") """ import asyncio import json import os import re from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from dotenv import load_dotenv # Project root + .env PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent load_dotenv(PROJECT_ROOT / ".env") # 分类库 JSON 路径(由 inputs/ 移到 script/resource/ 下作为引用) RESOURCE_DIR = Path(__file__).resolve().parent / "resource" SHI_JSON_PATH = RESOURCE_DIR / "分类库导出_实质_20260512_132218.json" XING_JSON_PATH = RESOURCE_DIR / "分类库导出_形式_20260512_170623.json" # 并发上限(agent 调用比一次性 LLM 慢且贵,控制并发避免 OAuth 配额爆炸) DEFAULT_MAX_CONCURRENT = int(os.getenv("GROUNDING_AGENT_MAX_CONCURRENT", "2")) # 单 capability agent 最多对话轮数 DEFAULT_MAX_TURNS = int(os.getenv("GROUNDING_AGENT_MAX_TURNS", "100")) # 单 cap 失败时的重试次数(每次重试启动一个新 ClaudeSDKClient,拿到 fresh context) DEFAULT_MAX_RETRIES = int(os.getenv("GROUNDING_AGENT_MAX_RETRIES", "2")) # 默认模型(走 OAuth/Max 订阅) DEFAULT_MODEL = os.getenv("GROUNDING_AGENT_MODEL", "claude-sonnet-4-6") def load_prompt_template(name: str = "apply_to_grounding_agent") -> str: """加载 prompts/.prompt 模板(agent 或 oneshot 通用)""" prompt_path = Path(__file__).resolve().parent.parent / "prompts" / f"{name}.prompt" with open(prompt_path, "r", encoding="utf-8") as f: content = f.read() if content.startswith("---"): parts = content.split("---", 2) if len(parts) >= 3: content = parts[2] return content.strip() def render_oneshot_prompt(template: str, capability: Dict[str, Any], shi_json: str, xing_json: str) -> str: """渲染 oneshot prompt:把 body + 两个分类库全文注入""" body = capability.get("body") or "" return ( template .replace("{capability的body信息}", body) .replace("{shi_json}", shi_json) .replace("{xing_json}", xing_json) ) async def ground_single_capability_oneshot( capability: Dict[str, Any], template: str, llm_call: Any, model: str, shi_json: str, xing_json: str, task_label: str = "", max_retries: int = DEFAULT_MAX_RETRIES, ) -> Tuple[Optional[Dict[str, Any]], float, int]: """ Oneshot 模式:一次性把"分类库 + capability body"塞进 prompt,调 LLM 拿 JSON。 跟 agent 模式共享 retry / recover / schema 校验 / 日志格式。 Returns: (parsed_output, cost, attempts_used) """ cap_id = capability.get("capability_id") or "?" label = task_label or cap_id full_prompt = render_oneshot_prompt(template, capability, shi_json, xing_json) last_reason: Optional[str] = None last_in_tok = 0 last_out_tok = 0 last_cache_read = 0 print( f" [grounding-oneshot] cap={label} → starting (input ~{len(full_prompt):,} chars, model={model})", flush=True, ) for attempt in range(max_retries + 1): if attempt > 0: print( f" [grounding-oneshot] cap={label} → RETRY {attempt}/{max_retries} " f"(prev: {last_reason})", flush=True, ) try: response = await llm_call( messages=[{"role": "user", "content": full_prompt}], model=model, temperature=0.1, max_tokens=16384, ) except Exception as e: last_reason = f"LLM call failed: {type(e).__name__}: {e}" print(f" [grounding-oneshot] cap={label} → {last_reason}", flush=True) # 服务端 body for attr in ("response", "body"): obj = getattr(e, attr, None) if obj is not None: try: text = obj.text if hasattr(obj, "text") else str(obj) print(f" server body: {text[:500]}", flush=True) except Exception: pass continue # 提取 content content = response.get("content", "") if isinstance(content, list): content = "".join( (b.get("text") or "") if isinstance(b, dict) else str(b) for b in content ) elif not isinstance(content, str): content = str(content) # Usage usage = response.get("usage") or {} if hasattr(usage, "__dict__") and not isinstance(usage, dict): usage = {k: getattr(usage, k) for k in dir(usage) if not k.startswith("_") and not callable(getattr(usage, k))} last_in_tok = usage.get("input_tokens") or usage.get("prompt_tokens") or 0 last_out_tok = usage.get("output_tokens") or usage.get("completion_tokens") or 0 last_cache_read = usage.get("cache_read_input_tokens") or usage.get("cached_tokens") or 0 # 提取 JSON parsed = extract_json_from_response(content) if parsed is None: last_reason = "JSON parse FAILED (no valid JSON block found)" print(f" [grounding-oneshot] cap={label} → {last_reason}", flush=True) print(f" raw output total: {len(content)} chars", flush=True) print(f" raw output head: {content[:1000]}", flush=True) if len(content) > 1000: print(f" raw output tail: {content[-500:]}", flush=True) continue # Schema 校验 schema_err = _validate_against_schema(parsed) schema_warn = f" ⚠ schema violation: {schema_err}" if schema_err else "" matched_all = parsed.get("matched", []) or [] shi_n = sum(1 for m in matched_all if isinstance(m, dict) and m.get("category_type") == "实质") xing_n = sum(1 for m in matched_all if isinstance(m, dict) and m.get("category_type") == "形式") suggest_n = len(parsed.get("suggested_additions", []) or []) cache_note = f" ({last_cache_read:,} cached)" if last_cache_read else "" print( f" [grounding-oneshot] cap={label} → done: in={last_in_tok:,}{cache_note} out={last_out_tok:,} " f"matched={len(matched_all)} (substance={shi_n} form={xing_n}) suggested={suggest_n}{schema_warn}", flush=True, ) return parsed, 0.0, attempt + 1 # 所有 retry 用完 print( f" [grounding-oneshot] cap={label} → ✗ FAILED after {max_retries + 1} attempts: {last_reason}", flush=True, ) return None, 0.0, max_retries + 1 def render_prompt(template: str, capability: Dict[str, Any]) -> str: """渲染 prompt:把 capability.body 填进「做法描述」占位符 + 两个分类库的绝对路径。 新版 prompt 只用 body 字段(不再喂整个 capability JSON)。inputs/outputs 不传给 agent, 保持 prompt 简洁,跟用户提供的 prompt_input.txt 模板一致。 """ body = capability.get("body") or "" return ( template # 新占位符(用户最新 prompt 模板用的) .replace("{capability的body信息}", body) # 旧占位符(兼容,万一未来需要换回完整 JSON) .replace("{capability_json}", body) .replace("{shi_path}", str(SHI_JSON_PATH.resolve())) .replace("{xing_path}", str(XING_JSON_PATH.resolve())) ) def _find_balanced_json_object(text: str, start_idx: int = 0) -> Optional[str]: """ 从 text[start_idx:] 找一个花括号配平的 JSON 对象,返回原始片段(含两端 {})。 用字符串状态机正确处理 JSON 字符串内的 { } 不算 brace。 找不到返回 None。 """ start = text.find("{", start_idx) if start < 0: return None depth = 0 in_string = False escape_next = False for i in range(start, len(text)): c = text[i] if escape_next: escape_next = False continue if c == "\\": escape_next = True continue if c == '"': in_string = not in_string continue if in_string: continue if c == "{": depth += 1 elif c == "}": depth -= 1 if depth == 0: return text[start:i + 1] return None # 没配平 def extract_json_from_response(text: str) -> Optional[Dict[str, Any]]: """ 从 agent 的最终输出文本中提取 JSON。多策略,尽量宽容: 1. ```json ... ``` 代码块(findall 拿所有,倒序试,取第一个 valid 的 —— agent 倾向"先预览后正式") 2. 花括号配平 —— 不依赖 ``` 包裹,从文本任意位置找完整 {…} 对象(多个候选倒序试,取后面的优先) 3. 贪婪 regex 兜底 —— 取整段文本里 { 到最后一个 } 之间的内容 """ # 策略 1:``` 代码块(允许 ```json 或纯 ```) blocks = re.findall(r"```(?:json|JSON)?\s*\n?(\{[\s\S]*?\})\s*```", text) for raw in reversed(blocks): try: return json.loads(raw) except json.JSONDecodeError: continue # 策略 2:花括号配平 — 扫描所有可能的起点,倒序试(最终结果通常在文本末尾) candidates: List[str] = [] cursor = 0 while cursor < len(text): nxt = text.find("{", cursor) if nxt < 0: break obj = _find_balanced_json_object(text, nxt) if obj is None: break candidates.append(obj) cursor = nxt + len(obj) for raw in reversed(candidates): try: return json.loads(raw) except json.JSONDecodeError: continue # 策略 3:贪婪兜底(最长 { ... } 块,可能解析失败但试一下) greedy = re.search(r"\{[\s\S]*\}", text) if greedy: try: return json.loads(greedy.group(0)) except json.JSONDecodeError: pass return None async def _ground_single_capability_attempt( capability: Dict[str, Any], template: str, model: str, max_turns: int, label: str, attempt_no: int, # 0=首次,>=1=retry ) -> Tuple[Optional[Dict[str, Any]], int, Optional[str]]: """ 单次 agent 调用尝试。 Returns: (parsed_output, turns_used, failure_reason) - parsed != None & failure_reason == None: 成功(或 recover) - parsed == None & failure_reason != None: 失败,可 retry """ from claude_agent_sdk import ( AssistantMessage, ClaudeAgentOptions, ClaudeSDKClient, ClaudeSDKError, ResultMessage, TextBlock, ) # 抹掉父进程 API key,强制 SDK 子进程走 OAuth override_env = { "ANTHROPIC_API_KEY": "", "ANTHROPIC_BASE_URL": "", "ANTHROPIC_AUTH_TOKEN": "", } full_prompt = render_prompt(template, capability) stderr_lines: List[str] = [] def _capture_stderr(line: str) -> None: if line: stderr_lines.append(line) options = ClaudeAgentOptions( model=model, allowed_tools=["Read", "Grep", "Glob"], max_turns=max_turns, cwd=str(RESOURCE_DIR), env=override_env, stderr=_capture_stderr, setting_sources=[], ) text_parts: List[str] = [] is_error = False turns_used = 0 tool_use_count = 0 attempt_label = f" attempt#{attempt_no+1}" if attempt_no > 0 else "" print(f" [grounding-agent] cap={label}{attempt_label} → starting (max_turns={max_turns})", flush=True) try: async with ClaudeSDKClient(options=options) as client: await client.query(full_prompt) async for msg in client.receive_response(): if isinstance(msg, AssistantMessage): for block in msg.content: if hasattr(block, "thinking"): continue elif isinstance(block, TextBlock): text_parts.append(block.text) preview = block.text.replace("\n", " ").strip()[:160] if preview: print(f" [cap={label}] text: {preview}", flush=True) elif hasattr(block, "name") and hasattr(block, "input"): tool_use_count += 1 tool_input_str = json.dumps(block.input, ensure_ascii=False)[:160] print( f" [cap={label}] tool#{tool_use_count} {block.name}({tool_input_str})", flush=True, ) elif isinstance(msg, ResultMessage): is_error = msg.is_error turns_used = msg.num_turns except ClaudeSDKError as e: reason = f"SDK ERROR {type(e).__name__}: {e}" print(f" [grounding-agent] cap={label} → {reason}", flush=True) if stderr_lines: print(f" CLI stderr tail: {stderr_lines[-3:]}", flush=True) return None, turns_used, reason # Recover 策略:is_error=True 也尝试解析 JSON content = "".join(text_parts).strip() parsed = extract_json_from_response(content) if parsed is None: reason = ( "is_error=True (likely max_turns) AND no parseable JSON in output" if is_error else "JSON parse FAILED (model output didn't contain valid JSON block)" ) print(f" [grounding-agent] cap={label} → {reason}", flush=True) print(f" raw output total: {len(content)} chars", flush=True) print(f" raw output head: {content[:1000]}", flush=True) if len(content) > 1000: print(f" raw output tail: {content[-500:]}", flush=True) return None, turns_used, reason # Schema 校验 schema_err = _validate_against_schema(parsed) schema_warn = f" ⚠ schema violation: {schema_err}" if schema_err else "" matched_all = parsed.get("matched", []) or [] shi_n = sum(1 for m in matched_all if isinstance(m, dict) and m.get("category_type") == "实质") xing_n = sum(1 for m in matched_all if isinstance(m, dict) and m.get("category_type") == "形式") suggest_n = len(parsed.get("suggested_additions", []) or []) if is_error: print( f" [grounding-agent] cap={label} → ⚠ RECOVERED: turns={turns_used} tools={tool_use_count} " f"matched={len(matched_all)} (substance={shi_n} form={xing_n}) suggested={suggest_n} " f"(is_error=True, but JSON captured before truncation){schema_warn}", flush=True, ) else: print( f" [grounding-agent] cap={label} → done: turns={turns_used} tools={tool_use_count} " f"matched={len(matched_all)} (substance={shi_n} form={xing_n}) suggested={suggest_n}{schema_warn}", flush=True, ) return parsed, turns_used, None # 成功 async def ground_single_capability( capability: Dict[str, Any], template: str, model: str = DEFAULT_MODEL, max_turns: int = DEFAULT_MAX_TURNS, task_label: str = "", max_retries: int = DEFAULT_MAX_RETRIES, ) -> Tuple[Optional[Dict[str, Any]], float, int]: """ 对单个 capability 跑 agent grounding,失败时自动 retry。 Returns: (parsed_output, cost, turns_used) parsed_output 为 None 表示解析失败或 agent 出错且重试用完 cost 当前总为 0.0(OAuth 模式不计费),保留接口便于后续切换 """ try: import claude_agent_sdk # noqa: F401 — 早失败,避免 retry 循环里反复 ImportError except ImportError as e: raise RuntimeError(f"claude_agent_sdk not installed: {e}") from e cap_id = capability.get("capability_id") or "?" label = task_label or cap_id last_reason: Optional[str] = None last_turns_used = 0 for attempt in range(max_retries + 1): if attempt > 0: print( f" [grounding-agent] cap={label} → RETRY {attempt}/{max_retries} " f"(prev failure: {last_reason})", flush=True, ) parsed, turns_used, failure_reason = await _ground_single_capability_attempt( capability, template, model, max_turns, label, attempt ) last_turns_used = turns_used if failure_reason is None: # 成功(包括 recover 路径) if attempt > 0: print( f" [grounding-agent] cap={label} → ✓ recovered after retry#{attempt}", flush=True, ) return parsed, 0.0, turns_used # 失败:记录原因,进入下一次重试 last_reason = failure_reason # 所有 retry 用完 print( f" [grounding-agent] cap={label} → ✗ FAILED after {max_retries + 1} attempts: {last_reason}", flush=True, ) return None, 0.0, last_turns_used def _validate_against_schema(parsed: Dict[str, Any]) -> Optional[str]: """ 用 apply_to_grounding_agent.schema.json 校验 agent 输出。 返回 None 表示通过,否则返回错误字符串。 任何加载/调用异常都吞掉(返回 None) — schema 校验是 nice-to-have,不阻塞主流程。 """ try: from examples.process_pipeline.script.schema_manager import validate_with_schema return validate_with_schema(parsed, "apply_to_grounding_agent") except Exception as e: # schema 文件不存在 / schema_manager 异常 — 静默跳过,不影响主流程 return None async def ground_single_case( case_item: Dict[str, Any], template: str, model: str = DEFAULT_MODEL, max_turns: int = DEFAULT_MAX_TURNS, max_concurrent: int = DEFAULT_MAX_CONCURRENT, on_cap_done: Optional[Any] = None, # async callable, called after each cap completes (for incremental persist) backend: str = "agent", # "agent" | "oneshot" llm_call: Any = None, # oneshot 必需 oneshot_template: Optional[str] = None, # oneshot 必需 shi_json: Optional[str] = None, # oneshot 必需(预读,避免每个 cap 重读) xing_json: Optional[str] = None, # oneshot 必需 ) -> Tuple[Dict[str, Any], float]: """ 对单个 case 的所有 capability 跑 agent grounding。 **直接 in-place 修改 case_item**(不再创建副本)— 每个 cap 完成时,对应的 capability dict 立刻被更新 apply_to / suggest_apply_to 字段。 on_cap_done: 可选 async 回调,每个 cap 完成后调用(无参)。caller 可以在回调里 持久化 case_item 所属的 case.json,实现 cap 级增量写回。 """ total_cost = 0.0 case_index = case_item.get("index", "?") title = (case_item.get("title") or "")[:20] or "untitled" workflow_groups = case_item.get("workflow_groups") if not isinstance(workflow_groups, list) or not workflow_groups: return case_item, total_cost # 跨 group 收集所有 capability — 直接持有原 case_item 内部的 dict 引用。 # 跳过被 workflow.steps[*].phase == "非制作" 标注的 cap(meta-task,没有真实产出物)。 flat_pairs: List[tuple] = [] skipped_non_production = 0 for group_idx, group in enumerate(workflow_groups): if not isinstance(group, dict): continue workflow_id = group.get("workflow_id") or f"g{group_idx + 1}" # 构造 step_id → phase 映射 wf = group.get("workflow") or {} step_phase = { s.get("step_id"): s.get("phase") for s in (wf.get("steps") or []) if isinstance(s, dict) } capability_items = group.get("capability") if not isinstance(capability_items, list) or not capability_items: continue for cap_idx, cap in enumerate(capability_items): if not (isinstance(cap, dict) and cap.get("capability_id")): continue cap_id = cap.get("capability_id") # 通过 workflow_step_ref 反查 phase step_ref = cap.get("workflow_step_ref") or {} step_id = step_ref.get("step_id") phase = step_phase.get(step_id) if phase == "非制作": print( f" [grounding-agent] skip cap={case_index}/{workflow_id}/{cap_id} " f"(step={step_id} phase=非制作, no production output to ground)", flush=True, ) skipped_non_production += 1 continue flat_pairs.append((group_idx, cap_idx, cap, workflow_id)) if not flat_pairs: return case_item, total_cost print(f"\n[{case_index}] === case '{title}' grounding {len(flat_pairs)} capabilities ===", flush=True) sem = asyncio.Semaphore(max_concurrent) # cap 级写锁:max_concurrent>1 时同 case 内多个 cap 并发,写回 case_item + 触发 persist 需要串行 write_lock = asyncio.Lock() grounded_count = 0 async def _run_one(group_idx: int, cap_idx: int, cap: Dict, workflow_id: str): nonlocal grounded_count, total_cost async with sem: cap_id = cap.get("capability_id") or f"idx{cap_idx}" label = f"{case_index}/{workflow_id}/{cap_id}" if backend == "oneshot": parsed, cost, _ = await ground_single_capability_oneshot( cap, oneshot_template, llm_call, model, shi_json=shi_json, xing_json=xing_json, task_label=label, ) else: parsed, cost, _turns = await ground_single_capability( cap, template, model=model, max_turns=max_turns, task_label=label, ) # in-place 写回到 case_item 内的 capability dict(直接修改原对象) # write_lock 保证:多个并发 cap 完成时,更新 + 触发 persist 是串行的 async with write_lock: total_cost += cost if parsed is not None: target_cap = case_item["workflow_groups"][group_idx]["capability"][cap_idx] if isinstance(target_cap, dict): target_cap["apply_to"] = parsed.get("matched", []) or [] target_cap["suggest_apply_to"] = parsed.get("suggested_additions", []) or [] grounded_count += 1 # 通知 caller 持久化(即使 parsed=None 也通知 — failure 也是个"完成", # caller 可以选择是否在失败时也持久化部分进度) if on_cap_done is not None: try: await on_cap_done() except Exception as e: print(f" [cap={label}] persist callback failed: {type(e).__name__}: {e}", flush=True) return group_idx, cap_idx, parsed, cost tasks = [_run_one(gi, ci, cap, wid) for (gi, ci, cap, wid) in flat_pairs] await asyncio.gather(*tasks) print( f"[{case_index}] >>> case done: {grounded_count}/{len(flat_pairs)} caps grounded, " f"{skipped_non_production} skipped (phase=非制作) <<<", flush=True, ) return case_item, total_cost async def apply_grounding( case_file: Path, llm_call: Any = None, model: str = DEFAULT_MODEL, max_concurrent: int = DEFAULT_MAX_CONCURRENT, backend: Optional[str] = None, # "agent" | "oneshot",None 时读环境变量 ) -> Dict[str, Any]: """ 顶层入口:遍历 case.json,对每个 case 的 capability 跑 grounding。 backend 选择: - "agent" :用 Claude Agent SDK + OAuth/Max 订阅,按需 Read/Grep 分类库 - "oneshot" :用一次性 LLM 调用(通过 llm_call,通常是 OpenRouter),把分类库全文塞 prompt - None :从环境变量 GROUNDING_BACKEND 读,默认 "agent" 保持与旧 apply_grounding 一致的签名,便于 run_pipeline.py 平滑切换。 """ # 确定 backend if backend is None: backend = os.getenv("GROUNDING_BACKEND", "agent").strip().lower() if backend not in ("agent", "oneshot"): raise ValueError(f"Invalid backend: {backend!r}. Choose 'agent' or 'oneshot'.") if backend == "oneshot" and llm_call is None: raise ValueError("backend='oneshot' requires llm_call (e.g. OpenRouter llm_call from run_pipeline.py)") if not SHI_JSON_PATH.exists(): raise FileNotFoundError(f"实质分类库不存在: {SHI_JSON_PATH}") if not XING_JSON_PATH.exists(): raise FileNotFoundError(f"形式分类库不存在: {XING_JSON_PATH}") with open(case_file, "r", encoding="utf-8") as f: case_data = json.load(f) cases = case_data.get("cases", []) template = load_prompt_template("apply_to_grounding_agent") # Oneshot 模式:预读两个分类库 JSON(每个 cap 共享,不重复读) oneshot_template: Optional[str] = None shi_json: Optional[str] = None xing_json: Optional[str] = None if backend == "oneshot": oneshot_template = load_prompt_template("apply_to_grounding_oneshot") shi_json = SHI_JSON_PATH.read_text(encoding="utf-8") xing_json = XING_JSON_PATH.read_text(encoding="utf-8") print( f"[apply-grounding] backend=oneshot model={model} " f"shi_json={len(shi_json):,} chars xing_json={len(xing_json):,} chars", flush=True, ) else: print(f"[apply-grounding] backend=agent model={model}", flush=True) # 选出有 capability 的 case needs = [ c for c in cases if isinstance(c.get("workflow_groups"), list) and any( isinstance(g, dict) and isinstance(g.get("capability"), list) and any( isinstance(cap, dict) and cap.get("capability_id") for cap in g.get("capability", []) ) for g in c["workflow_groups"] ) ] print(f"Grounding (agent) for {len(needs)}/{len(cases)} cases (max_concurrent={max_concurrent})", flush=True) if not needs: return { "total": len(cases), "grounded": 0, "total_cost": 0.0, "output_file": str(case_file), } # 顶层 case 之间限制并发(agent 单次就贵,case 间不要再并发) # 默认每次只跑 1 个 case,case 内部按 max_concurrent 并发 cap # 注意:case_sem=1 同时保证了"case 间串行" — 增量写回 case.json 不会有并发冲突。 case_sem = asyncio.Semaphore(1) # 跑前先做一次快照(保留改动前的 case.json) try: from examples.process_pipeline.script.case_history import snapshot_case_file snap = snapshot_case_file(case_file, step="apply_grounding_agent") if snap: print(f" [snapshot] {snap.name}", flush=True) except Exception as e: print(f" [snapshot] skipped: {e}", flush=True) # 全局写锁:跨 case 共享一个 — 即使 case_sem=1 的设定将来放宽,I/O 也安全 persist_lock = asyncio.Lock() async def _persist_case_data() -> None: """异步串行 dump case_data 到 case.json(cap 级增量写回 + finally 兜底都用这个)""" async with persist_lock: with open(case_file, "w", encoding="utf-8") as f: json.dump(case_data, f, ensure_ascii=False, indent=2) async def _process_case(case_item): async with case_sem: idx = case_item.get("index", 0) cid = (case_item.get("_raw") or {}).get("case_id", "unknown") print(f" -> [{idx}] [{cid}] agent grounding starts", flush=True) # cap 级增量持久化:ground_single_case 内部每个 cap 完成时调一次 # case_item 已被 in-place 修改,case_data["cases"] 里的引用就是 case_item, # 所以直接 dump case_data 就能反映最新进度。 async def _on_cap_done(): try: await _persist_case_data() except Exception as e: print(f" [case {idx}] persist failed: {type(e).__name__}: {e}", flush=True) grounded, cost = await ground_single_case( case_item, template, model=model, max_concurrent=max_concurrent, on_cap_done=_on_cap_done, backend=backend, llm_call=llm_call, oneshot_template=oneshot_template, shi_json=shi_json, xing_json=xing_json, ) print(f" <- [{idx}] [{cid}] agent grounding done (case.json saved incrementally at each cap)", flush=True) return case_item, grounded, cost total_cost = 0.0 try: tasks = [_process_case(c) for c in needs] results = await asyncio.gather(*tasks) total_cost = sum(c for _, _, c in results) except (KeyboardInterrupt, asyncio.CancelledError) as e: # 兜底:用户 Ctrl+C 或上游 cancel 时,把已写入的 case_data 落盘 # (cap 级增量写回已经在每个 cap 完成时写过,这里再保险一次) print(f"\n[grounding-agent] interrupted ({type(e).__name__}) — flushing partial results to {case_file}", flush=True) try: await _persist_case_data() except Exception as flush_err: print(f" [grounding-agent] final flush failed: {flush_err}", flush=True) raise # 最终 dump 一次(兜底,确保最终状态一致) await _persist_case_data() return { "total": len(cases), "grounded": len(needs), "total_cost": total_cost, "output_file": str(case_file), } if __name__ == "__main__": import sys print("Please use this module through run_pipeline.py") sys.exit(1)