batch_extract_groups.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. #!/usr/bin/env python3
  2. """
  3. 按 method-groups.json 的分组,把每组的所有成员帖(含图片)打包喂给 LLM,
  4. 按 docs/prompts/batch.md 模板做跨帖批次提炼,每组产出一份 JSON。
  5. - model: claude-opus-4-7
  6. - effort: xhigh(开启扩展思考)
  7. - 模型可用 Read 工具读取每个帖子配套的图片
  8. - stream-json 实时输出,便于看到中间过程
  9. - 已产出的 group 默认跳过,--force 可重跑
  10. """
  11. from __future__ import annotations
  12. import argparse
  13. import json
  14. import subprocess
  15. import sys
  16. import time
  17. from pathlib import Path
  18. ROOT = Path(__file__).resolve().parent
  19. PROMPT_PATH = Path("/Users/sunlit/Code/Agent/knowhub/docs/prompts/batch.md")
  20. RESULT_JSON = ROOT / "result.json"
  21. GROUPS_JSON = ROOT / "method-groups.json"
  22. IMAGES_DIR = ROOT / "images"
  23. OUT_DIR = ROOT / "batch_extracted"
  24. MODEL = "claude-opus-4-7"
  25. EFFORT = "xhigh"
  26. # batch.md 里要被整体替换掉的占位段落(包含示例的 3 个 post + 「(按需续)」)
  27. PROMPT_PLACEHOLDER_BLOCK = """[POST id=p1]
  28. {post_1_content}
  29. [POST id=p2]
  30. {post_2_content}
  31. [POST id=p3]
  32. {post_3_content}
  33. (按需续)"""
  34. def build_post_block(post: dict, post_id: str) -> str:
  35. """把一条帖子组装成喂给 LLM 的 [POST id=pN] 区块。"""
  36. lines: list[str] = []
  37. lines.append(f"[POST id={post_id}]")
  38. lines.append(f"标题:{post.get('title') or ''}")
  39. if post.get("author"):
  40. lines.append(f"作者:{post['author']}")
  41. if post.get("category"):
  42. lines.append(f"分类:{post['category']}")
  43. if post.get("description"):
  44. lines.append(f"摘要:{post['description']}")
  45. if post.get("method"):
  46. lines.append(f"方法概述:{post['method']}")
  47. if post.get("url"):
  48. lines.append(f"URL:{post['url']}")
  49. body = (post.get("body") or "").strip()
  50. lines.append("")
  51. lines.append("正文:")
  52. lines.append(body if body else "(无正文)")
  53. return "\n".join(lines)
  54. def render_prompt(template: str, group: dict, posts_by_index: dict[int, dict]) -> tuple[str, list[str]]:
  55. """渲染 batch prompt:替换示例 post 块,并附加图片读取指引。
  56. 返回 (prompt_text, image_paths)
  57. """
  58. member_ids: list[str] = group.get("member_post_ids") or []
  59. blocks: list[str] = []
  60. image_paths: list[str] = []
  61. for pid in member_ids:
  62. # "p1" -> index 1
  63. if not (pid.startswith("p") and pid[1:].isdigit()):
  64. raise ValueError(f"非法 post id:{pid}")
  65. idx = int(pid[1:])
  66. post = posts_by_index.get(idx)
  67. if post is None:
  68. raise ValueError(f"在 result.json 中未找到 index={idx}({pid})")
  69. blocks.append(build_post_block(post, pid))
  70. for p in post.get("images") or []:
  71. image_paths.append(str(IMAGES_DIR / Path(p).name))
  72. posts_section = "\n\n".join(blocks)
  73. if PROMPT_PLACEHOLDER_BLOCK not in template:
  74. raise RuntimeError(
  75. "batch.md 模板里找不到预期的占位段落([POST id=p1] ... (按需续)),"
  76. "请检查模板是否被修改。"
  77. )
  78. prompt = template.replace(PROMPT_PLACEHOLDER_BLOCK, posts_section)
  79. if image_paths:
  80. prompt += "\n\n# 配套参考图片(请先逐张用 Read 工具查看,再综合判断)\n"
  81. prompt += "图片是各帖的成品 / 步骤示意 / 排版示意。请把视觉特征也纳入 capability/strategy 的判断。\n\n"
  82. # 按 post 分组列出,便于模型对应
  83. for pid in member_ids:
  84. idx = int(pid[1:])
  85. post = posts_by_index[idx]
  86. imgs = post.get("images") or []
  87. if not imgs:
  88. continue
  89. prompt += f"[{pid}] 图片:\n"
  90. for ip in imgs:
  91. prompt += f"- {IMAGES_DIR / Path(ip).name}\n"
  92. prompt += "\n"
  93. prompt += (
  94. "# 输出约束\n"
  95. "- 只输出最终的严格 JSON,不要包裹 ```json 代码块。\n"
  96. "- 不要任何前言、寒暄、解释、Markdown 标题。\n"
  97. "- 字段以提示词定义为准。\n"
  98. "- source_post_ids 只用本次输入里出现过的 id(如 p1/p2/...),不要编造。\n"
  99. "- **JSON 字符串值内严禁使用 ASCII 双引号 \" 强调或引述**,需要引号一律用中文角引号 「」。"
  100. "如必须出现 ASCII 双引号则务必转义为 \\\"。\n"
  101. )
  102. return prompt, image_paths
  103. def stream_claude(prompt: str, log_prefix: str) -> tuple[str | None, dict]:
  104. """调 claude -p,stream-json 实时打印中间过程,返回 (final_text, stats)。"""
  105. cmd = [
  106. "claude",
  107. "-p",
  108. "--model", MODEL,
  109. "--effort", EFFORT,
  110. "--output-format", "stream-json",
  111. "--verbose",
  112. "--add-dir", str(IMAGES_DIR),
  113. "--tools", "Read",
  114. "--permission-mode", "bypassPermissions",
  115. "--no-session-persistence",
  116. prompt,
  117. ]
  118. print(f"{log_prefix} ▶ launching claude (model={MODEL}, effort={EFFORT})", flush=True)
  119. proc = subprocess.Popen(
  120. cmd,
  121. stdout=subprocess.PIPE,
  122. stderr=subprocess.PIPE,
  123. text=True,
  124. bufsize=1,
  125. )
  126. final_text: str | None = None
  127. last_assistant_text = ""
  128. stats: dict = {"tool_calls": [], "thinking_chunks": 0, "raw_events": 0}
  129. assert proc.stdout is not None
  130. for line in proc.stdout:
  131. line = line.rstrip("\n")
  132. if not line.strip():
  133. continue
  134. stats["raw_events"] += 1
  135. try:
  136. ev = json.loads(line)
  137. except json.JSONDecodeError:
  138. print(f"{log_prefix} ! non-json line: {line[:200]}", flush=True)
  139. continue
  140. et = ev.get("type")
  141. if et == "system":
  142. sub = ev.get("subtype")
  143. if sub == "init":
  144. print(f"{log_prefix} · session={ev.get('session_id','?')[:8]}", flush=True)
  145. elif et == "assistant":
  146. msg = ev.get("message", {}) or {}
  147. for block in msg.get("content", []) or []:
  148. btype = block.get("type")
  149. if btype == "thinking":
  150. stats["thinking_chunks"] += 1
  151. text = (block.get("thinking") or "").strip()
  152. if text:
  153. first = text.splitlines()[0][:140]
  154. print(f"{log_prefix} 🧠 {first}", flush=True)
  155. elif btype == "tool_use":
  156. name = block.get("name")
  157. inp = block.get("input") or {}
  158. desc = inp.get("file_path") or inp.get("path") or inp.get("command") or ""
  159. stats["tool_calls"].append({"name": name, "input": inp})
  160. print(f"{log_prefix} 🔧 {name}({str(desc)[:140]})", flush=True)
  161. elif btype == "text":
  162. txt = block.get("text") or ""
  163. if txt.strip():
  164. last_assistant_text = txt
  165. snippet = txt.strip().splitlines()[0][:140]
  166. print(f"{log_prefix} ✏️ {snippet}", flush=True)
  167. elif et == "user":
  168. msg = ev.get("message", {}) or {}
  169. for block in msg.get("content", []) or []:
  170. if block.get("type") == "tool_result":
  171. is_err = block.get("is_error")
  172. if is_err:
  173. c = block.get("content")
  174. text = c if isinstance(c, str) else json.dumps(c, ensure_ascii=False)
  175. print(f"{log_prefix} ⚠️ tool_result error: {text[:200]}", flush=True)
  176. elif et == "result":
  177. final_text = ev.get("result") or last_assistant_text or None
  178. stats["duration_ms"] = ev.get("duration_ms")
  179. stats["total_cost_usd"] = ev.get("total_cost_usd")
  180. stats["num_turns"] = ev.get("num_turns")
  181. stats["is_error"] = ev.get("is_error")
  182. stats["stop_reason"] = ev.get("stop_reason")
  183. rc = proc.wait()
  184. err = (proc.stderr.read() if proc.stderr else "") or ""
  185. if rc != 0:
  186. print(f"{log_prefix} ! claude exited rc={rc}, stderr={err[:400]}", flush=True)
  187. stats["exit_code"] = rc
  188. stats["stderr"] = err
  189. return None, stats
  190. if final_text is None:
  191. final_text = last_assistant_text or None
  192. return final_text, stats
  193. def _repair_inner_quotes(text: str) -> str:
  194. """把 JSON 字符串值里裸用 ASCII 双引号包裹的中文短语 "xxx" 改成 「xxx」。
  195. 模型偶尔会写 \"避免出现"塑料感"和"AI 感"\",导致 JSON 截断。
  196. 用启发式:一个 ASCII 双引号紧挨 CJK 字符开头、再有一个紧挨 CJK 或中文标点的 ASCII 双引号,
  197. 很可能是内嵌引述而不是结构边界。替换为 「」。
  198. """
  199. import re
  200. pattern = re.compile(
  201. r'(?<=[一-鿿,。、;:!?「」()"])"([^"\n]{1,60}?)"(?=[一-鿿,。、;:!?「」()"])'
  202. )
  203. prev = None
  204. cur = text
  205. # 反复跑直到稳定(一次替换可能让相邻 pattern 暴露出来)
  206. while prev != cur:
  207. prev = cur
  208. cur = pattern.sub(r"「\1」", cur)
  209. return cur
  210. def parse_extraction(text: str) -> dict | None:
  211. """模型应直接产出 JSON;做多层容错(去 ``` 包裹、找首末花括号、修复内嵌引号)。"""
  212. t = text.strip()
  213. if t.startswith("```"):
  214. t = t.strip("`")
  215. if t.lower().startswith("json"):
  216. t = t[4:]
  217. t = t.strip()
  218. if not t.startswith("{"):
  219. l, r = t.find("{"), t.rfind("}")
  220. if l != -1 and r != -1 and r > l:
  221. t = t[l : r + 1]
  222. try:
  223. return json.loads(t)
  224. except json.JSONDecodeError:
  225. pass
  226. # 兜底:尝试修复中文短语裸双引号
  227. repaired = _repair_inner_quotes(t)
  228. if repaired != t:
  229. try:
  230. return json.loads(repaired)
  231. except json.JSONDecodeError:
  232. return None
  233. return None
  234. def process_group(template: str, group: dict, posts_by_index: dict[int, dict], force: bool) -> bool:
  235. gid = group.get("group_id") or "g?"
  236. label = (group.get("method_label") or "").strip()
  237. members: list[str] = group.get("member_post_ids") or []
  238. log_prefix = f"[{gid:>4}]"
  239. out_path = OUT_DIR / f"{gid}.json"
  240. if out_path.exists() and not force:
  241. print(f"{log_prefix} ⏭ 已存在 {out_path.name},跳过(用 --force 重跑)", flush=True)
  242. return True
  243. print(f"\n{log_prefix} ── {label[:80]}", flush=True)
  244. print(f"{log_prefix} members={members}", flush=True)
  245. prompt, image_paths = render_prompt(template, group, posts_by_index)
  246. print(f"{log_prefix} posts={len(members)} images={len(image_paths)} prompt_chars={len(prompt)}", flush=True)
  247. t0 = time.time()
  248. final_text, stats = stream_claude(prompt, log_prefix)
  249. dt = time.time() - t0
  250. if final_text is None:
  251. print(f"{log_prefix} ✗ 调用失败", flush=True)
  252. return False
  253. parsed = parse_extraction(final_text)
  254. out_payload: dict = {
  255. "group_id": gid,
  256. "method_label": label,
  257. "member_post_ids": members,
  258. "elapsed_sec": round(dt, 1),
  259. "stats": stats,
  260. "extraction": parsed,
  261. "extraction_raw": None if parsed is not None else final_text,
  262. }
  263. OUT_DIR.mkdir(parents=True, exist_ok=True)
  264. out_path.write_text(json.dumps(out_payload, ensure_ascii=False, indent=2), encoding="utf-8")
  265. if parsed is None:
  266. print(f"{log_prefix} ⚠️ JSON 解析失败,原文已保存到 {out_path}", flush=True)
  267. return False
  268. strat_n = len(parsed.get("strategies") or [])
  269. cap_n = len(parsed.get("capabilities") or [])
  270. skipped_n = len(parsed.get("skipped_posts") or [])
  271. cost = stats.get("total_cost_usd")
  272. cost_s = f"${cost:.3f}" if isinstance(cost, (int, float)) else "?"
  273. print(
  274. f"{log_prefix} ✓ strategies={strat_n} capabilities={cap_n} skipped_posts={skipped_n} "
  275. f"time={dt:.0f}s cost={cost_s}",
  276. flush=True,
  277. )
  278. return True
  279. def main() -> int:
  280. ap = argparse.ArgumentParser()
  281. ap.add_argument("--only", nargs="+", help="只处理这些 group_id(如 g1 g6)")
  282. ap.add_argument("--limit", type=int, help="最多处理 N 组")
  283. ap.add_argument("--force", action="store_true", help="已有输出也重跑")
  284. ap.add_argument("--dry-run", action="store_true", help="只打印将要处理的 group")
  285. args = ap.parse_args()
  286. for p in (PROMPT_PATH, RESULT_JSON, GROUPS_JSON):
  287. if not p.exists():
  288. print(f"文件不存在:{p}", file=sys.stderr)
  289. return 2
  290. template = PROMPT_PATH.read_text(encoding="utf-8")
  291. posts = json.loads(RESULT_JSON.read_text(encoding="utf-8"))
  292. posts_by_index: dict[int, dict] = {p["index"]: p for p in posts if "index" in p}
  293. groups_doc = json.loads(GROUPS_JSON.read_text(encoding="utf-8"))
  294. groups: list[dict] = groups_doc.get("groups") or []
  295. selected = groups
  296. if args.only:
  297. wanted = set(args.only)
  298. selected = [g for g in selected if g.get("group_id") in wanted]
  299. if args.limit:
  300. selected = selected[: args.limit]
  301. print(f"将处理 {len(selected)} 组(共 {len(groups)} 组)→ 输出目录 {OUT_DIR}")
  302. OUT_DIR.mkdir(parents=True, exist_ok=True)
  303. if args.dry_run:
  304. for g in selected:
  305. label = (g.get("method_label") or "")[:60]
  306. print(f" - {g.get('group_id')} members={g.get('member_post_ids')} {label}")
  307. return 0
  308. ok, fail = 0, 0
  309. t_start = time.time()
  310. for g in selected:
  311. try:
  312. if process_group(template, g, posts_by_index, force=args.force):
  313. ok += 1
  314. else:
  315. fail += 1
  316. except KeyboardInterrupt:
  317. print("\n中断", flush=True)
  318. break
  319. except Exception as e:
  320. fail += 1
  321. print(f"[{g.get('group_id')}] ✗ 异常:{e!r}", flush=True)
  322. elapsed = time.time() - t_start
  323. print(f"\n完成:成功 {ok},失败 {fail},总耗时 {elapsed:.0f}s")
  324. return 0 if fail == 0 else 1
  325. if __name__ == "__main__":
  326. sys.exit(main())