| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464 |
- """
- 三形式 query 批量搜索 + 多模态评估
- 针对 high_priority_queries.json 的前 N 条高优 query,每条 query 用三种形式搜索 + 评估:
- 形式 A(原词组合):直接用 item["q"],如 "反推 提示词 教程"
- 形式 B(句子填充):gemini flash 把词组改写成自然搜索短句,**禁止注入具体工具/品牌/示例**
- 形式 C(同义替换):按 synonym_pools 对 动作/类型/知识词 各取同义词重组
- 输出(按 query 分文件夹):
- output_dir/
- q00/ form_A.json form_B.json form_C.json
- q01/ ...
- ...
- forms_preview.json # 三形式 query 预览
- summary.json # 三形式对比汇总
- 每个 form_X.json = {query 词} ↔ {帖子源信息 + 评估结果}(一对多)。
- 搜索 / 评估 / 多模态图片逻辑复用 script/search_and_evaluate.py。
- 用法:
- python batch_3forms.py --count 10 --platforms xhs,gzh,zhihu --max-count 10 \
- --output-dir runs/3forms_001
- """
- import argparse
- import asyncio
- import json
- import random
- import sys
- from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional, Tuple
# Put the repository root on sys.path so the absolute `examples.process_pipeline...`
# imports below resolve when this file is executed directly as a script.
# parents[4] walks four directory levels up from this file's own directory
# (NOTE(review): assumes the file sits exactly that deep under the repo root — confirm).
_PROJECT_ROOT = Path(__file__).resolve().parents[4]
if str(_PROJECT_ROOT) not in sys.path:
    # Prepend so project modules take priority over same-named entries later on sys.path.
    sys.path.insert(0, str(_PROJECT_ROOT))
- from examples.process_pipeline.script.llm_helper import call_llm_with_retry
- from examples.process_pipeline.script.search_and_evaluate import (
- search_all, evaluate_posts, transcribe_video_posts, build_query_overrides,
- )
- from examples.process_pipeline.script.llm_evaluate_sources import build_eval_llm_call
# Directory holding the evaluation assets (synonym_pools.json lives here), resolved from the repo root.
_EVAL_DIR = _PROJECT_ROOT / "examples" / "process_pipeline" / "test_script" / "evaluation"
# Default query source, located next to this script; run() uses --queries-file instead when given.
_HIGH_PRIORITY = Path(__file__).resolve().parent / "high_priority_queries.json"
# Synonym pools consumed by form C (loaded in form_c).
_SYNONYM_POOLS = _EVAL_DIR / "synonym_pools.json"
# ── Form A: raw keyword combination ────────────────────────────────────────────
def form_a(items: List[Dict[str, Any]]) -> List[str]:
    """Form A: search with each item's original keyword phrase ``q`` unchanged.

    Returns one query per item, in input order.
    """
    return [entry["q"] for entry in items]
- # ── 形式 B:gemini 句子化(禁止注入示例)─────────────────────────────────────────
- def _validate_sentences(data: Dict[str, Any], n: int) -> Optional[str]:
- qs = data.get("sentences")
- if not isinstance(qs, list):
- return "sentences 必须是数组"
- if len(qs) != n:
- return f"sentences 长度应为 {n},得到 {len(qs)}"
- if not all(isinstance(x, str) and x.strip() for x in qs):
- return "sentences 每项必须是非空字符串"
- return None
async def form_b(items: List[Dict[str, Any]], llm_call: Callable, model: str) -> Tuple[List[str], float]:
    """Form B: rewrite every keyword phrase into a natural search sentence.

    All phrases are sent to the LLM in one batched call, and the reply is aligned with
    ``items`` by position (``_validate_sentences`` enforces the matching length). The
    prompt forbids injecting concrete tools / brands / products / models / styles /
    numeric examples, so each sentence expresses only the phrase's own intent.

    Args:
        items: query items; each must carry a ``"q"`` keyword phrase.
        llm_call: LLM transport passed through to ``call_llm_with_retry``.
        model: model id for the generation call.

    Returns:
        ``(sentences, cost)``. When the LLM returns no usable data, the raw phrases are
        returned unchanged (one query per item is still guaranteed) together with the
        reported cost. An empty ``items`` short-circuits to ``([], 0.0)`` with no LLM call.
    """
    words = [it["q"] for it in items]
    if not words:
        # Nothing to rewrite: skip a paid request whose only valid answer is an empty list.
        return [], 0.0
    system = (
        "你是中文搜索词改写器。把每个『关键词组』改写成一句自然、口语、适合在内容平台"
        "搜索框输入的短句。严格要求:只表达词组本身的意图,"
        "**绝不添加任何具体工具名 / 品牌 / 产品 / 模型名 / 风格名 / 数字示例**"
        "(如 Midjourney、赛博朋克、SD 等都禁止出现)。只输出 JSON。"
    )
    user = (
        "把下面每个词组改写成一句自然搜索短句,顺序一一对应,输出:\n"
        '{"sentences": ["短句1", "短句2", ...]}\n\n'
        f"词组列表(共 {len(words)} 个):\n{json.dumps(words, ensure_ascii=False, indent=2)}"
    )
    data, cost = await call_llm_with_retry(
        llm_call=llm_call, messages=[{"role": "system", "content": system},
                                     {"role": "user", "content": user}],
        model=model, temperature=0.4, max_tokens=2000,
        validate_fn=lambda d: _validate_sentences(d, len(words)), task_name="FormB",
    )
    if not data:
        print(" ⚠️ form B 生成失败,回退用原词组")
        return list(words), cost
    return [s.strip() for s in data["sentences"]], cost
