#!/usr/bin/env python3
"""Batch-extract cross-post knowledge for each group in method-groups.json.

For every group, all member posts (including their images) are packed into a
single prompt built from the docs/prompts/batch.md template and sent to the
``claude`` CLI. Each group yields one JSON file under ``batch_extracted/``.

- model: claude-opus-4-7
- effort: xhigh (extended thinking enabled)
- the model may use the Read tool to open the images that accompany each post
- stream-json output is echoed live so intermediate progress is visible
- groups that already have a usable output are skipped; --force re-runs them
"""
from __future__ import annotations

import argparse
import json
import re
import subprocess
import sys
import threading
import time
from pathlib import Path

ROOT = Path(__file__).resolve().parent
# NOTE(review): machine-specific absolute path; main() aborts with exit code 2
# when it does not exist.
PROMPT_PATH = Path("/Users/sunlit/Code/Agent/knowhub/docs/prompts/batch.md")
RESULT_JSON = ROOT / "result.json"
GROUPS_JSON = ROOT / "method-groups.json"
IMAGES_DIR = ROOT / "images"
OUT_DIR = ROOT / "batch_extracted"

MODEL = "claude-opus-4-7"
EFFORT = "xhigh"

# Placeholder section of batch.md that is replaced wholesale: the three example
# posts plus the trailing "continue as needed" marker.
PROMPT_PLACEHOLDER_BLOCK = """[POST id=p1]
{post_1_content}

[POST id=p2]
{post_2_content}

[POST id=p3]
{post_3_content}

(按需续)"""

# The placeholder is matched token by token with arbitrary whitespace between
# tokens, so the exact line/blank-line layout of the template does not matter.
_PLACEHOLDER_RE = re.compile(
    r"\s+".join(re.escape(token) for token in PROMPT_PLACEHOLDER_BLOCK.split())
)

# An ASCII double quote pair that is hugged on both outer sides by CJK
# characters / CJK punctuation (or another quote) and encloses a short,
# single-line phrase: most likely an inline quotation inside a JSON string
# value rather than a structural JSON delimiter.
_INNER_QUOTE_RE = re.compile(
    r'(?<=[一-鿿,。、;:!?「」()"])"([^"\n]{1,60}?)"(?=[一-鿿,。、;:!?「」()"])'
)


def build_post_block(post: dict, post_id: str) -> str:
    """Render one post as the ``[POST id=pN]`` block that is fed to the LLM.

    The title line is always emitted; author, category, description, method
    and URL lines are emitted only when the corresponding field is truthy.
    The stripped body follows, or a "no body" marker when it is empty.
    """
    lines: list[str] = [
        f"[POST id={post_id}]",
        f"标题:{post.get('title') or ''}",
    ]
    optional_fields = (
        ("author", "作者"),
        ("category", "分类"),
        ("description", "摘要"),
        ("method", "方法概述"),
        ("url", "URL"),
    )
    for key, label in optional_fields:
        if post.get(key):
            lines.append(f"{label}:{post[key]}")
    body = (post.get("body") or "").strip()
    lines.append("")
    lines.append("正文:")
    lines.append(body if body else "(无正文)")
    return "\n".join(lines)


def render_prompt(template: str, group: dict, posts_by_index: dict[int, dict]) -> tuple[str, list[str]]:
    """Render the batch prompt for one group.

    The example post section of ``template`` is replaced by the group's real
    posts, an image-reading guide is appended when any member post has images,
    and the output constraints are appended last.

    Returns:
        ``(prompt_text, image_paths)``; image paths point into ``IMAGES_DIR``
        and use only the file name of each entry in a post's ``images`` list.

    Raises:
        ValueError: a member id is not of the form ``p<digits>`` or has no
            matching entry in ``posts_by_index``.
        RuntimeError: the template does not contain the placeholder section.
    """
    member_ids: list[str] = group.get("member_post_ids") or []
    members: list[tuple[str, dict]] = []
    blocks: list[str] = []
    image_paths: list[str] = []
    for pid in member_ids:
        # "p1" -> index 1
        if not (isinstance(pid, str) and pid.startswith("p") and pid[1:].isdigit()):
            raise ValueError(f"非法 post id:{pid}")
        idx = int(pid[1:])
        post = posts_by_index.get(idx)
        if post is None:
            raise ValueError(f"在 result.json 中未找到 index={idx}({pid})")
        members.append((pid, post))
        blocks.append(build_post_block(post, pid))
        image_paths.extend(str(IMAGES_DIR / Path(name).name) for name in post.get("images") or [])

    posts_section = "\n\n".join(blocks)
    if _PLACEHOLDER_RE.search(template) is None:
        raise RuntimeError(
            "batch.md 模板里找不到预期的占位段落([POST id=p1] ... (按需续)),"
            "请检查模板是否被修改。"
        )
    # A callable replacement keeps backslashes in post bodies literal.
    parts: list[str] = [_PLACEHOLDER_RE.sub(lambda _match: posts_section, template)]

    if image_paths:
        parts.append("\n\n# 配套参考图片(请先逐张用 Read 工具查看,再综合判断)\n")
        parts.append("图片是各帖的成品 / 步骤示意 / 排版示意。请把视觉特征也纳入 capability/strategy 的判断。\n\n")
        # Listed per post so the model can map each image back to its post.
        for pid, post in members:
            names = post.get("images") or []
            if not names:
                continue
            parts.append(f"[{pid}] 图片:\n")
            parts.extend(f"- {IMAGES_DIR / Path(name).name}\n" for name in names)
            parts.append("\n")

    parts.append(
        "# 输出约束\n"
        "- 只输出最终的严格 JSON,不要包裹 ```json 代码块。\n"
        "- 不要任何前言、寒暄、解释、Markdown 标题。\n"
        "- 字段以提示词定义为准。\n"
        "- source_post_ids 只用本次输入里出现过的 id(如 p1/p2/...),不要编造。\n"
        "- **JSON 字符串值内严禁使用 ASCII 双引号 \" 强调或引述**,需要引号一律用中文角引号 「」。"
        "如必须出现 ASCII 双引号则务必转义为 \\\"。\n"
    )
    return "".join(parts), image_paths


