| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756 |
- """
- Stage 2 (Agent 版):将 capability 对接到内容分类库
- 替换原本 step1 (extract-draft-query) + step2 (search API) + step3 (grounding LLM)
- 的三段式流程为单次 agent 调用 —— 用 Claude Agent SDK 启动一个带 Read/Grep/Glob
- 工具的 agent,让它自己读两份分类库 JSON 并在最终输出中包含 JSON 代码块;
- Python 端从 agent 的文本输出里提取 JSON 并写回 case.json。
- 输入(每个 capability):
- {body, inputs, outputs, ...}(不再依赖 apply_to_draft,agent 自己看 body)
- Agent 输出的 JSON:
- {
- "matched": [{ category_path, category_type: "实质"|"形式", action, ability_type,
- matched_elements, structured_content }, ...],
- "suggested_additions": [{ category_type: "实质", parent_path, suggested_level, ...}, ...]
- }
- 写回 case.json(直接平铺,跟现有 grounded cap 结构对齐):
- capability.apply_to = matched 数组(每项自带 category_type 字段区分实质/形式)
- capability.suggest_apply_to = suggested_additions 数组(每项 category_type 必为"实质")
- """
- import asyncio
- import json
- import os
- import re
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple
- from dotenv import load_dotenv
- # Project root + .env
- PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent
- load_dotenv(PROJECT_ROOT / ".env")
- # 分类库 JSON 路径(由 inputs/ 移到 script/resource/ 下作为引用)
- RESOURCE_DIR = Path(__file__).resolve().parent / "resource"
- SHI_JSON_PATH = RESOURCE_DIR / "分类库导出_实质_20260512_132218.json"
- XING_JSON_PATH = RESOURCE_DIR / "分类库导出_形式_20260512_170623.json"
- # 并发上限(agent 调用比一次性 LLM 慢且贵,控制并发避免 OAuth 配额爆炸)
- DEFAULT_MAX_CONCURRENT = int(os.getenv("GROUNDING_AGENT_MAX_CONCURRENT", "2"))
- # 单 capability agent 最多对话轮数
- DEFAULT_MAX_TURNS = int(os.getenv("GROUNDING_AGENT_MAX_TURNS", "100"))
- # 单 cap 失败时的重试次数(每次重试启动一个新 ClaudeSDKClient,拿到 fresh context)
- DEFAULT_MAX_RETRIES = int(os.getenv("GROUNDING_AGENT_MAX_RETRIES", "2"))
- # 默认模型(走 OAuth/Max 订阅)
- DEFAULT_MODEL = os.getenv("GROUNDING_AGENT_MODEL", "claude-sonnet-4-6")
def load_prompt_template(name: str = "apply_to_grounding_agent") -> str:
    """Load the prompts/<name>.prompt template (shared by agent and oneshot).

    If the file begins with a "---" front-matter fence, everything up to and
    including the second "---" delimiter is stripped and only the template
    body is returned, trimmed of surrounding whitespace.
    """
    template_file = Path(__file__).resolve().parent.parent / "prompts" / f"{name}.prompt"
    text = template_file.read_text(encoding="utf-8")
    if text.startswith("---"):
        pieces = text.split("---", 2)
        if len(pieces) >= 3:
            text = pieces[2]
    return text.strip()
def render_oneshot_prompt(template: str, capability: Dict[str, Any], shi_json: str, xing_json: str) -> str:
    """Render the oneshot prompt: inject the capability body plus the full
    text of both category libraries into the template placeholders."""
    body_text = capability.get("body") or ""
    rendered = template.replace("{capability的body信息}", body_text)
    rendered = rendered.replace("{shi_json}", shi_json)
    return rendered.replace("{xing_json}", xing_json)
