# apply_to_grounding.py
  1. """
  2. Stage 2: 将 apply_to_draft 映射为正式 apply_to
  3. 从 case.json 读取,只对每个 case 的 workflow_groups[*].capability 中的 apply_to_draft 做映射。
  4. 调用 LLM 映射到内容树的正式节点,原位回填到 case.json
  5. 改造版本:通过远程 API 获取内容树,不再依赖本地文件
  6. """
  7. import asyncio
  8. import json
  9. import os
  10. from pathlib import Path
  11. from typing import Any, Dict, List, Optional
  12. import httpx
  13. from dotenv import load_dotenv
  14. from examples.process_pipeline.script.llm_helper import call_llm_with_retry
# Load environment variables from the project-root .env file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent
load_dotenv(PROJECT_ROOT / ".env")
# Search API configuration (base URL, trailing slash stripped for clean joins).
SEARCH_API = os.getenv("SEARCH_API", "http://8.147.104.190:8001").rstrip("/")
# How many capabilities are grounded per LLM call.
CAPABILITY_GROUNDING_BATCH_SIZE = int(os.getenv("CAPABILITY_GROUNDING_BATCH_SIZE", "8"))
# Local file paths (fallback when the remote API is not used/unavailable).
EXTRACT_DIR = Path(__file__).resolve().parent / "resource"
CATEGORY_TREE_PATH = EXTRACT_DIR / "category_tree_56.json"
  24. def load_prompt_template(prompt_name: str) -> str:
  25. """加载 prompt 模板"""
  26. base_dir = Path(__file__).parent.parent
  27. prompt_path = base_dir / "prompts" / f"{prompt_name}.prompt"
  28. with open(prompt_path, "r", encoding="utf-8") as f:
  29. content = f.read()
  30. if content.startswith("---"):
  31. parts = content.split("---", 2)
  32. if len(parts) >= 3:
  33. content = parts[2]
  34. content = content.replace("$system$", "").replace("$user$", "")
  35. return content.strip()
  36. def load_category_tree_from_local() -> List[Dict]:
  37. """从本地文件加载内容树(回退方案)"""
  38. if not CATEGORY_TREE_PATH.exists():
  39. raise FileNotFoundError(f"Category tree not found: {CATEGORY_TREE_PATH}")
  40. with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as f:
  41. data = json.load(f)
  42. cats = data.get("categories", [])
  43. if not cats:
  44. raise RuntimeError("category_tree is empty")
  45. return cats
  46. async def load_category_tree(use_api: bool = False) -> List[Dict]:
  47. """
  48. 加载内容树(支持本地文件和远程 API)
  49. Args:
  50. use_api: 是否使用远程 API(默认 False,使用本地文件)
  51. Returns:
  52. 内容树节点列表
  53. """
  54. if use_api:
  55. try:
  56. print("Attempting to load category tree from API...")
  57. return await fetch_category_tree_from_api()
  58. except Exception as e:
  59. print(f"API failed ({e}), falling back to local file...")
  60. return load_category_tree_from_local()
  61. else:
  62. print("Loading category tree from local file...")
  63. return load_category_tree_from_local()
  64. async def search_categories_by_keywords(
  65. keywords: List[str],
  66. source_types: List[str] = None,
  67. top_k: int = 10,
  68. timeout: int = 10
  69. ) -> List[Dict]:
  70. """
  71. 根据关键词搜索相关分类节点
  72. Args:
  73. keywords: 搜索关键词列表
  74. source_types: 来源类型列表,默认 ["形式", "实质"]
  75. top_k: 每个关键词返回的结果数
  76. timeout: 请求超时时间(秒)
  77. Returns:
  78. 相关分类节点列表
  79. """
  80. if source_types is None:
  81. source_types = ["形式", "实质"]
  82. all_categories = []
  83. seen_ids = set()
  84. async with httpx.AsyncClient(timeout=timeout) as client:
  85. for keyword in keywords:
  86. for source_type in source_types:
  87. try:
  88. params = {
  89. "q": keyword,
  90. "source_type": source_type,
  91. "entity_type": "category",
  92. "top_k": top_k,
  93. "mode": "vector"
  94. }
  95. resp = await client.get(f"{SEARCH_API}/api/search", params=params)
  96. resp.raise_for_status()
  97. data = resp.json()
  98. results = data.get("results", [])
  99. # 转换为内容树格式,去重
  100. for item in results:
  101. entity_id = item.get("entity_id")
  102. if entity_id and entity_id not in seen_ids:
  103. category = {
  104. "id": entity_id,
  105. "path": item.get("path", ""),
  106. "source_type": source_type,
  107. "description": item.get("description", ""),
  108. "elements": []
  109. }
  110. all_categories.append(category)
  111. seen_ids.add(entity_id)
  112. except Exception as e:
  113. # 静默失败,继续处理其他关键词
  114. continue
  115. return all_categories
  116. def extract_keywords_from_draft(draft_text: str) -> List[str]:
  117. """
  118. 从 apply_to_draft 文本中提取关键词
  119. Args:
  120. draft_text: apply_to_draft 的文本内容
  121. Returns:
  122. 关键词列表
  123. """
  124. if not draft_text or not isinstance(draft_text, str):
  125. return []
  126. # 简单的关键词提取:分词并过滤
  127. import re
  128. # 移除标点符号,按空格和常见分隔符分词
  129. words = re.split(r'[,。、;:!?\s]+', draft_text)
  130. # 过滤短词和停用词
  131. keywords = [w.strip() for w in words if len(w.strip()) >= 2]
  132. # 去重并限制数量
  133. keywords = list(dict.fromkeys(keywords))[:5] # 最多5个关键词
  134. return keywords
  135. def build_compact_tree(cats: List[Dict]) -> str:
  136. """构建紧凑版内容树(用于注入 prompt)"""
  137. rows = []
  138. for c in cats:
  139. if c.get("source_type") not in ("实质", "形式"):
  140. continue
  141. row = {
  142. "id": c.get("id"),
  143. "path": c.get("path"),
  144. "source_type": c.get("source_type"),
  145. "description": c.get("description"),
  146. }
  147. elems = c.get("elements", [])
  148. if isinstance(elems, list) and elems:
  149. elem_names = [
  150. e.get("name") if isinstance(e, dict) else e
  151. for e in elems
  152. if e
  153. ]
  154. if elem_names:
  155. row["elements"] = elem_names
  156. rows.append(row)
  157. return json.dumps(rows, ensure_ascii=False, separators=(",", ":"))
  158. def build_valid_ids(cats: List[Dict]) -> Dict[int, Dict]:
  159. """构建 id -> node 映射"""
  160. return {c["id"]: c for c in cats if "id" in c}
  161. def iter_batches(items: List[Any], batch_size: int) -> List[List[Any]]:
  162. """按固定大小切分列表"""
  163. batch_size = max(1, batch_size)
  164. return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
  165. def build_capability_grounding_input(capability: Dict[str, Any]) -> Dict[str, Any]:
  166. """只保留 capability grounding 需要的最小字段"""
  167. return {
  168. "capability_id": capability.get("capability_id"),
  169. "body": capability.get("body") or "",
  170. "apply_to_draft": capability.get("apply_to_draft", {}),
  171. }
  172. def body_excerpts_are_verbatim(apply_to: Any, suggest_apply_to: Any, body: str) -> bool:
  173. """确认非空 body_excerpt 都逐字来自 capability.body;空字符串允许存在。"""
  174. def _entry_is_valid(item: Any) -> bool:
  175. if not isinstance(item, dict):
  176. return False
  177. excerpt = item.get("body_excerpt")
  178. note = item.get("body_excerpt_note")
  179. if not isinstance(excerpt, str) or not isinstance(note, str):
  180. return False
  181. if not excerpt.strip() and note.strip():
  182. return False
  183. if excerpt.strip() and excerpt.strip() not in body:
  184. return False
  185. return True
  186. if not isinstance(apply_to, dict) or not isinstance(suggest_apply_to, list) or not isinstance(body, str):
  187. return False
  188. for source_type in ("实质", "形式"):
  189. items = apply_to.get(source_type, [])
  190. if not isinstance(items, list):
  191. return False
  192. for item in items:
  193. if not _entry_is_valid(item):
  194. return False
  195. for item in suggest_apply_to:
  196. if not _entry_is_valid(item):
  197. return False
  198. return True
  199. def render_grounding_prompt(
  200. template: str,
  201. task: str,
  202. draft: Dict,
  203. compact_tree: str,
  204. reference_paths: List[str] = None,
  205. ) -> str:
  206. """渲染 Stage 2 prompt"""
  207. target = "capability 数组中的每一条 capability"
  208. paths_str = json.dumps(reference_paths or [], ensure_ascii=False)
  209. return (
  210. template
  211. .replace("{target}", target)
  212. .replace("{compact_tree}", compact_tree)
  213. .replace("{reference_paths}", paths_str)
  214. .replace("{draft_json}", json.dumps(draft, ensure_ascii=False, indent=2))
  215. )
