| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377 |
- #!/usr/bin/env python3
- """
- 按 method-groups.json 的分组,把每组的所有成员帖(含图片)打包喂给 LLM,
- 按 docs/prompts/batch.md 模板做跨帖批次提炼,每组产出一份 JSON。
- - model: claude-opus-4-7
- - effort: xhigh(开启扩展思考)
- - 模型可用 Read 工具读取每个帖子配套的图片
- - stream-json 实时输出,便于看到中间过程
- - 已产出的 group 默认跳过,--force 可重跑
- """
from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
import threading
import time
from pathlib import Path
# Directory containing this script; every input/output except the prompt
# template lives next to it.
ROOT = Path(__file__).resolve().parent
# Prompt template (batch.md). The default is the original author's checkout;
# set BATCH_PROMPT_PATH to use a copy elsewhere instead of editing the script.
_DEFAULT_PROMPT_PATH = "/Users/sunlit/Code/Agent/knowhub/docs/prompts/batch.md"
PROMPT_PATH = Path(os.environ.get("BATCH_PROMPT_PATH") or _DEFAULT_PROMPT_PATH).expanduser()
RESULT_JSON = ROOT / "result.json"  # list of posts, looked up by their "index" field
GROUPS_JSON = ROOT / "method-groups.json"  # document with a top-level "groups" list
IMAGES_DIR = ROOT / "images"  # post images are resolved here by file name only
OUT_DIR = ROOT / "batch_extracted"  # one <group_id>.json written per group
MODEL = "claude-opus-4-7"
EFFORT = "xhigh"
# Example section of batch.md that is replaced wholesale by the real posts
# (the three sample posts plus the trailing "(按需续)" line).
PROMPT_PLACEHOLDER_BLOCK = """[POST id=p1]
{post_1_content}
[POST id=p2]
{post_2_content}
[POST id=p3]
{post_3_content}
(按需续)"""
def build_post_block(post: dict, post_id: str) -> str:
    """Format one post as the ``[POST id=...]`` section fed to the LLM.

    The id header and the title line are always present; a missing or empty
    title renders as an empty value. Author, category, description, method
    and URL lines appear only when the corresponding field is truthy, in that
    order. A blank line and the "正文:" label follow, then the stripped body,
    or a placeholder when the body is absent or whitespace-only.
    """
    optional_fields = (
        ("author", "作者"),
        ("category", "分类"),
        ("description", "摘要"),
        ("method", "方法概述"),
        ("url", "URL"),
    )
    header = [f"[POST id={post_id}]", f"标题:{post.get('title') or ''}"]
    header.extend(f"{label}:{post[key]}" for key, label in optional_fields if post.get(key))
    body_text = (post.get("body") or "").strip() or "(无正文)"
    return "\n".join([*header, "", "正文:", body_text])