def _content_blocks(ev: dict) -> list[dict]:
    """Return the dict-shaped content blocks of a stream event's message."""
    message = ev.get("message")
    if not isinstance(message, dict):
        return []
    content = message.get("content")
    if not isinstance(content, list):
        # Content may be absent or a plain string; there are no blocks then.
        return []
    return [block for block in content if isinstance(block, dict)]


def _handle_assistant_event(ev: dict, log_prefix: str, stats: dict) -> str | None:
    """Log an ``assistant`` event and update ``stats`` in place.

    Returns the last non-blank text block of the event, or None if it has none.
    """
    latest_text: str | None = None
    for block in _content_blocks(ev):
        kind = block.get("type")
        if kind == "thinking":
            stats["thinking_chunks"] += 1
            thought = (block.get("thinking") or "").strip()
            if thought:
                print(f"{log_prefix} 🧠 {thought.splitlines()[0][:140]}", flush=True)
        elif kind == "tool_use":
            name = block.get("name")
            inp = block.get("input") or {}
            desc = inp.get("file_path") or inp.get("path") or inp.get("command") or ""
            stats["tool_calls"].append({"name": name, "input": inp})
            print(f"{log_prefix} 🔧 {name}({str(desc)[:140]})", flush=True)
        elif kind == "text":
            txt = block.get("text") or ""
            if txt.strip():
                latest_text = txt
                print(f"{log_prefix} ✏️ {txt.strip().splitlines()[0][:140]}", flush=True)
    return latest_text


def _log_tool_errors(ev: dict, log_prefix: str) -> None:
    """Print every failed ``tool_result`` block carried by a ``user`` event."""
    for block in _content_blocks(ev):
        if block.get("type") == "tool_result" and block.get("is_error"):
            payload = block.get("content")
            text = payload if isinstance(payload, str) else json.dumps(payload, ensure_ascii=False)
            print(f"{log_prefix} ⚠️ tool_result error: {text[:200]}", flush=True)


