| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516 |
- """
- Stage 2: 将 apply_to_draft 映射为正式 apply_to
- 从 case.json 读取,对每个 case 的 workflow 和 capabilities 中的 apply_to_draft,
- 调用 LLM 映射到内容树的正式节点,按 index 原位回填到 case.json
- 改造版本:通过远程 API 获取内容树,不再依赖本地文件
- """
import asyncio
import json
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional

import httpx
from dotenv import load_dotenv

from examples.process_pipeline.script.llm_helper import call_llm_with_retry
- # 加载环境变量
- PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent
- load_dotenv(PROJECT_ROOT / ".env")
- # 搜索 API 配置
- SEARCH_API = os.getenv("SEARCH_API", "http://8.147.104.190:8001").rstrip("/")
- # 本地文件路径(作为回退方案)
- EXTRACT_DIR = Path(__file__).resolve().parent / "resource"
- CATEGORY_TREE_PATH = EXTRACT_DIR / "category_tree_56.json"
def load_prompt_template(prompt_name: str) -> str:
    """Load a prompt template by name, stripping front matter and role markers."""
    prompt_path = Path(__file__).parent.parent / "prompts" / f"{prompt_name}.prompt"
    with open(prompt_path, "r", encoding="utf-8") as fh:
        text = fh.read()
    # Drop YAML-style front matter delimited by a leading "---" pair.
    if text.startswith("---"):
        pieces = text.split("---", 2)
        if len(pieces) >= 3:
            text = pieces[2]
    # Strip role placeholders left over from the template format.
    for marker in ("$system$", "$user$"):
        text = text.replace(marker, "")
    return text.strip()
def load_category_tree_from_local() -> List[Dict]:
    """Read the category tree from the bundled JSON file (fallback path).

    Raises:
        FileNotFoundError: If the bundled JSON file is missing.
        RuntimeError: If the file contains no categories.
    """
    if not CATEGORY_TREE_PATH.exists():
        raise FileNotFoundError(f"Category tree not found: {CATEGORY_TREE_PATH}")
    with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as fh:
        payload = json.load(fh)
    categories = payload.get("categories", [])
    if not categories:
        raise RuntimeError("category_tree is empty")
    return categories
async def load_category_tree(use_api: bool = False) -> List[Dict]:
    """
    Load the category (content) tree, from the remote API or the local file.

    Args:
        use_api: When True, try the remote API first and fall back to the
            local file on any failure; when False (default), read local only.

    Returns:
        List of category tree nodes.
    """
    if use_api:
        try:
            print("Attempting to load category tree from API...")
            # NOTE(review): fetch_category_tree_from_api is not defined in this
            # module — it must be provided elsewhere; otherwise this raises
            # NameError, which the broad except below silently converts into a
            # local-file fallback. Confirm the helper actually exists.
            return await fetch_category_tree_from_api()
        except Exception as e:
            # Any failure (network, parsing, missing helper) falls back to disk.
            print(f"API failed ({e}), falling back to local file...")
            return load_category_tree_from_local()
    else:
        print("Loading category tree from local file...")
        return load_category_tree_from_local()
async def search_categories_by_keywords(
    keywords: List[str],
    source_types: Optional[List[str]] = None,
    top_k: int = 10,
    timeout: int = 10,
) -> List[Dict]:
    """
    Search related category nodes for each keyword via the remote search API.

    One GET request is issued per (keyword, source_type) pair; failures are
    deliberately best-effort so remaining queries still contribute results.

    Args:
        keywords: Search keywords.
        source_types: Source types to query; defaults to ["形式", "实质"].
        top_k: Number of results requested per keyword/source-type pair.
        timeout: Per-request timeout in seconds.

    Returns:
        De-duplicated category nodes converted to content-tree format.
    """
    if source_types is None:
        source_types = ["形式", "实质"]

    all_categories: List[Dict] = []
    seen_ids = set()
    async with httpx.AsyncClient(timeout=timeout) as client:
        for keyword in keywords:
            for source_type in source_types:
                try:
                    params = {
                        "q": keyword,
                        "source_type": source_type,
                        "entity_type": "category",
                        "top_k": top_k,
                        "mode": "vector",
                    }
                    resp = await client.get(f"{SEARCH_API}/api/search", params=params)
                    resp.raise_for_status()
                    results = resp.json().get("results", [])
                    # Convert hits to content-tree nodes, skipping duplicates.
                    for item in results:
                        entity_id = item.get("entity_id")
                        if entity_id and entity_id not in seen_ids:
                            all_categories.append({
                                "id": entity_id,
                                "path": item.get("path", ""),
                                "source_type": source_type,
                                "description": item.get("description", ""),
                                "elements": [],
                            })
                            seen_ids.add(entity_id)
                except Exception:
                    # Best-effort: a failed keyword/source-type query is skipped
                    # so the other searches can still return results.
                    continue
    return all_categories
