# apply_to_grounding.py
"""
Stage 2: map apply_to_draft to the formal apply_to.

Reads case.json and maps only the apply_to_draft entries found under each
case's workflow_groups[*].capability. An LLM grounds every draft onto formal
nodes of the content tree and the results are written back in place to
case.json.

This revision can fetch the content tree through a remote API instead of
depending only on a local file.
"""
import asyncio
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
import httpx
from dotenv import load_dotenv
from examples.process_pipeline.script.llm_helper import call_llm_with_retry

# Load environment variables from the project-root .env file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent
load_dotenv(PROJECT_ROOT / ".env")

# Search API configuration (endpoint + LLM batch size, both env-overridable).
SEARCH_API = os.getenv("SEARCH_API", "http://8.147.104.190:8001").rstrip("/")
CAPABILITY_GROUNDING_BATCH_SIZE = int(os.getenv("CAPABILITY_GROUNDING_BATCH_SIZE", "8"))

# Local file paths (fallback when the API is unavailable).
EXTRACT_DIR = Path(__file__).resolve().parent / "resource"
CATEGORY_TREE_PATH = EXTRACT_DIR / "category_tree_56.json"
  24. def load_prompt_template(prompt_name: str) -> str:
  25. """加载 prompt 模板"""
  26. base_dir = Path(__file__).parent.parent
  27. prompt_path = base_dir / "prompts" / f"{prompt_name}.prompt"
  28. with open(prompt_path, "r", encoding="utf-8") as f:
  29. content = f.read()
  30. if content.startswith("---"):
  31. parts = content.split("---", 2)
  32. if len(parts) >= 3:
  33. content = parts[2]
  34. content = content.replace("$system$", "").replace("$user$", "")
  35. return content.strip()
  36. def load_category_tree_from_local() -> List[Dict]:
  37. """从本地文件加载内容树(回退方案)"""
  38. if not CATEGORY_TREE_PATH.exists():
  39. raise FileNotFoundError(f"Category tree not found: {CATEGORY_TREE_PATH}")
  40. with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as f:
  41. data = json.load(f)
  42. cats = data.get("categories", [])
  43. if not cats:
  44. raise RuntimeError("category_tree is empty")
  45. return cats
  46. async def load_category_tree(use_api: bool = False) -> List[Dict]:
  47. """
  48. 加载内容树(支持本地文件和远程 API)
  49. Args:
  50. use_api: 是否使用远程 API(默认 False,使用本地文件)
  51. Returns:
  52. 内容树节点列表
  53. """
  54. if use_api:
  55. try:
  56. print("Attempting to load category tree from API...")
  57. return await fetch_category_tree_from_api()
  58. except Exception as e:
  59. print(f"API failed ({e}), falling back to local file...")
  60. return load_category_tree_from_local()
  61. else:
  62. print("Loading category tree from local file...")
  63. return load_category_tree_from_local()
  64. async def search_categories_by_keywords(
  65. keywords: List[str],
  66. source_types: List[str] = None,
  67. top_k: int = 10,
  68. timeout: int = 10
  69. ) -> List[Dict]:
  70. """
  71. 根据关键词搜索相关分类节点
  72. Args:
  73. keywords: 搜索关键词列表
  74. source_types: 来源类型列表,默认 ["形式", "实质"]
  75. top_k: 每个关键词返回的结果数
  76. timeout: 请求超时时间(秒)
  77. Returns:
  78. 相关分类节点列表
  79. """
  80. if source_types is None:
  81. source_types = ["形式", "实质"]
  82. all_categories = []
  83. seen_ids = set()
  84. async with httpx.AsyncClient(timeout=timeout) as client:
  85. for keyword in keywords:
  86. for source_type in source_types:
  87. try:
  88. params = {
  89. "q": keyword,
  90. "source_type": source_type,
  91. "entity_type": "category",
  92. "top_k": top_k,
  93. "mode": "vector"
  94. }
  95. resp = await client.get(f"{SEARCH_API}/api/search", params=params)
  96. resp.raise_for_status()
  97. data = resp.json()
  98. results = data.get("results", [])
  99. # 转换为内容树格式,去重
  100. for item in results:
  101. entity_id = item.get("entity_id")
  102. if entity_id and entity_id not in seen_ids:
  103. category = {
  104. "id": entity_id,
  105. "path": item.get("path", ""),
  106. "source_type": source_type,
  107. "description": item.get("description", ""),
  108. "elements": []
  109. }
  110. all_categories.append(category)
  111. seen_ids.add(entity_id)
  112. except Exception as e:
  113. # 静默失败,继续处理其他关键词
  114. continue
  115. return all_categories
  116. def extract_keywords_from_draft(draft_text: str) -> List[str]:
  117. """
  118. 从 apply_to_draft 文本中提取关键词
  119. Args:
  120. draft_text: apply_to_draft 的文本内容
  121. Returns:
  122. 关键词列表
  123. """
  124. if not draft_text or not isinstance(draft_text, str):
  125. return []
  126. # 简单的关键词提取:分词并过滤
  127. import re
  128. # 移除标点符号,按空格和常见分隔符分词
  129. words = re.split(r'[,。、;:!?\s]+', draft_text)
  130. # 过滤短词和停用词
  131. keywords = [w.strip() for w in words if len(w.strip()) >= 2]
  132. # 去重并限制数量
  133. keywords = list(dict.fromkeys(keywords))[:5] # 最多5个关键词
  134. return keywords
  135. def build_compact_tree(cats: List[Dict]) -> str:
  136. """构建紧凑版内容树(用于注入 prompt)"""
  137. rows = []
  138. for c in cats:
  139. if c.get("source_type") not in ("实质", "形式"):
  140. continue
  141. row = {
  142. "id": c.get("id"),
  143. "path": c.get("path"),
  144. "source_type": c.get("source_type"),
  145. "description": c.get("description"),
  146. }
  147. elems = c.get("elements", [])
  148. if isinstance(elems, list) and elems:
  149. elem_names = [
  150. e.get("name") if isinstance(e, dict) else e
  151. for e in elems
  152. if e
  153. ]
  154. if elem_names:
  155. row["elements"] = elem_names
  156. rows.append(row)
  157. return json.dumps(rows, ensure_ascii=False, separators=(",", ":"))
  158. def build_valid_ids(cats: List[Dict]) -> Dict[int, Dict]:
  159. """构建 id -> node 映射"""
  160. return {c["id"]: c for c in cats if "id" in c}
  161. def iter_batches(items: List[Any], batch_size: int) -> List[List[Any]]:
  162. """按固定大小切分列表"""
  163. batch_size = max(1, batch_size)
  164. return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
  165. def build_capability_grounding_input(capability: Dict[str, Any]) -> Dict[str, Any]:
  166. """只保留 capability grounding 需要的最小字段"""
  167. return {
  168. "capability_id": capability.get("capability_id"),
  169. "body": capability.get("body") or "",
  170. "apply_to_draft": capability.get("apply_to_draft", {}),
  171. }
  172. def apply_to_body_excerpts_are_verbatim(apply_to: Any, body: str) -> bool:
  173. """确认每个 apply_to 条目的 body_excerpt 都逐字来自 capability.body。"""
  174. if not isinstance(apply_to, dict) or not isinstance(body, str) or not body.strip():
  175. return False
  176. for source_type in ("实质", "形式"):
  177. items = apply_to.get(source_type, [])
  178. if not isinstance(items, list):
  179. return False
  180. for item in items:
  181. if not isinstance(item, dict):
  182. return False
  183. excerpt = item.get("body_excerpt")
  184. if not isinstance(excerpt, str) or not excerpt.strip():
  185. return False
  186. if excerpt.strip() not in body:
  187. return False
  188. return True
  189. def render_grounding_prompt(
  190. template: str,
  191. task: str,
  192. draft: Dict,
  193. compact_tree: str,
  194. reference_paths: List[str] = None,
  195. ) -> str:
  196. """渲染 Stage 2 prompt"""
  197. target = "capability 数组中的每一条 capability"
  198. paths_str = json.dumps(reference_paths or [], ensure_ascii=False)
  199. return (
  200. template
  201. .replace("{target}", target)
  202. .replace("{compact_tree}", compact_tree)
  203. .replace("{reference_paths}", paths_str)
  204. .replace("{draft_json}", json.dumps(draft, ensure_ascii=False, indent=2))
  205. )
