Просмотр исходного кода

feat(mode_workflow): 工序解构 pipeline(读库→LLM→mode_process)

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
刘文武 5 дней назад
Родитель
Сommit
b2e2ad5878
1 измененных файлов с 206 добавлено и 0 удалено
  1. 206 0
      examples/mode_workflow/pipeline/procedure_extract.py

+ 206 - 0
examples/mode_workflow/pipeline/procedure_extract.py

@@ -0,0 +1,206 @@
+# -*- coding: utf-8 -*-
+"""工序解构 · search_data 帖子 → workflow JSON → mode_process 表
+================================================================================
+单次大模型直出(无 agent / 无 validate 循环),prompt 见 prompts/procedure_extract_system.md。
+配图下载转 base64(绕防盗链)随文本一起发。结果按工序拆行写 mode_process。
+
+用法(一般由 server.py 起子进程调):
+  python pipeline/procedure_extract.py --query-id q0000 --case-ids xhs_abc
+  python pipeline/procedure_extract.py --query-id q0000 --case-ids xhs_abc --model google/gemini-3.1-flash-lite
+"""
+import argparse
+import asyncio
+import base64
+import json
+import sys
+import time
+from datetime import datetime
+from pathlib import Path
+
+PROJECT_ROOT = Path(__file__).resolve().parents[3]   # …/Agent
+sys.path.insert(0, str(PROJECT_ROOT))
+
+from dotenv import load_dotenv
+load_dotenv()
+
+from examples.process_pipeline.script.llm_helper import call_llm_with_retry
+
+HERE = Path(__file__).resolve().parent
+MW = HERE.parent
+sys.path.insert(0, str(MW))
+import db
+
+PROMPT_FILE = MW / "prompts" / "procedure_extract_system.md"
+DEFAULT_MODEL = "anthropic/claude-sonnet-4-6"
+MAX_IMAGES = 8
+
+
+# ── 以下 4 个助手原样取自 mode_procedure/mode-dsl/procedure_model_extract.py ──
+
+def _detect_image_mime(data: bytes):
+    if not data or len(data) < 12:
+        return None
+    if data[:3] == b"\xff\xd8\xff":
+        return "image/jpeg"
+    if data[:8] == b"\x89PNG\r\n\x1a\n":
+        return "image/png"
+    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
+        return "image/webp"
+    if data[:6] in (b"GIF87a", b"GIF89a"):
+        return "image/gif"
+    return None
+
+
+async def _fetch_data_url(url, sem):
+    from agent.tools.builtin.file.image_cdn import _download_image
+    async with sem:
+        try:
+            data = await _download_image(url)
+        except Exception:
+            return None
+    mime = _detect_image_mime(data)
+    if mime is None:
+        return None
+    return f"data:{mime};base64,{base64.b64encode(data).decode()}"
+
+
+async def _collect_images(urls, max_images, concurrency):
+    urls = [u for u in urls if isinstance(u, str) and u][:max_images]
+    if not urls:
+        return []
+    sem = asyncio.Semaphore(concurrency)
+    results = await asyncio.gather(*[_fetch_data_url(u, sem) for u in urls])
+    return [d for d in results if d]
+
+
+def _validate_wf(data):
+    if not isinstance(data, dict):
+        return "顶层必须是 JSON 对象"
+    if "procedures" not in data:
+        return '缺少 "procedures" 字段'
+    if not isinstance(data["procedures"], list):
+        return '"procedures" 必须是数组'
+    return None
+
+
+def _sanitize_workflow(data):
+    dropped = {"procedures": 0, "steps": 0, "io": 0}
+    procs = data.get("procedures")
+    if not isinstance(procs, list):
+        return data, dropped
+    clean_procs = []
+    for p in procs:
+        if not isinstance(p, dict):
+            dropped["procedures"] += 1
+            continue
+        steps = p.get("steps")
+        if isinstance(steps, list):
+            kept = []
+            for s in steps:
+                if not isinstance(s, dict):
+                    dropped["steps"] += 1
+                    continue
+                for io in ("inputs", "outputs"):
+                    if isinstance(s.get(io), list):
+                        before = len(s[io])
+                        s[io] = [x for x in s[io] if isinstance(x, dict)]
+                        dropped["io"] += before - len(s[io])
+                kept.append(s)
+            p["steps"] = kept
+        if not isinstance(p.get("declarations"), dict):
+            p.pop("declarations", None)
+        if not isinstance(p.get("type_registry"), dict):
+            p.pop("type_registry", None)
+        clean_procs.append(p)
+    data["procedures"] = clean_procs
+    return data, dropped
+
+# ── 助手复制结束 ──────────────────────────────────────────────────────────────
+
+
+async def extract_one(row, system, llm_call, model, args):
+    """单帖工序解构 → 写 mode_process。返回 cost。"""
+    cid = row["case_id"]
+    t0 = time.monotonic()
+    post_text = (f"【标题】{row['title'] or ''}\n【来源】{row['url'] or ''}\n"
+                 f"【正文】\n{row['body'] or ''}")
+    data_urls = [] if args.no_images else await _collect_images(
+        row["images"] or [], args.max_images, args.max_concurrent)
+    print(f"🖼️  {cid} 配图 {len(data_urls)}/{len(row['images'] or [])} 张")
+
+    if data_urls:
+        user_content = [{"type": "text", "text": post_text}]
+        for u in data_urls:
+            user_content.append({"type": "image_url", "image_url": {"url": u}})
+        messages = [{"role": "system", "content": system},
+                    {"role": "user", "content": user_content}]
+    else:
+        messages = [{"role": "system", "content": system},
+                    {"role": "user", "content": post_text}]
+
+    data, cost = await call_llm_with_retry(
+        llm_call=llm_call, messages=messages, model=model,
+        temperature=0.2, max_tokens=args.max_tokens,
+        validate_fn=_validate_wf, task_name=f"ProcExtract[{cid}]",
+    )
+    if not data:
+        print(f"❌ {cid} 解构失败(重试耗尽)")
+        return cost
+
+    data, dropped = _sanitize_workflow(data)
+    if any(dropped.values()):
+        print(f"🧹 {cid} 清洗:丢弃 procedure {dropped['procedures']} / "
+              f"step {dropped['steps']} / io {dropped['io']}")
+
+    dur = round(time.monotonic() - t0, 1)
+    n = db.replace_process(args.query_id, cid, row["platform"], row["title"],
+                           data, model, args.version, cost, dur)
+    out_dir = MW / "runs" / "procedures"
+    out_dir.mkdir(parents=True, exist_ok=True)
+    (out_dir / f"{cid}_{args.version}.json").write_text(
+        json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
+    print(f"   ✅ {cid} → {n} 个工序 · ${cost:.4f} · {dur}s")
+    return cost
+
+
+async def run(args):
+    case_ids = [c.strip() for c in args.case_ids.split(",") if c.strip()]
+    rows = []
+    for cid in case_ids:
+        row = db.fetch_post(args.query_id, cid)
+        if row is None:
+            print(f"⚠️ {args.query_id}/{cid} 不在 search_data,跳过")
+            continue
+        rows.append(row)
+    if not rows:
+        print("❌ 没有可解构的帖子"); return 1
+
+    system = PROMPT_FILE.read_text(encoding="utf-8")
+    from agent.llm.openrouter import create_openrouter_llm_call
+    llm_call = create_openrouter_llm_call(model=args.model)
+    args.version = args.version or ("v_" + datetime.now().strftime("%m%d%H%M"))
+    print(f"🤖 工序解构 {len(rows)} 帖 · 模型 {args.model} · 版本 {args.version}")
+
+    costs = []
+    for row in rows:   # 工序解构 token 重,串行跑,避免 OpenRouter 限流
+        costs.append(await extract_one(row, system, llm_call, args.model, args))
+    print(f"\n📊 完成 {len(rows)} 帖 · 总成本 ${sum(costs):.4f}")
+    return 0
+
+
+def main():
+    p = argparse.ArgumentParser(description="工序解构:search_data 帖子 → mode_process")
+    p.add_argument("--query-id", required=True)
+    p.add_argument("--case-ids", required=True, help="逗号分隔 case_id 列表")
+    p.add_argument("--model", default=DEFAULT_MODEL)
+    p.add_argument("--version", default=None, help="默认自动 v_月日时分")
+    p.add_argument("--max-images", type=int, default=MAX_IMAGES)
+    p.add_argument("--max-concurrent", type=int, default=4)
+    p.add_argument("--max-tokens", type=int, default=8000)
+    p.add_argument("--no-images", action="store_true")
+    args = p.parse_args()
+    raise SystemExit(asyncio.run(run(args)))
+
+
+if __name__ == "__main__":
+    main()