def extract_keywords_from_draft(draft_text: str) -> List[str]:
    """
    Extract up to five search keywords from an apply_to_draft text.

    Splits on common Chinese/ASCII punctuation and whitespace, keeps tokens
    at least two characters long, and de-duplicates preserving first-seen
    order.

    Args:
        draft_text: Text content of the apply_to_draft entry.

    Returns:
        At most five unique keywords; empty list for non-string or empty input.
    """
    if not draft_text or not isinstance(draft_text, str):
        return []
    # Naive tokenization: split on punctuation marks and whitespace.
    words = re.split(r"[,。、;:!?\s]+", draft_text)
    # Keep only tokens long enough to be meaningful as search terms.
    keywords = [w.strip() for w in words if len(w.strip()) >= 2]
    # Preserve first-seen order while de-duplicating; cap at 5 keywords.
    return list(dict.fromkeys(keywords))[:5]
def build_compact_tree(cats: List[Dict]) -> str:
    """Serialize category nodes into a compact JSON string for prompt injection.

    Only nodes whose source_type is "实质" or "形式" are kept; element names
    are flattened into a plain list when present.
    """
    rows = []
    for node in cats:
        if node.get("source_type") not in ("实质", "形式"):
            continue
        row = {
            "id": node.get("id"),
            "path": node.get("path"),
            "source_type": node.get("source_type"),
            "description": node.get("description"),
        }
        elements = node.get("elements", [])
        if isinstance(elements, list) and elements:
            # Elements may be dicts (take their "name") or bare strings.
            names = [
                entry.get("name") if isinstance(entry, dict) else entry
                for entry in elements
                if entry
            ]
            if names:
                row["elements"] = names
        rows.append(row)
    return json.dumps(rows, ensure_ascii=False, separators=(",", ":"))
def build_valid_ids(cats: List[Dict]) -> Dict[int, Dict]:
    """Map each node's ``id`` to the node itself, skipping nodes without one."""
    mapping: Dict[int, Dict] = {}
    for node in cats:
        if "id" in node:
            mapping[node["id"]] = node
    return mapping
def render_grounding_prompt(
    template: str,
    task: str,
    draft: Dict,
    compact_tree: str,
    reference_paths: List[str] = None,
) -> str:
    """Fill the Stage-2 grounding template with target, tree, paths and draft."""
    if task == "capability":
        target = "capabilities 数组中的每一条 capability"
    else:
        target = "strategy;如果 strategy 为 null,则原样返回"
    # Placeholder -> rendered value; applied sequentially below.
    replacements = {
        "{target}": target,
        "{compact_tree}": compact_tree,
        "{reference_paths}": json.dumps(reference_paths or [], ensure_ascii=False),
        "{draft_json}": json.dumps(draft, ensure_ascii=False, indent=2),
    }
    rendered = template
    for placeholder, value in replacements.items():
        rendered = rendered.replace(placeholder, value)
    return rendered
def _collect_draft_keywords(items: List[Any]) -> List[str]:
    """Gather up to 10 unique keywords from the apply_to_draft fields of *items*."""
    keywords: List[str] = []
    for item in items:
        if isinstance(item, dict) and "apply_to_draft" in item:
            apply_to_draft = item.get("apply_to_draft", {})
            for key in ["实质", "形式"]:
                for draft_text in apply_to_draft.get(key, []):
                    keywords.extend(extract_keywords_from_draft(draft_text))
    return list(dict.fromkeys(keywords))[:10]


async def _grounding_context(
    items: List[Any],
    use_api: bool,
    compact_tree: Optional[str],
) -> tuple[str, List[str]]:
    """Select the compact tree and reference paths for one grounding call.

    API mode searches categories on demand from the items' draft keywords;
    otherwise the preloaded full tree is used with no reference paths.
    """
    if use_api:
        keywords = _collect_draft_keywords(items)
        if keywords:
            categories = await search_categories_by_keywords(keywords, top_k=5)
            ref_paths = list(dict.fromkeys(
                c["path"] for c in categories if c.get("path")
            ))
            return build_compact_tree(categories), ref_paths
    return compact_tree or "[]", []


