# apply_to_grounding.py
"""
Stage 2: map apply_to_draft entries to formal apply_to values.

Reads case.json and processes only the apply_to_draft fields inside each
case's capability list. Calls the LLM to map each draft onto formal nodes
of the content tree, then writes the results back into case.json in place.

This revision can fetch the content tree from a remote API instead of
depending only on local files.
"""
import asyncio
import json
import os
import re
from pathlib import Path
from typing import Any, Dict, List, Optional

import httpx
from dotenv import load_dotenv

from examples.process_pipeline.script.llm_helper import call_llm_with_retry
# Load environment variables from the project-level .env file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent
load_dotenv(PROJECT_ROOT / ".env")

# Search API configuration (base URL, trailing slash stripped).
SEARCH_API = os.getenv("SEARCH_API", "http://8.147.104.190:8001").rstrip("/")
# Number of capabilities sent to the LLM per grounding batch.
CAPABILITY_GROUNDING_BATCH_SIZE = int(os.getenv("CAPABILITY_GROUNDING_BATCH_SIZE", "8"))

# Local file paths (fallback when the remote API is unavailable).
EXTRACT_DIR = Path(__file__).resolve().parent / "resource"
CATEGORY_TREE_PATH = EXTRACT_DIR / "category_tree_56.json"
  24. def load_prompt_template(prompt_name: str) -> str:
  25. """加载 prompt 模板"""
  26. base_dir = Path(__file__).parent.parent
  27. prompt_path = base_dir / "prompts" / f"{prompt_name}.prompt"
  28. with open(prompt_path, "r", encoding="utf-8") as f:
  29. content = f.read()
  30. if content.startswith("---"):
  31. parts = content.split("---", 2)
  32. if len(parts) >= 3:
  33. content = parts[2]
  34. content = content.replace("$system$", "").replace("$user$", "")
  35. return content.strip()
  36. def load_category_tree_from_local() -> List[Dict]:
  37. """从本地文件加载内容树(回退方案)"""
  38. if not CATEGORY_TREE_PATH.exists():
  39. raise FileNotFoundError(f"Category tree not found: {CATEGORY_TREE_PATH}")
  40. with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as f:
  41. data = json.load(f)
  42. cats = data.get("categories", [])
  43. if not cats:
  44. raise RuntimeError("category_tree is empty")
  45. return cats
  46. async def load_category_tree(use_api: bool = False) -> List[Dict]:
  47. """
  48. 加载内容树(支持本地文件和远程 API)
  49. Args:
  50. use_api: 是否使用远程 API(默认 False,使用本地文件)
  51. Returns:
  52. 内容树节点列表
  53. """
  54. if use_api:
  55. try:
  56. print("Attempting to load category tree from API...")
  57. return await fetch_category_tree_from_api()
  58. except Exception as e:
  59. print(f"API failed ({e}), falling back to local file...")
  60. return load_category_tree_from_local()
  61. else:
  62. print("Loading category tree from local file...")
  63. return load_category_tree_from_local()
  64. async def search_categories_by_keywords(
  65. keywords: List[str],
  66. source_types: List[str] = None,
  67. top_k: int = 10,
  68. timeout: int = 10
  69. ) -> List[Dict]:
  70. """
  71. 根据关键词搜索相关分类节点
  72. Args:
  73. keywords: 搜索关键词列表
  74. source_types: 来源类型列表,默认 ["形式", "实质"]
  75. top_k: 每个关键词返回的结果数
  76. timeout: 请求超时时间(秒)
  77. Returns:
  78. 相关分类节点列表
  79. """
  80. if source_types is None:
  81. source_types = ["形式", "实质"]
  82. all_categories = []
  83. seen_ids = set()
  84. async with httpx.AsyncClient(timeout=timeout) as client:
  85. for keyword in keywords:
  86. for source_type in source_types:
  87. try:
  88. params = {
  89. "q": keyword,
  90. "source_type": source_type,
  91. "entity_type": "category",
  92. "top_k": top_k,
  93. "mode": "vector"
  94. }
  95. resp = await client.get(f"{SEARCH_API}/api/search", params=params)
  96. resp.raise_for_status()
  97. data = resp.json()
  98. results = data.get("results", [])
  99. # 转换为内容树格式,去重
  100. for item in results:
  101. entity_id = item.get("entity_id")
  102. if entity_id and entity_id not in seen_ids:
  103. category = {
  104. "id": entity_id,
  105. "path": item.get("path", ""),
  106. "source_type": source_type,
  107. "description": item.get("description", ""),
  108. "elements": []
  109. }
  110. all_categories.append(category)
  111. seen_ids.add(entity_id)
  112. except Exception as e:
  113. # 静默失败,继续处理其他关键词
  114. continue
  115. return all_categories
  116. def extract_keywords_from_draft(draft_text: str) -> List[str]:
  117. """
  118. 从 apply_to_draft 文本中提取关键词
  119. Args:
  120. draft_text: apply_to_draft 的文本内容
  121. Returns:
  122. 关键词列表
  123. """
  124. if not draft_text or not isinstance(draft_text, str):
  125. return []
  126. # 简单的关键词提取:分词并过滤
  127. import re
  128. # 移除标点符号,按空格和常见分隔符分词
  129. words = re.split(r'[,。、;:!?\s]+', draft_text)
  130. # 过滤短词和停用词
  131. keywords = [w.strip() for w in words if len(w.strip()) >= 2]
  132. # 去重并限制数量
  133. keywords = list(dict.fromkeys(keywords))[:5] # 最多5个关键词
  134. return keywords
  135. def build_compact_tree(cats: List[Dict]) -> str:
  136. """构建紧凑版内容树(用于注入 prompt)"""
  137. rows = []
  138. for c in cats:
  139. if c.get("source_type") not in ("实质", "形式"):
  140. continue
  141. row = {
  142. "id": c.get("id"),
  143. "path": c.get("path"),
  144. "source_type": c.get("source_type"),
  145. "description": c.get("description"),
  146. }
  147. elems = c.get("elements", [])
  148. if isinstance(elems, list) and elems:
  149. elem_names = [
  150. e.get("name") if isinstance(e, dict) else e
  151. for e in elems
  152. if e
  153. ]
  154. if elem_names:
  155. row["elements"] = elem_names
  156. rows.append(row)
  157. return json.dumps(rows, ensure_ascii=False, separators=(",", ":"))
  158. def build_valid_ids(cats: List[Dict]) -> Dict[int, Dict]:
  159. """构建 id -> node 映射"""
  160. return {c["id"]: c for c in cats if "id" in c}
  161. def iter_batches(items: List[Any], batch_size: int) -> List[List[Any]]:
  162. """按固定大小切分列表"""
  163. batch_size = max(1, batch_size)
  164. return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
  165. def build_capability_grounding_input(capability: Dict[str, Any]) -> Dict[str, Any]:
  166. """只保留 capability grounding 需要的最小字段"""
  167. return {
  168. "capability_id": capability.get("capability_id"),
  169. "apply_to_draft": capability.get("apply_to_draft", {}),
  170. }
  171. def render_grounding_prompt(
  172. template: str,
  173. task: str,
  174. draft: Dict,
  175. compact_tree: str,
  176. reference_paths: List[str] = None,
  177. ) -> str:
  178. """渲染 Stage 2 prompt"""
  179. target = "capability 数组中的每一条 capability"
  180. paths_str = json.dumps(reference_paths or [], ensure_ascii=False)
  181. return (
  182. template
  183. .replace("{target}", target)
  184. .replace("{compact_tree}", compact_tree)
  185. .replace("{reference_paths}", paths_str)
  186. .replace("{draft_json}", json.dumps(draft, ensure_ascii=False, indent=2))
  187. )
