"""
Stage 2: map ``apply_to_draft`` to the formal ``apply_to``.

Reads case.json and, for every case, maps only the ``apply_to_draft`` entries
found under ``workflow_groups[*].capability``. An LLM maps each draft onto
formal content-tree nodes and the result is written back into case.json in
place.

Candidate tree nodes come from the remote search API when ``USE_SEARCH_API``
is enabled (the default); otherwise the bundled local tree file is used.
"""

import asyncio
import json
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional

import httpx
from dotenv import load_dotenv

from examples.process_pipeline.script.llm_helper import call_llm_with_retry

# Load environment variables from the project-level .env file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent
load_dotenv(PROJECT_ROOT / ".env")

# Search API configuration.
SEARCH_API = os.getenv("SEARCH_API", "http://8.147.104.190:8001").rstrip("/")
CAPABILITY_GROUNDING_BATCH_SIZE = int(os.getenv("CAPABILITY_GROUNDING_BATCH_SIZE", "8"))

# Local file paths (used as the fallback source of the content tree).
EXTRACT_DIR = Path(__file__).resolve().parent / "resource"
CATEGORY_TREE_PATH = EXTRACT_DIR / "category_tree_56.json"

# Separators used to split draft text into keyword candidates. Both the ASCII
# and the full-width forms of , ; : ! ? are listed so that Chinese-punctuated
# drafts are split as well.
_KEYWORD_SPLIT_RE = re.compile(r"[,。、;:!?\uff0c\uff1b\uff1a\uff01\uff1f\s]+")

# Placeholders understood by render_grounding_prompt.
_PLACEHOLDER_RE = re.compile(r"\{(target|compact_tree|reference_paths|draft_json)\}")


def load_prompt_template(prompt_name: str) -> str:
    """Load a prompt template from the sibling ``prompts`` directory.

    A leading ``---`` front-matter section is dropped, the ``$system$`` and
    ``$user$`` markers are removed, and surrounding whitespace is stripped.

    Args:
        prompt_name: Template file name without the ``.prompt`` suffix.

    Returns:
        The cleaned template text.
    """
    base_dir = Path(__file__).parent.parent
    prompt_path = base_dir / "prompts" / f"{prompt_name}.prompt"
    with open(prompt_path, "r", encoding="utf-8") as f:
        content = f.read()

    if content.startswith("---"):
        # "---<front matter>---<body>": keep only the body.
        parts = content.split("---", 2)
        if len(parts) >= 3:
            content = parts[2]

    content = content.replace("$system$", "").replace("$user$", "")
    return content.strip()


def load_category_tree_from_local() -> List[Dict]:
    """Load the content tree from the local JSON file (fallback source).

    Returns:
        The non-empty list stored under the file's ``categories`` key.

    Raises:
        FileNotFoundError: If ``CATEGORY_TREE_PATH`` does not exist.
        RuntimeError: If the file contains no categories.
    """
    if not CATEGORY_TREE_PATH.exists():
        raise FileNotFoundError(f"Category tree not found: {CATEGORY_TREE_PATH}")
    with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as f:
        data = json.load(f)
    cats = data.get("categories", [])
    if not cats:
        raise RuntimeError("category_tree is empty")
    return cats


async def load_category_tree(use_api: bool = False) -> List[Dict]:
    """Load the content tree from the remote API or the local file.

    Args:
        use_api: Try the remote API first (default False: local file only).

    Returns:
        The list of content-tree nodes.
    """
    if use_api:
        try:
            print("Attempting to load category tree from API...")
            # NOTE(review): fetch_category_tree_from_api is not defined in the
            # part of this module that is visible here. Unless it is provided
            # elsewhere, the resulting NameError is caught below and the local
            # file is used instead.
            return await fetch_category_tree_from_api()
        except Exception as e:
            print(f"API failed ({e}), falling back to local file...")
            return load_category_tree_from_local()
    else:
        print("Loading category tree from local file...")
        return load_category_tree_from_local()