def render_prompt(template: str, group: dict, posts_by_index: dict[int, dict]) -> tuple[str, list[str]]:
    """Render the batch prompt for one group.

    The template's example post section (``PROMPT_PLACEHOLDER_BLOCK``) is
    replaced by one ``[POST id=pN]`` block per member post. When any member
    post has images, a section listing them per post is appended, asking the
    model to view each with the Read tool. A fixed block of output
    constraints is appended last, always preceded by a blank line.

    Args:
        template: Raw text of batch.md.
        group: One group entry; ``member_post_ids`` holds ids such as
            ``"p12"`` whose number is the post's ``index``.
        posts_by_index: Posts keyed by their ``index``.

    Returns:
        ``(prompt_text, image_paths)``. ``image_paths`` are the member posts'
        images resolved by file name inside ``IMAGES_DIR``, in member order.

    Raises:
        ValueError: The group has no members, an id is not ``p<digits>``,
            or no post has that index.
        RuntimeError: The template no longer contains the placeholder section.
    """
    member_ids: list[str] = group.get("member_post_ids") or []
    if not member_ids:
        # An empty group would still trigger a full (expensive) model call.
        raise ValueError(f"分组 {group.get('group_id')} 没有成员帖(member_post_ids 为空)")

    post_blocks: list[str] = []
    image_paths: list[str] = []
    image_sections: list[str] = []  # per-post image listings, in member order
    for pid in member_ids:
        # "p1" -> index 1. isdecimal() only admits characters int() can parse
        # (isdigit() also accepts e.g. superscripts, which int() rejects).
        digits = pid[1:] if isinstance(pid, str) and pid.startswith("p") else ""
        if not digits.isdecimal():
            raise ValueError(f"非法 post id:{pid}")
        idx = int(digits)
        post = posts_by_index.get(idx)
        if post is None:
            raise ValueError(f"在 result.json 中未找到 index={idx}({pid})")
        post_blocks.append(build_post_block(post, pid))
        # Only the file name is kept; the image must live directly in IMAGES_DIR.
        post_images = [str(IMAGES_DIR / Path(p).name) for p in post.get("images") or []]
        if post_images:
            image_paths.extend(post_images)
            listing = "".join(f"- {path}\n" for path in post_images)
            image_sections.append(f"[{pid}] 图片:\n{listing}\n")

    if PROMPT_PLACEHOLDER_BLOCK not in template:
        raise RuntimeError(
            "batch.md 模板里找不到预期的占位段落([POST id=p1] ... (按需续)),"
            "请检查模板是否被修改。"
        )

    parts = [template.replace(PROMPT_PLACEHOLDER_BLOCK, "\n\n".join(post_blocks))]
    if image_paths:
        parts.append("\n\n# 配套参考图片(请先逐张用 Read 工具查看,再综合判断)\n")
        parts.append("图片是各帖的成品 / 步骤示意 / 排版示意。请把视觉特征也纳入 capability/strategy 的判断。\n\n")
        parts.extend(image_sections)
    prompt = "".join(parts)
    # Without images nothing separated the constraints heading from the
    # preceding text; if the placeholder sat at the very end of the template,
    # "# 输出约束" was glued onto the last post's body line.
    if not prompt.endswith("\n\n"):
        prompt = prompt.rstrip("\n") + "\n\n"
    prompt += (
        "# 输出约束\n"
        "- 只输出最终的严格 JSON,不要包裹 ```json 代码块。\n"
        "- 不要任何前言、寒暄、解释、Markdown 标题。\n"
        "- 字段以提示词定义为准。\n"
        "- source_post_ids 只用本次输入里出现过的 id(如 p1/p2/...),不要编造。\n"
        "- **JSON 字符串值内严禁使用 ASCII 双引号 \" 强调或引述**,需要引号一律用中文角引号 「」。"
        "如必须出现 ASCII 双引号则务必转义为 \\\"。\n"
    )
    return prompt, image_paths
def _content_blocks(event: dict) -> list[dict]:
    """Return the dict entries of ``event["message"]["content"]``, or [] if malformed."""
    msg = event.get("message")
    if not isinstance(msg, dict):
        return []
    content = msg.get("content")
    if not isinstance(content, list):
        # A plain-string content would otherwise be iterated char by char.
        return []
    return [block for block in content if isinstance(block, dict)]