async def ground_single_case(
    case_item: Dict[str, Any],
    template: str,
    llm_call: Any,
    model: str,
    use_api: bool = False,
    compact_tree: Optional[str] = None,
) -> tuple[Dict[str, Any], float]:
    """
    Map apply_to_draft entries of one case's workflow and capabilities to
    formal apply_to values.

    The workflow is grounded in a single LLM call (so step context is kept)
    and results are back-filled by step ``order``; capabilities are grounded
    as one batch and back-filled by list index.

    Args:
        case_item: Case data.
        template: Prompt template.
        llm_call: LLM invocation callable.
        model: Model name.
        use_api: Whether to search categories on demand via the API.
        compact_tree: Preloaded full content tree (used when use_api=False).

    Returns:
        (updated case dict, total LLM cost in dollars).
    """
    total_cost = 0.0
    result = dict(case_item)
    title = case_item.get("title", "")[:20] or "untitled"

    # --- workflow (strategy): grounded as a whole to preserve step context ---
    workflow = case_item.get("workflow")
    if isinstance(workflow, dict) and "steps" in workflow:
        steps = workflow.get("steps", [])
        if any(isinstance(step, dict) and "apply_to_draft" in step for step in steps):
            tree_str, ref_paths = await _grounding_context(steps, use_api, compact_tree)
            prompt = render_grounding_prompt(
                template, "strategy", {"strategy": workflow}, tree_str, ref_paths
            )
            grounded, cost = await call_llm_with_retry(
                llm_call=llm_call,
                messages=[{"role": "user", "content": prompt}],
                model=model,
                temperature=0.1,
                max_tokens=4000,
                max_retries=3,
                schema_name="apply_to_grounding_strategy",
                task_name=f"Ground_W_{title}",
            )
            total_cost += cost

            if grounded and isinstance(grounded.get("strategy"), dict):
                # Build order -> apply_to from the LLM output...
                order_to_apply_to = {}
                for grounded_step in grounded["strategy"].get("steps", []):
                    if isinstance(grounded_step, dict):
                        order = grounded_step.get("order")
                        apply_to = grounded_step.get("apply_to")
                        if order is not None and apply_to is not None:
                            order_to_apply_to[order] = apply_to
                # ...then back-fill the original steps in place.
                updated_steps = []
                for step in steps:
                    updated_step = dict(step)
                    order = step.get("order")
                    if order in order_to_apply_to:
                        updated_step["apply_to"] = order_to_apply_to[order]
                        updated_step.pop("apply_to_draft", None)
                    updated_steps.append(updated_step)
                result["workflow"] = dict(workflow)
                result["workflow"]["steps"] = updated_steps

    # --- capabilities: grounded as one batch, back-filled by index ---
    capabilities = case_item.get("capabilities")
    if isinstance(capabilities, list) and capabilities:
        if any(isinstance(cap, dict) and "apply_to_draft" in cap for cap in capabilities):
            tree_str, ref_paths = await _grounding_context(
                capabilities, use_api, compact_tree
            )
            prompt = render_grounding_prompt(
                template, "capability", {"capabilities": capabilities}, tree_str, ref_paths
            )
            grounded, cost = await call_llm_with_retry(
                llm_call=llm_call,
                messages=[{"role": "user", "content": prompt}],
                model=model,
                temperature=0.1,
                max_tokens=4000,
                max_retries=3,
                schema_name="apply_to_grounding_capability",
                task_name=f"Ground_C_{title}",
            )
            total_cost += cost

            if grounded and isinstance(grounded.get("capabilities"), list):
                grounded_caps = grounded["capabilities"]
                updated_capabilities = []
                for idx, cap in enumerate(capabilities):
                    updated_cap = dict(cap)
                    # Positional match: the LLM returns capabilities in order.
                    if idx < len(grounded_caps) and isinstance(grounded_caps[idx], dict):
                        apply_to = grounded_caps[idx].get("apply_to")
                        if apply_to is not None:
                            updated_cap["apply_to"] = apply_to
                            updated_cap.pop("apply_to_draft", None)
                    updated_capabilities.append(updated_cap)
                result["capabilities"] = updated_capabilities

    return result, total_cost