async def ground_single_capability_oneshot(
    capability: Dict[str, Any],
    template: str,
    llm_call: Any,
    model: str,
    shi_json: str,
    xing_json: str,
    task_label: str = "",
    max_retries: int = DEFAULT_MAX_RETRIES,
) -> Tuple[Optional[Dict[str, Any]], float, int]:
    """
    Oneshot mode: put "category libraries + capability body" into a single
    prompt and call the LLM once for the grounding JSON.
    Shares retry / recover / schema-validation / log formatting conventions
    with agent mode.
    Returns:
        (parsed_output, cost, attempts_used)
        parsed_output is None when every attempt failed; cost is always 0.0.
    """
    cap_id = capability.get("capability_id") or "?"
    label = task_label or cap_id
    full_prompt = render_oneshot_prompt(template, capability, shi_json, xing_json)
    last_reason: Optional[str] = None
    last_in_tok = 0
    last_out_tok = 0
    last_cache_read = 0
    print(
        f" [grounding-oneshot] cap={label} → starting (input ~{len(full_prompt):,} chars, model={model})",
        flush=True,
    )
    for attempt in range(max_retries + 1):
        if attempt > 0:
            print(
                f" [grounding-oneshot] cap={label} → RETRY {attempt}/{max_retries} "
                f"(prev: {last_reason})",
                flush=True,
            )
        try:
            response = await llm_call(
                messages=[{"role": "user", "content": full_prompt}],
                model=model,
                temperature=0.1,
                max_tokens=16384,
            )
        except Exception as e:
            last_reason = f"LLM call failed: {type(e).__name__}: {e}"
            print(f" [grounding-oneshot] cap={label} → {last_reason}", flush=True)
            # Best-effort: surface the server response body for diagnostics
            for attr in ("response", "body"):
                obj = getattr(e, attr, None)
                if obj is not None:
                    try:
                        text = obj.text if hasattr(obj, "text") else str(obj)
                        print(f" server body: {text[:500]}", flush=True)
                    except Exception:
                        pass
            continue
        # Extract the text content (providers may return a list of content blocks)
        content = response.get("content", "")
        if isinstance(content, list):
            content = "".join(
                (b.get("text") or "") if isinstance(b, dict) else str(b) for b in content
            )
        elif not isinstance(content, str):
            content = str(content)
        # Usage bookkeeping — normalize object-style usage into a plain dict
        usage = response.get("usage") or {}
        if hasattr(usage, "__dict__") and not isinstance(usage, dict):
            usage = {k: getattr(usage, k) for k in dir(usage)
                     if not k.startswith("_") and not callable(getattr(usage, k))}
        # Field names differ across providers (Anthropic vs OpenAI-style)
        last_in_tok = usage.get("input_tokens") or usage.get("prompt_tokens") or 0
        last_out_tok = usage.get("output_tokens") or usage.get("completion_tokens") or 0
        last_cache_read = usage.get("cache_read_input_tokens") or usage.get("cached_tokens") or 0
        # Pull the JSON payload out of the raw model text
        parsed = extract_json_from_response(content)
        if parsed is None:
            last_reason = "JSON parse FAILED (no valid JSON block found)"
            print(f" [grounding-oneshot] cap={label} → {last_reason}", flush=True)
            print(f" raw output total: {len(content)} chars", flush=True)
            print(f" raw output head: {content[:1000]}", flush=True)
            if len(content) > 1000:
                print(f" raw output tail: {content[-500:]}", flush=True)
            continue
        # Schema validation is advisory: violations are logged, never fatal
        schema_err = _validate_against_schema(parsed)
        schema_warn = f" ⚠ schema violation: {schema_err}" if schema_err else ""
        matched_all = parsed.get("matched", []) or []
        shi_n = sum(1 for m in matched_all if isinstance(m, dict) and m.get("category_type") == "实质")
        xing_n = sum(1 for m in matched_all if isinstance(m, dict) and m.get("category_type") == "形式")
        suggest_n = len(parsed.get("suggested_additions", []) or [])
        cache_note = f" ({last_cache_read:,} cached)" if last_cache_read else ""
        print(
            f" [grounding-oneshot] cap={label} → done: in={last_in_tok:,}{cache_note} out={last_out_tok:,} "
            f"matched={len(matched_all)} (substance={shi_n} form={xing_n}) suggested={suggest_n}{schema_warn}",
            flush=True,
        )
        return parsed, 0.0, attempt + 1
    # All retries exhausted
    print(
        f" [grounding-oneshot] cap={label} → ✗ FAILED after {max_retries + 1} attempts: {last_reason}",
        flush=True,
    )
    return None, 0.0, max_retries + 1
