search_and_evaluate.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. """
  2. 独立脚本:query 词多渠道搜索 + qwen/LLM 逐条评估(不经过 agent,也不写 source.json)
  3. 定位:一个自包含的「搜 + 评」工具。
  4. 输入 :一组 query 词(json 文件)
  5. 可选 :先用 LLM(默认 qwen)改写 / 扩展 query
  6. 搜索 :直接调各渠道 search_impl 拿 post 详情(绕开 agent)
  7. 评估 :把每条 post 详情交给 LLM,按 rubric 逐条评估
  8. 输出 :evaluated.json —— 每条 = post 详情 + llm_evaluation
  9. 与 run_pipeline / llm_evaluate_sources 的区别:这里不维护 source.json / filtered_cases,
  10. 不做跨轮去重与回写,只产出一份评估结果,方便单独跑、单独看。
  11. 评估的 rubric prompt 与单帖评估逻辑仍复用 llm_evaluate_sources(忠实 rubric)。
  12. queries.json 支持两种格式:
  13. ["query1", "query2", ...]
  14. {"requirement": "采集目标描述", "queries": ["query1", "query2", ...]}
  15. 典型用法:
  16. # 直接用给定 query 搜索 + qwen 评估
  17. python search_and_evaluate.py --queries q.json --platforms xhs,zhihu \
  18. --output-dir scratch/run1 --eval-model qwen
  19. # 先让 qwen 改写 query 再搜
  20. python search_and_evaluate.py --queries q.json --platforms xhs \
  21. --output-dir scratch/run1 --gen-query --gen-model qwen --keep-original
  22. # 只搜不评估
  23. python search_and_evaluate.py --queries q.json --platforms xhs --output-dir d --no-eval
  24. """
  25. import argparse
  26. import asyncio
  27. import json
  28. import logging
  29. import sys
  30. from pathlib import Path
  31. from typing import Any, Callable, Dict, List, Optional, Tuple
# Make the repository root importable so the `examples.*` / `agent.*` packages
# resolve when this file is executed directly as a script. parents[3] is the
# directory four levels above this file (e.g. <root>/examples/process_pipeline/
# script/<this file>); NOTE(review): update the index if the file is moved.
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(_PROJECT_ROOT) not in sys.path:
    # Prepend so the project tree wins over any same-named installed package.
    sys.path.insert(0, str(_PROJECT_ROOT))