async def apply_grounding(
    case_file: Path,
    llm_call: Any,
    model: str = "openai/gpt-5.4",
    max_concurrent: int = 3,
) -> Dict[str, Any]:
    """
    Walk case.json by index and map each apply_to_draft to a formal apply_to.

    Grounded results are written back to *case_file* in place, matched by
    each case's ``index`` field. Cases without any apply_to_draft are left
    untouched.

    Args:
        case_file: Path to case.json.
        llm_call: LLM invocation callable.
        model: Model name.
        max_concurrent: Maximum number of concurrent grounding calls.

    Returns:
        Stats dict: total cases, grounded count, total cost, output file.
    """
    with open(case_file, "r", encoding="utf-8") as f:
        case_data = json.load(f)
    cases = case_data.get("cases", [])

    # USE_SEARCH_API toggles on-demand API search vs. a preloaded local tree.
    use_api = os.getenv("USE_SEARCH_API", "true").lower() == "true"

    compact_tree = None
    if not use_api:
        cats = await load_category_tree(use_api=False)
        compact_tree = build_compact_tree(cats)
        print(f"Loaded category tree: {len(cats)} nodes, compact={len(compact_tree)} chars")
    else:
        # No placeholder needed here, so this is a plain string (was an f-string).
        print("Using API dynamic search mode (searching on-demand)")

    template = load_prompt_template("apply_to_grounding")

    # Only cases still carrying apply_to_draft (in steps or capabilities)
    # need an LLM round-trip.
    needs_grounding = []
    for case in cases:
        workflow = case.get("workflow")
        capabilities = case.get("capabilities")
        has_workflow_draft = (
            isinstance(workflow, dict) and
            any(
                isinstance(step, dict) and "apply_to_draft" in step
                for step in workflow.get("steps", [])
            )
        )
        has_cap_draft = isinstance(capabilities, list) and any(
            isinstance(c, dict) and "apply_to_draft" in c for c in capabilities
        )
        if has_workflow_draft or has_cap_draft:
            needs_grounding.append(case)

    print(f"Grounding apply_to for {len(needs_grounding)}/{len(cases)} cases...")
    if not needs_grounding:
        print(" No cases need grounding, skipping.")
        return {
            "total": len(cases),
            "grounded": 0,
            "total_cost": 0.0,
            "output_file": str(case_file),
        }

    semaphore = asyncio.Semaphore(max_concurrent)

    async def process_with_semaphore(case_item):
        # Bound concurrency so at most max_concurrent LLM calls are in flight.
        async with semaphore:
            index = case_item.get("index", 0)
            raw = case_item.get("_raw", {})
            case_id = raw.get("case_id", "unknown")
            print(f" -> [{index}] [{case_id}] grounding apply_to...")
            grounded, cost = await ground_single_case(
                case_item, template, llm_call, model, use_api, compact_tree
            )
            print(f" <- [{index}] [{case_id}] grounded (cost=${cost:.4f})")
            return grounded, cost

    tasks = [process_with_semaphore(case) for case in needs_grounding]
    results_with_costs = await asyncio.gather(*tasks)

    # Replace originals with grounded results, matched by index.
    grounded_map = {}
    total_cost = 0.0
    for grounded, cost in results_with_costs:
        grounded_map[grounded.get("index")] = grounded
        total_cost += cost

    updated_cases = []
    for case in cases:
        idx = case.get("index")
        if idx in grounded_map:
            updated_cases.append(grounded_map[idx])
        else:
            updated_cases.append(case)
    updated_cases.sort(key=lambda x: x.get("index", 0))
    case_data["cases"] = updated_cases

    case_file.parent.mkdir(parents=True, exist_ok=True)
    with open(case_file, "w", encoding="utf-8") as f:
        json.dump(case_data, f, ensure_ascii=False, indent=2)

    return {
        "total": len(cases),
        "grounded": len(needs_grounding),
        "total_cost": total_cost,
        "output_file": str(case_file),
    }
if __name__ == "__main__":
    import sys

    # This stage has no standalone CLI: it must be driven by run_pipeline.py,
    # which supplies the llm_call callable and output paths. The previous
    # usage message wrongly implied "python apply_to_grounding.py <output_dir>"
    # was supported; exit non-zero so scripts don't mistake this for success.
    print("Please use this module through run_pipeline.py")
    sys.exit(1)
|