def render_prompt(template: str, capability: Dict[str, Any]) -> str:
    """Render the agent prompt: fill capability.body into the 「做法描述」
    placeholder and inject the absolute paths of both category libraries.

    The current prompt template only consumes the `body` field (no full
    capability JSON); inputs/outputs are intentionally omitted to keep the
    prompt minimal, matching the user-provided prompt_input.txt template.
    """
    body_text = capability.get("body") or ""
    substitutions = {
        # new placeholder (used by the latest user prompt template)
        "{capability的body信息}": body_text,
        # legacy placeholder (kept in case the full-JSON form returns)
        "{capability_json}": body_text,
        "{shi_path}": str(SHI_JSON_PATH.resolve()),
        "{xing_path}": str(XING_JSON_PATH.resolve()),
    }
    rendered = template
    for placeholder, value in substitutions.items():
        rendered = rendered.replace(placeholder, value)
    return rendered
def _find_balanced_json_object(text: str, start_idx: int = 0) -> Optional[str]:
    """
    Return the first brace-balanced JSON object in text[start_idx:], including
    both outer braces, or None if no balanced object exists.

    A small string-state machine ensures that braces occurring inside JSON
    string literals (and escaped characters) do not affect the depth count.
    """
    begin = text.find("{", start_idx)
    if begin < 0:
        return None
    depth = 0
    inside_str = False
    escaped = False
    idx = begin
    while idx < len(text):
        ch = text[idx]
        if escaped:
            escaped = False
        elif ch == "\\":
            escaped = True
        elif ch == '"':
            inside_str = not inside_str
        elif not inside_str:
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[begin:idx + 1]
        idx += 1
    return None  # braces never balanced
def extract_json_from_response(text: str) -> Optional[Dict[str, Any]]:
    """
    Extract a JSON object from the agent's final text output.

    Several lenient strategies are tried in order:
      1. ``` fenced code blocks (```json or bare ```): all blocks are
         collected and tried back-to-front, since agents tend to "preview
         first, emit the real answer last".
      2. Brace-balanced scan — no fence required; complete {...} objects are
         collected from anywhere in the text and tried back-to-front.
      3. Greedy regex fallback — everything from the first "{" to the last "}".
    Returns the first successfully parsed dict, or None.
    """
    # Strategy 1: fenced code blocks
    fenced = re.findall(r"```(?:json|JSON)?\s*\n?(\{[\s\S]*?\})\s*```", text)
    for candidate in fenced[::-1]:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            pass
    # Strategy 2: balanced-brace scan over the whole text
    balanced: List[str] = []
    pos = 0
    while pos < len(text):
        brace_at = text.find("{", pos)
        if brace_at < 0:
            break
        fragment = _find_balanced_json_object(text, brace_at)
        if fragment is None:
            break
        balanced.append(fragment)
        pos = brace_at + len(fragment)
    for candidate in balanced[::-1]:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            pass
    # Strategy 3: greedy fallback (longest { ... } span; may well fail to parse)
    greedy_match = re.search(r"\{[\s\S]*\}", text)
    if greedy_match is not None:
        try:
            return json.loads(greedy_match.group(0))
        except json.JSONDecodeError:
            pass
    return None