async def search_categories_by_keywords(
    keywords: List[str],
    source_types: Optional[List[str]] = None,
    top_k: int = 10,
    timeout: int = 10,
) -> List[Dict]:
    """Search the remote API for category nodes related to the keywords.

    One GET request to ``{SEARCH_API}/api/search`` is issued per
    (keyword, source_type) pair. Results are de-duplicated by ``entity_id``
    and converted to content-tree style dicts. The search is best-effort: a
    failing request is skipped and a single summary line is printed afterwards.

    Args:
        keywords: Search keywords.
        source_types: Source types to query, default ``["形式", "实质"]``.
        top_k: Number of results requested per query.
        timeout: Request timeout in seconds.

    Returns:
        The list of related category nodes (possibly empty).
    """
    if source_types is None:
        source_types = ["形式", "实质"]

    all_categories = []
    seen_ids = set()
    failed_queries = 0
    last_error = None

    async with httpx.AsyncClient(timeout=timeout) as client:
        for keyword in keywords:
            for source_type in source_types:
                try:
                    params = {
                        "q": keyword,
                        "source_type": source_type,
                        "entity_type": "category",
                        "top_k": top_k,
                        "mode": "vector"
                    }
                    resp = await client.get(f"{SEARCH_API}/api/search", params=params)
                    resp.raise_for_status()
                    data = resp.json()
                    results = data.get("results", [])

                    # Convert to content-tree format, de-duplicating by id.
                    for item in results:
                        entity_id = item.get("entity_id")
                        if entity_id and entity_id not in seen_ids:
                            category = {
                                "id": entity_id,
                                "path": item.get("path", ""),
                                "source_type": source_type,
                                "description": item.get("description", ""),
                                "elements": []
                            }
                            all_categories.append(category)
                            seen_ids.add(entity_id)
                except Exception as exc:
                    # Best-effort: keep going with the remaining queries, but
                    # remember the failure so it does not go unnoticed.
                    failed_queries += 1
                    last_error = exc
                    continue

    if failed_queries:
        print(
            f" ⚠️ Category search: {failed_queries} query(ies) failed "
            f"(last error: {last_error!r})"
        )

    return all_categories


def extract_keywords_from_draft(draft_text: str) -> List[str]:
    """Extract search keywords from one ``apply_to_draft`` text.

    The text is split on whitespace and common punctuation, fragments shorter
    than two characters are dropped, duplicates are removed (first occurrence
    wins) and at most five keywords are returned.

    Args:
        draft_text: The draft text; anything that is not a non-empty string
            yields an empty list.

    Returns:
        Up to five keywords in order of first appearance.
    """
    if not draft_text or not isinstance(draft_text, str):
        return []

    words = _KEYWORD_SPLIT_RE.split(draft_text)
    # Drop short fragments (the only "stop word" filter applied here).
    keywords = [w.strip() for w in words if len(w.strip()) >= 2]
    # De-duplicate while preserving order; keep at most 5 keywords.
    return list(dict.fromkeys(keywords))[:5]


def build_compact_tree(cats: List[Dict]) -> str:
    """Serialize the content tree compactly for prompt injection.

    Only nodes whose ``source_type`` is "实质" or "形式" are kept. Each row
    carries ``id``, ``path``, ``source_type`` and ``description``, plus
    ``elements`` (element names) when the node has any.

    Returns:
        A JSON array string without extra whitespace.
    """
    rows = []
    for c in cats:
        if c.get("source_type") not in ("实质", "形式"):
            continue
        row = {
            "id": c.get("id"),
            "path": c.get("path"),
            "source_type": c.get("source_type"),
            "description": c.get("description"),
        }
        elems = c.get("elements", [])
        if isinstance(elems, list) and elems:
            # Elements may be dicts (use their "name") or plain values.
            elem_names = [
                e.get("name") if isinstance(e, dict) else e
                for e in elems if e
            ]
            if elem_names:
                row["elements"] = elem_names
        rows.append(row)
    return json.dumps(rows, ensure_ascii=False, separators=(",", ":"))


def build_valid_ids(cats: List[Dict]) -> Dict[int, Dict]:
    """Build an ``id -> node`` mapping for nodes that have an ``id`` key."""
    return {c["id"]: c for c in cats if "id" in c}