async def ground_single_case(
    case_item: Dict[str, Any],
    template: str,
    llm_call: Any,
    model: str,
    use_api: bool = False,
    compact_tree: Optional[str] = None,
) -> tuple[Dict[str, Any], float]:
    """
    Map workflow_groups[*].capability[*].apply_to_draft to apply_to for one case.

    Args:
        case_item: One case dict; not mutated (shallow copies are returned).
        template: Prompt template with {target}/{compact_tree}/{reference_paths}/
            {draft_json} placeholders.
        llm_call: LLM call function forwarded to call_llm_with_retry.
        model: Model name.
        use_api: When True, build a per-group compact tree by keyword search
            against the remote API; otherwise use the preloaded *compact_tree*.
        compact_tree: Preloaded compact-tree JSON string, used when use_api is
            False or when keyword extraction yields nothing.

    Returns:
        (updated case dict, accumulated LLM cost).
    """
    total_cost = 0.0
    result = dict(case_item)
    # Short title used only for naming LLM tasks in logs.
    # NOTE(review): assumes "title" is a string when present -- a None title
    # would raise on the slice; confirm upstream guarantees.
    title = case_item.get("title", "")[:20] or "untitled"
    workflow_groups = case_item.get("workflow_groups")
    if not isinstance(workflow_groups, list) or not workflow_groups:
        return result, total_cost
    # Shallow-copy each group dict so the caller's case is not mutated.
    updated_groups = [
        dict(group) if isinstance(group, dict) else group
        for group in workflow_groups
    ]
    for group_idx, group in enumerate(updated_groups):
        if not isinstance(group, dict):
            continue
        workflow_id = group.get("workflow_id") or f"g{group_idx + 1}"
        capability_items = group.get("capability")
        if not isinstance(capability_items, list) or not capability_items:
            continue
        # Only capabilities that still carry an apply_to_draft need grounding.
        draft_capability_pairs = [
            (idx, capability)
            for idx, capability in enumerate(capability_items)
            if isinstance(capability, dict) and "apply_to_draft" in capability
        ]
        if not draft_capability_pairs:
            continue
        # Collect keywords from the drafts (used for the API search path).
        if use_api:
            all_keywords = []
            for _, capability in draft_capability_pairs:
                apply_to_draft = capability.get("apply_to_draft", {})
                for key in ["实质", "形式"]:
                    for draft_text in apply_to_draft.get(key, []):
                        all_keywords.extend(extract_keywords_from_draft(draft_text))
            # De-duplicate (order-preserving) and cap at 10 keywords.
            all_keywords = list(dict.fromkeys(all_keywords))[:10]
            if all_keywords:
                categories = await search_categories_by_keywords(all_keywords, top_k=5)
                capability_compact_tree = build_compact_tree(categories)
                capability_ref_paths = list(dict.fromkeys(
                    c["path"] for c in categories if c.get("path")
                ))
            else:
                # No keywords extracted: fall back to the preloaded tree.
                capability_compact_tree = compact_tree or "[]"
                capability_ref_paths = []
        else:
            capability_compact_tree = compact_tree or "[]"
            capability_ref_paths = []
        # Work on copies so failed batches leave the originals untouched.
        updated_capabilities = [
            dict(capability) if isinstance(capability, dict) else capability
            for capability in capability_items
        ]
        # capability_id -> index in capability_items, for matching LLM output.
        id_to_index = {
            capability.get("capability_id"): idx
            for idx, capability in draft_capability_pairs
            if isinstance(capability.get("capability_id"), str)
        }
        batches = iter_batches(draft_capability_pairs, CAPABILITY_GROUNDING_BATCH_SIZE)
        for batch_idx, batch_pairs in enumerate(batches, start=1):
            draft_capabilities = [
                build_capability_grounding_input(capability)
                for _, capability in batch_pairs
            ]
            draft = {"capability": draft_capabilities}
            prompt = render_grounding_prompt(template, "capability", draft, capability_compact_tree, capability_ref_paths)
            messages = [{"role": "user", "content": prompt}]
            grounded, cost = await call_llm_with_retry(
                llm_call=llm_call,
                messages=messages,
                model=model,
                temperature=0.1,
                max_tokens=4000,
                max_retries=3,
                schema_name="apply_to_grounding_capability",
                task_name=f"Ground_C_{title}_{workflow_id}_B{batch_idx}/{len(batches)}",
            )
            total_cost += cost
            # Skip the batch entirely on a malformed LLM response.
            if not grounded or not isinstance(grounded.get("capability"), list):
                continue
            grounded_capabilities = grounded["capability"]
            # Guards against the LLM returning the same capability twice.
            # NOTE(review): reset per batch; id_to_index spans all batches, so
            # a cross-batch duplicate id would not be caught -- confirm intent.
            used_indices = set()
            for output_idx, grounded_capability in enumerate(grounded_capabilities):
                if not isinstance(grounded_capability, dict):
                    continue
                # Match LLM output to its capability: by id first, then by
                # positional fallback within the batch.
                capability_idx = None
                capability_id = grounded_capability.get("capability_id")
                if isinstance(capability_id, str):
                    capability_idx = id_to_index.get(capability_id)
                if capability_idx is None and output_idx < len(batch_pairs):
                    capability_idx = batch_pairs[output_idx][0]
                if capability_idx is None or capability_idx in used_indices:
                    continue
                apply_to = grounded_capability.get("apply_to")
                suggest_apply_to = grounded_capability.get("suggest_apply_to")
                body = updated_capabilities[capability_idx].get("body", "")
                # Only write back when the output is complete AND every
                # body_excerpt is verbatim from the capability body.
                if (
                    apply_to is not None
                    and isinstance(suggest_apply_to, str)
                    and suggest_apply_to.strip()
                    and isinstance(updated_capabilities[capability_idx], dict)
                    and apply_to_body_excerpts_are_verbatim(apply_to, body)
                ):
                    updated_capabilities[capability_idx]["apply_to"] = apply_to
                    updated_capabilities[capability_idx]["suggest_apply_to"] = suggest_apply_to.strip()
                    # The draft is consumed once grounding succeeds.
                    updated_capabilities[capability_idx].pop("apply_to_draft", None)
                    used_indices.add(capability_idx)
                else:
                    print(
                        f" ⚠️ Skip capability grounding writeback: "
                        f"{capability_id or capability_idx} has missing/non-verbatim body_excerpt"
                    )
        group["capability"] = updated_capabilities
    result["workflow_groups"] = updated_groups
    return result, total_cost
