| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510 |
- """
- Stage 2: 将 apply_to_draft 映射为正式 apply_to
- 从 case.json 读取,只对每个 case 的 workflow_groups[*].capability 中的 apply_to_draft 做映射。
- 调用 LLM 映射到内容树的正式节点,原位回填到 case.json
- 改造版本:通过远程 API 获取内容树,不再依赖本地文件
- """
- import asyncio
- import json
- import os
- from pathlib import Path
- from typing import Any, Dict, List, Optional
- import httpx
- from dotenv import load_dotenv
- from examples.process_pipeline.script.llm_helper import call_llm_with_retry
- # 加载环境变量
- PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent
- load_dotenv(PROJECT_ROOT / ".env")
- # 搜索 API 配置
- SEARCH_API = os.getenv("SEARCH_API", "http://8.147.104.190:8001").rstrip("/")
- CAPABILITY_GROUNDING_BATCH_SIZE = int(os.getenv("CAPABILITY_GROUNDING_BATCH_SIZE", "8"))
- # 本地文件路径(作为回退方案)
- EXTRACT_DIR = Path(__file__).resolve().parent / "resource"
- CATEGORY_TREE_PATH = EXTRACT_DIR / "category_tree_56.json"
def load_prompt_template(prompt_name: str) -> str:
    """Read a prompt template from the prompts directory and strip metadata.

    Front matter delimited by "---" and the $system$/$user$ role markers
    are removed before the template body is returned.

    Args:
        prompt_name: base name of the template (without the .prompt suffix).

    Returns:
        The cleaned template text, stripped of surrounding whitespace.
    """
    prompt_file = Path(__file__).parent.parent / "prompts" / f"{prompt_name}.prompt"
    with open(prompt_file, "r", encoding="utf-8") as fh:
        text = fh.read()
    # Drop YAML-style front matter ("---" ... "---") when present.
    if text.startswith("---"):
        pieces = text.split("---", 2)
        if len(pieces) >= 3:
            text = pieces[2]
    # Remove role placeholder markers, then trim.
    return text.replace("$system$", "").replace("$user$", "").strip()
def load_category_tree_from_local() -> List[Dict]:
    """Load the content tree from the bundled JSON file (fallback path).

    Returns:
        The list stored under the "categories" key.

    Raises:
        FileNotFoundError: if the category-tree JSON file is missing.
        RuntimeError: if the file contains no categories.
    """
    if not CATEGORY_TREE_PATH.exists():
        raise FileNotFoundError(f"Category tree not found: {CATEGORY_TREE_PATH}")
    with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as fh:
        payload = json.load(fh)
    categories = payload.get("categories", [])
    if not categories:
        raise RuntimeError("category_tree is empty")
    return categories
async def load_category_tree(use_api: bool = False) -> List[Dict]:
    """Load the content tree, either from the remote API or the local file.

    Args:
        use_api: when True, try the remote API first and fall back to the
            local file on any failure; when False, read the local file only.

    Returns:
        List of content-tree node dicts.
    """
    if not use_api:
        print("Loading category tree from local file...")
        return load_category_tree_from_local()
    try:
        print("Attempting to load category tree from API...")
        # NOTE(review): fetch_category_tree_from_api is not defined in this
        # module; if it is missing at runtime, the NameError is caught below
        # and we fall back to the local file. Confirm it exists elsewhere.
        return await fetch_category_tree_from_api()
    except Exception as e:
        print(f"API failed ({e}), falling back to local file...")
        return load_category_tree_from_local()