async def _ground_single_capability_attempt(
    capability: Dict[str, Any],
    template: str,
    model: str,
    max_turns: int,
    label: str,
    attempt_no: int,  # 0 = first try, >= 1 = retry
) -> Tuple[Optional[Dict[str, Any]], int, Optional[str]]:
    """
    One agent-call attempt for a single capability.
    Returns:
        (parsed_output, turns_used, failure_reason)
        - parsed != None & failure_reason == None: success (or recovered)
        - parsed == None & failure_reason != None: failure, caller may retry
    """
    from claude_agent_sdk import (
        AssistantMessage,
        ClaudeAgentOptions,
        ClaudeSDKClient,
        ClaudeSDKError,
        ResultMessage,
        TextBlock,
    )
    # Blank out the parent process's API-key env vars so the SDK subprocess is
    # forced onto OAuth authentication.
    override_env = {
        "ANTHROPIC_API_KEY": "",
        "ANTHROPIC_BASE_URL": "",
        "ANTHROPIC_AUTH_TOKEN": "",
    }
    full_prompt = render_prompt(template, capability)
    stderr_lines: List[str] = []

    def _capture_stderr(line: str) -> None:
        # Keep CLI stderr so SDK failures can log a useful tail below.
        if line:
            stderr_lines.append(line)

    options = ClaudeAgentOptions(
        model=model,
        allowed_tools=["Read", "Grep", "Glob"],
        max_turns=max_turns,
        cwd=str(RESOURCE_DIR),
        env=override_env,
        stderr=_capture_stderr,
        setting_sources=[],
    )
    text_parts: List[str] = []
    is_error = False
    turns_used = 0
    tool_use_count = 0
    attempt_label = f" attempt#{attempt_no+1}" if attempt_no > 0 else ""
    print(f" [grounding-agent] cap={label}{attempt_label} → starting (max_turns={max_turns})", flush=True)
    try:
        async with ClaudeSDKClient(options=options) as client:
            await client.query(full_prompt)
            async for msg in client.receive_response():
                if isinstance(msg, AssistantMessage):
                    for block in msg.content:
                        if hasattr(block, "thinking"):
                            # Skip thinking blocks entirely — only visible text
                            # is collected for JSON extraction.
                            continue
                        elif isinstance(block, TextBlock):
                            text_parts.append(block.text)
                            preview = block.text.replace("\n", " ").strip()[:160]
                            if preview:
                                print(f" [cap={label}] text: {preview}", flush=True)
                        elif hasattr(block, "name") and hasattr(block, "input"):
                            # Tool-use block: log the tool name + truncated input
                            tool_use_count += 1
                            tool_input_str = json.dumps(block.input, ensure_ascii=False)[:160]
                            print(
                                f" [cap={label}] tool#{tool_use_count} {block.name}({tool_input_str})",
                                flush=True,
                            )
                elif isinstance(msg, ResultMessage):
                    is_error = msg.is_error
                    turns_used = msg.num_turns
    except ClaudeSDKError as e:
        reason = f"SDK ERROR {type(e).__name__}: {e}"
        print(f" [grounding-agent] cap={label} → {reason}", flush=True)
        if stderr_lines:
            print(f" CLI stderr tail: {stderr_lines[-3:]}", flush=True)
        return None, turns_used, reason
    # Recover strategy: even when is_error=True (e.g. max_turns hit), still
    # try to parse JSON from whatever text was produced before truncation.
    content = "".join(text_parts).strip()
    parsed = extract_json_from_response(content)
    if parsed is None:
        reason = (
            "is_error=True (likely max_turns) AND no parseable JSON in output"
            if is_error
            else "JSON parse FAILED (model output didn't contain valid JSON block)"
        )
        print(f" [grounding-agent] cap={label} → {reason}", flush=True)
        print(f" raw output total: {len(content)} chars", flush=True)
        print(f" raw output head: {content[:1000]}", flush=True)
        if len(content) > 1000:
            print(f" raw output tail: {content[-500:]}", flush=True)
        return None, turns_used, reason
    # Schema validation is advisory: violations are logged, never fatal
    schema_err = _validate_against_schema(parsed)
    schema_warn = f" ⚠ schema violation: {schema_err}" if schema_err else ""
    matched_all = parsed.get("matched", []) or []
    shi_n = sum(1 for m in matched_all if isinstance(m, dict) and m.get("category_type") == "实质")
    xing_n = sum(1 for m in matched_all if isinstance(m, dict) and m.get("category_type") == "形式")
    suggest_n = len(parsed.get("suggested_additions", []) or [])
    if is_error:
        print(
            f" [grounding-agent] cap={label} → ⚠ RECOVERED: turns={turns_used} tools={tool_use_count} "
            f"matched={len(matched_all)} (substance={shi_n} form={xing_n}) suggested={suggest_n} "
            f"(is_error=True, but JSON captured before truncation){schema_warn}",
            flush=True,
        )
    else:
        print(
            f" [grounding-agent] cap={label} → done: turns={turns_used} tools={tool_use_count} "
            f"matched={len(matched_all)} (substance={shi_n} form={xing_n}) suggested={suggest_n}{schema_warn}",
            flush=True,
        )
    return parsed, turns_used, None  # success
