apply_to_grounding.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
"""
Stage 2: map apply_to_draft to the formal apply_to.

Reads case.json and processes only the apply_to_draft entries inside each
case's workflow_groups[*].capability. Calls an LLM to ground them onto formal
content-tree nodes and writes the results back into case.json in place.
This revision can fetch the content tree through a remote API instead of
relying solely on local files.
"""
import asyncio
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
import httpx
from dotenv import load_dotenv
from examples.process_pipeline.script.llm_helper import call_llm_with_retry
# Load environment variables from the project-root .env file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent
load_dotenv(PROJECT_ROOT / ".env")
# Remote search API configuration (trailing slash stripped for URL joining).
SEARCH_API = os.getenv("SEARCH_API", "http://8.147.104.190:8001").rstrip("/")
# Number of capabilities grounded per LLM call.
CAPABILITY_GROUNDING_BATCH_SIZE = int(os.getenv("CAPABILITY_GROUNDING_BATCH_SIZE", "8"))
# Local file paths (fallback when the remote API is unavailable).
EXTRACT_DIR = Path(__file__).resolve().parent / "resource"
CATEGORY_TREE_PATH = EXTRACT_DIR / "category_tree_56.json"
  24. def load_prompt_template(prompt_name: str) -> str:
  25. """加载 prompt 模板"""
  26. base_dir = Path(__file__).parent.parent
  27. prompt_path = base_dir / "prompts" / f"{prompt_name}.prompt"
  28. with open(prompt_path, "r", encoding="utf-8") as f:
  29. content = f.read()
  30. if content.startswith("---"):
  31. parts = content.split("---", 2)
  32. if len(parts) >= 3:
  33. content = parts[2]
  34. content = content.replace("$system$", "").replace("$user$", "")
  35. return content.strip()
  36. def load_category_tree_from_local() -> List[Dict]:
  37. """从本地文件加载内容树(回退方案)"""
  38. if not CATEGORY_TREE_PATH.exists():
  39. raise FileNotFoundError(f"Category tree not found: {CATEGORY_TREE_PATH}")
  40. with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as f:
  41. data = json.load(f)
  42. cats = data.get("categories", [])
  43. if not cats:
  44. raise RuntimeError("category_tree is empty")
  45. return cats
