procedure_extract.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
# -*- coding: utf-8 -*-
"""Procedure extraction: search_process post -> workflow JSON -> mode_process table.
================================================================================
Single direct LLM call (no agent / no validate loop); the prompt is in
prompts/procedure_extract_system.md.
Post images are downloaded and converted to base64 (to get around hotlink
protection) and sent together with the text. The result is split per procedure
and written as rows to mode_process.

Usage (normally started as a subprocess by server.py):
    python pipeline/procedure_extract.py --query-id q0000 --case-ids xhs_abc
    python pipeline/procedure_extract.py --query-id q0000 --case-ids xhs_abc --model google/gemini-3.1-flash-lite
"""
import argparse
import asyncio
import base64
import json
import sys
import time
from datetime import datetime
from pathlib import Path

# Three directory levels above this file; put on sys.path before the
# `examples.…` import below.
PROJECT_ROOT = Path(__file__).resolve().parents[3] # …/Agent
sys.path.insert(0, str(PROJECT_ROOT))

from dotenv import load_dotenv
# Load .env before importing project modules (presumably they read credentials
# from the environment at import time — confirm).
load_dotenv()

from examples.process_pipeline.script.llm_helper import call_llm_with_retry

# HERE is this script's directory; MW is its parent. MW is put on sys.path
# before `import db` (presumably where the db module lives — confirm).
HERE = Path(__file__).resolve().parent
MW = HERE.parent
sys.path.insert(0, str(MW))
import db