async def ground_single_case(
    case_item: Dict[str, Any],
    template: str,
    llm_call: Any,
    model: str,
    use_api: bool = False,
    compact_tree: Optional[str] = None,
) -> tuple[Dict[str, Any], float]:
    """Map capability[*].apply_to_draft to apply_to for a single case.

    Only capability entries are processed. The case is copied (shallow) and
    returned with grounded capabilities; the original dict is not mutated
    at the top level.

    Args:
        case_item: one case record (expects "capability" list, "title").
        template: the Stage 2 prompt template.
        llm_call: the LLM call function passed through to call_llm_with_retry.
        model: model name for the LLM call.
        use_api: when True, build a per-case compact tree by keyword search
            against the remote API; otherwise use the preloaded compact_tree.
        compact_tree: preloaded compact tree JSON ("[]" used when absent).

    Returns:
        (updated case dict, accumulated LLM cost in dollars).
    """
    total_cost = 0.0
    result = dict(case_item)
    # Short title used only for task labels in logs.
    title = case_item.get("title", "")[:20] or "untitled"
    capability_items = case_item.get("capability")
    if not isinstance(capability_items, list) or not capability_items:
        return result, total_cost
    # (original index, capability) pairs that still carry a draft to ground.
    draft_capability_pairs = [
        (idx, capability)
        for idx, capability in enumerate(capability_items)
        if isinstance(capability, dict) and "apply_to_draft" in capability
    ]
    if not draft_capability_pairs:
        return result, total_cost
    # Collect keywords from the capability drafts (used for API search).
    if use_api:
        all_keywords = []
        for _, capability in draft_capability_pairs:
            apply_to_draft = capability.get("apply_to_draft", {})
            for key in ["实质", "形式"]:
                for draft_text in apply_to_draft.get(key, []):
                    all_keywords.extend(extract_keywords_from_draft(draft_text))
        # De-duplicate (order-preserving) and cap at 10 keywords.
        all_keywords = list(dict.fromkeys(all_keywords))[:10]
        if all_keywords:
            categories = await search_categories_by_keywords(all_keywords, top_k=5)
            capability_compact_tree = build_compact_tree(categories)
            capability_ref_paths = list(dict.fromkeys(
                c["path"] for c in categories if c.get("path")
            ))
        else:
            # No usable keywords: fall back to the preloaded tree (if any).
            capability_compact_tree = compact_tree or "[]"
            capability_ref_paths = []
    else:
        capability_compact_tree = compact_tree or "[]"
        capability_ref_paths = []
    # Work on shallow copies so draft removal does not mutate the input dicts.
    updated_capabilities = [
        dict(capability) if isinstance(capability, dict) else capability
        for capability in capability_items
    ]
    # capability_id -> original index, for matching LLM output back by id.
    id_to_index = {
        capability.get("capability_id"): idx
        for idx, capability in draft_capability_pairs
        if isinstance(capability.get("capability_id"), str)
    }
    batches = iter_batches(draft_capability_pairs, CAPABILITY_GROUNDING_BATCH_SIZE)
    for batch_idx, batch_pairs in enumerate(batches, start=1):
        draft_capabilities = [
            build_capability_grounding_input(capability)
            for _, capability in batch_pairs
        ]
        draft = {"capability": draft_capabilities}
        prompt = render_grounding_prompt(template, "capability", draft, capability_compact_tree, capability_ref_paths)
        messages = [{"role": "user", "content": prompt}]
        grounded, cost = await call_llm_with_retry(
            llm_call=llm_call,
            messages=messages,
            model=model,
            temperature=0.1,
            max_tokens=4000,
            max_retries=3,
            schema_name="apply_to_grounding_capability",
            task_name=f"Ground_C_{title}_B{batch_idx}/{len(batches)}",
        )
        total_cost += cost
        # Skip the batch entirely if the LLM output is missing or malformed.
        if not grounded or not isinstance(grounded.get("capability"), list):
            continue
        grounded_capabilities = grounded["capability"]
        # Guard against the LLM writing the same capability twice in a batch.
        used_indices = set()
        for output_idx, grounded_capability in enumerate(grounded_capabilities):
            if not isinstance(grounded_capability, dict):
                continue
            # Match by capability_id first, then fall back to batch position.
            capability_idx = None
            capability_id = grounded_capability.get("capability_id")
            if isinstance(capability_id, str):
                capability_idx = id_to_index.get(capability_id)
            if capability_idx is None and output_idx < len(batch_pairs):
                capability_idx = batch_pairs[output_idx][0]
            if capability_idx is None or capability_idx in used_indices:
                continue
            apply_to = grounded_capability.get("apply_to")
            suggest_apply_to = grounded_capability.get("suggest_apply_to")
            # Only accept outputs that carry both a grounded apply_to and a
            # non-empty suggest_apply_to string; the draft is then removed.
            if (
                apply_to is not None
                and isinstance(suggest_apply_to, str)
                and suggest_apply_to.strip()
                and isinstance(updated_capabilities[capability_idx], dict)
            ):
                updated_capabilities[capability_idx]["apply_to"] = apply_to
                updated_capabilities[capability_idx]["suggest_apply_to"] = suggest_apply_to.strip()
                updated_capabilities[capability_idx].pop("apply_to_draft", None)
                used_indices.add(capability_idx)
    result["capability"] = updated_capabilities
    return result, total_cost