def iter_batches(items: List[Any], batch_size: int) -> List[List[Any]]:
    """Split ``items`` into consecutive chunks of ``batch_size`` (minimum 1)."""
    batch_size = max(1, batch_size)
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def build_capability_grounding_input(capability: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the fields the grounding prompt needs for one capability."""
    return {
        "capability_id": capability.get("capability_id"),
        "body": capability.get("body") or "",
        "apply_to_draft": capability.get("apply_to_draft", {}),
    }


def body_excerpts_are_verbatim(apply_to: Any, suggest_apply_to: Any, body: str) -> bool:
    """Check that every non-empty ``body_excerpt`` is a verbatim part of ``body``.

    Each entry of ``apply_to["实质"]``, ``apply_to["形式"]`` and
    ``suggest_apply_to`` must be a dict whose ``body_excerpt`` and
    ``body_excerpt_note`` are strings. An empty excerpt is allowed only when
    the note is empty too; a non-empty excerpt (stripped) must occur in
    ``body``.

    Returns:
        True when all entries pass; False on any structural or content
        violation.
    """
    def _entry_is_valid(item: Any) -> bool:
        if not isinstance(item, dict):
            return False
        excerpt = item.get("body_excerpt")
        note = item.get("body_excerpt_note")
        if not isinstance(excerpt, str) or not isinstance(note, str):
            return False
        # A note without an excerpt has nothing to refer to.
        if not excerpt.strip() and note.strip():
            return False
        if excerpt.strip() and excerpt.strip() not in body:
            return False
        return True

    if not isinstance(apply_to, dict) or not isinstance(suggest_apply_to, list) or not isinstance(body, str):
        return False

    for source_type in ("实质", "形式"):
        items = apply_to.get(source_type, [])
        if not isinstance(items, list):
            return False
        for item in items:
            if not _entry_is_valid(item):
                return False

    for item in suggest_apply_to:
        if not _entry_is_valid(item):
            return False

    return True


def render_grounding_prompt(
    template: str,
    task: str,
    draft: Dict,
    compact_tree: str,
    reference_paths: List[str] = None,
) -> str:
    """Render the Stage 2 prompt.

    Fills ``{target}``, ``{compact_tree}``, ``{reference_paths}`` and
    ``{draft_json}`` in ``template``. Substitution happens in one pass over
    the template, so placeholder-looking text inside the injected tree, paths
    or draft is left untouched.

    Args:
        template: Prompt template text.
        task: Accepted for interface compatibility; not used by the rendering.
        draft: Draft payload, serialized as indented JSON.
        compact_tree: Compact content-tree JSON string.
        reference_paths: Optional reference paths, serialized as a JSON list.

    Returns:
        The rendered prompt.
    """
    values = {
        "target": "capability 数组中的每一条 capability",
        "compact_tree": compact_tree,
        "reference_paths": json.dumps(reference_paths or [], ensure_ascii=False),
        "draft_json": json.dumps(draft, ensure_ascii=False, indent=2),
    }
    # A callable replacement avoids backslash processing of the injected text.
    return _PLACEHOLDER_RE.sub(lambda match: values[match.group(1)], template)


def _collect_draft_keywords(draft_capability_pairs: List[Any]) -> List[str]:
    """Collect up to 10 unique search keywords from the capabilities' drafts."""
    all_keywords = []
    for _, capability in draft_capability_pairs:
        apply_to_draft = capability.get("apply_to_draft", {})
        if not isinstance(apply_to_draft, dict):
            continue
        for key in ["实质", "形式"]:
            drafts = apply_to_draft.get(key) or []
            if isinstance(drafts, str):
                # Tolerate a single text where a list of texts is expected.
                drafts = [drafts]
            elif not isinstance(drafts, list):
                continue
            for draft_text in drafts:
                all_keywords.extend(extract_keywords_from_draft(draft_text))
    return list(dict.fromkeys(all_keywords))[:10]


async def ground_single_case(
    case_item: Dict[str, Any],
    template: str,
    llm_call: Any,
    model: str,
    use_api: bool = False,
    compact_tree: str = None,
) -> tuple[Dict[str, Any], float]:
    """Ground ``workflow_groups[*].capability[*].apply_to_draft`` of one case.

    For every workflow group, the capabilities that carry ``apply_to_draft``
    are sent to the LLM in batches of ``CAPABILITY_GROUNDING_BATCH_SIZE``. A
    returned capability is written back (``apply_to`` and
    ``suggest_apply_to`` set, ``apply_to_draft`` removed) only when it passes
    validation; otherwise the capability is left untouched and a warning is
    printed.

    Args:
        case_item: The case dict; it is not mutated (shallow copies are made
            of the case, its groups and its capabilities).
        template: Prompt template text.
        llm_call: LLM call function, forwarded to ``call_llm_with_retry``.
        model: Model name.
        use_api: Search the remote API for candidate nodes per group instead
            of using ``compact_tree``.
        compact_tree: Pre-built compact tree used when ``use_api`` is False
            or no keywords could be extracted.

    Returns:
        ``(updated_case, total_cost)``.
    """
    total_cost = 0.0
    result = dict(case_item)
    # Tolerate a missing / None / non-string title in the task label.
    title = str(case_item.get("title") or "")[:20] or "untitled"

    workflow_groups = case_item.get("workflow_groups")
    if not isinstance(workflow_groups, list) or not workflow_groups:
        return result, total_cost

    updated_groups = [
        dict(group) if isinstance(group, dict) else group
        for group in workflow_groups
    ]

    for group_idx, group in enumerate(updated_groups):
        if not isinstance(group, dict):
            continue

        workflow_id = group.get("workflow_id") or f"g{group_idx + 1}"
        capability_items = group.get("capability")
        if not isinstance(capability_items, list) or not capability_items:
            continue

        draft_capability_pairs = [
            (idx, capability)
            for idx, capability in enumerate(capability_items)
            if isinstance(capability, dict) and "apply_to_draft" in capability
        ]
        if not draft_capability_pairs:
            continue

        # Choose the candidate tree for this group.
        capability_compact_tree = compact_tree or "[]"
        capability_ref_paths = []
        if use_api:
            all_keywords = _collect_draft_keywords(draft_capability_pairs)
            if all_keywords:
                categories = await search_categories_by_keywords(all_keywords, top_k=5)
                capability_compact_tree = build_compact_tree(categories)
                capability_ref_paths = list(dict.fromkeys(
                    c["path"] for c in categories if c.get("path")
                ))

        updated_capabilities = [
            dict(capability) if isinstance(capability, dict) else capability
            for capability in capability_items
        ]

        batches = iter_batches(draft_capability_pairs, CAPABILITY_GROUNDING_BATCH_SIZE)
        for batch_idx, batch_pairs in enumerate(batches, start=1):
            draft_capabilities = [
                build_capability_grounding_input(capability)
                for _, capability in batch_pairs
            ]
            draft = {"capability": draft_capabilities}
            prompt = render_grounding_prompt(
                template,
                "capability",
                draft,
                capability_compact_tree,
                capability_ref_paths,
            )
            messages = [{"role": "user", "content": prompt}]

            grounded, cost = await call_llm_with_retry(
                llm_call=llm_call,
                messages=messages,
                model=model,
                temperature=0.1,
                max_tokens=4000,
                max_retries=3,
                schema_name="apply_to_grounding_capability",
                task_name=f"Ground_C_{title}_{workflow_id}_B{batch_idx}/{len(batches)}",
            )
            total_cost += cost

            if not grounded or not isinstance(grounded.get("capability"), list):
                continue

            # Only ids that were actually sent in this batch may be matched,
            # so a stray or duplicated id cannot write onto a capability that
            # belongs to another batch.
            batch_id_to_index = {
                capability.get("capability_id"): idx
                for idx, capability in batch_pairs
                if isinstance(capability.get("capability_id"), str)
            }

            grounded_capabilities = grounded["capability"]
            used_indices = set()
            for output_idx, grounded_capability in enumerate(grounded_capabilities):
                if not isinstance(grounded_capability, dict):
                    continue

                # Match by capability_id first, then fall back to position.
                capability_idx = None
                capability_id = grounded_capability.get("capability_id")
                if isinstance(capability_id, str):
                    capability_idx = batch_id_to_index.get(capability_id)
                if capability_idx is None and output_idx < len(batch_pairs):
                    capability_idx = batch_pairs[output_idx][0]
                if capability_idx is None or capability_idx in used_indices:
                    continue

                apply_to = grounded_capability.get("apply_to")
                suggest_apply_to = grounded_capability.get("suggest_apply_to")
                # Same normalisation as build_capability_grounding_input, so
                # verification runs against the body the LLM was shown.
                body = updated_capabilities[capability_idx].get("body") or ""
                if (
                    apply_to is not None
                    and isinstance(suggest_apply_to, list)
                    and len(suggest_apply_to) <= 3
                    and isinstance(updated_capabilities[capability_idx], dict)
                    and body_excerpts_are_verbatim(apply_to, suggest_apply_to, body)
                ):
                    updated_capabilities[capability_idx]["apply_to"] = apply_to
                    updated_capabilities[capability_idx]["suggest_apply_to"] = suggest_apply_to
                    updated_capabilities[capability_idx].pop("apply_to_draft", None)
                    used_indices.add(capability_idx)
                else:
                    print(
                        f" ⚠️ Skip capability grounding writeback: "
                        f"{capability_id or capability_idx} has missing/non-verbatim body_excerpt"
                    )

        group["capability"] = updated_capabilities

    result["workflow_groups"] = updated_groups
    return result, total_cost


async def apply_grounding(
    case_file: Path,
    llm_call: Any,
    model: str = "openai/gpt-5.4",
    max_concurrent: int = 3,
) -> Dict[str, Any]:
    """Ground every case in case.json and write the file back in place.

    Args:
        case_file: Path of case.json.
        llm_call: LLM call function.
        model: Model name.
        max_concurrent: Maximum number of cases grounded concurrently.

    Returns:
        A dict with ``total``, ``grounded``, ``total_cost`` and
        ``output_file``.
    """
    with open(case_file, "r", encoding="utf-8") as f:
        case_data = json.load(f)

    cases = case_data.get("cases", [])

    # Whether to use the API's on-demand search mode.
    use_api = os.getenv("USE_SEARCH_API", "true").lower() == "true"

    # Without the API, preload the full content tree once.
    compact_tree = None
    if not use_api:
        cats = await load_category_tree(use_api=False)
        compact_tree = build_compact_tree(cats)
        print(f"Loaded category tree: {len(cats)} nodes, compact={len(compact_tree)} chars")
    else:
        print("Using API dynamic search mode (searching on-demand)")

    # Load the prompt template.
    template = load_prompt_template("apply_to_grounding")

    # Select cases that still carry a capability-level apply_to_draft.
    needs_grounding = []
    for case in cases:
        workflow_groups = case.get("workflow_groups")
        has_capability_draft = isinstance(workflow_groups, list) and any(
            isinstance(group, dict)
            and isinstance(group.get("capability"), list)
            and any(
                isinstance(capability, dict) and "apply_to_draft" in capability
                for capability in group.get("capability", [])
            )
            for group in workflow_groups
        )
        if has_capability_draft:
            needs_grounding.append(case)

    print(f"Grounding apply_to for {len(needs_grounding)}/{len(cases)} cases...")

    if not needs_grounding:
        print(" No cases need grounding, skipping.")
        return {
            "total": len(cases),
            "grounded": 0,
            "total_cost": 0.0,
            "output_file": str(case_file),
        }

    semaphore = asyncio.Semaphore(max_concurrent)

    async def process_with_semaphore(case_item):
        async with semaphore:
            index = case_item.get("index", 0)
            raw = case_item.get("_raw")
            # "_raw" is only used for the log label; tolerate None / non-dict.
            case_id = raw.get("case_id", "unknown") if isinstance(raw, dict) else "unknown"
            print(f" -> [{index}] [{case_id}] grounding apply_to...")
            grounded, cost = await ground_single_case(
                case_item, template, llm_call, model, use_api, compact_tree
            )
            print(f" <- [{index}] [{case_id}] grounded (cost=${cost:.4f})")
            return case_item, grounded, cost

    tasks = [process_with_semaphore(case) for case in needs_grounding]
    results_with_costs = await asyncio.gather(*tasks)

    # Replace original cases with their grounded versions. Matching is by
    # object identity so a missing or duplicated "index" cannot misroute a
    # result to the wrong case.
    grounded_map = {}
    total_cost = 0.0
    for original_case, grounded, cost in results_with_costs:
        grounded_map[id(original_case)] = grounded
        total_cost += cost

    updated_cases = [grounded_map.get(id(case), case) for case in cases]
    updated_cases.sort(key=lambda x: x.get("index", 0))

    case_data["cases"] = updated_cases
    case_file.parent.mkdir(parents=True, exist_ok=True)
    with open(case_file, "w", encoding="utf-8") as f:
        json.dump(case_data, f, ensure_ascii=False, indent=2)

    return {
        "total": len(cases),
        "grounded": len(needs_grounding),
        "total_cost": total_cost,
        "output_file": str(case_file),
    }


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python apply_to_grounding.py ")
        sys.exit(1)

    print("Please use this module through run_pipeline.py")