# ── Form C: synonym substitution and recombination ─────────────────────────────
class SynonymComposer:
    """Rebuild a query from synonym pools (form C).

    Word order follows ``synonym_pools._usage``:
    ``[modality / tool-type prefix] action type knowledge-word``.
    Each slot draws a random synonym from its pool through the injected ``rng`` (a
    seeded ``random.Random`` gives reproducible output) and falls back to the item's
    own word when the pool has no string entries. Empty slots are dropped.
    """

    def __init__(self, pools: Dict[str, Any], rng: random.Random):
        self.action = pools.get("action_leaves", {})
        self.types = pools.get("types", {})
        self.knowledge = pools.get("knowledge", {})
        self.tool_type = pools.get("tool_type", {})
        self.rng = rng

    def _pick(self, pool: Any, fallback: str) -> str:
        """Return a random string from *pool*, or *fallback* if *pool* is not a list holding strings."""
        if not isinstance(pool, list):
            return fallback
        candidates = [word for word in pool if isinstance(word, str)]
        return self.rng.choice(candidates) if candidates else fallback

    def _constraint_prefix(self, constraint: Any) -> Optional[str]:
        """Return the optional leading word derived from an item's ``constraint`` dict."""
        if not isinstance(constraint, dict):
            return None
        kind = constraint.get("kind")
        if kind == "模态":
            value = constraint.get("value")
            return str(value) if value else None
        if kind == "工具类型":
            qualifier = str(constraint.get("限定词") or "")
            return self._pick(self.tool_type.get(constraint.get("value")), qualifier)
        return None

    def compose(self, item: Dict[str, Any]) -> str:
        """Build the form-C query for one item (slots are filled left to right)."""
        prefix = self._constraint_prefix(item.get("constraint"))
        action = item.get("action", "")
        type_word = item.get("type", "")
        process = self.knowledge.get("工序", {})
        words = [
            prefix,
            self._pick(self.action.get(action), action),
            self._pick(self.types.get(type_word), type_word),
            self._pick(process.get("单步") if isinstance(process, dict) else None, "教程"),
        ]
        return " ".join(filter(None, words))
def form_c(items: List[Dict[str, Any]], seed: int) -> List[str]:
    """Form C: one synonym-substituted query per item, reproducible for a given *seed*.

    Reloads synonym_pools.json on every call and draws every synonym from a single
    ``random.Random(seed)``, so the output depends on both the seed and the item order.
    """
    pools = json.loads(_SYNONYM_POOLS.read_text(encoding="utf-8"))
    rng = random.Random(seed)
    composer = SynonymComposer(pools, rng)
    return list(map(composer.compose, items))