async def ground_single_capability(
    capability: Dict[str, Any],
    template: str,
    model: str = DEFAULT_MODEL,
    max_turns: int = DEFAULT_MAX_TURNS,
    task_label: str = "",
    max_retries: int = DEFAULT_MAX_RETRIES,
) -> Tuple[Optional[Dict[str, Any]], float, int]:
    """
    Run agent grounding for a single capability, retrying automatically on
    failure (each retry uses a brand-new client/context).
    Returns:
        (parsed_output, cost, turns_used)
        parsed_output is None when parsing failed or the agent errored and
        all retries were exhausted.
        cost is currently always 0.0 (OAuth mode is not billed); the slot is
        kept in the interface for a future switch to billed calls.
    """
    try:
        import claude_agent_sdk  # noqa: F401 — fail fast rather than ImportError on every retry
    except ImportError as e:
        raise RuntimeError(f"claude_agent_sdk not installed: {e}") from e
    label = task_label or (capability.get("capability_id") or "?")
    last_reason: Optional[str] = None
    turns_seen = 0
    for attempt in range(max_retries + 1):
        if attempt:
            print(
                f" [grounding-agent] cap={label} → RETRY {attempt}/{max_retries} "
                f"(prev failure: {last_reason})",
                flush=True,
            )
        parsed, turns, why_failed = await _ground_single_capability_attempt(
            capability, template, model, max_turns, label, attempt
        )
        turns_seen = turns
        if why_failed is None:
            # Success — including the "recovered despite is_error" path.
            if attempt:
                print(
                    f" [grounding-agent] cap={label} → ✓ recovered after retry#{attempt}",
                    flush=True,
                )
            return parsed, 0.0, turns
        # Failure: remember why, then fall through to the next attempt.
        last_reason = why_failed
    # Every attempt failed.
    print(
        f" [grounding-agent] cap={label} → ✗ FAILED after {max_retries + 1} attempts: {last_reason}",
        flush=True,
    )
    return None, 0.0, turns_seen
def _validate_against_schema(parsed: Dict[str, Any]) -> Optional[str]:
    """
    Validate the agent output against apply_to_grounding_agent.schema.json.

    Returns None on success — or whenever validation cannot run at all —
    otherwise an error string describing the violation. Schema validation is
    best-effort (nice-to-have): any load/import/call failure is swallowed so
    it never blocks the main pipeline.
    """
    try:
        from examples.process_pipeline.script.schema_manager import validate_with_schema
        return validate_with_schema(parsed, "apply_to_grounding_agent")
    except Exception:
        # Fix: the exception variable was bound but never used (lint F841);
        # missing schema file / schema_manager errors are skipped silently.
        return None