# This import must stay below the sys.path tweak: it resolves from the project tree.
from examples.process_pipeline.script.llm_helper import call_llm_with_retry
logger = logging.getLogger(__name__)
  37. # ── queries 加载 ────────────────────────────────────────────────────────────────
  38. def load_queries(path: Path) -> Tuple[List[str], str]:
  39. """读 queries.json,返回 (queries, requirement)。requirement 可能为空串。"""
  40. with open(path, "r", encoding="utf-8") as f:
  41. data = json.load(f)
  42. if isinstance(data, list):
  43. return [str(q).strip() for q in data if str(q).strip()], ""
  44. if isinstance(data, dict):
  45. raw = data.get("queries") or []
  46. queries = [str(q).strip() for q in raw if str(q).strip()]
  47. return queries, str(data.get("requirement") or "").strip()
  48. raise ValueError(f"无法识别的 queries 文件格式: {path}(应为数组或 {{queries:[...]}})")
  49. # ── query 生成 / 改写 ───────────────────────────────────────────────────────────
  50. def _validate_gen(data: Dict[str, Any]) -> Optional[str]:
  51. qs = data.get("queries")
  52. if not isinstance(qs, list) or not qs:
  53. return "queries 必须是非空数组"
  54. if not all(isinstance(q, str) and q.strip() for q in qs):
  55. return "queries 每一项必须是非空字符串"
  56. return None
  57. async def generate_queries(
  58. base_queries: List[str],
  59. requirement: str,
  60. llm_call: Callable,
  61. model: str,
  62. target_count: int,
  63. ) -> Tuple[List[str], float]:
  64. """让 LLM 基于采集需求改写 / 扩展已有 query。返回 (新 query 列表, cost)。"""
  65. system = (
  66. "你是内容采集的搜索词优化器。基于采集需求和已有 query,产出一组更适合在"
  67. "社媒/内容平台搜索框直接使用的关键词:覆盖同义表达、具体工具名、典型用法场景,"
  68. "去掉过于宽泛或重复的词。只输出一个 JSON 对象,不要解释、不要 markdown。"
  69. )
  70. user = (
  71. f"【采集需求 / 目标格子】\n{requirement or '(未提供,参考已有 query 自行归纳)'}\n\n"
  72. f"【已有 query】\n{json.dumps(base_queries, ensure_ascii=False)}\n\n"
  73. f"【要求】产出约 {target_count} 个改写 / 扩展后的搜索词,输出格式:\n"
  74. '{"queries": ["词1", "词2", ...]}\n'
  75. "词应简短(适合搜索框)、彼此不同、贴合『制作做法』而非泛泛话题。只输出 JSON。"
  76. )
  77. data, cost = await call_llm_with_retry(
  78. llm_call=llm_call,
  79. messages=[{"role": "system", "content": system},
  80. {"role": "user", "content": user}],
  81. model=model, temperature=0.5, max_tokens=1500,
  82. validate_fn=_validate_gen, task_name="GenQuery",
  83. )
  84. if not data:
  85. logger.warning("query 生成失败,回退使用原始 query")
  86. return [], cost
  87. out, seen = [], set()
  88. for q in data["queries"]:
  89. q = q.strip()
  90. if q and q not in seen:
  91. seen.add(q)
  92. out.append(q)
  93. return out, cost
  94. # ── 渠道搜索 ────────────────────────────────────────────────────────────────────
  95. def _post_cid(post: Dict[str, Any]) -> Optional[str]:
  96. cid = post.get("channel_content_id") or post.get("video_id")
  97. if cid:
  98. return str(cid)
  99. link = post.get("link") or post.get("url")
  100. return str(link) if link else None
  101. async def _search_one(pdef, query: str, max_count: int, sem: asyncio.Semaphore):
  102. """跑一次 (platform, query) 搜索,返回 (platform, query, posts)。失败返回空 posts。"""
  103. async with sem:
  104. try:
  105. result = await pdef.search_impl(
  106. platform_id=pdef.id, keyword=query, max_count=max_count,
  107. cursor="", extras=None,
  108. )
  109. except Exception as e:
  110. logger.warning("search 失败 [%s/%s]: %s", pdef.id, query, e)
  111. return pdef.id, query, []
  112. if getattr(result, "error", None):
  113. logger.warning("search 返回错误 [%s/%s]: %s", pdef.id, query, result.error)
  114. return pdef.id, query, []
  115. posts = (result.metadata or {}).get("posts", []) or []
  116. return pdef.id, query, posts
  117. async def search_all(
  118. platforms: List[str], queries: List[str], max_count: int, max_concurrent: int,
  119. ) -> List[Dict[str, Any]]:
  120. """对所有 (platform × query) 组合并发搜索,按 (platform, cid) 去重。
  121. 返回 source_dict 列表,每条带:case_id / platform / channel_content_id /
  122. source_url / post / comments / found_by_queries(命中它的 query,用于回溯 query 质量)。
  123. """
  124. import agent.tools.builtin.content.tools # noqa: F401 触发平台自注册
  125. from agent.tools.builtin.content.registry import get_platform, all_platforms
  126. pdefs = []
  127. for p in platforms:
  128. pdef = get_platform(p)
  129. if not pdef:
  130. avail = ", ".join(x.id for x in all_platforms())
  131. raise ValueError(f"未知平台 '{p}'。可用: {avail}")
  132. if not pdef.search_impl:
  133. raise ValueError(f"平台 '{p}' 不支持搜索")
  134. pdefs.append(pdef)
  135. sem = asyncio.Semaphore(max_concurrent)
  136. tasks = [_search_one(pdef, q, max_count, sem) for pdef in pdefs for q in queries]
  137. print(f"🔎 搜索 {len(pdefs)} 渠道 × {len(queries)} query = {len(tasks)} 次请求 (并发 {max_concurrent})")
  138. results = await asyncio.gather(*tasks)
  139. collected: Dict[Tuple[str, str], Dict[str, Any]] = {}
  140. per_query_counts: Dict[str, int] = {}
  141. for platform, query, posts in results:
  142. per_query_counts[f"{platform}/{query}"] = len(posts)
  143. for post in posts:
  144. if not isinstance(post, dict):
  145. continue
  146. cid = _post_cid(post)
  147. if not cid:
  148. continue
  149. key = (platform, cid)
  150. if key in collected:
  151. collected[key]["found_by_queries"].append(query)
  152. continue
  153. link = post.get("link") or post.get("url") or ""
  154. collected[key] = {
  155. "case_id": f"{platform}_{cid}",
  156. "platform": platform,
  157. "channel_content_id": cid,
  158. "source_url": link,
  159. "post": post,
  160. "comments": post.get("author_comments", []) or [],
  161. "found_by_queries": [query],
  162. }
  163. print(" 每个 (渠道/query) 命中数:")
  164. for k, n in sorted(per_query_counts.items()):
  165. print(f" - {k}: {n}")
  166. print(f" 去重后唯一 post:{len(collected)}")
  167. return list(collected.values())
  168. # ── 图片获取(下载转 base64 data URL,喂多模态评估)──────────────────────────────
  169. import base64
  170. _MIME_BY_EXT = {".png": "image/png", ".webp": "image/webp", ".gif": "image/gif"}
  171. def _collect_post_image_urls(post: Dict[str, Any], max_images: int) -> List[str]:
  172. """从 post 收集图片 URL(复用 generate_case 的字段映射),截断到 max_images。"""
  173. try:
  174. from examples.process_pipeline.script.generate_case import _extract_raw_images
  175. urls = _extract_raw_images(post, post.get("channel") or "")
  176. except Exception:
  177. urls = post.get("images") or []
  178. urls = [u for u in urls if isinstance(u, str) and u.startswith("http")]
  179. return urls[:max_images]
  180. async def _fetch_data_url(url: str, sem: asyncio.Semaphore) -> Optional[str]:
  181. """下载单张图片转 base64 data URL(用项目 _download_image,带 Referer/UA 绕防盗链)。"""
  182. from agent.tools.builtin.file.image_cdn import _download_image
  183. async with sem:
  184. try:
  185. data = await _download_image(url)
  186. except Exception as e:
  187. logger.warning("图片下载失败 %s: %s", url[:60], e)
  188. return None
  189. if not data:
  190. return None
  191. ext = next((e for e in _MIME_BY_EXT if url.lower().split("?")[0].endswith(e)), "")
  192. mime = _MIME_BY_EXT.get(ext, "image/jpeg")
  193. b64 = base64.b64encode(data).decode("ascii")
  194. return f"data:{mime};base64,{b64}"
  195. async def _attach_image_data_urls(
  196. sources: List[Dict[str, Any]], max_images: int, max_concurrent: int,
  197. ) -> int:
  198. """为每条 source 下载配图转 data URL,挂到 s['_image_data_urls'](评估后会清掉,不写进报告)。"""
  199. sem = asyncio.Semaphore(max_concurrent)
  200. plan: List[Tuple[Dict[str, Any], List[str]]] = []
  201. for s in sources:
  202. urls = _collect_post_image_urls(s.get("post", {}) or {}, max_images)
  203. plan.append((s, urls))
  204. # 拉平并发下载
  205. flat = [(s, u) for s, urls in plan for u in urls]
  206. results = await asyncio.gather(*[_fetch_data_url(u, sem) for _, u in flat])
  207. bucket: Dict[int, List[str]] = {}
  208. for (s, _), data_url in zip(flat, results):
  209. if data_url:
  210. bucket.setdefault(id(s), []).append(data_url)
  211. total = 0
  212. for s in sources:
  213. s["_image_data_urls"] = bucket.get(id(s), [])
  214. total += len(s["_image_data_urls"])
  215. return total
  216. # ── 逐条评估(直评,不写 source.json)─────────────────────────────────────────────
  217. async def evaluate_posts(
  218. sources: List[Dict[str, Any]],
  219. requirement: str,
  220. llm_call: Callable,
  221. model: str,
  222. max_concurrent: int,
  223. include_images: bool = True,
  224. max_images: int = 4,
  225. ) -> Tuple[List[Dict[str, Any]], float]:
  226. """对每条 post 用 rubric 逐条评估,把 llm_evaluation 挂到 source 上。返回 (sources, total_cost)。
  227. 复用 llm_evaluate_sources 的 rubric 加载与单帖评估逻辑,保证与管线评估口径一致。
  228. include_images=True 时把帖子配图(下载转 base64)一并发给模型做多模态评估。
  229. 评估失败(重试耗尽)的条目标 error 标记并保留,不丢。
  230. """
  231. from examples.process_pipeline.script.llm_evaluate_sources import (
  232. load_post_rubric, load_rubric_md, _evaluate_one,
  233. )
  234. post_rubric = load_post_rubric()
  235. rubric_md = load_rubric_md()
  236. sem = asyncio.Semaphore(max_concurrent)
  237. if include_images:
  238. n_img = await _attach_image_data_urls(sources, max_images, max_concurrent * 2)
  239. print(f"🖼️ 下载配图 {n_img} 张(每帖≤{max_images})用于多模态评估")
  240. print(f"🧠 逐条评估 {len(sources)} 条 (并发 {max_concurrent}) ...")
  241. results = await asyncio.gather(*[
  242. _evaluate_one(s, post_rubric, rubric_md, requirement, llm_call, model, sem,
  243. image_urls=(s.get("_image_data_urls") if include_images else None))
  244. for s in sources
  245. ])
  246. total_cost = 0.0
  247. rep = dis = failed = 0
  248. for s, (llm_eval, cost) in zip(sources, results):
  249. total_cost += cost
  250. if llm_eval is None:
  251. s["llm_evaluation"] = {"decision": "report", "reason": "llm_eval_failed", "error": True}
  252. failed += 1
  253. else:
  254. s["llm_evaluation"] = llm_eval
  255. if llm_eval.get("decision") == "discard":
  256. dis += 1
  257. else:
  258. rep += 1
  259. title = (s.get("post", {}) or {}).get("title", "")[:30]
  260. dec = s["llm_evaluation"].get("decision")
  261. print(f" - [{dec:7}] {s['case_id'][:24]} {title}")
  262. print(f" 汇总:report={rep} discard={dis} failed={failed} cost=${total_cost:.4f}")
  263. return sources, total_cost
  264. # ── 主流程 ────────────────────────────────────────────────────────────────────
  265. async def run(args: argparse.Namespace) -> None:
  266. from examples.process_pipeline.script.llm_evaluate_sources import build_eval_llm_call
  267. queries, req_from_file = load_queries(Path(args.queries))
  268. if not queries:
  269. print("❌ queries 为空"); sys.exit(1)
  270. platforms = [p.strip() for p in args.platforms.split(",") if p.strip()]
  271. if not platforms:
  272. print("❌ 未指定 --platforms"); sys.exit(1)
  273. requirement = args.requirement or req_from_file or (";".join(queries))[:200]
  274. output_dir = Path(args.output_dir)
  275. output_dir.mkdir(parents=True, exist_ok=True)
  276. print(f"📋 需求: {requirement[:80]}")
  277. print(f"📡 渠道: {platforms} | 原始 query 数: {len(queries)}")
  278. total_cost = 0.0
  279. # 1. 可选:LLM 改写 / 扩展 query
  280. if args.gen_query:
  281. gen_llm, gen_model_id = build_eval_llm_call(args.gen_model)
  282. print(f"✍️ query 生成模型: {args.gen_model} -> {gen_model_id}")
  283. gen_queries, gen_cost = await generate_queries(
  284. queries, requirement, gen_llm, gen_model_id, args.gen_count,
  285. )
  286. total_cost += gen_cost
  287. if gen_queries:
  288. search_queries = list(dict.fromkeys(
  289. (queries + gen_queries) if args.keep_original else gen_queries
  290. ))
  291. (output_dir / "generated_queries.json").write_text(
  292. json.dumps({"requirement": requirement, "original": queries,
  293. "generated": gen_queries, "used": search_queries},
  294. ensure_ascii=False, indent=2), encoding="utf-8",
  295. )
  296. print(f" 生成 {len(gen_queries)} 个 query,实际搜索 {len(search_queries)} 个"
  297. f"({'含' if args.keep_original else '不含'}原始)→ generated_queries.json")
  298. else:
  299. search_queries = queries
  300. print(" 生成失败,回退原始 query")
  301. else:
  302. search_queries = queries
  303. # 2. 多渠道搜索
  304. sources = await search_all(platforms, search_queries, args.max_count, args.max_concurrent)
  305. # 3. 时间戳转可读
  306. try:
  307. from examples.process_pipeline.script.extract_sources import _convert_timestamps
  308. _convert_timestamps(sources)
  309. except Exception as e:
  310. logger.warning("时间戳转换跳过: %s", e)
  311. # 4. 可选:视频帖自动转写
  312. if args.transcribe and sources:
  313. try:
  314. from examples.process_pipeline.script.extract_sources import (
  315. _transcribe_pending_async, _merge_transcript_into_body,
  316. )
  317. updates = await _transcribe_pending_async(sources, concurrency=3)
  318. print(f"🎙️ 转写: {len(updates)} 条视频获得字幕")
  319. for s in sources:
  320. post = s.get("post")
  321. if isinstance(post, dict) and post.get("video_transcript"):
  322. merged = _merge_transcript_into_body(post)
  323. if merged is not post:
  324. post["body_text"] = merged.get("body_text", post.get("body_text", ""))
  325. except Exception as e:
  326. logger.warning("转写跳过: %s", e)
  327. # 5. 评估(除非 --no-eval)
  328. eval_model_id = None
  329. include_images = not args.no_images
  330. if not args.no_eval and sources:
  331. eval_llm, eval_model_id = build_eval_llm_call(args.eval_model)
  332. print(f"🧠 评估模型: {args.eval_model} -> {eval_model_id} | 多模态图片: {'开' if include_images else '关'}")
  333. sources, eval_cost = await evaluate_posts(
  334. sources, requirement, eval_llm, eval_model_id, args.max_concurrent,
  335. include_images=include_images, max_images=args.max_images,
  336. )
  337. total_cost += eval_cost
  338. # 清掉评估用的临时 base64 图(别写进报告,否则文件爆炸),只留张数留痕
  339. for s in sources:
  340. imgs = s.pop("_image_data_urls", None)
  341. if imgs is not None:
  342. s["images_sent"] = len(imgs)
  343. # 6. 写 evaluated.json
  344. out_file = output_dir / "evaluated.json"
  345. rep = sum(1 for s in sources if (s.get("llm_evaluation") or {}).get("decision") == "report"
  346. and not (s.get("llm_evaluation") or {}).get("error"))
  347. dis = sum(1 for s in sources if (s.get("llm_evaluation") or {}).get("decision") == "discard")
  348. out = {
  349. "requirement": requirement,
  350. "platforms": platforms,
  351. "queries_used": search_queries,
  352. "eval_model": eval_model_id,
  353. "total": len(sources),
  354. "report": rep,
  355. "discard": dis,
  356. "results": sources,
  357. }
  358. with open(out_file, "w", encoding="utf-8") as f:
  359. json.dump(out, f, ensure_ascii=False, indent=2)
  360. print(f"💾 evaluated.json: {len(sources)} 条 (report={rep} discard={dis}) → {out_file}")
  361. print(f"💰 累计成本: ${total_cost:.4f}")
  362. def main() -> None:
  363. from dotenv import load_dotenv
  364. load_dotenv()
  365. from examples.process_pipeline.script.llm_evaluate_sources import EVAL_MODELS, DEFAULT_EVAL_MODEL
  366. parser = argparse.ArgumentParser(description="query 词多渠道直搜 + LLM 逐条评估")
  367. parser.add_argument("--queries", required=True, help="query 词 json 文件(数组或 {queries:[...]})")
  368. parser.add_argument("--platforms", required=True, help="逗号分隔渠道,如 xhs,zhihu,gzh,douyin,sph,youtube")
  369. parser.add_argument("--output-dir", required=True, help="输出目录(写 evaluated.json)")
  370. parser.add_argument("--max-count", type=int, default=20, help="每个 (渠道,query) 返回条数上限(部分渠道如 xhs 由后端决定)")
  371. parser.add_argument("--requirement", default="", help="评估用采集目标描述(缺省取 queries 文件 requirement 或 query 拼接)")
  372. parser.add_argument("--eval-model", default=DEFAULT_EVAL_MODEL, choices=list(EVAL_MODELS),
  373. help=f"评估模型(默认 {DEFAULT_EVAL_MODEL})")
  374. parser.add_argument("--max-concurrent", type=int, default=3, help="搜索 / 评估并发上限")
  375. parser.add_argument("--no-eval", action="store_true", help="只搜索,跳过 LLM 评估")
  376. parser.add_argument("--no-images", action="store_true", help="不把帖子配图发给模型(默认发,多模态评估;纯文本模型请加此项)")
  377. parser.add_argument("--max-images", type=int, default=4, help="每帖最多发给模型几张配图(默认 4)")
  378. parser.add_argument("--transcribe", action="store_true", help="评估前对视频帖跑 Deepgram 转写")
  379. # query 生成模式
  380. parser.add_argument("--gen-query", action="store_true", help="搜索前用 LLM 改写 / 扩展 query")
  381. parser.add_argument("--gen-model", default=DEFAULT_EVAL_MODEL, choices=list(EVAL_MODELS),
  382. help=f"query 生成模型(默认 {DEFAULT_EVAL_MODEL})")
  383. parser.add_argument("--gen-count", type=int, default=10, help="生成 query 的目标数量")
  384. parser.add_argument("--keep-original", action="store_true", help="生成的 query 与原始 query 合并搜索(默认只用生成的)")
  385. args = parser.parse_args()
  386. asyncio.run(run(args))
  387. if __name__ == "__main__":
  388. main()