def stream_claude(prompt: str, log_prefix: str) -> tuple[str | None, dict]:
    """Run ``claude -p`` and echo its stream-json events while it works.

    Returns:
        ``(final_text, stats)``. ``final_text`` is the ``result`` field of the
        result event, falling back to the last non-blank assistant text; it is
        None when the process exits non-zero or produced no text at all.
        ``stats`` collects tool calls, event counters, the result-event
        metadata and, on a non-zero exit, ``exit_code`` and ``stderr``.
    """
    cmd = [
        "claude",
        "-p",
        "--model", MODEL,
        "--effort", EFFORT,
        "--output-format", "stream-json",
        "--verbose",
        "--add-dir", str(IMAGES_DIR),
        "--tools", "Read",
        "--permission-mode", "bypassPermissions",
        "--no-session-persistence",
        prompt,
    ]
    print(f"{log_prefix} ▶ launching claude (model={MODEL}, effort={EFFORT})", flush=True)

    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        encoding="utf-8",
        errors="replace",
        bufsize=1,
    )
    assert proc.stdout is not None

    # stderr must be drained while stdout is being read: if the child fills
    # the stderr pipe buffer it blocks on write and stdout never reaches EOF.
    stderr_chunks: list[str] = []

    def _drain_stderr() -> None:
        if proc.stderr is None:
            return
        try:
            stderr_chunks.extend(proc.stderr)
        except (OSError, ValueError):
            # The pipe was closed underneath us during cleanup.
            pass

    stderr_thread = threading.Thread(target=_drain_stderr, daemon=True)
    stderr_thread.start()

    final_text: str | None = None
    last_assistant_text = ""
    stats: dict = {"tool_calls": [], "thinking_chunks": 0, "raw_events": 0}

    try:
        for raw_line in proc.stdout:
            line = raw_line.rstrip("\n")
            if not line.strip():
                continue
            stats["raw_events"] += 1
            try:
                ev = json.loads(line)
            except json.JSONDecodeError:
                print(f"{log_prefix} ! non-json line: {line[:200]}", flush=True)
                continue
            if not isinstance(ev, dict):
                continue

            event_type = ev.get("type")
            if event_type == "system":
                if ev.get("subtype") == "init":
                    session = str(ev.get("session_id") or "?")
                    print(f"{log_prefix} · session={session[:8]}", flush=True)
            elif event_type == "assistant":
                text = _handle_assistant_event(ev, log_prefix, stats)
                if text is not None:
                    last_assistant_text = text
            elif event_type == "user":
                _log_tool_errors(ev, log_prefix)
            elif event_type == "result":
                final_text = ev.get("result") or last_assistant_text or None
                for key in ("duration_ms", "total_cost_usd", "num_turns", "is_error", "stop_reason"):
                    stats[key] = ev.get(key)
        rc = proc.wait()
    finally:
        # Reached with a live child only on an exception (e.g. Ctrl-C): make
        # sure no orphaned claude process keeps running.
        if proc.poll() is None:
            proc.kill()
            proc.wait()
        stderr_thread.join(timeout=5)
        for stream in (proc.stdout, proc.stderr):
            if stream is not None:
                stream.close()

    err = "".join(stderr_chunks)
    if rc != 0:
        print(f"{log_prefix} ! claude exited rc={rc}, stderr={err[:400]}", flush=True)
        stats["exit_code"] = rc
        stats["stderr"] = err
        return None, stats

    if final_text is None:
        final_text = last_assistant_text or None
    return final_text, stats


def _repair_inner_quotes(text: str) -> str:
    """Turn bare ASCII-quoted CJK phrases inside JSON strings into 「...」.

    The model occasionally quotes a phrase with raw ASCII double quotes inside
    a JSON string value, which terminates the string early and breaks parsing.
    Heuristic: a quote pair enclosing a short single-line phrase, with CJK
    characters or CJK punctuation directly outside both quotes, is treated as
    an inline quotation and rewritten with corner brackets.
    """
    previous = None
    current = text
    # Repeat until stable: one substitution can expose an adjacent match.
    while previous != current:
        previous = current
        current = _INNER_QUOTE_RE.sub(r"「\1」", current)
    return current


def parse_extraction(text: str) -> dict | None:
    """Parse the model's reply into a JSON object, tolerating common slips.

    Tolerated: a surrounding ``` / ```json fence, prose before or after the
    object (the span from the first ``{`` to the last ``}`` is tried), and
    bare ASCII quotes around CJK phrases inside string values.

    Returns:
        The decoded object, or None when no candidate decodes to a JSON
        object (arrays and scalars are rejected as well).
    """
    stripped = text.strip()
    if stripped.startswith("```"):
        stripped = stripped.strip("`")
        if stripped.lower().startswith("json"):
            stripped = stripped[4:]
        stripped = stripped.strip()

    candidates = [stripped]
    left, right = stripped.find("{"), stripped.rfind("}")
    if left != -1 and right > left:
        sliced = stripped[left : right + 1]
        if sliced != stripped:
            candidates.append(sliced)

    for candidate in candidates:
        # Try the text verbatim first, then with inner quotes repaired.
        variants = [candidate]
        repaired = _repair_inner_quotes(candidate)
        if repaired != candidate:
            variants.append(repaired)
        for variant in variants:
            try:
                parsed = json.loads(variant)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
    return None


def _has_usable_output(path: Path) -> bool:
    """Return True if ``path`` holds a saved payload with a parsed extraction."""
    try:
        payload = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        return False
    return isinstance(payload, dict) and payload.get("extraction") is not None