async def ground_single_case(
    case_item: Dict[str, Any],
    template: str,
    model: str = DEFAULT_MODEL,
    max_turns: int = DEFAULT_MAX_TURNS,
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    on_cap_done: Optional[Any] = None,  # async callable, called after each cap completes (for incremental persist)
    backend: str = "agent",  # "agent" | "oneshot"
    llm_call: Any = None,  # required for oneshot
    oneshot_template: Optional[str] = None,  # required for oneshot
    shi_json: Optional[str] = None,  # required for oneshot (pre-read, avoids re-reading per cap)
    xing_json: Optional[str] = None,  # required for oneshot
) -> Tuple[Dict[str, Any], float]:
    """
    Run agent grounding for all capabilities of one case.

    **Mutates case_item in place** (no copy) — as each cap completes, the
    corresponding capability dict immediately gets its apply_to /
    suggest_apply_to fields updated.

    on_cap_done: optional async callback invoked (no args) after each cap
    finishes; the caller can persist the case.json that owns case_item there,
    giving cap-level incremental write-back.
    """
    total_cost = 0.0
    case_index = case_item.get("index", "?")
    title = (case_item.get("title") or "")[:20] or "untitled"
    workflow_groups = case_item.get("workflow_groups")
    if not isinstance(workflow_groups, list) or not workflow_groups:
        return case_item, total_cost
    # Collect all capabilities across groups — holding direct references to
    # the dicts inside case_item. Caps whose workflow step is tagged
    # phase == "非制作" (meta-tasks with no real production output) are skipped.
    flat_pairs: List[tuple] = []
    skipped_non_production = 0
    for group_idx, group in enumerate(workflow_groups):
        if not isinstance(group, dict):
            continue
        workflow_id = group.get("workflow_id") or f"g{group_idx + 1}"
        # Build a step_id → phase lookup
        wf = group.get("workflow") or {}
        step_phase = {
            s.get("step_id"): s.get("phase")
            for s in (wf.get("steps") or []) if isinstance(s, dict)
        }
        capability_items = group.get("capability")
        if not isinstance(capability_items, list) or not capability_items:
            continue
        for cap_idx, cap in enumerate(capability_items):
            if not (isinstance(cap, dict) and cap.get("capability_id")):
                continue
            cap_id = cap.get("capability_id")
            # Resolve the phase through workflow_step_ref
            step_ref = cap.get("workflow_step_ref") or {}
            step_id = step_ref.get("step_id")
            phase = step_phase.get(step_id)
            if phase == "非制作":
                print(
                    f" [grounding-agent] skip cap={case_index}/{workflow_id}/{cap_id} "
                    f"(step={step_id} phase=非制作, no production output to ground)",
                    flush=True,
                )
                skipped_non_production += 1
                continue
            flat_pairs.append((group_idx, cap_idx, cap, workflow_id))
    if not flat_pairs:
        return case_item, total_cost
    print(f"\n[{case_index}] === case '{title}' grounding {len(flat_pairs)} capabilities ===", flush=True)
    sem = asyncio.Semaphore(max_concurrent)
    # Cap-level write lock: when max_concurrent > 1 several caps of the same
    # case run concurrently — writing back into case_item and triggering the
    # persist callback must be serialized.
    write_lock = asyncio.Lock()
    grounded_count = 0

    async def _run_one(group_idx: int, cap_idx: int, cap: Dict, workflow_id: str):
        # Ground one capability under the semaphore, then write back under
        # the write lock.
        nonlocal grounded_count, total_cost
        async with sem:
            cap_id = cap.get("capability_id") or f"idx{cap_idx}"
            label = f"{case_index}/{workflow_id}/{cap_id}"
            if backend == "oneshot":
                parsed, cost, _ = await ground_single_capability_oneshot(
                    cap, oneshot_template, llm_call, model,
                    shi_json=shi_json, xing_json=xing_json, task_label=label,
                )
            else:
                parsed, cost, _turns = await ground_single_capability(
                    cap, template, model=model, max_turns=max_turns, task_label=label,
                )
            # Write back in place into the capability dict inside case_item
            # (the original object is modified directly).
            # write_lock guarantees that when several concurrent caps finish,
            # update + persist trigger happen serially.
            async with write_lock:
                total_cost += cost
                if parsed is not None:
                    target_cap = case_item["workflow_groups"][group_idx]["capability"][cap_idx]
                    if isinstance(target_cap, dict):
                        target_cap["apply_to"] = parsed.get("matched", []) or []
                        target_cap["suggest_apply_to"] = parsed.get("suggested_additions", []) or []
                    grounded_count += 1
                # Notify the caller to persist (even when parsed is None —
                # a failure is still a "completion"; the caller may choose to
                # persist partial progress on failure too).
                if on_cap_done is not None:
                    try:
                        await on_cap_done()
                    except Exception as e:
                        print(f" [cap={label}] persist callback failed: {type(e).__name__}: {e}", flush=True)
            return group_idx, cap_idx, parsed, cost

    tasks = [_run_one(gi, ci, cap, wid) for (gi, ci, cap, wid) in flat_pairs]
    await asyncio.gather(*tasks)
    print(
        f"[{case_index}] >>> case done: {grounded_count}/{len(flat_pairs)} caps grounded, "
        f"{skipped_non_production} skipped (phase=非制作) <<<",
        flush=True,
    )
    return case_item, total_cost