async def search_categories_by_keywords(
    keywords: List[str],
    source_types: List[str] = None,
    top_k: int = 10,
    timeout: int = 10
) -> List[Dict]:
    """Query the search API for category nodes matching the given keywords.

    One vector-mode query is issued per (keyword, source_type) pair; results
    are converted into content-tree shape and de-duplicated by entity_id.
    Failed queries are skipped silently (deliberate best-effort behavior).

    Args:
        keywords: search keywords, queried one at a time.
        source_types: source types to query; defaults to ["形式", "实质"].
        top_k: number of results requested per query.
        timeout: per-request timeout in seconds.

    Returns:
        De-duplicated category node dicts.
    """
    if source_types is None:
        source_types = ["形式", "实质"]
    collected = []
    known_ids = set()
    async with httpx.AsyncClient(timeout=timeout) as client:
        for term in keywords:
            for stype in source_types:
                try:
                    response = await client.get(
                        f"{SEARCH_API}/api/search",
                        params={
                            "q": term,
                            "source_type": stype,
                            "entity_type": "category",
                            "top_k": top_k,
                            "mode": "vector",
                        },
                    )
                    response.raise_for_status()
                    hits = response.json().get("results", [])
                    # Convert hits into content-tree nodes, skipping duplicates.
                    for hit in hits:
                        node_id = hit.get("entity_id")
                        if not node_id or node_id in known_ids:
                            continue
                        known_ids.add(node_id)
                        collected.append({
                            "id": node_id,
                            "path": hit.get("path", ""),
                            "source_type": stype,
                            "description": hit.get("description", ""),
                            "elements": [],
                        })
                except Exception:
                    # Best effort: ignore this query and move on to the next.
                    continue
    return collected
def extract_keywords_from_draft(draft_text: str) -> List[str]:
    """Pull up to five keywords out of an apply_to_draft text.

    Splits on common CJK/ASCII punctuation and whitespace, keeps tokens of
    at least two characters, and de-duplicates while preserving order.

    Args:
        draft_text: the apply_to_draft text to mine for keywords.

    Returns:
        At most five unique keywords; empty list for empty/non-string input.
    """
    if not draft_text or not isinstance(draft_text, str):
        return []
    import re

    # Tokenize on punctuation/whitespace, then keep stripped tokens >= 2 chars.
    tokens = (t.strip() for t in re.split(r'[,。、;:!?\s]+', draft_text))
    candidates = [t for t in tokens if len(t) >= 2]
    # Order-preserving de-dup, capped at five keywords.
    return list(dict.fromkeys(candidates))[:5]
def build_compact_tree(cats: List[Dict]) -> str:
    """Serialize category nodes into a compact JSON string for prompt injection.

    Only nodes whose source_type is "实质" or "形式" are kept. Element entries
    are reduced to names (dict elements contribute their "name" key, other
    truthy elements are kept as-is); empty element lists are omitted.

    Args:
        cats: content-tree node dicts.

    Returns:
        Compact (no-whitespace) JSON array string with non-ASCII preserved.
    """
    compact_rows = []
    for node in cats:
        if node.get("source_type") not in ("实质", "形式"):
            continue
        entry = {
            "id": node.get("id"),
            "path": node.get("path"),
            "source_type": node.get("source_type"),
            "description": node.get("description"),
        }
        elements = node.get("elements", [])
        if isinstance(elements, list) and elements:
            names = [
                item.get("name") if isinstance(item, dict) else item
                for item in elements
                if item
            ]
            if names:
                entry["elements"] = names
        compact_rows.append(entry)
    return json.dumps(compact_rows, ensure_ascii=False, separators=(",", ":"))
def build_valid_ids(cats: List[Dict]) -> Dict[int, Dict]:
    """Index category nodes by their "id" key; nodes without one are dropped."""
    mapping: Dict[int, Dict] = {}
    for node in cats:
        if "id" in node:
            mapping[node["id"]] = node
    return mapping
def iter_batches(items: List[Any], batch_size: int) -> List[List[Any]]:
    """Split *items* into consecutive chunks of at most *batch_size* elements.

    A batch_size below 1 is clamped to 1, so the call never loops forever
    and always makes progress.
    """
    step = max(1, batch_size)
    return [items[start:start + step] for start in range(0, len(items), step)]
def build_capability_grounding_input(capability: Dict[str, Any]) -> Dict[str, Any]:
    """Project a capability down to the minimal fields needed for grounding.

    A falsy body (None/"") is normalized to the empty string; a missing
    apply_to_draft becomes an empty dict.
    """
    normalized_body = capability.get("body") or ""
    return {
        "capability_id": capability.get("capability_id"),
        "body": normalized_body,
        "apply_to_draft": capability.get("apply_to_draft", {}),
    }