async def ground_single_case(
    case_item: Dict[str, Any],
    template: str,
    llm_call: Any,
    model: str,
    use_api: bool = False,
    compact_tree: Optional[str] = None,
) -> tuple[Dict[str, Any], float]:
    """Map apply_to_draft → apply_to for every capability in one case.

    Iterates ``workflow_groups[*].capability[*]``, batches the capabilities
    that still carry ``apply_to_draft``, asks the LLM to ground each draft
    onto official category-tree nodes, and writes ``apply_to`` /
    ``suggest_apply_to`` back in place (dropping ``apply_to_draft``) only
    when every body_excerpt is verbatim from the capability body.

    Args:
        case_item: one case dict from case.json (not mutated; shallow copies
            are written back).
        template: the grounding prompt template text.
        llm_call: LLM call function forwarded to call_llm_with_retry.
        model: model name.
        use_api: when True, build a per-group compact tree via keyword search
            against the remote API instead of using *compact_tree*.
        compact_tree: pre-built compact tree JSON, used when use_api is False
            or when keyword extraction yields nothing.

    Returns:
        (updated case dict, accumulated LLM cost in dollars).
    """
    total_cost = 0.0
    result = dict(case_item)
    # Short title used only to label the LLM task for logging.
    title = case_item.get("title", "")[:20] or "untitled"
    workflow_groups = case_item.get("workflow_groups")
    if not isinstance(workflow_groups, list) or not workflow_groups:
        return result, total_cost
    # Shallow-copy each group so the original case_item dicts are not mutated.
    updated_groups = [
        dict(group) if isinstance(group, dict) else group
        for group in workflow_groups
    ]
    for group_idx, group in enumerate(updated_groups):
        if not isinstance(group, dict):
            continue
        workflow_id = group.get("workflow_id") or f"g{group_idx + 1}"
        capability_items = group.get("capability")
        if not isinstance(capability_items, list) or not capability_items:
            continue
        # Only capabilities that still carry an apply_to_draft need grounding;
        # keep the original index so results can be written back in place.
        draft_capability_pairs = [
            (idx, capability)
            for idx, capability in enumerate(capability_items)
            if isinstance(capability, dict) and "apply_to_draft" in capability
        ]
        if not draft_capability_pairs:
            continue
        # Collect keywords from the drafts (used for the API keyword search).
        if use_api:
            all_keywords = []
            for _, capability in draft_capability_pairs:
                apply_to_draft = capability.get("apply_to_draft", {})
                for key in ["实质", "形式"]:
                    for draft_text in apply_to_draft.get(key, []):
                        all_keywords.extend(extract_keywords_from_draft(draft_text))
            # De-duplicate preserving order, cap at 10 keywords per group.
            all_keywords = list(dict.fromkeys(all_keywords))[:10]
            if all_keywords:
                categories = await search_categories_by_keywords(all_keywords, top_k=5)
                capability_compact_tree = build_compact_tree(categories)
                capability_ref_paths = list(dict.fromkeys(
                    c["path"] for c in categories if c.get("path")
                ))
            else:
                # No usable keywords: fall back to the pre-built tree (if any).
                capability_compact_tree = compact_tree or "[]"
                capability_ref_paths = []
        else:
            capability_compact_tree = compact_tree or "[]"
            capability_ref_paths = []
        # Work on shallow copies so a failed writeback leaves originals intact.
        updated_capabilities = [
            dict(capability) if isinstance(capability, dict) else capability
            for capability in capability_items
        ]
        # capability_id -> original index, for matching LLM output back by id.
        id_to_index = {
            capability.get("capability_id"): idx
            for idx, capability in draft_capability_pairs
            if isinstance(capability.get("capability_id"), str)
        }
        batches = iter_batches(draft_capability_pairs, CAPABILITY_GROUNDING_BATCH_SIZE)
        for batch_idx, batch_pairs in enumerate(batches, start=1):
            draft_capabilities = [
                build_capability_grounding_input(capability)
                for _, capability in batch_pairs
            ]
            draft = {"capability": draft_capabilities}
            prompt = render_grounding_prompt(template, "capability", draft, capability_compact_tree, capability_ref_paths)
            messages = [{"role": "user", "content": prompt}]
            grounded, cost = await call_llm_with_retry(
                llm_call=llm_call,
                messages=messages,
                model=model,
                temperature=0.1,
                max_tokens=4000,
                max_retries=3,
                schema_name="apply_to_grounding_capability",
                task_name=f"Ground_C_{title}_{workflow_id}_B{batch_idx}/{len(batches)}",
            )
            total_cost += cost
            if not grounded or not isinstance(grounded.get("capability"), list):
                continue
            grounded_capabilities = grounded["capability"]
            # Guards against the LLM mapping two outputs onto the same capability.
            used_indices = set()
            for output_idx, grounded_capability in enumerate(grounded_capabilities):
                if not isinstance(grounded_capability, dict):
                    continue
                # Prefer matching by capability_id; fall back to positional match.
                capability_idx = None
                capability_id = grounded_capability.get("capability_id")
                if isinstance(capability_id, str):
                    capability_idx = id_to_index.get(capability_id)
                if capability_idx is None and output_idx < len(batch_pairs):
                    capability_idx = batch_pairs[output_idx][0]
                if capability_idx is None or capability_idx in used_indices:
                    continue
                apply_to = grounded_capability.get("apply_to")
                suggest_apply_to = grounded_capability.get("suggest_apply_to")
                body = updated_capabilities[capability_idx].get("body", "")
                # Write back only well-formed results whose excerpts are verbatim
                # from the capability body; at most 3 suggestions are accepted.
                if (
                    apply_to is not None
                    and isinstance(suggest_apply_to, list)
                    and len(suggest_apply_to) <= 3
                    and isinstance(updated_capabilities[capability_idx], dict)
                    and body_excerpts_are_verbatim(apply_to, suggest_apply_to, body)
                ):
                    updated_capabilities[capability_idx]["apply_to"] = apply_to
                    updated_capabilities[capability_idx]["suggest_apply_to"] = suggest_apply_to
                    # The draft is consumed once the official mapping is in place.
                    updated_capabilities[capability_idx].pop("apply_to_draft", None)
                    used_indices.add(capability_idx)
                else:
                    print(
                        f" ⚠️ Skip capability grounding writeback: "
                        f"{capability_id or capability_idx} has missing/non-verbatim body_excerpt"
                    )
        group["capability"] = updated_capabilities
    result["workflow_groups"] = updated_groups
    return result, total_cost