async def load_category_tree(use_api: bool = False) -> List[Dict]:
    """
    Load the category tree (supports both a local file and the remote API).

    Args:
        use_api: whether to try the remote API first (default False: local file).

    Returns:
        List of category-tree nodes.
    """
    if use_api:
        try:
            print("Attempting to load category tree from API...")
            # NOTE(review): fetch_category_tree_from_api is not defined anywhere
            # in this module. If it is not provided elsewhere, this raises
            # NameError, which the broad except below silently turns into the
            # local-file fallback — confirm the helper exists or drop this path.
            return await fetch_category_tree_from_api()
        except Exception as e:
            print(f"API failed ({e}), falling back to local file...")
            return load_category_tree_from_local()
    else:
        print("Loading category tree from local file...")
        return load_category_tree_from_local()
  64. async def search_categories_by_keywords(
  65. keywords: List[str],
  66. source_types: List[str] = None,
  67. top_k: int = 10,
  68. timeout: int = 10
  69. ) -> List[Dict]:
  70. """
  71. 根据关键词搜索相关分类节点
  72. Args:
  73. keywords: 搜索关键词列表
  74. source_types: 来源类型列表,默认 ["形式", "实质"]
  75. top_k: 每个关键词返回的结果数
  76. timeout: 请求超时时间(秒)
  77. Returns:
  78. 相关分类节点列表
  79. """
  80. if source_types is None:
  81. source_types = ["形式", "实质"]
  82. all_categories = []
  83. seen_ids = set()
  84. async with httpx.AsyncClient(timeout=timeout) as client:
  85. for keyword in keywords:
  86. for source_type in source_types:
  87. try:
  88. params = {
  89. "q": keyword,
  90. "source_type": source_type,
  91. "entity_type": "category",
  92. "top_k": top_k,
  93. "mode": "vector"
  94. }
  95. resp = await client.get(f"{SEARCH_API}/api/search", params=params)
  96. resp.raise_for_status()
  97. data = resp.json()
  98. results = data.get("results", [])
  99. # 转换为内容树格式,去重
  100. for item in results:
  101. entity_id = item.get("entity_id")
  102. if entity_id and entity_id not in seen_ids:
  103. category = {
  104. "id": entity_id,
  105. "path": item.get("path", ""),
  106. "source_type": source_type,
  107. "description": item.get("description", ""),
  108. "elements": []
  109. }
  110. all_categories.append(category)
  111. seen_ids.add(entity_id)
  112. except Exception as e:
  113. # 静默失败,继续处理其他关键词
  114. continue
  115. return all_categories
  116. def extract_keywords_from_draft(draft_text: str) -> List[str]:
  117. """
  118. 从 apply_to_draft 文本中提取关键词
  119. Args:
  120. draft_text: apply_to_draft 的文本内容
  121. Returns:
  122. 关键词列表
  123. """
  124. if not draft_text or not isinstance(draft_text, str):
  125. return []
  126. # 简单的关键词提取:分词并过滤
  127. import re
  128. # 移除标点符号,按空格和常见分隔符分词
  129. words = re.split(r'[,。、;:!?\s]+', draft_text)
  130. # 过滤短词和停用词
  131. keywords = [w.strip() for w in words if len(w.strip()) >= 2]
  132. # 去重并限制数量
  133. keywords = list(dict.fromkeys(keywords))[:5] # 最多5个关键词
  134. return keywords
  135. def build_compact_tree(cats: List[Dict]) -> str:
  136. """构建紧凑版内容树(用于注入 prompt)"""
  137. rows = []
  138. for c in cats:
  139. if c.get("source_type") not in ("实质", "形式"):
  140. continue
  141. row = {
  142. "id": c.get("id"),
  143. "path": c.get("path"),
  144. "source_type": c.get("source_type"),
  145. "description": c.get("description"),
  146. }
  147. elems = c.get("elements", [])
  148. if isinstance(elems, list) and elems:
  149. elem_names = [
  150. e.get("name") if isinstance(e, dict) else e
  151. for e in elems
  152. if e
  153. ]
  154. if elem_names:
  155. row["elements"] = elem_names
  156. rows.append(row)
  157. return json.dumps(rows, ensure_ascii=False, separators=(",", ":"))
  158. def build_valid_ids(cats: List[Dict]) -> Dict[int, Dict]:
  159. """构建 id -> node 映射"""
  160. return {c["id"]: c for c in cats if "id" in c}
  161. def iter_batches(items: List[Any], batch_size: int) -> List[List[Any]]:
  162. """按固定大小切分列表"""
  163. batch_size = max(1, batch_size)
  164. return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
  165. def build_capability_grounding_input(capability: Dict[str, Any]) -> Dict[str, Any]:
  166. """只保留 capability grounding 需要的最小字段"""
  167. return {
  168. "capability_id": capability.get("capability_id"),
  169. "body": capability.get("body") or "",
  170. "apply_to_draft": capability.get("apply_to_draft", {}),
  171. }
  172. def body_excerpts_are_verbatim(apply_to: Any, suggest_apply_to: Any, body: str) -> bool:
  173. """确认非空 body_excerpt 都逐字来自 capability.body;空字符串允许存在。"""
  174. def _entry_is_valid(item: Any) -> bool:
  175. if not isinstance(item, dict):
  176. return False
  177. excerpt = item.get("body_excerpt")
  178. note = item.get("body_excerpt_note")
  179. excerpt_type = item.get("body_excerpt_type")
  180. if not isinstance(excerpt, str) or not isinstance(note, str) or not isinstance(excerpt_type, str):
  181. return False
  182. if not excerpt.strip() and (note.strip() or excerpt_type.strip()):
  183. return False
  184. if excerpt.strip() and excerpt.strip() not in body:
  185. return False
  186. return True
  187. if not isinstance(apply_to, dict) or not isinstance(suggest_apply_to, list) or not isinstance(body, str):
  188. return False
  189. for source_type in ("实质", "形式"):
  190. items = apply_to.get(source_type, [])
  191. if not isinstance(items, list):
  192. return False
  193. for item in items:
  194. if not _entry_is_valid(item):
  195. return False
  196. for item in suggest_apply_to:
  197. if not _entry_is_valid(item):
  198. return False
  199. return True
  200. def render_grounding_prompt(
  201. template: str,
  202. task: str,
  203. draft: Dict,
  204. compact_tree: str,
  205. reference_paths: List[str] = None,
  206. ) -> str:
  207. """渲染 Stage 2 prompt"""
  208. target = "capability 数组中的每一条 capability"
  209. paths_str = json.dumps(reference_paths or [], ensure_ascii=False)
  210. return (
  211. template
  212. .replace("{target}", target)
  213. .replace("{compact_tree}", compact_tree)
  214. .replace("{reference_paths}", paths_str)
  215. .replace("{draft_json}", json.dumps(draft, ensure_ascii=False, indent=2))
  216. )