async def apply_grounding(
    case_file: Path,
    llm_call: Any,
    model: str = "openai/gpt-5.4",
    max_concurrent: int = 3,
) -> Dict[str, Any]:
    """
    Walk case.json and map every apply_to_draft to apply_to, writing back in place.

    Args:
        case_file: Path to case.json (read, then overwritten with results).
        llm_call: LLM call function.
        model: Model name.
        max_concurrent: Maximum number of concurrently processed cases.

    Returns:
        Stats dict: total, grounded, total_cost, output_file.
    """
    with open(case_file, "r", encoding="utf-8") as f:
        case_data = json.load(f)
    cases = case_data.get("cases", [])
    # Whether to use the API's dynamic-search mode (default: on).
    use_api = os.getenv("USE_SEARCH_API", "true").lower() == "true"
    # Without the API, preload the full content tree once and reuse it.
    compact_tree = None
    if not use_api:
        cats = await load_category_tree(use_api=False)
        compact_tree = build_compact_tree(cats)
        print(f"Loaded category tree: {len(cats)} nodes, compact={len(compact_tree)} chars")
    else:
        print(f"Using API dynamic search mode (searching on-demand)")
    # Load the prompt template.
    template = load_prompt_template("apply_to_grounding")
    # Select cases that need work (only workflow_groups[*].capability[*]
    # entries that still carry an apply_to_draft).
    needs_grounding = []
    for case in cases:
        workflow_groups = case.get("workflow_groups")
        has_capability_draft = isinstance(workflow_groups, list) and any(
            isinstance(group, dict)
            and isinstance(group.get("capability"), list)
            and any(
                isinstance(capability, dict) and "apply_to_draft" in capability
                for capability in group.get("capability", [])
            )
            for group in workflow_groups
        )
        if has_capability_draft:
            needs_grounding.append(case)
    print(f"Grounding apply_to for {len(needs_grounding)}/{len(cases)} cases...")
    if not needs_grounding:
        print(" No cases need grounding, skipping.")
        return {
            "total": len(cases),
            "grounded": 0,
            "total_cost": 0.0,
            "output_file": str(case_file),
        }
    # Bound concurrency across cases with a semaphore.
    semaphore = asyncio.Semaphore(max_concurrent)
    async def process_with_semaphore(case_item):
        # Ground one case while holding a semaphore slot; returns the
        # original object too, so results can be matched by identity.
        async with semaphore:
            index = case_item.get("index", 0)
            raw = case_item.get("_raw", {})
            case_id = raw.get("case_id", "unknown")
            print(f" -> [{index}] [{case_id}] grounding apply_to...")
            grounded, cost = await ground_single_case(
                case_item, template, llm_call, model, use_api, compact_tree
            )
            print(f" <- [{index}] [{case_id}] grounded (cost=${cost:.4f})")
            return case_item, grounded, cost
    tasks = [process_with_semaphore(case) for case in needs_grounding]
    results_with_costs = await asyncio.gather(*tasks)
    # Replace originals with grounded results, keyed by object identity so a
    # missing or duplicated "index" cannot cause a wrong writeback.
    grounded_map = {}
    total_cost = 0.0
    for original_case, grounded, cost in results_with_costs:
        grounded_map[id(original_case)] = grounded
        total_cost += cost
    updated_cases = []
    for case in cases:
        case_id = id(case)
        if case_id in grounded_map:
            updated_cases.append(grounded_map[case_id])
        else:
            updated_cases.append(case)
    # Keep the file ordered by case index (missing index sorts first as 0).
    updated_cases.sort(key=lambda x: x.get("index", 0))
    case_data["cases"] = updated_cases
    case_file.parent.mkdir(parents=True, exist_ok=True)
    with open(case_file, "w", encoding="utf-8") as f:
        json.dump(case_data, f, ensure_ascii=False, indent=2)
    return {
        "total": len(cases),
        "grounded": len(needs_grounding),
        "total_cost": total_cost,
        "output_file": str(case_file),
    }
# CLI stub: this module is not meant to be run standalone. It only validates
# the argument count and then points the user at run_pipeline.py.
if __name__ == "__main__":
    import sys
    if len(sys.argv) < 2:
        print("Usage: python apply_to_grounding.py <output_dir>")
        sys.exit(1)
    print("Please use this module through run_pipeline.py")