def apply_to_body_excerpts_are_verbatim(apply_to: Any, body: str) -> bool:
    """Check that every apply_to body_excerpt is a verbatim substring of *body*.

    Both the "实质" and "形式" buckets are validated: each must be a list of
    dicts whose non-empty, stripped "body_excerpt" occurs in *body*.

    Returns:
        True only when apply_to is well-formed and every excerpt matches;
        False for malformed input or an empty/blank body.
    """
    if not isinstance(apply_to, dict):
        return False
    if not isinstance(body, str) or not body.strip():
        return False
    for bucket in ("实质", "形式"):
        entries = apply_to.get(bucket, [])
        if not isinstance(entries, list):
            return False
        for entry in entries:
            if not isinstance(entry, dict):
                return False
            excerpt = entry.get("body_excerpt")
            if not isinstance(excerpt, str) or not excerpt.strip():
                return False
            if excerpt.strip() not in body:
                return False
    return True
def render_grounding_prompt(
    template: str,
    task: str,
    draft: Dict,
    compact_tree: str,
    reference_paths: List[str] = None,
) -> str:
    """Fill the Stage-2 grounding template's placeholders.

    Args:
        template: prompt text containing the {target}, {compact_tree},
            {reference_paths} and {draft_json} placeholders.
        task: accepted for call-site symmetry but currently unused.
        draft: payload serialized into {draft_json} as pretty-printed JSON.
        compact_tree: compact content-tree JSON injected as {compact_tree}.
        reference_paths: optional path hints serialized into {reference_paths};
            None is treated as an empty list.

    Returns:
        The rendered prompt string.
    """
    rendered = template.replace("{target}", "capability 数组中的每一条 capability")
    rendered = rendered.replace("{compact_tree}", compact_tree)
    rendered = rendered.replace(
        "{reference_paths}", json.dumps(reference_paths or [], ensure_ascii=False)
    )
    return rendered.replace(
        "{draft_json}", json.dumps(draft, ensure_ascii=False, indent=2)
    )