def process_group(template: str, group: dict, posts_by_index: dict[int, dict], force: bool) -> bool:
    """Run the batch extraction for one group and write ``OUT_DIR/<gid>.json``.

    The output file is written whenever the model returned text; when that
    text could not be parsed, ``extraction`` is null and the raw text is kept
    in ``extraction_raw``. Such a file does not count as done: the group is
    processed again on the next run even without ``force``.

    Returns:
        True when the group was skipped (usable output exists) or extracted
        successfully; False when the call failed or the reply was not
        parseable.
    """
    gid = group.get("group_id") or "g?"
    label = (group.get("method_label") or "").strip()
    members: list[str] = group.get("member_post_ids") or []
    log_prefix = f"[{gid:>4}]"

    out_path = OUT_DIR / f"{gid}.json"
    if out_path.exists() and not force:
        if _has_usable_output(out_path):
            print(f"{log_prefix} ⏭ 已存在 {out_path.name},跳过(用 --force 重跑)", flush=True)
            return True
        print(f"{log_prefix} ↻ {out_path.name} 已存在但没有有效 extraction,重新处理", flush=True)

    print(f"\n{log_prefix} ── {label[:80]}", flush=True)
    print(f"{log_prefix} members={members}", flush=True)

    prompt, image_paths = render_prompt(template, group, posts_by_index)
    print(f"{log_prefix} posts={len(members)} images={len(image_paths)} prompt_chars={len(prompt)}", flush=True)

    t0 = time.time()
    final_text, stats = stream_claude(prompt, log_prefix)
    dt = time.time() - t0

    if final_text is None:
        print(f"{log_prefix} ✗ 调用失败", flush=True)
        return False

    parsed = parse_extraction(final_text)
    out_payload: dict = {
        "group_id": gid,
        "method_label": label,
        "member_post_ids": members,
        "elapsed_sec": round(dt, 1),
        "stats": stats,
        "extraction": parsed,
        "extraction_raw": None if parsed is not None else final_text,
    }
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out_payload, ensure_ascii=False, indent=2), encoding="utf-8")

    if parsed is None:
        print(f"{log_prefix} ⚠️ JSON 解析失败,原文已保存到 {out_path}", flush=True)
        return False

    strat_n = len(parsed.get("strategies") or [])
    cap_n = len(parsed.get("capabilities") or [])
    skipped_n = len(parsed.get("skipped_posts") or [])
    cost = stats.get("total_cost_usd")
    cost_s = f"${cost:.3f}" if isinstance(cost, (int, float)) else "?"
    print(
        f"{log_prefix} ✓ strategies={strat_n} capabilities={cap_n} skipped_posts={skipped_n} "
        f"time={dt:.0f}s cost={cost_s}",
        flush=True,
    )
    return True


def main() -> int:
    """CLI entry point.

    Returns:
        0 when every selected group succeeded (or on --dry-run), 1 when at
        least one group failed, 2 when a required input file is missing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--only", nargs="+", help="只处理这些 group_id(如 g1 g6)")
    ap.add_argument("--limit", type=int, help="最多处理 N 组")
    ap.add_argument("--force", action="store_true", help="已有输出也重跑")
    ap.add_argument("--dry-run", action="store_true", help="只打印将要处理的 group")
    args = ap.parse_args()

    for required in (PROMPT_PATH, RESULT_JSON, GROUPS_JSON):
        if not required.exists():
            print(f"文件不存在:{required}", file=sys.stderr)
            return 2

    template = PROMPT_PATH.read_text(encoding="utf-8")
    posts = json.loads(RESULT_JSON.read_text(encoding="utf-8"))
    posts_by_index: dict[int, dict] = {p["index"]: p for p in posts if "index" in p}
    groups_doc = json.loads(GROUPS_JSON.read_text(encoding="utf-8"))
    groups: list[dict] = groups_doc.get("groups") or []

    selected = groups
    if args.only:
        wanted = set(args.only)
        selected = [g for g in selected if g.get("group_id") in wanted]
    if args.limit is not None:
        # "is not None" so that --limit 0 really selects nothing.
        selected = selected[: max(args.limit, 0)]

    print(f"将处理 {len(selected)} 组(共 {len(groups)} 组)→ 输出目录 {OUT_DIR}")

    if args.dry_run:
        for g in selected:
            label = (g.get("method_label") or "")[:60]
            print(f" - {g.get('group_id')} members={g.get('member_post_ids')} {label}")
        return 0

    OUT_DIR.mkdir(parents=True, exist_ok=True)

    ok, fail = 0, 0
    t_start = time.time()
    for g in selected:
        try:
            if process_group(template, g, posts_by_index, force=args.force):
                ok += 1
            else:
                fail += 1
        except KeyboardInterrupt:
            print("\n中断", flush=True)
            break
        except Exception as e:
            # Top-level boundary: one broken group must not stop the batch.
            fail += 1
            print(f"[{g.get('group_id')}] ✗ 异常:{e!r}", flush=True)

    elapsed = time.time() - t_start
    print(f"\n完成:成功 {ok},失败 {fail},总耗时 {elapsed:.0f}s")
    return 0 if fail == 0 else 1


if __name__ == "__main__":
    sys.exit(main())