elksmmx пре 6 дана
родитељ
комит
c7b6337463
1 измењених фајлова са 125 додато и 36 уклоњено
  1. 125 36
      examples/process_pipeline/script/apply_to_grounding.py

+ 125 - 36
examples/process_pipeline/script/apply_to_grounding.py

@@ -1,8 +1,9 @@
 """
 Stage 2: 将 apply_to_draft 映射为正式 apply_to
 
-从 case.json 读取,对每个 case 的 workflow 和 capabilities 中的 apply_to_draft,
-调用 LLM 映射到内容树的正式节点,按 index 原位回填到 case.json
+从 case.json 读取,优先对每个 case 的 fragments 中的 apply_to_draft 做映射;
+没有 fragments 时,回退处理 workflow steps / capabilities 中的 apply_to_draft。
+调用 LLM 映射到内容树的正式节点,原位回填到 case.json
 
 改造版本:通过远程 API 获取内容树,不再依赖本地文件
 """
@@ -226,10 +227,12 @@ async def ground_single_case(
     compact_tree: str = None,
 ) -> tuple[Dict[str, Any], float]:
     """
-    对单个 case 的 workflow 和 capabilities 做 apply_to 映射
+    对单个 case 做 apply_to 映射
 
-    对于 workflow:一次性处理整个 workflow,为每个 step 生成对应的 apply_to
-    对于 capabilities:对每个有 apply_to_draft 的 capability 进行映射
+    优先级:
+    1. 如果存在 fragments,只处理 fragments[*].apply_to_draft,并回填到 fragments[*].apply_to
+    2. 没有 fragments 时,处理旧格式 workflow.steps[*].apply_to_draft
+    3. workflow 没有 draft 时,再处理 capabilities[*].apply_to_draft
 
     Args:
         case_item: 案例数据
@@ -243,25 +246,26 @@ async def ground_single_case(
     result = dict(case_item)
     title = case_item.get("title", "")[:20] or "untitled"
 
-    # 处理 fragments - 整体处理,保持上下文。只要存在 fragments,就不再读取 capabilities。
+    # 处理 fragments - 整体处理,保持上下文。只要存在 fragments,就不再读取 workflow/capabilities。
     fragments = case_item.get("fragments")
     has_fragments = isinstance(fragments, list) and bool(fragments)
     if has_fragments:
-        has_draft = any(
-            isinstance(frag, dict) and "apply_to_draft" in frag
-            for frag in fragments
-        )
+        draft_fragment_pairs = [
+            (idx, frag)
+            for idx, frag in enumerate(fragments)
+            if isinstance(frag, dict) and "apply_to_draft" in frag
+        ]
+        has_draft = bool(draft_fragment_pairs)
 
         if has_draft:
             # 收集所有 fragment 的关键词(用于 API 搜索)
             if use_api:
                 all_keywords = []
-                for frag in fragments:
-                    if isinstance(frag, dict) and "apply_to_draft" in frag:
-                        apply_to_draft = frag.get("apply_to_draft", {})
-                        for key in ["实质", "形式"]:
-                            for draft_text in apply_to_draft.get(key, []):
-                                all_keywords.extend(extract_keywords_from_draft(draft_text))
+                for _, frag in draft_fragment_pairs:
+                    apply_to_draft = frag.get("apply_to_draft", {})
+                    for key in ["实质", "形式"]:
+                        for draft_text in apply_to_draft.get(key, []):
+                            all_keywords.extend(extract_keywords_from_draft(draft_text))
                 all_keywords = list(dict.fromkeys(all_keywords))[:10]
 
                 if all_keywords:
@@ -278,7 +282,9 @@ async def ground_single_case(
                 frag_ref_paths = []
 
             # 复用 capability grounding 的 prompt/schema,只把数据源从 workflow step 换成 fragment。
-            draft = {"capabilities": fragments}
+            # 只发送带 apply_to_draft 的 fragment,再按原始下标回填,避免数组错位。
+            draft_fragments = [frag for _, frag in draft_fragment_pairs]
+            draft = {"capabilities": draft_fragments}
             prompt = render_grounding_prompt(template, "capability", draft, frag_compact_tree, frag_ref_paths)
             messages = [{"role": "user", "content": prompt}]
 
@@ -297,20 +303,98 @@ async def ground_single_case(
             # 按索引回填 apply_to。输入数组来自 fragments,输出数组使用 capability schema。
             if grounded and isinstance(grounded.get("capabilities"), list):
                 grounded_frags = grounded["capabilities"]
-                updated_fragments = []
-                for idx, frag in enumerate(fragments):
-                    updated_frag = dict(frag)
-                    if idx < len(grounded_frags) and isinstance(grounded_frags[idx], dict):
-                        apply_to = grounded_frags[idx].get("apply_to")
+                updated_fragments = [
+                    dict(frag) if isinstance(frag, dict) else frag
+                    for frag in fragments
+                ]
+                for draft_idx, (frag_idx, _) in enumerate(draft_fragment_pairs):
+                    if draft_idx < len(grounded_frags) and isinstance(grounded_frags[draft_idx], dict):
+                        apply_to = grounded_frags[draft_idx].get("apply_to")
                         if apply_to is not None:
-                            updated_frag["apply_to"] = apply_to
-                    updated_frag.pop("apply_to_draft", None)
-                    updated_fragments.append(updated_frag)
+                            updated_fragments[frag_idx]["apply_to"] = apply_to
+                            updated_fragments[frag_idx].pop("apply_to_draft", None)
                 result["fragments"] = updated_fragments
 
-    # 没有 fragments 时,才回退处理 capabilities。
+    # 没有 fragments 时,回退处理旧格式 workflow step draft。
+    workflow = case_item.get("workflow")
+    handled_workflow = False
+    if not has_fragments and isinstance(workflow, dict) and "steps" in workflow:
+        steps = workflow.get("steps", [])
+
+        has_draft = any(
+            isinstance(step, dict) and "apply_to_draft" in step
+            for step in steps
+        )
+
+        if has_draft:
+            handled_workflow = True
+            # 收集所有 step 的关键词(用于 API 搜索)
+            if use_api:
+                all_keywords = []
+                for step in steps:
+                    if isinstance(step, dict) and "apply_to_draft" in step:
+                        apply_to_draft = step.get("apply_to_draft", {})
+                        for key in ["实质", "形式"]:
+                            for draft_text in apply_to_draft.get(key, []):
+                                all_keywords.extend(extract_keywords_from_draft(draft_text))
+                all_keywords = list(dict.fromkeys(all_keywords))[:10]
+
+                if all_keywords:
+                    categories = await search_categories_by_keywords(all_keywords, top_k=5)
+                    workflow_compact_tree = build_compact_tree(categories)
+                    workflow_ref_paths = list(dict.fromkeys(
+                        c["path"] for c in categories if c.get("path")
+                    ))
+                else:
+                    workflow_compact_tree = compact_tree or "[]"
+                    workflow_ref_paths = []
+            else:
+                workflow_compact_tree = compact_tree or "[]"
+                workflow_ref_paths = []
+
+            # 整个 workflow 传给 LLM(保持上下文)
+            draft = {"strategy": workflow}
+            prompt = render_grounding_prompt(template, "strategy", draft, workflow_compact_tree, workflow_ref_paths)
+            messages = [{"role": "user", "content": prompt}]
+
+            grounded, cost = await call_llm_with_retry(
+                llm_call=llm_call,
+                messages=messages,
+                model=model,
+                temperature=0.1,
+                max_tokens=4000,
+                max_retries=3,
+                schema_name="apply_to_grounding_strategy",
+                task_name=f"Ground_W_{title}",
+            )
+            total_cost += cost
+
+            # 按 order 回填 apply_to
+            if grounded and isinstance(grounded.get("strategy"), dict):
+                grounded_steps = grounded["strategy"].get("steps", [])
+                order_to_apply_to = {}
+                for grounded_step in grounded_steps:
+                    if isinstance(grounded_step, dict):
+                        order = grounded_step.get("order")
+                        apply_to = grounded_step.get("apply_to")
+                        if order is not None and apply_to is not None:
+                            order_to_apply_to[order] = apply_to
+
+                updated_steps = []
+                for step in steps:
+                    updated_step = dict(step)
+                    order = step.get("order")
+                    if order in order_to_apply_to:
+                        updated_step["apply_to"] = order_to_apply_to[order]
+                        updated_step.pop("apply_to_draft", None)
+                    updated_steps.append(updated_step)
+
+                result["workflow"] = dict(workflow)
+                result["workflow"]["steps"] = updated_steps
+
+    # 没有 fragments 且 workflow 没处理时,才回退处理 capabilities。
     capabilities = case_item.get("capabilities")
-    if not has_fragments and isinstance(capabilities, list) and capabilities:
+    if not has_fragments and not handled_workflow and isinstance(capabilities, list) and capabilities:
         has_draft = any(
             isinstance(cap, dict) and "apply_to_draft" in cap
             for cap in capabilities
@@ -418,15 +502,20 @@ async def apply_grounding(
     needs_grounding = []
     for case in cases:
         fragments = case.get("fragments")
+        workflow = case.get("workflow")
         capabilities = case.get("capabilities")
         has_fragments = isinstance(fragments, list) and bool(fragments)
         has_frag_draft = isinstance(fragments, list) and any(
             isinstance(frag, dict) and "apply_to_draft" in frag for frag in fragments
         )
-        has_cap_draft = not has_fragments and isinstance(capabilities, list) and any(
+        has_workflow_draft = not has_fragments and isinstance(workflow, dict) and any(
+            isinstance(step, dict) and "apply_to_draft" in step
+            for step in workflow.get("steps", [])
+        )
+        has_cap_draft = not has_fragments and not has_workflow_draft and isinstance(capabilities, list) and any(
             isinstance(c, dict) and "apply_to_draft" in c for c in capabilities
         )
-        if has_frag_draft or has_cap_draft:
+        if has_frag_draft or has_workflow_draft or has_cap_draft:
             needs_grounding.append(case)
 
     print(f"Grounding apply_to for {len(needs_grounding)}/{len(cases)} cases...")
@@ -455,23 +544,23 @@ async def apply_grounding(
             )
 
             print(f"  <- [{index}] [{case_id}] grounded (cost=${cost:.4f})")
-            return grounded, cost
+            return case_item, grounded, cost
 
     tasks = [process_with_semaphore(case) for case in needs_grounding]
     results_with_costs = await asyncio.gather(*tasks)
 
-    # 用 grounded 结果替换原 case(按 index 匹配
+    # 用 grounded 结果替换原 case(按对象身份匹配,避免 index 缺失或重复时回填错 case)
     grounded_map = {}
     total_cost = 0.0
-    for grounded, cost in results_with_costs:
-        grounded_map[grounded.get("index")] = grounded
+    for original_case, grounded, cost in results_with_costs:
+        grounded_map[id(original_case)] = grounded
         total_cost += cost
 
     updated_cases = []
     for case in cases:
-        idx = case.get("index")
-        if idx in grounded_map:
-            updated_cases.append(grounded_map[idx])
+        case_id = id(case)
+        if case_id in grounded_map:
+            updated_cases.append(grounded_map[case_id])
         else:
             updated_cases.append(case)