def stream_claude(prompt: str, log_prefix: str) -> tuple[str | None, dict]:
    """Run ``claude -p`` with stream-json output and echo progress live.

    Each stdout line is decoded as one JSON event and summarised on the
    console: session start, thinking blocks (first line only), tool calls,
    assistant text snippets and tool-result errors. stderr is drained on a
    background thread, so a chatty stderr cannot fill its pipe and stall the
    child while stdout is still being read.

    Args:
        prompt: Full prompt text, passed as the final CLI argument.
        log_prefix: Prefix for every progress line.

    Returns:
        ``(final_text, stats)``. ``final_text`` is the ``result`` event's
        text, falling back to the last non-empty assistant text block. It is
        ``None`` when the process exits non-zero or no text was produced.
        ``stats`` holds ``tool_calls``, ``thinking_chunks`` and ``raw_events``,
        plus the ``result`` event's duration/cost/turns/error/stop-reason
        fields when that event arrives. On a non-zero exit it also holds
        ``exit_code`` and ``stderr``.

    Raises:
        Anything that interrupts the read loop (e.g. ``KeyboardInterrupt``)
        is re-raised after the child process has been killed.
    """
    cmd = [
        "claude",
        "-p",
        "--model", MODEL,
        "--effort", EFFORT,
        "--output-format", "stream-json",
        "--verbose",
        "--add-dir", str(IMAGES_DIR),
        "--tools", "Read",
        "--permission-mode", "bypassPermissions",
        "--no-session-persistence",
        prompt,
    ]
    print(f"{log_prefix} ▶ launching claude (model={MODEL}, effort={EFFORT})", flush=True)
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        # Output is largely Chinese; do not depend on the locale's encoding.
        encoding="utf-8",
        errors="replace",
        bufsize=1,
    )
    assert proc.stdout is not None and proc.stderr is not None
    err_stream = proc.stderr
    stderr_parts: list[str] = []

    def _drain_stderr() -> None:
        try:
            stderr_parts.append(err_stream.read())
        finally:
            err_stream.close()

    stderr_reader = threading.Thread(target=_drain_stderr, daemon=True)
    stderr_reader.start()

    final_text: str | None = None
    last_assistant_text = ""
    stats: dict = {"tool_calls": [], "thinking_chunks": 0, "raw_events": 0}
    try:
        for line in proc.stdout:
            line = line.rstrip("\n")
            if not line.strip():
                continue
            stats["raw_events"] += 1
            try:
                ev = json.loads(line)
            except json.JSONDecodeError:
                print(f"{log_prefix} ! non-json line: {line[:200]}", flush=True)
                continue
            if not isinstance(ev, dict):
                continue
            et = ev.get("type")
            if et == "system":
                if ev.get("subtype") == "init":
                    session = str(ev.get("session_id") or "?")
                    print(f"{log_prefix} · session={session[:8]}", flush=True)
            elif et == "assistant":
                for block in _content_blocks(ev):
                    btype = block.get("type")
                    if btype == "thinking":
                        stats["thinking_chunks"] += 1
                        text = (block.get("thinking") or "").strip()
                        if text:
                            first = text.splitlines()[0][:140]
                            print(f"{log_prefix} 🧠 {first}", flush=True)
                    elif btype == "tool_use":
                        name = block.get("name")
                        inp = block.get("input") or {}
                        if isinstance(inp, dict):
                            desc = inp.get("file_path") or inp.get("path") or inp.get("command") or ""
                        else:
                            desc = inp
                        stats["tool_calls"].append({"name": name, "input": inp})
                        print(f"{log_prefix} 🔧 {name}({str(desc)[:140]})", flush=True)
                    elif btype == "text":
                        txt = block.get("text") or ""
                        if txt.strip():
                            last_assistant_text = txt
                            snippet = txt.strip().splitlines()[0][:140]
                            print(f"{log_prefix} ✏️ {snippet}", flush=True)
            elif et == "user":
                for block in _content_blocks(ev):
                    if block.get("type") == "tool_result" and block.get("is_error"):
                        c = block.get("content")
                        text = c if isinstance(c, str) else json.dumps(c, ensure_ascii=False)
                        print(f"{log_prefix} ⚠️ tool_result error: {text[:200]}", flush=True)
            elif et == "result":
                result = ev.get("result")
                final_text = (result if isinstance(result, str) else None) or last_assistant_text or None
                for key in ("duration_ms", "total_cost_usd", "num_turns", "is_error", "stop_reason"):
                    stats[key] = ev.get(key)
        rc = proc.wait()
    except BaseException:
        # Ctrl-C or an unexpected error: never leave an orphaned claude running.
        proc.kill()
        proc.wait()
        raise
    finally:
        proc.stdout.close()
        stderr_reader.join(timeout=10)

    err = "".join(stderr_parts)
    if rc != 0:
        print(f"{log_prefix} ! claude exited rc={rc}, stderr={err[:400]}", flush=True)
        stats["exit_code"] = rc
        stats["stderr"] = err
        return None, stats
    if final_text is None:
        final_text = last_assistant_text or None
    return final_text, stats
