procedure_extract.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
# -*- coding: utf-8 -*-
"""Procedure extraction: search_data post → workflow JSON → mode_process table.
================================================================================
One direct LLM call per post (no agent / no validate loop); the system prompt is
in prompts/procedure_extract_system.md.
Post images are downloaded and converted to base64 (to get around hotlink
protection) and sent together with the text. The result is split per procedure
and written as rows into mode_process.

NOTE(review): extract_one passes validate_fn to call_llm_with_retry, so confirm
what "no validate loop" is meant to cover.

Usage (normally started as a subprocess by server.py):
    python pipeline/procedure_extract.py --query-id q0000 --case-ids xhs_abc
    python pipeline/procedure_extract.py --query-id q0000 --case-ids xhs_abc --model google/gemini-3.1-flash-lite
"""
  10. import argparse
  11. import asyncio
  12. import base64
  13. import json
  14. import sys
  15. import time
  16. from datetime import datetime
  17. from pathlib import Path
# Ancestor directory three levels above this file's own directory.
PROJECT_ROOT = Path(__file__).resolve().parents[3] # …/Agent
# Put that directory first on sys.path; the `examples...` import below
# presumably resolves from there (confirm against the repository layout).
sys.path.insert(0, str(PROJECT_ROOT))
from dotenv import load_dotenv
# Load variables from a .env file before the project imports that follow.
load_dotenv()
from examples.process_pipeline.script.llm_helper import call_llm_with_retry
# HERE is the directory containing this script; MW is its parent directory.
HERE = Path(__file__).resolve().parent
MW = HERE.parent
# Make modules located directly under MW importable; `import db` below
# presumably depends on this entry (confirm where db.py lives).
sys.path.insert(0, str(MW))
import db
# System prompt file; its text is read in run().
PROMPT_FILE = MW / "prompts" / "procedure_extract_system.md"
# Default value of the --model command-line option.
DEFAULT_MODEL = "anthropic/claude-sonnet-4-6"
# Default value of the --max-images command-line option.
MAX_IMAGES = 8
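# ── The 4 helpers below are copied verbatim from mode_procedure/mode-dsl/procedure_model_extract.py ──
# NOTE(review): five functions appear between this marker and the closing one; confirm which were copied.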
  31. def _detect_image_mime(data: bytes):
  32. if not data or len(data) < 12:
  33. return None
  34. if data[:3] == b"\xff\xd8\xff":
  35. return "image/jpeg"
  36. if data[:8] == b"\x89PNG\r\n\x1a\n":
  37. return "image/png"
  38. if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
  39. return "image/webp"
  40. if data[:6] in (b"GIF87a", b"GIF89a"):
  41. return "image/gif"
  42. return None
  43. async def _fetch_data_url(url, sem):
  44. from agent.tools.builtin.file.image_cdn import _download_image
  45. async with sem:
  46. try:
  47. data = await _download_image(url)
  48. except Exception:
  49. return None
  50. mime = _detect_image_mime(data)
  51. if mime is None:
  52. return None
  53. return f"data:{mime};base64,{base64.b64encode(data).decode()}"
  54. async def _collect_images(urls, max_images, concurrency):
  55. urls = [u for u in urls if isinstance(u, str) and u][:max_images]
  56. if not urls:
  57. return []
  58. sem = asyncio.Semaphore(concurrency)
  59. results = await asyncio.gather(*[_fetch_data_url(u, sem) for u in urls])
  60. return [d for d in results if d]
  61. def _validate_wf(data):
  62. if not isinstance(data, dict):
  63. return "顶层必须是 JSON 对象"
  64. if "procedures" not in data:
  65. return '缺少 "procedures" 字段'
  66. if not isinstance(data["procedures"], list):
  67. return '"procedures" 必须是数组'
  68. return None
  69. def _sanitize_workflow(data):
  70. dropped = {"procedures": 0, "steps": 0, "io": 0}
  71. procs = data.get("procedures")
  72. if not isinstance(procs, list):
  73. return data, dropped
  74. clean_procs = []
  75. for p in procs:
  76. if not isinstance(p, dict):
  77. dropped["procedures"] += 1
  78. continue
  79. steps = p.get("steps")
  80. if isinstance(steps, list):
  81. kept = []
  82. for s in steps:
  83. if not isinstance(s, dict):
  84. dropped["steps"] += 1
  85. continue
  86. for io in ("inputs", "outputs"):
  87. if isinstance(s.get(io), list):
  88. before = len(s[io])
  89. s[io] = [x for x in s[io] if isinstance(x, dict)]
  90. dropped["io"] += before - len(s[io])
  91. kept.append(s)
  92. p["steps"] = kept
  93. if not isinstance(p.get("declarations"), dict):
  94. p.pop("declarations", None)
  95. if not isinstance(p.get("type_registry"), dict):
  96. p.pop("type_registry", None)
  97. clean_procs.append(p)
  98. data["procedures"] = clean_procs
  99. return data, dropped
# ── End of copied helpers ─────────────────────────────────────────────────────
  101. async def extract_one(row, system, llm_call, model, args):
  102. """单帖工序解构 → 写 mode_process。返回 cost。"""
  103. cid = row["case_id"]
  104. t0 = time.monotonic()
  105. post_text = (f"【标题】{row['title'] or ''}\n【来源】{row['url'] or ''}\n"
  106. f"【正文】\n{row['body'] or ''}")
  107. data_urls = [] if args.no_images else await _collect_images(
  108. row["images"] or [], args.max_images, args.max_concurrent)
  109. print(f"🖼️ {cid} 配图 {len(data_urls)}/{len(row['images'] or [])} 张")
  110. if data_urls:
  111. user_content = [{"type": "text", "text": post_text}]
  112. for u in data_urls:
  113. user_content.append({"type": "image_url", "image_url": {"url": u}})
  114. messages = [{"role": "system", "content": system},
  115. {"role": "user", "content": user_content}]
  116. else:
  117. messages = [{"role": "system", "content": system},
  118. {"role": "user", "content": post_text}]
  119. data, cost = await call_llm_with_retry(
  120. llm_call=llm_call, messages=messages, model=model,
  121. temperature=0.2, max_tokens=args.max_tokens,
  122. validate_fn=_validate_wf, task_name=f"ProcExtract[{cid}]",
  123. )
  124. if not data:
  125. print(f"❌ {cid} 解构失败(重试耗尽)")
  126. return cost
  127. data, dropped = _sanitize_workflow(data)
  128. if any(dropped.values()):
  129. print(f"🧹 {cid} 清洗:丢弃 procedure {dropped['procedures']} / "
  130. f"step {dropped['steps']} / io {dropped['io']}")
  131. dur = round(time.monotonic() - t0, 1)
  132. n = db.replace_process(args.query_id, cid, row["platform"], row["title"],
  133. data, model, args.version, cost, dur)
  134. out_dir = MW / "runs" / "procedures"
  135. out_dir.mkdir(parents=True, exist_ok=True)
  136. (out_dir / f"{cid}_{args.version}.json").write_text(
  137. json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
  138. print(f" ✅ {cid} → {n} 个工序 · ${cost:.4f} · {dur}s")
  139. return cost
  140. async def run(args):
  141. case_ids = [c.strip() for c in args.case_ids.split(",") if c.strip()]
  142. rows = []
  143. for cid in case_ids:
  144. row = db.fetch_post(args.query_id, cid)
  145. if row is None:
  146. print(f"⚠️ {args.query_id}/{cid} 不在 search_data,跳过")
  147. continue
  148. rows.append(row)
  149. if not rows:
  150. print("❌ 没有可解构的帖子"); return 1
  151. system = PROMPT_FILE.read_text(encoding="utf-8")
  152. from agent.llm.openrouter import create_openrouter_llm_call
  153. llm_call = create_openrouter_llm_call(model=args.model)
  154. args.version = args.version or ("v_" + datetime.now().strftime("%m%d%H%M"))
  155. print(f"🤖 工序解构 {len(rows)} 帖 · 模型 {args.model} · 版本 {args.version}")
  156. costs = []
  157. for row in rows: # 工序解构 token 重,串行跑,避免 OpenRouter 限流
  158. costs.append(await extract_one(row, system, llm_call, args.model, args))
  159. print(f"\n📊 完成 {len(rows)} 帖 · 总成本 ${sum(costs):.4f}")
  160. return 0
  161. def main():
  162. p = argparse.ArgumentParser(description="工序解构:search_data 帖子 → mode_process")
  163. p.add_argument("--query-id", required=True)
  164. p.add_argument("--case-ids", required=True, help="逗号分隔 case_id 列表")
  165. p.add_argument("--model", default=DEFAULT_MODEL)
  166. p.add_argument("--version", default=None, help="默认自动 v_月日时分")
  167. p.add_argument("--max-images", type=int, default=MAX_IMAGES)
  168. p.add_argument("--max-concurrent", type=int, default=4)
  169. p.add_argument("--max-tokens", type=int, default=8000)
  170. p.add_argument("--no-images", action="store_true")
  171. args = p.parse_args()
  172. raise SystemExit(asyncio.run(run(args)))
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()