apply_to_grounding.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
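"""
Stage 2: map ``apply_to_draft`` to the formal ``apply_to``.

Reads case.json and, for every case, takes the ``apply_to_draft`` entries found in
its workflow steps and capabilities, calls an LLM to map them onto formal
content-tree nodes, and writes the result back into case.json in place.

The content tree is loaded from a local file by default; when the environment
variable ``USE_SEARCH_API`` is "true", candidate nodes are fetched on demand from
the remote search API instead.
"""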
  7. import asyncio
  8. import json
  9. import os
  10. from pathlib import Path
  11. from typing import Any, Dict, List, Optional
  12. import httpx
  13. from dotenv import load_dotenv
  14. from examples.process_pipeline.script.llm_helper import call_llm_with_retry
  15. # 加载环境变量
  16. PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent.parent
  17. load_dotenv(PROJECT_ROOT / ".env")
  18. # 搜索 API 配置
  19. SEARCH_API = os.getenv("SEARCH_API", "http://8.147.104.190:8001").rstrip("/")
  20. # 本地文件路径(作为回退方案)
  21. EXTRACT_DIR = Path(__file__).resolve().parent / "resource"
  22. CATEGORY_TREE_PATH = EXTRACT_DIR / "category_tree_56.json"
  23. def load_prompt_template(prompt_name: str) -> str:
  24. """加载 prompt 模板"""
  25. base_dir = Path(__file__).parent.parent
  26. prompt_path = base_dir / "prompts" / f"{prompt_name}.prompt"
  27. with open(prompt_path, "r", encoding="utf-8") as f:
  28. content = f.read()
  29. if content.startswith("---"):
  30. parts = content.split("---", 2)
  31. if len(parts) >= 3:
  32. content = parts[2]
  33. content = content.replace("$system$", "").replace("$user$", "")
  34. return content.strip()
  35. def load_category_tree_from_local() -> List[Dict]:
  36. """从本地文件加载内容树(回退方案)"""
  37. if not CATEGORY_TREE_PATH.exists():
  38. raise FileNotFoundError(f"Category tree not found: {CATEGORY_TREE_PATH}")
  39. with open(CATEGORY_TREE_PATH, "r", encoding="utf-8") as f:
  40. data = json.load(f)
  41. cats = data.get("categories", [])
  42. if not cats:
  43. raise RuntimeError("category_tree is empty")
  44. return cats
  45. async def load_category_tree(use_api: bool = False) -> List[Dict]:
  46. """
  47. 加载内容树(支持本地文件和远程 API)
  48. Args:
  49. use_api: 是否使用远程 API(默认 False,使用本地文件)
  50. Returns:
  51. 内容树节点列表
  52. """
  53. if use_api:
  54. try:
  55. print("Attempting to load category tree from API...")
  56. return await fetch_category_tree_from_api()
  57. except Exception as e:
  58. print(f"API failed ({e}), falling back to local file...")
  59. return load_category_tree_from_local()
  60. else:
  61. print("Loading category tree from local file...")
  62. return load_category_tree_from_local()
  63. async def search_categories_by_keywords(
  64. keywords: List[str],
  65. source_types: List[str] = None,
  66. top_k: int = 10,
  67. timeout: int = 10
  68. ) -> List[Dict]:
  69. """
  70. 根据关键词搜索相关分类节点
  71. Args:
  72. keywords: 搜索关键词列表
  73. source_types: 来源类型列表,默认 ["形式", "实质"]
  74. top_k: 每个关键词返回的结果数
  75. timeout: 请求超时时间(秒)
  76. Returns:
  77. 相关分类节点列表
  78. """
  79. if source_types is None:
  80. source_types = ["形式", "实质"]
  81. all_categories = []
  82. seen_ids = set()
  83. async with httpx.AsyncClient(timeout=timeout) as client:
  84. for keyword in keywords:
  85. for source_type in source_types:
  86. try:
  87. params = {
  88. "q": keyword,
  89. "source_type": source_type,
  90. "entity_type": "category",
  91. "top_k": top_k,
  92. "mode": "vector"
  93. }
  94. resp = await client.get(f"{SEARCH_API}/api/search", params=params)
  95. resp.raise_for_status()
  96. data = resp.json()
  97. results = data.get("results", [])
  98. # 转换为内容树格式,去重
  99. for item in results:
  100. entity_id = item.get("entity_id")
  101. if entity_id and entity_id not in seen_ids:
  102. category = {
  103. "id": entity_id,
  104. "path": item.get("path", ""),
  105. "source_type": source_type,
  106. "description": item.get("description", ""),
  107. "elements": []
  108. }
  109. all_categories.append(category)
  110. seen_ids.add(entity_id)
  111. except Exception as e:
  112. # 静默失败,继续处理其他关键词
  113. continue
  114. return all_categories
  115. def extract_keywords_from_draft(draft_text: str) -> List[str]:
  116. """
  117. 从 apply_to_draft 文本中提取关键词
  118. Args:
  119. draft_text: apply_to_draft 的文本内容
  120. Returns:
  121. 关键词列表
  122. """
  123. if not draft_text or not isinstance(draft_text, str):
  124. return []
  125. # 简单的关键词提取:分词并过滤
  126. import re
  127. # 移除标点符号,按空格和常见分隔符分词
  128. words = re.split(r'[,。、;:!?\s]+', draft_text)
  129. # 过滤短词和停用词
  130. keywords = [w.strip() for w in words if len(w.strip()) >= 2]
  131. # 去重并限制数量
  132. keywords = list(dict.fromkeys(keywords))[:5] # 最多5个关键词
  133. return keywords
  134. def build_compact_tree(cats: List[Dict]) -> str:
  135. """构建紧凑版内容树(用于注入 prompt)"""
  136. rows = []
  137. for c in cats:
  138. if c.get("source_type") not in ("实质", "形式"):
  139. continue
  140. row = {
  141. "id": c.get("id"),
  142. "path": c.get("path"),
  143. "source_type": c.get("source_type"),
  144. "description": c.get("description"),
  145. }
  146. elems = c.get("elements", [])
  147. if isinstance(elems, list) and elems:
  148. elem_names = [
  149. e.get("name") if isinstance(e, dict) else e
  150. for e in elems
  151. if e
  152. ]
  153. if elem_names:
  154. row["elements"] = elem_names
  155. rows.append(row)
  156. return json.dumps(rows, ensure_ascii=False, separators=(",", ":"))
  157. def build_valid_ids(cats: List[Dict]) -> Dict[int, Dict]:
  158. """构建 id -> node 映射"""
  159. return {c["id"]: c for c in cats if "id" in c}
  160. def render_grounding_prompt(
  161. template: str,
  162. task: str,
  163. draft: Dict,
  164. compact_tree: str,
  165. ) -> str:
  166. """渲染 Stage 2 prompt"""
  167. if task == "capability":
  168. target = "capabilities 数组中的每一条 capability"
  169. else:
  170. target = "strategy;如果 strategy 为 null,则原样返回"
  171. return (
  172. template
  173. .replace("{target}", target)
  174. .replace("{compact_tree}", compact_tree)
  175. .replace("{draft_json}", json.dumps(draft, ensure_ascii=False, indent=2))
  176. )
  177. async def ground_single_case(
  178. case_item: Dict[str, Any],
  179. template: str,
  180. llm_call: Any,
  181. model: str,
  182. use_api: bool = False,
  183. compact_tree: str = None,
  184. ) -> tuple[Dict[str, Any], float]:
  185. """
  186. 对单个 case 的 workflow 和 capabilities 做 apply_to 映射
  187. 对于 workflow:一次性处理整个 workflow,为每个 step 生成对应的 apply_to
  188. 对于 capabilities:对每个有 apply_to_draft 的 capability 进行映射
  189. Args:
  190. case_item: 案例数据
  191. template: prompt 模板
  192. llm_call: LLM 调用函数
  193. model: 模型名称
  194. use_api: 是否使用 API 动态搜索
  195. compact_tree: 预加载的完整内容树(use_api=False 时使用)
  196. """
  197. total_cost = 0.0
  198. result = dict(case_item)
  199. title = case_item.get("title", "")[:20] or "untitled"
  200. # 处理 workflow (strategy) - 整体处理,保持上下文
  201. workflow = case_item.get("workflow")
  202. if isinstance(workflow, dict) and "steps" in workflow:
  203. steps = workflow.get("steps", [])
  204. # 检查是否有任何 step 包含 apply_to_draft
  205. has_draft = any(
  206. isinstance(step, dict) and "apply_to_draft" in step
  207. for step in steps
  208. )
  209. if has_draft:
  210. # 收集所有 step 的关键词(用于 API 搜索)
  211. if use_api:
  212. all_keywords = []
  213. for step in steps:
  214. if isinstance(step, dict) and "apply_to_draft" in step:
  215. apply_to_draft = step.get("apply_to_draft", {})
  216. for key in ["实质", "形式"]:
  217. for draft_text in apply_to_draft.get(key, []):
  218. all_keywords.extend(extract_keywords_from_draft(draft_text))
  219. all_keywords = list(dict.fromkeys(all_keywords))[:10]
  220. if all_keywords:
  221. categories = await search_categories_by_keywords(all_keywords, top_k=5)
  222. workflow_compact_tree = build_compact_tree(categories)
  223. else:
  224. workflow_compact_tree = compact_tree or "[]"
  225. else:
  226. workflow_compact_tree = compact_tree or "[]"
  227. # 整个 workflow 传给 LLM(保持上下文)
  228. draft = {"strategy": workflow}
  229. prompt = render_grounding_prompt(template, "strategy", draft, workflow_compact_tree)
  230. messages = [{"role": "user", "content": prompt}]
  231. grounded, cost = await call_llm_with_retry(
  232. llm_call=llm_call,
  233. messages=messages,
  234. model=model,
  235. temperature=0.1,
  236. max_tokens=4000,
  237. max_retries=3,
  238. schema_name="apply_to_grounding_strategy",
  239. task_name=f"Ground_W_{title}",
  240. )
  241. total_cost += cost
  242. # 按 order 回填 apply_to
  243. if grounded and isinstance(grounded.get("strategy"), dict):
  244. grounded_steps = grounded["strategy"].get("steps", [])
  245. # 建立 order -> apply_to 的映射
  246. order_to_apply_to = {}
  247. for grounded_step in grounded_steps:
  248. if isinstance(grounded_step, dict):
  249. order = grounded_step.get("order")
  250. apply_to = grounded_step.get("apply_to")
  251. if order is not None and apply_to is not None:
  252. order_to_apply_to[order] = apply_to
  253. # 回填到原 steps
  254. updated_steps = []
  255. for step in steps:
  256. updated_step = dict(step)
  257. order = step.get("order")
  258. if order in order_to_apply_to:
  259. updated_step["apply_to"] = order_to_apply_to[order]
  260. updated_step.pop("apply_to_draft", None)
  261. updated_steps.append(updated_step)
  262. result["workflow"] = dict(workflow)
  263. result["workflow"]["steps"] = updated_steps
  264. # 处理 capabilities - 整体处理,保持上下文
  265. capabilities = case_item.get("capabilities")
  266. if isinstance(capabilities, list) and capabilities:
  267. has_draft = any(
  268. isinstance(cap, dict) and "apply_to_draft" in cap
  269. for cap in capabilities
  270. )
  271. if has_draft:
  272. # 收集所有 capability 的关键词
  273. if use_api:
  274. all_keywords = []
  275. for cap in capabilities:
  276. if isinstance(cap, dict) and "apply_to_draft" in cap:
  277. apply_to_draft = cap.get("apply_to_draft", {})
  278. for key in ["实质", "形式"]:
  279. for draft_text in apply_to_draft.get(key, []):
  280. all_keywords.extend(extract_keywords_from_draft(draft_text))
  281. all_keywords = list(dict.fromkeys(all_keywords))[:10]
  282. if all_keywords:
  283. categories = await search_categories_by_keywords(all_keywords, top_k=5)
  284. cap_compact_tree = build_compact_tree(categories)
  285. else:
  286. cap_compact_tree = compact_tree or "[]"
  287. else:
  288. cap_compact_tree = compact_tree or "[]"
  289. # 整个 capabilities 传给 LLM(保持上下文)
  290. draft = {"capabilities": capabilities}
  291. prompt = render_grounding_prompt(template, "capability", draft, cap_compact_tree)
  292. messages = [{"role": "user", "content": prompt}]
  293. grounded, cost = await call_llm_with_retry(
  294. llm_call=llm_call,
  295. messages=messages,
  296. model=model,
  297. temperature=0.1,
  298. max_tokens=4000,
  299. max_retries=3,
  300. schema_name="apply_to_grounding_capability",
  301. task_name=f"Ground_C_{title}",
  302. )
  303. total_cost += cost
  304. # 回填 apply_to(按索引匹配)
  305. if grounded and isinstance(grounded.get("capabilities"), list):
  306. grounded_caps = grounded["capabilities"]
  307. updated_capabilities = []
  308. for idx, cap in enumerate(capabilities):
  309. updated_cap = dict(cap)
  310. # 如果有对应的 grounded capability,提取 apply_to
  311. if idx < len(grounded_caps) and isinstance(grounded_caps[idx], dict):
  312. apply_to = grounded_caps[idx].get("apply_to")
  313. if apply_to is not None:
  314. updated_cap["apply_to"] = apply_to
  315. updated_cap.pop("apply_to_draft", None)
  316. updated_capabilities.append(updated_cap)
  317. result["capabilities"] = updated_capabilities
  318. return result, total_cost
  319. async def apply_grounding(
  320. case_file: Path,
  321. llm_call: Any,
  322. model: str = "openai/gpt-5.4",
  323. max_concurrent: int = 3,
  324. ) -> Dict[str, Any]:
  325. """
  326. 按 index 遍历 case.json,将 apply_to_draft 映射为 apply_to
  327. Args:
  328. case_file: case.json 文件路径
  329. llm_call: LLM 调用函数
  330. model: 模型名称
  331. max_concurrent: 最大并发数
  332. Returns:
  333. 包含统计信息的字典
  334. """
  335. with open(case_file, "r", encoding="utf-8") as f:
  336. case_data = json.load(f)
  337. cases = case_data.get("cases", [])
  338. # 检查是否使用 API 动态搜索模式
  339. use_api = os.getenv("USE_SEARCH_API", "false").lower() == "true"
  340. # 如果不使用 API,预加载完整内容树
  341. compact_tree = None
  342. if not use_api:
  343. cats = await load_category_tree(use_api=False)
  344. compact_tree = build_compact_tree(cats)
  345. print(f"Loaded category tree: {len(cats)} nodes, compact={len(compact_tree)} chars")
  346. else:
  347. print(f"Using API dynamic search mode (searching on-demand)")
  348. # 加载 prompt 模板
  349. template = load_prompt_template("apply_to_grounding")
  350. # 过滤出需要处理的 case(有 apply_to_draft 的)
  351. needs_grounding = []
  352. for case in cases:
  353. workflow = case.get("workflow")
  354. capabilities = case.get("capabilities")
  355. # 检查 step 级别的 apply_to_draft
  356. has_workflow_draft = (
  357. isinstance(workflow, dict) and
  358. any(
  359. isinstance(step, dict) and "apply_to_draft" in step
  360. for step in workflow.get("steps", [])
  361. )
  362. )
  363. has_cap_draft = isinstance(capabilities, list) and any(
  364. isinstance(c, dict) and "apply_to_draft" in c for c in capabilities
  365. )
  366. if has_workflow_draft or has_cap_draft:
  367. needs_grounding.append(case)
  368. print(f"Grounding apply_to for {len(needs_grounding)}/{len(cases)} cases...")
  369. if not needs_grounding:
  370. print(" No cases need grounding, skipping.")
  371. return {
  372. "total": len(cases),
  373. "grounded": 0,
  374. "total_cost": 0.0,
  375. "output_file": str(case_file),
  376. }
  377. semaphore = asyncio.Semaphore(max_concurrent)
  378. async def process_with_semaphore(case_item):
  379. async with semaphore:
  380. index = case_item.get("index", 0)
  381. raw = case_item.get("_raw", {})
  382. case_id = raw.get("case_id", "unknown")
  383. print(f" -> [{index}] [{case_id}] grounding apply_to...")
  384. grounded, cost = await ground_single_case(
  385. case_item, template, llm_call, model, use_api, compact_tree
  386. )
  387. print(f" <- [{index}] [{case_id}] grounded (cost=${cost:.4f})")
  388. return grounded, cost
  389. tasks = [process_with_semaphore(case) for case in needs_grounding]
  390. results_with_costs = await asyncio.gather(*tasks)
  391. # 用 grounded 结果替换原 case(按 index 匹配)
  392. grounded_map = {}
  393. total_cost = 0.0
  394. for grounded, cost in results_with_costs:
  395. grounded_map[grounded.get("index")] = grounded
  396. total_cost += cost
  397. updated_cases = []
  398. for case in cases:
  399. idx = case.get("index")
  400. if idx in grounded_map:
  401. updated_cases.append(grounded_map[idx])
  402. else:
  403. updated_cases.append(case)
  404. updated_cases.sort(key=lambda x: x.get("index", 0))
  405. case_data["cases"] = updated_cases
  406. case_file.parent.mkdir(parents=True, exist_ok=True)
  407. with open(case_file, "w", encoding="utf-8") as f:
  408. json.dump(case_data, f, ensure_ascii=False, indent=2)
  409. return {
  410. "total": len(cases),
  411. "grounded": len(needs_grounding),
  412. "total_cost": total_cost,
  413. "output_file": str(case_file),
  414. }
  415. if __name__ == "__main__":
  416. import sys
  417. if len(sys.argv) < 2:
  418. print("Usage: python apply_to_grounding.py <output_dir>")
  419. sys.exit(1)
  420. print("Please use this module through run_pipeline.py")