async def ground_single_case(
    case_item: Dict[str, Any],
    template: str,
    llm_call: Any,
    model: str,
    use_api: bool = False,
    compact_tree: str = None,
) -> tuple[Dict[str, Any], float]:
    """
    Map workflow_groups[*].capability[*].apply_to_draft to apply_to for one case.

    For each workflow group, capabilities still carrying an apply_to_draft are
    batched and sent to the LLM; validated results are written back in place
    (apply_to / suggest_apply_to set, apply_to_draft removed). Invalid or
    non-verbatim outputs are skipped, leaving the draft untouched.

    Args:
        case_item: the case dict to process; a shallow-copied tree is returned
            so the caller's input is not mutated.
        template: prompt template for render_grounding_prompt.
        llm_call: LLM invocation callable forwarded to call_llm_with_retry.
        model: model name for the LLM call.
        use_api: when True, search the remote API per group for a focused
            content subtree; otherwise use the precomputed compact_tree.
        compact_tree: full compact content tree used when use_api is False,
            or as fallback when no keywords can be extracted.

    Returns:
        Tuple of (updated case dict, accumulated LLM cost in dollars).
    """
    total_cost = 0.0
    result = dict(case_item)
    title = case_item.get("title", "")[:20] or "untitled"
    workflow_groups = case_item.get("workflow_groups")
    if not isinstance(workflow_groups, list) or not workflow_groups:
        # Nothing to ground; return the shallow copy unchanged.
        return result, total_cost
    # Shallow-copy each group dict so writebacks don't mutate the input.
    updated_groups = [
        dict(group) if isinstance(group, dict) else group
        for group in workflow_groups
    ]
    for group_idx, group in enumerate(updated_groups):
        if not isinstance(group, dict):
            continue
        workflow_id = group.get("workflow_id") or f"g{group_idx + 1}"
        capability_items = group.get("capability")
        if not isinstance(capability_items, list) or not capability_items:
            continue
        # Only capabilities that still carry an apply_to_draft need grounding.
        draft_capability_pairs = [
            (idx, capability)
            for idx, capability in enumerate(capability_items)
            if isinstance(capability, dict) and "apply_to_draft" in capability
        ]
        if not draft_capability_pairs:
            continue
        # Collect keywords from the capabilities (used for the API search).
        if use_api:
            all_keywords = []
            for _, capability in draft_capability_pairs:
                apply_to_draft = capability.get("apply_to_draft", {})
                # NOTE(review): assumes apply_to_draft is a dict of string
                # lists keyed by "实质"/"形式" — confirm the upstream schema.
                for key in ["实质", "形式"]:
                    for draft_text in apply_to_draft.get(key, []):
                        all_keywords.extend(extract_keywords_from_draft(draft_text))
            # Order-preserving de-dup, capped at ten keywords per group.
            all_keywords = list(dict.fromkeys(all_keywords))[:10]
            if all_keywords:
                categories = await search_categories_by_keywords(all_keywords, top_k=5)
                capability_compact_tree = build_compact_tree(categories)
                capability_ref_paths = list(dict.fromkeys(
                    c["path"] for c in categories if c.get("path")
                ))
            else:
                # No keywords extracted: fall back to the full tree, if any.
                capability_compact_tree = compact_tree or "[]"
                capability_ref_paths = []
        else:
            capability_compact_tree = compact_tree or "[]"
            capability_ref_paths = []
        # Shallow-copy capabilities so per-item writeback is isolated.
        updated_capabilities = [
            dict(capability) if isinstance(capability, dict) else capability
            for capability in capability_items
        ]
        # Map capability_id -> original index for matching LLM output rows.
        id_to_index = {
            capability.get("capability_id"): idx
            for idx, capability in draft_capability_pairs
            if isinstance(capability.get("capability_id"), str)
        }
        batches = iter_batches(draft_capability_pairs, CAPABILITY_GROUNDING_BATCH_SIZE)
        for batch_idx, batch_pairs in enumerate(batches, start=1):
            draft_capabilities = [
                build_capability_grounding_input(capability)
                for _, capability in batch_pairs
            ]
            draft = {"capability": draft_capabilities}
            prompt = render_grounding_prompt(template, "capability", draft, capability_compact_tree, capability_ref_paths)
            messages = [{"role": "user", "content": prompt}]
            grounded, cost = await call_llm_with_retry(
                llm_call=llm_call,
                messages=messages,
                model=model,
                temperature=0.1,
                max_tokens=4000,
                max_retries=3,
                schema_name="apply_to_grounding_capability",
                task_name=f"Ground_C_{title}_{workflow_id}_B{batch_idx}/{len(batches)}",
            )
            total_cost += cost
            if not grounded or not isinstance(grounded.get("capability"), list):
                # Malformed LLM output: keep drafts and move to the next batch.
                continue
            grounded_capabilities = grounded["capability"]
            used_indices = set()
            for output_idx, grounded_capability in enumerate(grounded_capabilities):
                if not isinstance(grounded_capability, dict):
                    continue
                # Prefer matching by capability_id; fall back to positional order.
                capability_idx = None
                capability_id = grounded_capability.get("capability_id")
                if isinstance(capability_id, str):
                    capability_idx = id_to_index.get(capability_id)
                if capability_idx is None and output_idx < len(batch_pairs):
                    capability_idx = batch_pairs[output_idx][0]
                if capability_idx is None or capability_idx in used_indices:
                    continue
                apply_to = grounded_capability.get("apply_to")
                suggest_apply_to = grounded_capability.get("suggest_apply_to")
                body = updated_capabilities[capability_idx].get("body", "")
                # Write back only when the output is complete and every
                # body_excerpt is verbatim from the capability body.
                if (
                    apply_to is not None
                    and isinstance(suggest_apply_to, str)
                    and suggest_apply_to.strip()
                    and isinstance(updated_capabilities[capability_idx], dict)
                    and apply_to_body_excerpts_are_verbatim(apply_to, body)
                ):
                    updated_capabilities[capability_idx]["apply_to"] = apply_to
                    updated_capabilities[capability_idx]["suggest_apply_to"] = suggest_apply_to.strip()
                    updated_capabilities[capability_idx].pop("apply_to_draft", None)
                    used_indices.add(capability_idx)
                else:
                    print(
                        f" ⚠️ Skip capability grounding writeback: "
                        f"{capability_id or capability_idx} has missing/non-verbatim body_excerpt"
                    )
        group["capability"] = updated_capabilities
    result["workflow_groups"] = updated_groups
    return result, total_cost
