|
|
@@ -0,0 +1,206 @@
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""工序解构 · search_data 帖子 → workflow JSON → mode_process 表
|
|
|
+================================================================================
|
|
|
+单次大模型直出(无 agent / 无 validate 循环),prompt 见 prompts/procedure_extract_system.md。
|
|
|
+配图下载转 base64(绕防盗链)随文本一起发。结果按工序拆行写 mode_process。
|
|
|
+
|
|
|
+用法(一般由 server.py 起子进程调):
|
|
|
+ python pipeline/procedure_extract.py --query-id q0000 --case-ids xhs_abc
|
|
|
+ python pipeline/procedure_extract.py --query-id q0000 --case-ids xhs_abc --model google/gemini-3.1-flash-lite
|
|
|
+"""
|
|
|
+import argparse
|
|
|
+import asyncio
|
|
|
+import base64
|
|
|
+import json
|
|
|
+import sys
|
|
|
+import time
|
|
|
+from datetime import datetime
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
+PROJECT_ROOT = Path(__file__).resolve().parents[3] # …/Agent
|
|
|
+sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
+
|
|
|
+from dotenv import load_dotenv
|
|
|
+load_dotenv()
|
|
|
+
|
|
|
+from examples.process_pipeline.script.llm_helper import call_llm_with_retry
|
|
|
+
|
|
|
+HERE = Path(__file__).resolve().parent
|
|
|
+MW = HERE.parent
|
|
|
+sys.path.insert(0, str(MW))
|
|
|
+import db
|
|
|
+
|
|
|
+PROMPT_FILE = MW / "prompts" / "procedure_extract_system.md"
|
|
|
+DEFAULT_MODEL = "anthropic/claude-sonnet-4-6"
|
|
|
+MAX_IMAGES = 8
|
|
|
+
|
|
|
+
|
|
|
+# ── 以下 4 个助手原样取自 mode_procedure/mode-dsl/procedure_model_extract.py ──
|
|
|
+
|
|
|
+def _detect_image_mime(data: bytes):
|
|
|
+ if not data or len(data) < 12:
|
|
|
+ return None
|
|
|
+ if data[:3] == b"\xff\xd8\xff":
|
|
|
+ return "image/jpeg"
|
|
|
+ if data[:8] == b"\x89PNG\r\n\x1a\n":
|
|
|
+ return "image/png"
|
|
|
+ if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
|
|
|
+ return "image/webp"
|
|
|
+ if data[:6] in (b"GIF87a", b"GIF89a"):
|
|
|
+ return "image/gif"
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
+async def _fetch_data_url(url, sem):
|
|
|
+ from agent.tools.builtin.file.image_cdn import _download_image
|
|
|
+ async with sem:
|
|
|
+ try:
|
|
|
+ data = await _download_image(url)
|
|
|
+ except Exception:
|
|
|
+ return None
|
|
|
+ mime = _detect_image_mime(data)
|
|
|
+ if mime is None:
|
|
|
+ return None
|
|
|
+ return f"data:{mime};base64,{base64.b64encode(data).decode()}"
|
|
|
+
|
|
|
+
|
|
|
+async def _collect_images(urls, max_images, concurrency):
|
|
|
+ urls = [u for u in urls if isinstance(u, str) and u][:max_images]
|
|
|
+ if not urls:
|
|
|
+ return []
|
|
|
+ sem = asyncio.Semaphore(concurrency)
|
|
|
+ results = await asyncio.gather(*[_fetch_data_url(u, sem) for u in urls])
|
|
|
+ return [d for d in results if d]
|
|
|
+
|
|
|
+
|
|
|
+def _validate_wf(data):
|
|
|
+ if not isinstance(data, dict):
|
|
|
+ return "顶层必须是 JSON 对象"
|
|
|
+ if "procedures" not in data:
|
|
|
+ return '缺少 "procedures" 字段'
|
|
|
+ if not isinstance(data["procedures"], list):
|
|
|
+ return '"procedures" 必须是数组'
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
+def _sanitize_workflow(data):
|
|
|
+ dropped = {"procedures": 0, "steps": 0, "io": 0}
|
|
|
+ procs = data.get("procedures")
|
|
|
+ if not isinstance(procs, list):
|
|
|
+ return data, dropped
|
|
|
+ clean_procs = []
|
|
|
+ for p in procs:
|
|
|
+ if not isinstance(p, dict):
|
|
|
+ dropped["procedures"] += 1
|
|
|
+ continue
|
|
|
+ steps = p.get("steps")
|
|
|
+ if isinstance(steps, list):
|
|
|
+ kept = []
|
|
|
+ for s in steps:
|
|
|
+ if not isinstance(s, dict):
|
|
|
+ dropped["steps"] += 1
|
|
|
+ continue
|
|
|
+ for io in ("inputs", "outputs"):
|
|
|
+ if isinstance(s.get(io), list):
|
|
|
+ before = len(s[io])
|
|
|
+ s[io] = [x for x in s[io] if isinstance(x, dict)]
|
|
|
+ dropped["io"] += before - len(s[io])
|
|
|
+ kept.append(s)
|
|
|
+ p["steps"] = kept
|
|
|
+ if not isinstance(p.get("declarations"), dict):
|
|
|
+ p.pop("declarations", None)
|
|
|
+ if not isinstance(p.get("type_registry"), dict):
|
|
|
+ p.pop("type_registry", None)
|
|
|
+ clean_procs.append(p)
|
|
|
+ data["procedures"] = clean_procs
|
|
|
+ return data, dropped
|
|
|
+
|
|
|
+# ── 助手复制结束 ──────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+
|
|
|
+async def extract_one(row, system, llm_call, model, args):
|
|
|
+ """单帖工序解构 → 写 mode_process。返回 cost。"""
|
|
|
+ cid = row["case_id"]
|
|
|
+ t0 = time.monotonic()
|
|
|
+ post_text = (f"【标题】{row['title'] or ''}\n【来源】{row['url'] or ''}\n"
|
|
|
+ f"【正文】\n{row['body'] or ''}")
|
|
|
+ data_urls = [] if args.no_images else await _collect_images(
|
|
|
+ row["images"] or [], args.max_images, args.max_concurrent)
|
|
|
+ print(f"🖼️ {cid} 配图 {len(data_urls)}/{len(row['images'] or [])} 张")
|
|
|
+
|
|
|
+ if data_urls:
|
|
|
+ user_content = [{"type": "text", "text": post_text}]
|
|
|
+ for u in data_urls:
|
|
|
+ user_content.append({"type": "image_url", "image_url": {"url": u}})
|
|
|
+ messages = [{"role": "system", "content": system},
|
|
|
+ {"role": "user", "content": user_content}]
|
|
|
+ else:
|
|
|
+ messages = [{"role": "system", "content": system},
|
|
|
+ {"role": "user", "content": post_text}]
|
|
|
+
|
|
|
+ data, cost = await call_llm_with_retry(
|
|
|
+ llm_call=llm_call, messages=messages, model=model,
|
|
|
+ temperature=0.2, max_tokens=args.max_tokens,
|
|
|
+ validate_fn=_validate_wf, task_name=f"ProcExtract[{cid}]",
|
|
|
+ )
|
|
|
+ if not data:
|
|
|
+ print(f"❌ {cid} 解构失败(重试耗尽)")
|
|
|
+ return cost
|
|
|
+
|
|
|
+ data, dropped = _sanitize_workflow(data)
|
|
|
+ if any(dropped.values()):
|
|
|
+ print(f"🧹 {cid} 清洗:丢弃 procedure {dropped['procedures']} / "
|
|
|
+ f"step {dropped['steps']} / io {dropped['io']}")
|
|
|
+
|
|
|
+ dur = round(time.monotonic() - t0, 1)
|
|
|
+ n = db.replace_process(args.query_id, cid, row["platform"], row["title"],
|
|
|
+ data, model, args.version, cost, dur)
|
|
|
+ out_dir = MW / "runs" / "procedures"
|
|
|
+ out_dir.mkdir(parents=True, exist_ok=True)
|
|
|
+ (out_dir / f"{cid}_{args.version}.json").write_text(
|
|
|
+ json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
|
+ print(f" ✅ {cid} → {n} 个工序 · ${cost:.4f} · {dur}s")
|
|
|
+ return cost
|
|
|
+
|
|
|
+
|
|
|
+async def run(args):
|
|
|
+ case_ids = [c.strip() for c in args.case_ids.split(",") if c.strip()]
|
|
|
+ rows = []
|
|
|
+ for cid in case_ids:
|
|
|
+ row = db.fetch_post(args.query_id, cid)
|
|
|
+ if row is None:
|
|
|
+ print(f"⚠️ {args.query_id}/{cid} 不在 search_data,跳过")
|
|
|
+ continue
|
|
|
+ rows.append(row)
|
|
|
+ if not rows:
|
|
|
+ print("❌ 没有可解构的帖子"); return 1
|
|
|
+
|
|
|
+ system = PROMPT_FILE.read_text(encoding="utf-8")
|
|
|
+ from agent.llm.openrouter import create_openrouter_llm_call
|
|
|
+ llm_call = create_openrouter_llm_call(model=args.model)
|
|
|
+ args.version = args.version or ("v_" + datetime.now().strftime("%m%d%H%M"))
|
|
|
+ print(f"🤖 工序解构 {len(rows)} 帖 · 模型 {args.model} · 版本 {args.version}")
|
|
|
+
|
|
|
+ costs = []
|
|
|
+ for row in rows: # 工序解构 token 重,串行跑,避免 OpenRouter 限流
|
|
|
+ costs.append(await extract_one(row, system, llm_call, args.model, args))
|
|
|
+ print(f"\n📊 完成 {len(rows)} 帖 · 总成本 ${sum(costs):.4f}")
|
|
|
+ return 0
|
|
|
+
|
|
|
+
|
|
|
+def main():
|
|
|
+ p = argparse.ArgumentParser(description="工序解构:search_data 帖子 → mode_process")
|
|
|
+ p.add_argument("--query-id", required=True)
|
|
|
+ p.add_argument("--case-ids", required=True, help="逗号分隔 case_id 列表")
|
|
|
+ p.add_argument("--model", default=DEFAULT_MODEL)
|
|
|
+ p.add_argument("--version", default=None, help="默认自动 v_月日时分")
|
|
|
+ p.add_argument("--max-images", type=int, default=MAX_IMAGES)
|
|
|
+ p.add_argument("--max-concurrent", type=int, default=4)
|
|
|
+ p.add_argument("--max-tokens", type=int, default=8000)
|
|
|
+ p.add_argument("--no-images", action="store_true")
|
|
|
+ args = p.parse_args()
|
|
|
+ raise SystemExit(asyncio.run(run(args)))
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|