# Default system prompt; can be overridden per run with --prompt-file.
PROMPT_FILE = MW / "prompts" / "procedure_extract_system.md"
# Default for --model.
DEFAULT_MODEL = "anthropic/claude-sonnet-4-6"
# Default for --max-images (cap on images sent per post).
MAX_IMAGES = 8
MAX_IMG_DIM = 8000 # Anthropic's per-dimension pixel limit; exceeding it fails the whole request with HTTP 400 (long images always hit this)
  31. # ── 以下 4 个助手原样取自 mode_procedure/mode-dsl/procedure_model_extract.py ──
  32. def _detect_image_mime(data: bytes):
  33. if not data or len(data) < 12:
  34. return None
  35. if data[:3] == b"\xff\xd8\xff":
  36. return "image/jpeg"
  37. if data[:8] == b"\x89PNG\r\n\x1a\n":
  38. return "image/png"
  39. if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
  40. return "image/webp"
  41. if data[:6] in (b"GIF87a", b"GIF89a"):
  42. return "image/gif"
  43. return None
  44. def _downscale_if_oversized(data, mime):
  45. """任一维 >MAX_IMG_DIM 则等比缩到最长边 ≤MAX_IMG_DIM,返回 (data, mime)。
  46. Anthropic 拒收任一维 >8000px 的图(整请求 400),公众号「长图」常踩此坑。
  47. 缩图失败则原样返回(至多退回原 400,不致崩)。"""
  48. try:
  49. from io import BytesIO
  50. from PIL import Image
  51. im = Image.open(BytesIO(data))
  52. w, h = im.size
  53. if max(w, h) <= MAX_IMG_DIM:
  54. return data, mime
  55. scale = (MAX_IMG_DIM - 100) / max(w, h)
  56. im = im.convert("RGB").resize((max(1, int(w * scale)), max(1, int(h * scale))))
  57. fmt, out_mime = ("JPEG", "image/jpeg") if mime == "image/jpeg" else ("PNG", "image/png")
  58. buf = BytesIO()
  59. im.save(buf, format=fmt)
  60. print(f" 🔧 超限图 {w}x{h} → {im.width}x{im.height}(避 Anthropic 8000px 上限)")
  61. return buf.getvalue(), out_mime
  62. except Exception:
  63. return data, mime
  64. async def _fetch_data_url(url, sem):
  65. from agent.tools.builtin.file.image_cdn import _download_image
  66. async with sem:
  67. try:
  68. data = await _download_image(url)
  69. except Exception:
  70. return None
  71. mime = _detect_image_mime(data)
  72. if mime is None:
  73. return None
  74. data, mime = _downscale_if_oversized(data, mime)
  75. return f"data:{mime};base64,{base64.b64encode(data).decode()}"
  76. async def _collect_images(urls, max_images, concurrency):
  77. urls = [u for u in urls if isinstance(u, str) and u][:max_images]
  78. if not urls:
  79. return []
  80. sem = asyncio.Semaphore(concurrency)
  81. results = await asyncio.gather(*[_fetch_data_url(u, sem) for u in urls])
  82. return [d for d in results if d]
  83. def _validate_wf(data):
  84. if not isinstance(data, dict):
  85. return "顶层必须是 JSON 对象"
  86. if "procedures" not in data:
  87. return '缺少 "procedures" 字段'
  88. if not isinstance(data["procedures"], list):
  89. return '"procedures" 必须是数组'
  90. return None
  91. def _sanitize_workflow(data):
  92. dropped = {"procedures": 0, "steps": 0, "io": 0}
  93. procs = data.get("procedures")
  94. if not isinstance(procs, list):
  95. return data, dropped
  96. clean_procs = []
  97. for p in procs:
  98. if not isinstance(p, dict):
  99. dropped["procedures"] += 1
  100. continue
  101. steps = p.get("steps")
  102. if isinstance(steps, list):
  103. kept = []
  104. for s in steps:
  105. if not isinstance(s, dict):
  106. dropped["steps"] += 1
  107. continue
  108. for io in ("inputs", "outputs"):
  109. if isinstance(s.get(io), list):
  110. before = len(s[io])
  111. s[io] = [x for x in s[io] if isinstance(x, dict)]
  112. dropped["io"] += before - len(s[io])
  113. kept.append(s)
  114. p["steps"] = kept
  115. if not isinstance(p.get("declarations"), dict):
  116. p.pop("declarations", None)
  117. if not isinstance(p.get("type_registry"), dict):
  118. p.pop("type_registry", None)
  119. clean_procs.append(p)
  120. data["procedures"] = clean_procs
  121. return data, dropped
  122. # ── 助手复制结束 ──────────────────────────────────────────────────────────────
  123. async def extract_one(row, system, llm_call, model, args):
  124. """单帖工序解构 → 写 mode_process。返回 cost。"""
  125. cid = row["case_id"]
  126. t0 = time.monotonic()
  127. post_text = (f"【标题】{row['title'] or ''}\n【来源】{row['url'] or ''}\n"
  128. f"【正文】\n{row['body'] or ''}")
  129. data_urls = [] if args.no_images else await _collect_images(
  130. row["images"] or [], args.max_images, args.max_concurrent)
  131. print(f"🖼️ {cid} 配图 {len(data_urls)}/{len(row['images'] or [])} 张")
  132. if data_urls:
  133. user_content = [{"type": "text", "text": post_text}]
  134. for u in data_urls:
  135. user_content.append({"type": "image_url", "image_url": {"url": u}})
  136. messages = [{"role": "system", "content": system},
  137. {"role": "user", "content": user_content}]
  138. else:
  139. messages = [{"role": "system", "content": system},
  140. {"role": "user", "content": post_text}]
  141. data, cost = await call_llm_with_retry(
  142. llm_call=llm_call, messages=messages, model=model,
  143. temperature=0.2, max_tokens=args.max_tokens,
  144. validate_fn=_validate_wf, task_name=f"ProcExtract[{cid}]",
  145. )
  146. if not data:
  147. print(f"❌ {cid} 解构失败(重试耗尽)")
  148. return cost
  149. data, dropped = _sanitize_workflow(data)
  150. if any(dropped.values()):
  151. print(f"🧹 {cid} 清洗:丢弃 procedure {dropped['procedures']} / "
  152. f"step {dropped['steps']} / io {dropped['io']}")
  153. dur = round(time.monotonic() - t0, 1)
  154. n = db.replace_process(args.query_id, cid, row["platform"], row["title"],
  155. data, model, args.version, cost, dur)
  156. out_dir = MW / "runs" / "mode_process" / args.query_id # 按 query 分组存放
  157. out_dir.mkdir(parents=True, exist_ok=True)
  158. (out_dir / f"{cid}.json").write_text(
  159. json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
  160. print(f" ✅ {cid} → {n} 个工序 · ${cost:.4f} · {dur}s")
  161. return cost
  162. async def run(args):
  163. case_ids = [c.strip() for c in args.case_ids.split(",") if c.strip()]
  164. # 方案A:解构前先按 case 全局去重。已真实解构过的帖不再调 LLM(省钱),
  165. # 跨 query 的用 link_* 复制行补齐关联(cost=0)。--force 跳过去重强制重解构。
  166. linked = skipped = 0
  167. todo = []
  168. for cid in dict.fromkeys(case_ids): # 顺手去掉同批重复 case
  169. if not args.force:
  170. ex = db.latest_real_version(cid, mode="process")
  171. if ex:
  172. if ex["query_id"] == args.query_id:
  173. print(f"♻️ {cid} 本 query 已解构(版本 {ex['version']}),跳过")
  174. skipped += 1
  175. else:
  176. n = db.link_process(args.query_id, cid, mode="process")
  177. print(f"♻️ {cid} 已在 {ex['query_id']} 解构(版本 {ex['version']}),"
  178. f"link 补齐 {n} 行 · $0")
  179. linked += 1
  180. continue
  181. todo.append(cid)
  182. rows = []
  183. for cid in todo:
  184. row = db.fetch_post(args.query_id, cid, table="search_process")
  185. if row is None:
  186. print(f"⚠️ {args.query_id}/{cid} 不在 search_process,跳过")
  187. continue
  188. rows.append(row)
  189. if not rows:
  190. if linked or skipped:
  191. print(f"✅ 无需 LLM 解构(link 补齐 {linked} 帖 · 已存在跳过 {skipped} 帖)")
  192. return 0
  193. print("❌ 没有可解构的帖子"); return 1
  194. prompt_file = Path(args.prompt_file) if getattr(args, "prompt_file", None) else PROMPT_FILE
  195. system = prompt_file.read_text(encoding="utf-8")
  196. from agent.llm.openrouter import create_openrouter_llm_call
  197. llm_call = create_openrouter_llm_call(model=args.model)
  198. args.version = args.version or ("v_" + datetime.now().strftime("%m%d%H%M"))
  199. print(f"🤖 工序解构 {len(rows)} 帖 · 模型 {args.model} · 版本 {args.version}")
  200. costs = []
  201. for row in rows: # 工序解构 token 重,串行跑,避免 OpenRouter 限流
  202. costs.append(await extract_one(row, system, llm_call, args.model, args))
  203. print(f"\n📊 完成 {len(rows)} 帖 · link 补齐 {linked} 帖 · 总成本 ${sum(costs):.4f}")
  204. return 0
  205. def main():
  206. p = argparse.ArgumentParser(description="工序解构:search_process 帖子 → mode_process")
  207. p.add_argument("--query-id", required=True)
  208. p.add_argument("--case-ids", required=True, help="逗号分隔 case_id 列表")
  209. p.add_argument("--model", default=DEFAULT_MODEL)
  210. p.add_argument("--version", default=None, help="默认自动 v_月日时分")
  211. p.add_argument("--max-images", type=int, default=MAX_IMAGES)
  212. p.add_argument("--max-concurrent", type=int, default=4)
  213. p.add_argument("--max-tokens", type=int, default=8000)
  214. p.add_argument("--no-images", action="store_true")
  215. p.add_argument("--force", action="store_true",
  216. help="跳过去重,强制重解构(换 prompt/模型做对比时用)")
  217. p.add_argument("--prompt-file", default=None,
  218. help="覆盖默认解构 prompt(临时,仅本次;不改 prompts/*.md)")
  219. args = p.parse_args()
  220. raise SystemExit(asyncio.run(run(args)))
  221. if __name__ == "__main__":
  222. main()