- def _repair_inner_quotes(text: str) -> str:
- """把 JSON 字符串值里裸用 ASCII 双引号包裹的中文短语 "xxx" 改成 「xxx」。
- 模型偶尔会写 \"避免出现"塑料感"和"AI 感"\",导致 JSON 截断。
- 用启发式:一个 ASCII 双引号紧挨 CJK 字符开头、再有一个紧挨 CJK 或中文标点的 ASCII 双引号,
- 很可能是内嵌引述而不是结构边界。替换为 「」。
- """
- import re
- pattern = re.compile(
- r'(?<=[一-鿿,。、;:!?「」()"])"([^"\n]{1,60}?)"(?=[一-鿿,。、;:!?「」()"])'
- )
- prev = None
- cur = text
- # 反复跑直到稳定(一次替换可能让相邻 pattern 暴露出来)
- while prev != cur:
- prev = cur
- cur = pattern.sub(r"「\1」", cur)
- return cur
def parse_extraction(text: str) -> dict | None:
    """Parse the model's final answer into a JSON object, tolerating noise.

    Steps, in order:

    1. Strip whitespace and a surrounding ``` / ```json fence.
    2. Cut the text to the span from the first ``{`` to the last ``}``. This
       drops prose before the object and also after it, including a trailing
       note following an object that starts the text.
    3. Run ``json.loads``. If that fails, retry once on the output of
       ``_repair_inner_quotes``, if it changed anything.

    Returns:
        The parsed object, or ``None`` when nothing parses or the top-level
        value is not a JSON object (callers read the result with ``.get``).
    """
    t = text.strip()
    if t.startswith("```"):
        t = t.strip("`")
        if t.lower().startswith("json"):
            t = t[4:]
        t = t.strip()
    start, end = t.find("{"), t.rfind("}")
    if start != -1 and end > start:
        t = t[start : end + 1]
    try:
        value = json.loads(t)
    except json.JSONDecodeError:
        # Fallback: ASCII quotes used for emphasis inside Chinese string values.
        repaired = _repair_inner_quotes(t)
        if repaired == t:
            return None
        try:
            value = json.loads(repaired)
        except json.JSONDecodeError:
            return None
    return value if isinstance(value, dict) else None
def _has_successful_output(path: Path) -> bool:
    """Return True if ``path`` holds a previous output whose extraction parsed."""
    try:
        doc = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # Unreadable or truncated (e.g. an interrupted write): treat as absent.
        return False
    return isinstance(doc, dict) and isinstance(doc.get("extraction"), dict)


def process_group(template: str, group: dict, posts_by_index: dict[int, dict], force: bool) -> bool:
    """Run the batch extraction for one group and write ``OUT_DIR/<group_id>.json``.

    An existing output is reused (skipped) unless ``force`` is set. Only an
    output whose extraction parsed counts: a previous parse failure, which
    still writes the raw answer, is retried automatically. The output file
    is written atomically (temporary file + rename), so an interrupted write
    never leaves a truncated file behind.

    Args:
        template: Raw text of batch.md.
        group: One group entry (``group_id``, ``method_label``,
            ``member_post_ids``).
        posts_by_index: Posts keyed by their ``index``.
        force: Re-run even when a successful output already exists.

    Returns:
        True on success or when a successful output is reused. False when the
        model call fails or its answer does not parse to a JSON object; in
        the latter case the raw answer is saved in ``extraction_raw``.
        Exceptions from prompt rendering propagate to the caller.
    """
    gid = group.get("group_id") or "g?"
    label = (group.get("method_label") or "").strip()
    members: list[str] = group.get("member_post_ids") or []
    log_prefix = f"[{gid:>4}]"
    out_path = OUT_DIR / f"{gid}.json"
    if not force and out_path.exists():
        if _has_successful_output(out_path):
            print(f"{log_prefix} ⏭ 已存在 {out_path.name},跳过(用 --force 重跑)", flush=True)
            return True
        print(f"{log_prefix} ↻ {out_path.name} 上次未成功解析,重新处理", flush=True)

    print(f"\n{log_prefix} ── {label[:80]}", flush=True)
    print(f"{log_prefix} members={members}", flush=True)
    prompt, image_paths = render_prompt(template, group, posts_by_index)
    print(f"{log_prefix} posts={len(members)} images={len(image_paths)} prompt_chars={len(prompt)}", flush=True)

    t0 = time.time()
    final_text, stats = stream_claude(prompt, log_prefix)
    dt = time.time() - t0
    if final_text is None:
        print(f"{log_prefix} ✗ 调用失败", flush=True)
        return False

    parsed = parse_extraction(final_text)
    if not isinstance(parsed, dict):
        # Anything but a JSON object is unusable below (we call .get on it).
        parsed = None
    out_payload: dict = {
        "group_id": gid,
        "method_label": label,
        "member_post_ids": members,
        "elapsed_sec": round(dt, 1),
        "stats": stats,
        "extraction": parsed,
        "extraction_raw": None if parsed is not None else final_text,
    }
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    tmp_path.write_text(json.dumps(out_payload, ensure_ascii=False, indent=2), encoding="utf-8")
    tmp_path.replace(out_path)
    if parsed is None:
        print(f"{log_prefix} ⚠️ JSON 解析失败,原文已保存到 {out_path}", flush=True)
        return False

    strat_n = len(parsed.get("strategies") or [])
    cap_n = len(parsed.get("capabilities") or [])
    skipped_n = len(parsed.get("skipped_posts") or [])
    cost = stats.get("total_cost_usd")
    cost_s = f"${cost:.3f}" if isinstance(cost, (int, float)) else "?"
    print(
        f"{log_prefix} ✓ strategies={strat_n} capabilities={cap_n} skipped_posts={skipped_n} "
        f"time={dt:.0f}s cost={cost_s}",
        flush=True,
    )
    return True