async def ground_single_case(
    case_item: Dict[str, Any],
    template: str,
    llm_call: Any,
    model: str,
    use_api: bool = False,
    compact_tree: Optional[str] = None,
) -> tuple[Dict[str, Any], float]:
    """
    Map workflow_groups[*].capability[*].apply_to_draft to apply_to for one case.

    For each workflow group, capabilities still carrying an apply_to_draft are
    batched (CAPABILITY_GROUNDING_BATCH_SIZE per LLM call), sent to the LLM
    with either the preloaded compact tree or an on-demand API-searched
    subtree, and the validated outputs are written back in place (apply_to /
    suggest_apply_to replace apply_to_draft).

    Args:
        case_item: one case dict from case.json.
        template: grounding prompt template text.
        llm_call: LLM invocation callable passed through to call_llm_with_retry.
        model: model identifier for the LLM call.
        use_api: when True, search the remote API for a per-group subtree
            instead of using the preloaded compact_tree.
        compact_tree: precomputed compact category-tree JSON; used when
            use_api is False, or as fallback when no keywords were extracted.

    Returns:
        (updated case dict, accumulated LLM cost).
    """
    total_cost = 0.0
    result = dict(case_item)
    title = case_item.get("title", "")[:20] or "untitled"
    workflow_groups = case_item.get("workflow_groups")
    if not isinstance(workflow_groups, list) or not workflow_groups:
        return result, total_cost
    # Shallow-copy each group dict so writebacks don't mutate the caller's data.
    updated_groups = [
        dict(group) if isinstance(group, dict) else group
        for group in workflow_groups
    ]
    for group_idx, group in enumerate(updated_groups):
        if not isinstance(group, dict):
            continue
        workflow_id = group.get("workflow_id") or f"g{group_idx + 1}"
        capability_items = group.get("capability")
        if not isinstance(capability_items, list) or not capability_items:
            continue
        # (original index, capability) pairs that still need grounding.
        draft_capability_pairs = [
            (idx, capability)
            for idx, capability in enumerate(capability_items)
            if isinstance(capability, dict) and "apply_to_draft" in capability
        ]
        if not draft_capability_pairs:
            continue
        # Collect keywords from the drafts (only used for the API search mode).
        if use_api:
            all_keywords = []
            for _, capability in draft_capability_pairs:
                apply_to_draft = capability.get("apply_to_draft", {})
                for key in ["实质", "形式"]:
                    for draft_text in apply_to_draft.get(key, []):
                        all_keywords.extend(extract_keywords_from_draft(draft_text))
            # De-duplicate preserving order; cap at 10 keywords per group.
            all_keywords = list(dict.fromkeys(all_keywords))[:10]
            if all_keywords:
                categories = await search_categories_by_keywords(all_keywords, top_k=5)
                capability_compact_tree = build_compact_tree(categories)
                capability_ref_paths = list(dict.fromkeys(
                    c["path"] for c in categories if c.get("path")
                ))
            else:
                # No usable keywords: fall back to the preloaded tree, if any.
                capability_compact_tree = compact_tree or "[]"
                capability_ref_paths = []
        else:
            capability_compact_tree = compact_tree or "[]"
            capability_ref_paths = []
        # Shallow-copy capabilities so the writeback replaces, not mutates.
        updated_capabilities = [
            dict(capability) if isinstance(capability, dict) else capability
            for capability in capability_items
        ]
        # capability_id -> original index; used to match LLM output rows back.
        id_to_index = {
            capability.get("capability_id"): idx
            for idx, capability in draft_capability_pairs
            if isinstance(capability.get("capability_id"), str)
        }
        batches = iter_batches(draft_capability_pairs, CAPABILITY_GROUNDING_BATCH_SIZE)
        for batch_idx, batch_pairs in enumerate(batches, start=1):
            draft_capabilities = [
                build_capability_grounding_input(capability)
                for _, capability in batch_pairs
            ]
            draft = {"capability": draft_capabilities}
            prompt = render_grounding_prompt(template, "capability", draft, capability_compact_tree, capability_ref_paths)
            messages = [{"role": "user", "content": prompt}]
            grounded, cost = await call_llm_with_retry(
                llm_call=llm_call,
                messages=messages,
                model=model,
                temperature=0.1,
                max_tokens=4000,
                max_retries=3,
                schema_name="apply_to_grounding_capability",
                task_name=f"Ground_C_{title}_{workflow_id}_B{batch_idx}/{len(batches)}",
            )
            total_cost += cost
            # Malformed LLM output: keep the drafts and move to the next batch.
            if not grounded or not isinstance(grounded.get("capability"), list):
                continue
            grounded_capabilities = grounded["capability"]
            used_indices = set()
            for output_idx, grounded_capability in enumerate(grounded_capabilities):
                if not isinstance(grounded_capability, dict):
                    continue
                # Prefer matching by capability_id; fall back to positional order.
                capability_idx = None
                capability_id = grounded_capability.get("capability_id")
                if isinstance(capability_id, str):
                    capability_idx = id_to_index.get(capability_id)
                if capability_idx is None and output_idx < len(batch_pairs):
                    capability_idx = batch_pairs[output_idx][0]
                if capability_idx is None or capability_idx in used_indices:
                    continue
                apply_to = grounded_capability.get("apply_to")
                suggest_apply_to = grounded_capability.get("suggest_apply_to")
                body = updated_capabilities[capability_idx].get("body", "")
                # Only write back when the output is structurally sound and every
                # body_excerpt is a verbatim quote from the capability body.
                if (
                    apply_to is not None
                    and isinstance(suggest_apply_to, list)
                    and len(suggest_apply_to) <= 3
                    and isinstance(updated_capabilities[capability_idx], dict)
                    and body_excerpts_are_verbatim(apply_to, suggest_apply_to, body)
                ):
                    updated_capabilities[capability_idx]["apply_to"] = apply_to
                    updated_capabilities[capability_idx]["suggest_apply_to"] = suggest_apply_to
                    updated_capabilities[capability_idx].pop("apply_to_draft", None)
                    used_indices.add(capability_idx)
                else:
                    print(
                        f" ⚠️ Skip capability grounding writeback: "
                        f"{capability_id or capability_idx} has missing/non-verbatim body_excerpt"
                    )
        group["capability"] = updated_capabilities
    result["workflow_groups"] = updated_groups
    return result, total_cost