async def apply_grounding(
    case_file: Path,
    llm_call: Any = None,
    model: str = DEFAULT_MODEL,
    max_concurrent: int = DEFAULT_MAX_CONCURRENT,
    backend: Optional[str] = None,  # "agent" | "oneshot"; None → read env var
) -> Dict[str, Any]:
    """
    Top-level entry point: walk case.json and run grounding over every case's
    capabilities.

    Backend selection:
        - "agent"  : Claude Agent SDK + OAuth/Max subscription; the agent
                     Reads/Greps the category libraries on demand
        - "oneshot": one-shot LLM call (via llm_call, usually OpenRouter) with
                     the full category libraries inlined into the prompt
        - None     : read env var GROUNDING_BACKEND, defaulting to "agent"

    Keeps the same signature as the legacy apply_grounding so run_pipeline.py
    can switch over transparently.
    """
    # Resolve the backend
    if backend is None:
        backend = os.getenv("GROUNDING_BACKEND", "agent").strip().lower()
    if backend not in ("agent", "oneshot"):
        raise ValueError(f"Invalid backend: {backend!r}. Choose 'agent' or 'oneshot'.")
    if backend == "oneshot" and llm_call is None:
        raise ValueError("backend='oneshot' requires llm_call (e.g. OpenRouter llm_call from run_pipeline.py)")
    if not SHI_JSON_PATH.exists():
        raise FileNotFoundError(f"实质分类库不存在: {SHI_JSON_PATH}")
    if not XING_JSON_PATH.exists():
        raise FileNotFoundError(f"形式分类库不存在: {XING_JSON_PATH}")
    with open(case_file, "r", encoding="utf-8") as f:
        case_data = json.load(f)
    cases = case_data.get("cases", [])
    template = load_prompt_template("apply_to_grounding_agent")
    # Oneshot mode: pre-read both category-library JSONs once (shared by all
    # caps — no per-cap re-read)
    oneshot_template: Optional[str] = None
    shi_json: Optional[str] = None
    xing_json: Optional[str] = None
    if backend == "oneshot":
        oneshot_template = load_prompt_template("apply_to_grounding_oneshot")
        shi_json = SHI_JSON_PATH.read_text(encoding="utf-8")
        xing_json = XING_JSON_PATH.read_text(encoding="utf-8")
        print(
            f"[apply-grounding] backend=oneshot model={model} "
            f"shi_json={len(shi_json):,} chars xing_json={len(xing_json):,} chars",
            flush=True,
        )
    else:
        print(f"[apply-grounding] backend=agent model={model}", flush=True)
    # Select the cases that actually have capabilities
    needs = [
        c for c in cases
        if isinstance(c.get("workflow_groups"), list) and any(
            isinstance(g, dict) and isinstance(g.get("capability"), list) and any(
                isinstance(cap, dict) and cap.get("capability_id")
                for cap in g.get("capability", [])
            )
            for g in c["workflow_groups"]
        )
    ]
    print(f"Grounding (agent) for {len(needs)}/{len(cases)} cases (max_concurrent={max_concurrent})", flush=True)
    if not needs:
        return {
            "total": len(cases),
            "grounded": 0,
            "total_cost": 0.0,
            "output_file": str(case_file),
        }
    # Limit concurrency across top-level cases (a single agent run is already
    # expensive; don't multiply it with case-level concurrency).
    # Default: one case at a time; within a case, caps run at max_concurrent.
    # NOTE: case_sem=1 also guarantees cases are serialized — incremental
    # write-back to case.json cannot race across cases.
    case_sem = asyncio.Semaphore(1)
    # Snapshot case.json before running (preserve the pre-change state)
    try:
        from examples.process_pipeline.script.case_history import snapshot_case_file
        snap = snapshot_case_file(case_file, step="apply_grounding_agent")
        if snap:
            print(f" [snapshot] {snap.name}", flush=True)
    except Exception as e:
        print(f" [snapshot] skipped: {e}", flush=True)
    # Global write lock shared across cases — even if the case_sem=1 setting
    # is relaxed in the future, file I/O remains safe.
    persist_lock = asyncio.Lock()

    async def _persist_case_data() -> None:
        """Serialized async dump of case_data to case.json (used by both the
        cap-level incremental write-back and the final safety flush)."""
        async with persist_lock:
            with open(case_file, "w", encoding="utf-8") as f:
                json.dump(case_data, f, ensure_ascii=False, indent=2)

    async def _process_case(case_item):
        # Ground one case under the case-level semaphore.
        async with case_sem:
            idx = case_item.get("index", 0)
            cid = (case_item.get("_raw") or {}).get("case_id", "unknown")
            print(f" -> [{idx}] [{cid}] agent grounding starts", flush=True)
            # Cap-level incremental persistence: ground_single_case invokes
            # this once per completed cap. case_item is mutated in place and
            # is the very object referenced from case_data["cases"], so
            # dumping case_data reflects the latest progress.
            async def _on_cap_done():
                try:
                    await _persist_case_data()
                except Exception as e:
                    print(f" [case {idx}] persist failed: {type(e).__name__}: {e}", flush=True)
            grounded, cost = await ground_single_case(
                case_item, template, model=model, max_concurrent=max_concurrent,
                on_cap_done=_on_cap_done,
                backend=backend, llm_call=llm_call,
                oneshot_template=oneshot_template,
                shi_json=shi_json, xing_json=xing_json,
            )
            print(f" <- [{idx}] [{cid}] agent grounding done (case.json saved incrementally at each cap)", flush=True)
            return case_item, grounded, cost

    total_cost = 0.0
    try:
        tasks = [_process_case(c) for c in needs]
        results = await asyncio.gather(*tasks)
        total_cost = sum(c for _, _, c in results)
    except (KeyboardInterrupt, asyncio.CancelledError) as e:
        # Safety net: on Ctrl+C or upstream cancellation, flush whatever was
        # already written into case_data. (Cap-level incremental write-back
        # has already persisted each finished cap; this is one extra flush
        # for good measure.)
        print(f"\n[grounding-agent] interrupted ({type(e).__name__}) — flushing partial results to {case_file}", flush=True)
        try:
            await _persist_case_data()
        except Exception as flush_err:
            print(f" [grounding-agent] final flush failed: {flush_err}", flush=True)
        raise
    # Final dump (belt-and-braces: guarantees the end state is consistent)
    await _persist_case_data()
    return {
        "total": len(cases),
        "grounded": len(needs),
        "total_cost": total_cost,
        "output_file": str(case_file),
    }
- if __name__ == "__main__":
- import sys
- print("Please use this module through run_pipeline.py")
- sys.exit(1)
|