# ── Search + evaluate + persist a single (query, form) pair ────────────────────
async def run_one(
    qtext: str, form_key: str, original_q: str,
    args, eval_llm, eval_model_id, out_file: Path,
    query_overrides=None,
) -> Dict[str, Any]:
    """Search one form's query on every platform, evaluate the posts and write ``out_file``.

    Args:
        qtext: query actually searched for this form; also the evaluation anchor.
        form_key: form label ("A" / "B" / "C").
        original_q: raw keyword phrase (form A baseline), stored for reference.
        args: parsed CLI namespace (platforms, max_count, max_concurrent, no_transcribe,
            no_eval, no_images, max_images, image_mode are read).
        eval_llm, eval_model_id: evaluation LLM transport and model id.
        out_file: destination form_X.json; parent directories are created.
        query_overrides: optional per-platform query rewrites forwarded to ``search_all``.

    Returns:
        Stats dict: total / report (relevance score 2 or 3) / discard (score 1) /
        failed (evaluation carried ``_error``) / cost.
    """
    platforms = [p.strip() for p in args.platforms.split(",") if p.strip()]
    sources = await search_all(platforms, [qtext], args.max_count, args.max_concurrent,
                               query_overrides=query_overrides)
    try:
        from examples.process_pipeline.script.extract_sources import _convert_timestamps
        _convert_timestamps(sources)
    except Exception as exc:
        # Timestamp normalisation is best-effort: keep going, but don't hide the failure.
        print(f"  ⚠️ 时间戳转换失败(已跳过): {exc!r}")
    # Video posts: transcribe and merge subtitles into the body before evaluation (on by default).
    if not args.no_transcribe and sources:
        n = await transcribe_video_posts(sources, concurrency=args.max_concurrent)
        if n:
            print(f" 🎙️ 视频转写 {n} 条")
    cost = 0.0
    if not args.no_eval and sources:
        # Evaluation sees only the query + the post: this form's query is the retrieval anchor.
        sources, cost = await evaluate_posts(
            sources, "", eval_llm, eval_model_id, args.max_concurrent,
            include_images=not args.no_images, max_images=args.max_images,
            image_mode=args.image_mode, query=qtext,
        )
    for s in sources:
        # Don't persist the (possibly large) image payloads; keep only how many were sent.
        imgs = s.pop("_image_data_urls", None)
        if imgs is not None:
            s["images_sent"] = len(imgs)

    def _relevance(post: Dict[str, Any]) -> Any:
        # Production-relevance score from the evaluation, or None when absent.
        return ((post.get("llm_evaluation") or {}).get("制作相关性") or {}).get("得分")

    # Numeric scores compare equal across int/float (2 == 2.0); strings are matched explicitly.
    rep = sum(1 for s in sources if _relevance(s) in (2, 3, "2", "3"))
    dis = sum(1 for s in sources if _relevance(s) in (1, "1"))
    failed = sum(1 for s in sources if (s.get("llm_evaluation") or {}).get("_error"))
    out_file.parent.mkdir(parents=True, exist_ok=True)
    out_file.write_text(json.dumps({
        "form": form_key,
        "query": qtext,  # query actually searched by this form (also the evaluation anchor)
        "original_q": original_q,  # raw keyword phrase (form A baseline)
        "platforms": platforms,
        "total": len(sources), "report": rep, "discard": dis, "failed": failed,
        "results": sources,  # post source info + llm_evaluation, one-to-many
    }, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f" [{form_key}] {qtext!r} → total={len(sources)} report={rep} discard={dis} "
          f"failed={failed} cost=${cost:.4f}")
    return {"form": form_key, "total": len(sources), "report": rep,
            "discard": dis, "failed": failed, "cost": round(cost, 4)}
async def reeval_existing(args, eval_llm, eval_model_id) -> None:
    """Re-run evaluation only, overwriting the old evaluation in place (no new search).

    Reads the existing q*/form_*.json files under ``args.output_dir`` and re-evaluates the
    posts already stored there. Each file is anchored on its own ``query`` field (the query
    that form actually searched). Useful after changing the evaluation prompt or model.

    Scope: ``--reeval-q`` (comma-separated q folder names) wins. Otherwise ``--start`` /
    ``--count`` slice the q folders in natural numeric order, with the same meaning as the
    main flow. All form_A/B/C files inside a selected folder are re-evaluated together, so
    the three forms stay comparable.
    """
    import re

    root = Path(args.output_dir)

    def _qnum(p):
        # Natural order on the first number in the folder name, so "q2" sorts before "q10"
        # (same convention as server.py's _qnum).
        m = re.search(r"\d+", p.name)
        return (int(m.group()) if m else 0, p.name)

    q_dirs = sorted((d for d in root.glob("q*") if d.is_dir()), key=_qnum)
    if not q_dirs:
        print(f"❌ {root} 下没有 q*/ 子目录,无可复评内容")
        return

    wanted_spec = getattr(args, 'reeval_q', None)
    if wanted_spec:
        # Explicit q names ("q01" or "q01,q05,q12") take precedence over --start/--count.
        wanted = {name.strip() for name in wanted_spec.split(',') if name.strip()}
        chosen = [d for d in q_dirs if d.name in wanted]
        if not chosen:
            print(f"[X] 指定 q ({wanted_spec}) 在 {root} 下不存在")
            return
        scope = "q=" + ",".join(d.name for d in chosen)
    else:
        lo, hi = args.start, args.start + args.count
        chosen = q_dirs[lo:hi]
        scope = f"q[{lo}:{hi}]"

    files = [f for qd in chosen for f in sorted(qd.glob("form_*.json"))]
    if not files:
        print(f"❌ {root} 切片 {scope} 下没有 form_*.json")
        return
    print(f"♻️ 复评模式:{scope} → {len(chosen)} 个 query / "
          f"{len(files)} 个文件,模型 {eval_model_id}(不重新搜索)")

    for path in files:
        doc = json.loads(path.read_text(encoding="utf-8"))
        posts = doc.get("results", [])
        if not posts:
            print(f" - {path.relative_to(root)}: 空,跳过")
            continue
        # Strip every trace of the previous evaluation before re-scoring.
        for post in posts:
            for stale_key in ("llm_evaluation", "images_sent", "_image_data_urls"):
                post.pop(stale_key, None)
        anchor = doc.get("query", "")  # the form's searched query = evaluation anchor
        if not args.no_transcribe:
            # NOTE(review): posts may already hold merged transcripts from the first run;
            # confirm transcribe_video_posts is idempotent.
            await transcribe_video_posts(posts, concurrency=args.max_concurrent)
        posts, cost = await evaluate_posts(
            posts, "", eval_llm, eval_model_id, args.max_concurrent,
            include_images=not args.no_images, max_images=args.max_images,
            image_mode=args.image_mode, query=anchor,
        )
        for post in posts:
            sent = post.pop("_image_data_urls", None)
            if sent is not None:
                post["images_sent"] = len(sent)
        scores = [((p.get("llm_evaluation") or {}).get("制作相关性") or {}).get("得分") for p in posts]
        rep = sum(score in (2, 3, 2.0, 3.0, "2", "3") for score in scores)
        dis = sum(score in (1, 1.0, "1") for score in scores)
        failed = sum(1 for p in posts if (p.get("llm_evaluation") or {}).get("_error"))
        doc.update({"results": posts,
                    "total": len(posts), "report": rep, "discard": dis, "failed": failed})
        doc.pop("requirement", None)  # legacy field, no longer used
        path.write_text(json.dumps(doc, ensure_ascii=False, indent=2), encoding="utf-8")
        print(f" ✓ {path.relative_to(root)}: total={len(posts)} report={rep} "
              f"discard={dis} failed={failed} cost=${cost:.4f}")
    print("♻️ 复评完成(已覆盖原文件)")
async def append_existing(args, eval_llm, eval_model_id, gen_llm, gen_model_id) -> None:
    """Append results from new platforms to existing q*/form_*.json without re-searching old ones.

    For each file, the stored ``query`` is searched only on the ``--platforms`` channels. The
    new posts are evaluated and merged into ``results``, skipping any whose
    ``(platform, channel_content_id)`` is already present. Existing posts are kept as-is.
    Typical use: run the Chinese channels first, then add youtube / x later.

    Files are processed in natural q-number order, matching reeval mode.
    """
    import re

    output_dir = Path(args.output_dir)

    def _natural_key(path: Path):
        # Order by the number in the q folder name so q2 precedes q10
        # (plain lexicographic order would put q100 before q11).
        m = re.search(r"\d+", path.parent.name)
        return (int(m.group()) if m else 0, path.parent.name, path.name)

    files = sorted(output_dir.glob("q*/form_*.json"), key=_natural_key)
    if not files:
        print(f"❌ {output_dir} 下没有 q*/form_*.json,无可追加目标")
        return
    try:
        from examples.process_pipeline.script.extract_sources import _convert_timestamps
    except ImportError as exc:
        # Timestamp normalisation is best-effort (as in run_one): continue without it.
        print(f"⚠️ 时间戳转换不可用(已跳过): {exc!r}")
        _convert_timestamps = None
    new_plats = [p.strip() for p in args.platforms.split(",") if p.strip()]
    print(f"➕ 追加模式:{len(files)} 个文件追加渠道 {new_plats}(不重搜旧渠道)")
    # Translate all stored queries once, up front, for platforms that need it (e.g. English ones).
    queries = list(dict.fromkeys(json.loads(f.read_text(encoding="utf-8")).get("query", "") for f in files))
    overrides = await build_query_overrides(new_plats, queries, gen_llm, gen_model_id)

    def _relevance(post: Dict[str, Any]) -> Any:
        # Production-relevance score from the evaluation, or None when absent.
        return ((post.get("llm_evaluation") or {}).get("制作相关性") or {}).get("得分")

    for f in files:
        d = json.loads(f.read_text(encoding="utf-8"))
        qtext = d.get("query", "")
        existing = d.get("results", [])
        existing_keys = {(r.get("platform"), r.get("channel_content_id")) for r in existing}
        new_sources = await search_all(new_plats, [qtext], args.max_count, args.max_concurrent,
                                       query_overrides=overrides)
        new_sources = [s for s in new_sources
                       if (s.get("platform"), s.get("channel_content_id")) not in existing_keys]
        if _convert_timestamps is not None:
            try:
                _convert_timestamps(new_sources)
            except Exception as exc:
                # Best-effort: log instead of silently discarding the error.
                print(f"  ⚠️ 时间戳转换失败(已跳过): {exc!r}")
        if not args.no_transcribe and new_sources:
            await transcribe_video_posts(new_sources, concurrency=args.max_concurrent)
        cost = 0.0
        if not args.no_eval and new_sources:
            new_sources, cost = await evaluate_posts(
                new_sources, "", eval_llm, eval_model_id, args.max_concurrent,
                include_images=not args.no_images, max_images=args.max_images,
                image_mode=args.image_mode, query=qtext,
            )
        for s in new_sources:
            imgs = s.pop("_image_data_urls", None)
            if imgs is not None:
                s["images_sent"] = len(imgs)
        merged = existing + new_sources
        plats_union = list(dict.fromkeys((d.get("platforms") or []) + new_plats))
        rep = sum(1 for s in merged if _relevance(s) in (2, 3, "2", "3"))
        dis = sum(1 for s in merged if _relevance(s) in (1, "1"))
        failed = sum(1 for s in merged if (s.get("llm_evaluation") or {}).get("_error"))
        d.update({"platforms": plats_union, "results": merged,
                  "total": len(merged), "report": rep, "discard": dis, "failed": failed})
        f.write_text(json.dumps(d, ensure_ascii=False, indent=2), encoding="utf-8")
        print(f" ✓ {f.relative_to(output_dir)}: +{len(new_sources)} 新帖 → total={len(merged)} "
              f"report={rep} discard={dis} failed={failed} cost=${cost:.4f}")
    print("➕ 追加完成(已并入原文件)")
async def run(args) -> None:
    """CLI driver: dispatch to re-eval / append mode, or run the three-form batch.

    Main flow:
    1. Select queries from the query file, via ``--only-q`` or ``--start``/``--count``.
    2. Build forms A/B/C and merge them into forms_preview.json, keyed by absolute index.
    3. Build per-platform query overrides.
    4. Search and evaluate every (query, form) pair into ``q{idx:02d}/form_X.json``.
    5. Write summary.json.
    """
    # Built once and shared by every mode.
    eval_llm, eval_model_id = build_eval_llm_call(args.eval_model)
    if args.reeval:
        await reeval_existing(args, eval_llm, eval_model_id)
        return
    if args.append:
        append_gen_llm, append_gen_model = build_eval_llm_call(args.gen_model)
        await append_existing(args, eval_llm, eval_model_id, append_gen_llm, append_gen_model)
        return
    queries_file = Path(args.queries_file) if getattr(args, "queries_file", None) else _HIGH_PRIORITY
    all_items = json.loads(queries_file.read_text(encoding="utf-8"))["queries"]
    print(f"📂 query 源: {queries_file.name} ({len(all_items)} 条)")
    only_q = getattr(args, "only_q", None)
    if only_q:
        # Accepts "1,5,51" or "q01,q05,q51"; takes precedence over --start/--count.
        import re as _re
        raw = [t.strip() for t in only_q.split(",") if t.strip()]
        idxs = []
        for t in raw:
            m = _re.match(r"q?(\d+)$", t)
            if not m:
                print(f"⚠️ 忽略无法解析的 q: {t!r}")
                continue
            i = int(m.group(1))
            if 0 <= i < len(all_items):
                idxs.append(i)
            else:
                print(f"⚠️ idx {i} 超出范围 [0,{len(all_items)}),忽略")
        idxs = sorted(set(idxs))  # de-duplicate, then run in ascending index order
        if not idxs:
            print("❌ --only-q 没有合法索引可用")
            return
        items = [all_items[i] for i in idxs]
        print(f"📋 取 high_priority 指定 {len(idxs)} 条 query (idx={','.join(map(str, idxs))})"
              f" | 渠道 {args.platforms} | 每渠道≤{args.max_count} 帖")
    else:
        start = args.start
        if start < 0:
            # A negative start would slice from the end of the list while naming folders q-1, q-2, ...
            print(f"❌ --start 不能为负数(得到 {start})")
            return
        items = all_items[start:start + args.count]
        if not items:
            # Nothing selected: skip the form-B LLM call and leave any existing summary.json intact.
            print(f"❌ 区间 [{start}, {start + args.count}) 内没有 query(共 {len(all_items)} 条)")
            return
        idxs = list(range(start, start + len(items)))  # absolute indices, used for folder names
        print(f"📋 取 high_priority 第 {start}~{start+len(items)-1} 条 query(共 {len(items)} 条)"
              f" | 渠道 {args.platforms} | 每渠道≤{args.max_count} 帖")
    gen_llm, gen_model_id = build_eval_llm_call(args.gen_model)
    print(f"🧠 评估模型 {args.eval_model}->{eval_model_id} | form B 生成 {args.gen_model}->{gen_model_id}")
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    qa = form_a(items)
    qb, b_cost = await form_b(items, gen_llm, gen_model_id)
    qc = form_c(items, args.seed)
    # forms_preview is keyed by absolute index, so successive range runs merge instead of overwriting.
    preview_path = output_dir / "forms_preview.json"
    preview = {}
    if preview_path.exists():
        try:
            loaded = json.loads(preview_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            loaded = None  # unreadable / corrupt preview: start fresh
        if isinstance(loaded, dict):
            preview = loaded  # older versions wrote a list; anything but a dict is reset
    for j, absi in enumerate(idxs):
        preview[str(absi)] = {"idx": absi, "A": qa[j], "B": qb[j], "C": qc[j]}
    preview_path.write_text(json.dumps(preview, ensure_ascii=False, indent=2), encoding="utf-8")
    print("📝 三形式预览 → forms_preview.json")
    for j, absi in enumerate(idxs):
        print(f" [{absi}] A={qa[j]!r} B={qb[j]!r} C={qc[j]!r}")
    # English platforms (youtube / x): translate the queries of all forms in one go.
    platforms = [p.strip() for p in args.platforms.split(",") if p.strip()]
    all_q = list(dict.fromkeys(qa + qb + qc))
    overrides = await build_query_overrides(platforms, all_q, gen_llm, gen_model_id)
    summary = []
    for j, absi in enumerate(idxs):
        qdir = output_dir / f"q{absi:02d}"
        print(f"\n▶ q{absi:02d} 原词={qa[j]!r}")
        per_form = {}
        for fk, qtext in (("A", qa[j]), ("B", qb[j]), ("C", qc[j])):
            stat = await run_one(qtext, fk, qa[j], args, eval_llm, eval_model_id,
                                 qdir / f"form_{fk}.json", query_overrides=overrides)
            per_form[fk] = stat
        summary.append({"idx": absi, "q": qa[j], "forms": per_form})
    # NOTE: unlike forms_preview.json, summary.json reflects only this invocation's selection.
    (output_dir / "summary.json").write_text(json.dumps({
        "count": len(items), "platforms": args.platforms, "eval_model": eval_model_id,
        "form_b_gen_cost": round(b_cost, 4), "per_query": summary,
    }, ensure_ascii=False, indent=2), encoding="utf-8")
    # Aggregate comparison across forms.
    print(f"\n{'='*60}\n📊 三形式对比 (各形式 report/total 合计)")
    for fk in ("A", "B", "C"):
        tot = sum(s["forms"][fk]["total"] for s in summary)
        rep = sum(s["forms"][fk]["report"] for s in summary)
        dis = sum(s["forms"][fk]["discard"] for s in summary)
        print(f" 形式 {fk}: report={rep}/{tot} discard={dis}")
    print(f"→ {output_dir/'summary.json'}")
def main() -> None:
    """Parse CLI arguments and run the selected mode (batch, --reeval or --append).

    Loads environment variables from a .env file first (via python-dotenv). The model
    choices come from ``EVAL_MODELS``. ``--reeval`` and ``--append`` are mutually exclusive.
    """
    from dotenv import load_dotenv
    load_dotenv()
    from examples.process_pipeline.script.llm_evaluate_sources import EVAL_MODELS
    p = argparse.ArgumentParser(description="三形式 query 批量搜索 + 多模态评估")
    p.add_argument("--start", type=int, default=0, help="起始 query 下标(0-based,默认 0)")
    p.add_argument("--count", type=int, default=10, help="从 --start 起取几条 query(默认 10)")
    p.add_argument("--only-q", default=None,
                   help="离散指定要跑的 q(如 '51,55,331' 或 'q51,q55,q331'),优先于 --start/--count")
    p.add_argument("--queries-file", default=None,
                   help="自定义 query 源 JSON 路径(结构需含 queries[...]),默认读 high_priority_queries.json")
    p.add_argument("--platforms", default="xhs,gzh,zhihu", help="逗号分隔渠道(默认 xhs,gzh,zhihu)")
    p.add_argument("--max-count", type=int, default=10, help="每个 (渠道,query) 取几条帖子(默认 10)")
    p.add_argument("--output-dir", required=True, help="输出目录")
    p.add_argument("--eval-model", default="gemini-flash-lite", choices=list(EVAL_MODELS),
                   help="评估模型(默认 gemini-flash-lite,多模态)")
    p.add_argument("--gen-model", default="gemini-flash-lite", choices=list(EVAL_MODELS),
                   help="form B 句子生成模型(默认 gemini-flash-lite)")
    p.add_argument("--max-concurrent", type=int, default=3, help="搜索 / 评估并发上限")
    p.add_argument("--max-images", type=int, default=4, help="每帖最多发给模型几张配图")
    p.add_argument("--image-mode", choices=["url", "base64"], default="url",
                   help="图片传输:url 直传(快,默认) / base64 下载内嵌(稳)")
    p.add_argument("--no-images", action="store_true", help="不发图(纯文本评估)")
    p.add_argument("--no-transcribe", action="store_true",
                   help="不对视频帖转写(默认会转写并把字幕并入正文再评估)")
    p.add_argument("--no-eval", action="store_true", help="只搜不评估")
    # --reeval and --append are alternative modes. run() checks --reeval first and would
    # silently ignore --append if both were given, so reject the combination here.
    mode = p.add_mutually_exclusive_group()
    mode.add_argument("--reeval", action="store_true",
                      help="只重跑评估、覆盖 output-dir 下已有 q*/form_*.json(不重新搜索);"
                           "用 --start / --count 在 q 编号层限范围,或 --reeval-q 直接指定")
    p.add_argument("--reeval-q", default=None,
                   help="仅复评指定的 q(如 'q01' 或 'q01,q05,q12'),优先于 --start/--count")
    mode.add_argument("--append", action="store_true",
                      help="往已有 q*/form_*.json 追加 --platforms 指定的新渠道结果(不重搜旧渠道)")
    p.add_argument("--seed", type=int, default=42, help="form C 同义替换随机种子")
    args = p.parse_args()
    asyncio.run(run(args))


if __name__ == "__main__":
    main()
|