async def apply_grounding(
    case_file: Path,
    llm_call: Any,
    model: str = "openai/gpt-5.4",
    max_concurrent: int = 3,
) -> Dict[str, Any]:
    """
    Walk case.json and map every apply_to_draft to a final apply_to.

    Loads all cases, filters those with pending drafts, grounds them
    concurrently (bounded by a semaphore), and rewrites case.json in place.

    Args:
        case_file: path to case.json.
        llm_call: LLM invocation callable.
        model: model name.
        max_concurrent: maximum number of cases processed concurrently.

    Returns:
        Dict with statistics: total, grounded, total_cost, output_file.
    """
    with open(case_file, "r", encoding="utf-8") as f:
        case_data = json.load(f)
    cases = case_data.get("cases", [])
    # Whether to use API dynamic-search mode (enabled by default).
    use_api = os.getenv("USE_SEARCH_API", "true").lower() == "true"
    # When not using the API, preload the full content tree once.
    compact_tree = None
    if not use_api:
        cats = await load_category_tree(use_api=False)
        compact_tree = build_compact_tree(cats)
        print(f"Loaded category tree: {len(cats)} nodes, compact={len(compact_tree)} chars")
    else:
        print(f"Using API dynamic search mode (searching on-demand)")
    # Load the prompt template.
    template = load_prompt_template("apply_to_grounding")
    # Filter the cases that need processing (only
    # workflow_groups[*].capability[*].apply_to_draft is considered).
    needs_grounding = []
    for case in cases:
        workflow_groups = case.get("workflow_groups")
        has_capability_draft = isinstance(workflow_groups, list) and any(
            isinstance(group, dict)
            and isinstance(group.get("capability"), list)
            and any(
                isinstance(capability, dict) and "apply_to_draft" in capability
                for capability in group.get("capability", [])
            )
            for group in workflow_groups
        )
        if has_capability_draft:
            needs_grounding.append(case)
    print(f"Grounding apply_to for {len(needs_grounding)}/{len(cases)} cases...")
    if not needs_grounding:
        print(" No cases need grounding, skipping.")
        return {
            "total": len(cases),
            "grounded": 0,
            "total_cost": 0.0,
            "output_file": str(case_file),
        }
    # Bound concurrency so only max_concurrent cases run at once.
    semaphore = asyncio.Semaphore(max_concurrent)
    async def process_with_semaphore(case_item):
        async with semaphore:
            index = case_item.get("index", 0)
            raw = case_item.get("_raw", {})
            case_id = raw.get("case_id", "unknown")
            print(f" -> [{index}] [{case_id}] grounding apply_to...")
            grounded, cost = await ground_single_case(
                case_item, template, llm_call, model, use_api, compact_tree
            )
            print(f" <- [{index}] [{case_id}] grounded (cost=${cost:.4f})")
            return case_item, grounded, cost
    tasks = [process_with_semaphore(case) for case in needs_grounding]
    results_with_costs = await asyncio.gather(*tasks)
    # Replace originals with grounded results, matching by object identity so
    # a missing or duplicated "index" cannot write back into the wrong case.
    grounded_map = {}
    total_cost = 0.0
    for original_case, grounded, cost in results_with_costs:
        grounded_map[id(original_case)] = grounded
        total_cost += cost
    updated_cases = []
    for case in cases:
        case_id = id(case)
        if case_id in grounded_map:
            updated_cases.append(grounded_map[case_id])
        else:
            updated_cases.append(case)
    # Keep output ordered by case index for stable diffs.
    updated_cases.sort(key=lambda x: x.get("index", 0))
    case_data["cases"] = updated_cases
    case_file.parent.mkdir(parents=True, exist_ok=True)
    with open(case_file, "w", encoding="utf-8") as f:
        json.dump(case_data, f, ensure_ascii=False, indent=2)
    return {
        "total": len(cases),
        "grounded": len(needs_grounding),
        "total_cost": total_cost,
        "output_file": str(case_file),
    }
- if __name__ == "__main__":
- import sys
- if len(sys.argv) < 2:
- print("Usage: python apply_to_grounding.py <output_dir>")
- sys.exit(1)
- print("Please use this module through run_pipeline.py")
|