async def apply_grounding(
    case_file: Path,
    llm_call: Any,
    model: str = "openai/gpt-5.4",
    max_concurrent: int = 3,
) -> Dict[str, Any]:
    """Ground apply_to_draft → apply_to across all cases in case.json.

    Loads the case file, selects cases whose workflow_groups contain at least
    one capability with ``apply_to_draft``, grounds them concurrently via
    ground_single_case, and writes the updated cases back to the same file.

    Args:
        case_file: path to case.json.
        llm_call: LLM call function forwarded to ground_single_case.
        model: model name.
        max_concurrent: maximum number of cases grounded in parallel.

    Returns:
        Stats dict: total cases, number grounded, total LLM cost, output path.
    """
    with open(case_file, "r", encoding="utf-8") as f:
        case_data = json.load(f)
    cases = case_data.get("cases", [])
    # Whether to use the API dynamic-search mode (default: on).
    use_api = os.getenv("USE_SEARCH_API", "true").lower() == "true"
    # When not using the API, pre-load the full category tree once.
    compact_tree = None
    if not use_api:
        cats = await load_category_tree(use_api=False)
        compact_tree = build_compact_tree(cats)
        print(f"Loaded category tree: {len(cats)} nodes, compact={len(compact_tree)} chars")
    else:
        print(f"Using API dynamic search mode (searching on-demand)")
    # Load the grounding prompt template.
    template = load_prompt_template("apply_to_grounding")
    # Select only cases that have at least one capability still carrying
    # an apply_to_draft under workflow_groups[*].capability[*].
    needs_grounding = []
    for case in cases:
        workflow_groups = case.get("workflow_groups")
        has_capability_draft = isinstance(workflow_groups, list) and any(
            isinstance(group, dict)
            and isinstance(group.get("capability"), list)
            and any(
                isinstance(capability, dict) and "apply_to_draft" in capability
                for capability in group.get("capability", [])
            )
            for group in workflow_groups
        )
        if has_capability_draft:
            needs_grounding.append(case)
    print(f"Grounding apply_to for {len(needs_grounding)}/{len(cases)} cases...")
    if not needs_grounding:
        print(" No cases need grounding, skipping.")
        return {
            "total": len(cases),
            "grounded": 0,
            "total_cost": 0.0,
            "output_file": str(case_file),
        }
    # Bound concurrency across cases.
    semaphore = asyncio.Semaphore(max_concurrent)
    async def process_with_semaphore(case_item: Dict[str, Any]):
        async with semaphore:
            index = case_item.get("index", 0)
            raw = case_item.get("_raw", {})
            case_id = raw.get("case_id", "unknown")
            print(f" -> [{index}] [{case_id}] grounding apply_to...")
            grounded, cost = await ground_single_case(
                case_item, template, llm_call, model, use_api, compact_tree
            )
            print(f" <- [{index}] [{case_id}] grounded (cost=${cost:.4f})")
            # Return the original item too, so results can be matched by identity.
            return case_item, grounded, cost
    tasks = [process_with_semaphore(case) for case in needs_grounding]
    results_with_costs = await asyncio.gather(*tasks)
    # Replace each original case with its grounded version. Matching is done
    # by object identity (id()) to avoid mis-assignment when "index" values
    # are missing or duplicated.
    grounded_map = {}
    total_cost = 0.0
    for original_case, grounded, cost in results_with_costs:
        grounded_map[id(original_case)] = grounded
        total_cost += cost
    updated_cases = []
    for case in cases:
        case_id = id(case)
        if case_id in grounded_map:
            updated_cases.append(grounded_map[case_id])
        else:
            updated_cases.append(case)
    updated_cases.sort(key=lambda x: x.get("index", 0))
    case_data["cases"] = updated_cases
    case_file.parent.mkdir(parents=True, exist_ok=True)
    with open(case_file, "w", encoding="utf-8") as f:
        json.dump(case_data, f, ensure_ascii=False, indent=2)
    return {
        "total": len(cases),
        "grounded": len(needs_grounding),
        "total_cost": total_cost,
        "output_file": str(case_file),
    }
  431. if __name__ == "__main__":
  432. import sys
  433. if len(sys.argv) < 2:
  434. print("Usage: python apply_to_grounding.py <output_dir>")
  435. sys.exit(1)
  436. print("Please use this module through run_pipeline.py")