async def apply_grounding(
    case_file: Path,
    llm_call: Any,
    model: str = "openai/gpt-5.4",
    max_concurrent: int = 3,
) -> Dict[str, Any]:
    """
    Walk case.json by index and map every apply_to_draft to apply_to.

    Args:
        case_file: path to case.json (rewritten in place).
        llm_call: LLM invocation callable.
        model: model name.
        max_concurrent: maximum number of cases grounded concurrently.

    Returns:
        Statistics dict with total cases, number grounded, total LLM cost,
        and the output file path.
    """
    with open(case_file, "r", encoding="utf-8") as f:
        case_data = json.load(f)
    cases = case_data.get("cases", [])
    # Whether to use the on-demand API search mode (defaults to enabled).
    use_api = os.getenv("USE_SEARCH_API", "true").lower() == "true"
    # Without the API, preload the full category tree once and reuse it.
    compact_tree = None
    if not use_api:
        cats = await load_category_tree(use_api=False)
        compact_tree = build_compact_tree(cats)
        print(f"Loaded category tree: {len(cats)} nodes, compact={len(compact_tree)} chars")
    else:
        print(f"Using API dynamic search mode (searching on-demand)")
    # Load the grounding prompt template.
    template = load_prompt_template("apply_to_grounding")
    # Select only cases with at least one workflow_groups[*].capability[*].apply_to_draft.
    needs_grounding = []
    for case in cases:
        workflow_groups = case.get("workflow_groups")
        has_capability_draft = isinstance(workflow_groups, list) and any(
            isinstance(group, dict)
            and isinstance(group.get("capability"), list)
            and any(
                isinstance(capability, dict) and "apply_to_draft" in capability
                for capability in group.get("capability", [])
            )
            for group in workflow_groups
        )
        if has_capability_draft:
            needs_grounding.append(case)
    print(f"Grounding apply_to for {len(needs_grounding)}/{len(cases)} cases...")
    if not needs_grounding:
        print(" No cases need grounding, skipping.")
        return {
            "total": len(cases),
            "grounded": 0,
            "total_cost": 0.0,
            "output_file": str(case_file),
        }
    semaphore = asyncio.Semaphore(max_concurrent)
    async def process_with_semaphore(case_item):
        # Bound concurrency; each task grounds exactly one case.
        async with semaphore:
            index = case_item.get("index", 0)
            raw = case_item.get("_raw", {})
            case_id = raw.get("case_id", "unknown")
            print(f" -> [{index}] [{case_id}] grounding apply_to...")
            grounded, cost = await ground_single_case(
                case_item, template, llm_call, model, use_api, compact_tree
            )
            print(f" <- [{index}] [{case_id}] grounded (cost=${cost:.4f})")
            return case_item, grounded, cost
    tasks = [process_with_semaphore(case) for case in needs_grounding]
    results_with_costs = await asyncio.gather(*tasks)
    # Replace each original case with its grounded version. Matching is by
    # object identity so missing or duplicate "index" values cannot mis-route
    # the writeback.
    grounded_map = {}
    total_cost = 0.0
    for original_case, grounded, cost in results_with_costs:
        grounded_map[id(original_case)] = grounded
        total_cost += cost
    updated_cases = []
    for case in cases:
        case_id = id(case)
        if case_id in grounded_map:
            updated_cases.append(grounded_map[case_id])
        else:
            updated_cases.append(case)
    updated_cases.sort(key=lambda x: x.get("index", 0))
    case_data["cases"] = updated_cases
    case_file.parent.mkdir(parents=True, exist_ok=True)
    with open(case_file, "w", encoding="utf-8") as f:
        json.dump(case_data, f, ensure_ascii=False, indent=2)
    return {
        "total": len(cases),
        "grounded": len(needs_grounding),
        "total_cost": total_cost,
        "output_file": str(case_file),
    }
  432. if __name__ == "__main__":
  433. import sys
  434. if len(sys.argv) < 2:
  435. print("Usage: python apply_to_grounding.py <output_dir>")
  436. sys.exit(1)
  437. print("Please use this module through run_pipeline.py")