async def apply_grounding(
    case_file: Path,
    llm_call: Any,
    model: str = "openai/gpt-5.4",
    max_concurrent: int = 3,
) -> Dict[str, Any]:
    """Iterate cases in case.json and map apply_to_draft to apply_to in place.

    Args:
        case_file: path to case.json (read, then overwritten with results).
        llm_call: the LLM call function.
        model: model name.
        max_concurrent: maximum number of cases grounded concurrently.

    Returns:
        Summary dict: total cases, number grounded, total cost, output file.
    """
    with open(case_file, "r", encoding="utf-8") as f:
        case_data = json.load(f)
    cases = case_data.get("cases", [])
    # Whether to use the API's dynamic-search mode (default: on).
    use_api = os.getenv("USE_SEARCH_API", "true").lower() == "true"
    # When not using the API, preload the full content tree once up front.
    compact_tree = None
    if not use_api:
        cats = await load_category_tree(use_api=False)
        compact_tree = build_compact_tree(cats)
        print(f"Loaded category tree: {len(cats)} nodes, compact={len(compact_tree)} chars")
    else:
        print(f"Using API dynamic search mode (searching on-demand)")
    # Load the prompt template.
    template = load_prompt_template("apply_to_grounding")
    # Filter cases that need processing (any capability[*].apply_to_draft).
    needs_grounding = []
    for case in cases:
        capability_items = case.get("capability")
        has_capability_draft = isinstance(capability_items, list) and any(
            isinstance(capability, dict) and "apply_to_draft" in capability for capability in capability_items
        )
        if has_capability_draft:
            needs_grounding.append(case)
    print(f"Grounding apply_to for {len(needs_grounding)}/{len(cases)} cases...")
    if not needs_grounding:
        print(" No cases need grounding, skipping.")
        return {
            "total": len(cases),
            "grounded": 0,
            "total_cost": 0.0,
            "output_file": str(case_file),
        }
    # Bound concurrency: at most max_concurrent cases in flight at once.
    semaphore = asyncio.Semaphore(max_concurrent)
    async def process_with_semaphore(case_item):
        async with semaphore:
            index = case_item.get("index", 0)
            raw = case_item.get("_raw", {})
            case_id = raw.get("case_id", "unknown")
            print(f" -> [{index}] [{case_id}] grounding apply_to...")
            grounded, cost = await ground_single_case(
                case_item, template, llm_call, model, use_api, compact_tree
            )
            print(f" <- [{index}] [{case_id}] grounded (cost=${cost:.4f})")
            return case_item, grounded, cost
    tasks = [process_with_semaphore(case) for case in needs_grounding]
    results_with_costs = await asyncio.gather(*tasks)
    # Replace originals with grounded results, matching by object identity
    # so missing or duplicate "index" values cannot misroute a write-back.
    grounded_map = {}
    total_cost = 0.0
    for original_case, grounded, cost in results_with_costs:
        grounded_map[id(original_case)] = grounded
        total_cost += cost
    updated_cases = []
    for case in cases:
        case_id = id(case)
        if case_id in grounded_map:
            updated_cases.append(grounded_map[case_id])
        else:
            updated_cases.append(case)
    updated_cases.sort(key=lambda x: x.get("index", 0))
    case_data["cases"] = updated_cases
    case_file.parent.mkdir(parents=True, exist_ok=True)
    with open(case_file, "w", encoding="utf-8") as f:
        json.dump(case_data, f, ensure_ascii=False, indent=2)
    return {
        "total": len(cases),
        "grounded": len(needs_grounding),
        "total_cost": total_cost,
        "output_file": str(case_file),
    }
  379. if __name__ == "__main__":
  380. import sys
  381. if len(sys.argv) < 2:
  382. print("Usage: python apply_to_grounding.py <output_dir>")
  383. sys.exit(1)
  384. print("Please use this module through run_pipeline.py")