def main() -> int:
    """CLI entry point: select groups from method-groups.json and process each.

    Returns:
        Process exit status:

        - 0 when every processed group succeeded (or for ``--dry-run``);
        - 1 when any group failed;
        - 2 when a required input file is missing or an argument is invalid;
        - 130 when interrupted with Ctrl-C.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--only", nargs="+", help="只处理这些 group_id(如 g1 g6)")
    ap.add_argument("--limit", type=int, help="最多处理 N 组")
    ap.add_argument("--force", action="store_true", help="已有输出也重跑")
    ap.add_argument("--dry-run", action="store_true", help="只打印将要处理的 group")
    args = ap.parse_args()
    if args.limit is not None and args.limit < 0:
        # A negative slice bound would silently drop groups from the end.
        ap.error("--limit 必须 >= 0")

    for p in (PROMPT_PATH, RESULT_JSON, GROUPS_JSON):
        if not p.exists():
            print(f"文件不存在:{p}", file=sys.stderr)
            return 2
    template = PROMPT_PATH.read_text(encoding="utf-8")
    posts = json.loads(RESULT_JSON.read_text(encoding="utf-8"))
    posts_by_index: dict[int, dict] = {p["index"]: p for p in posts if "index" in p}
    groups_doc = json.loads(GROUPS_JSON.read_text(encoding="utf-8"))
    groups: list[dict] = groups_doc.get("groups") or []

    selected = groups
    if args.only:
        wanted = set(args.only)
        selected = [g for g in selected if g.get("group_id") in wanted]
        missing = wanted - {g.get("group_id") for g in selected}
        if missing:
            # Typos in --only used to be ignored silently.
            print(f"警告:--only 中的 group_id 不存在:{' '.join(sorted(missing))}", file=sys.stderr)
    if args.limit is not None:
        # `is not None` so that --limit 0 really means zero groups.
        selected = selected[: args.limit]
    print(f"将处理 {len(selected)} 组(共 {len(groups)} 组)→ 输出目录 {OUT_DIR}")
    if args.dry_run:
        for g in selected:
            label = (g.get("method_label") or "")[:60]
            print(f" - {g.get('group_id')} members={g.get('member_post_ids')} {label}")
        return 0
    # Created only for real runs: a dry run should not touch the filesystem.
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    ok, fail = 0, 0
    interrupted = False
    t_start = time.time()
    for g in selected:
        try:
            if process_group(template, g, posts_by_index, force=args.force):
                ok += 1
            else:
                fail += 1
        except KeyboardInterrupt:
            print("\n中断", flush=True)
            interrupted = True
            break
        except Exception as e:
            # Top-level boundary: one bad group must not abort the whole batch.
            fail += 1
            print(f"[{g.get('group_id')}] ✗ 异常:{e!r}", flush=True)
    elapsed = time.time() - t_start
    print(f"\n完成:成功 {ok},失败 {fail},总耗时 {elapsed:.0f}s")
    if interrupted:
        return 130
    return 0 if fail == 0 else 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
|