"""
Standalone script: multi-channel search for query terms + per-post qwen/LLM
evaluation (does not go through the agent and does not write source.json).

Purpose: a self-contained "search + evaluate" tool.
  Input    : a set of query terms (a json file)
  Optional : first rewrite / expand the queries with an LLM (qwen by default)
  Search   : call each channel's search_impl directly to get post details
             (bypassing the agent)
  Evaluate : hand each post's details to the LLM and evaluate it against the
             rubric, one post at a time
  Output   : evaluated.json -- each entry = post details + llm_evaluation

Difference from run_pipeline / llm_evaluate_sources: this script does not
maintain source.json / filtered_cases and does no cross-round dedup or
write-back; it only produces a single evaluation result, so it is easy to run
and inspect on its own.
The rubric prompt and the single-post evaluation logic are still reused from
llm_evaluate_sources (faithful to the rubric).

queries.json supports two formats:
  ["query1", "query2", ...]
  {"requirement": "description of the collection target", "queries": ["query1", "query2", ...]}

Typical usage:
  # Search with the given queries + evaluate with qwen
  python search_and_evaluate.py --queries q.json --platforms xhs,zhihu \
      --output-dir scratch/run1 --eval-model qwen
  # Let qwen rewrite the queries first, then search
  python search_and_evaluate.py --queries q.json --platforms xhs \
      --output-dir scratch/run1 --gen-query --gen-model qwen --keep-original
  # Search only, no evaluation
  python search_and_evaluate.py --queries q.json --platforms xhs --output-dir d --no-eval
"""
import argparse
import asyncio
import json
import logging
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

# `parents[3]` of this file's resolved path is taken to be the project root; it
# is put on sys.path so the absolute `examples.…` / `agent.…` imports below and
# inside the functions resolve when this file is executed directly as a script.
_PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(_PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(_PROJECT_ROOT))

from examples.process_pipeline.script.llm_helper import call_llm_with_retry  # noqa: E402

logger = logging.getLogger(__name__)
# ── Query loading ───────────────────────────────────────────────────────────────
def load_queries(path: Path) -> Tuple[List[str], str]:
    """Read the queries JSON file and return ``(queries, requirement)``.

    Two layouts are accepted:

    * a bare array: ``["query1", "query2", ...]``
    * an object: ``{"requirement": "...", "queries": ["query1", ...]}``

    Each entry is converted with ``str()`` and stripped. ``null`` entries and
    entries that are empty after stripping are dropped. ``requirement`` is the
    stripped ``requirement`` field of the object layout and ``""`` otherwise.

    Args:
        path: Location of the JSON file (read as UTF-8).

    Returns:
        A ``(queries, requirement)`` tuple; ``queries`` may be empty.

    Raises:
        ValueError: If the top-level JSON value is neither an array nor an
            object, or if the ``queries`` field is not an array.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, list):
        raw, requirement = data, ""
    elif isinstance(data, dict):
        raw = data.get("queries") or []
        requirement = str(data.get("requirement") or "").strip()
    else:
        raise ValueError(f"无法识别的 queries 文件格式: {path}(应为数组或 {{queries:[...]}})")

    if not isinstance(raw, list):
        # A string here would otherwise be iterated character by character and
        # silently turn into a list of one-letter queries.
        raise ValueError(f"无法识别的 queries 文件格式: {path}(queries 字段应为数组)")

    # `str(None)` is the truthy text "None", so null entries are skipped
    # explicitly instead of becoming a literal "None" search term.
    queries = [str(q).strip() for q in raw if q is not None and str(q).strip()]
    return queries, requirement
# ── Query generation / rewriting ────────────────────────────────────────────────
- def _validate_gen(data: Dict[str, Any]) -> Optional[str]:
- qs = data.get("queries")
- if not isinstance(qs, list) or not qs:
- return "queries 必须是非空数组"
- if not all(isinstance(q, str) and q.strip() for q in qs):
- return "queries 每一项必须是非空字符串"
- return None
async def generate_queries(
    base_queries: List[str],
    requirement: str,
    llm_call: Callable,
    model: str,
    target_count: int,
) -> Tuple[List[str], float]:
    """Ask the LLM to rewrite / expand the existing queries.

    Builds a system + user prompt from the collection requirement and the
    current queries, sends it through ``call_llm_with_retry`` (validated by
    ``_validate_gen``), then strips and de-duplicates the returned terms while
    keeping their first-seen order.

    Args:
        base_queries: Queries to start from; embedded in the prompt as JSON.
        requirement: Collection goal; a placeholder text is used when empty.
        llm_call: Callable forwarded unchanged to ``call_llm_with_retry``.
        model: Model identifier forwarded to ``call_llm_with_retry``.
        target_count: Approximate number of terms requested in the prompt
            (the result is not truncated to it).

    Returns:
        ``(queries, cost)``. ``queries`` is ``[]`` when the helper returns no
        data; falling back to the original queries is left to the caller.
    """
    system_prompt = (
        "你是内容采集的搜索词优化器。基于采集需求和已有 query,产出一组更适合在"
        "社媒/内容平台搜索框直接使用的关键词:覆盖同义表达、具体工具名、典型用法场景,"
        "去掉过于宽泛或重复的词。只输出一个 JSON 对象,不要解释、不要 markdown。"
    )
    requirement_text = requirement or '(未提供,参考已有 query 自行归纳)'
    existing_json = json.dumps(base_queries, ensure_ascii=False)
    user_prompt = (
        f"【采集需求 / 目标格子】\n{requirement_text}\n\n"
        f"【已有 query】\n{existing_json}\n\n"
        f"【要求】产出约 {target_count} 个改写 / 扩展后的搜索词,输出格式:\n"
        '{"queries": ["词1", "词2", ...]}\n'
        "词应简短(适合搜索框)、彼此不同、贴合『制作做法』而非泛泛话题。只输出 JSON。"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    data, cost = await call_llm_with_retry(
        llm_call=llm_call,
        messages=messages,
        model=model,
        temperature=0.5,
        max_tokens=1500,
        validate_fn=_validate_gen,
        task_name="GenQuery",
    )
    if not data:
        logger.warning("query 生成失败,回退使用原始 query")
        return [], cost

    # dict.fromkeys keeps the first occurrence of each term, in order.
    stripped = (term.strip() for term in data["queries"])
    unique_terms = list(dict.fromkeys(term for term in stripped if term))
    return unique_terms, cost
# ── Channel search ──────────────────────────────────────────────────────────────
- def _post_cid(post: Dict[str, Any]) -> Optional[str]:
- cid = post.get("channel_content_id") or post.get("video_id")
- if cid:
- return str(cid)
- link = post.get("link") or post.get("url")
- return str(link) if link else None
- async def _search_one(pdef, query: str, max_count: int, sem: asyncio.Semaphore):
- """跑一次 (platform, query) 搜索,返回 (platform, query, posts)。失败返回空 posts。"""
- async with sem:
- try:
- result = await pdef.search_impl(
- platform_id=pdef.id, keyword=query, max_count=max_count,
- cursor="", extras=None,
- )
- except Exception as e:
- logger.warning("search 失败 [%s/%s]: %s", pdef.id, query, e)
- return pdef.id, query, []
- if getattr(result, "error", None):
- logger.warning("search 返回错误 [%s/%s]: %s", pdef.id, query, result.error)
- return pdef.id, query, []
- posts = (result.metadata or {}).get("posts", []) or []
- return pdef.id, query, posts
async def search_all(
    platforms: List[str], queries: List[str], max_count: int, max_concurrent: int,
) -> List[Dict[str, Any]]:
    """Search every (platform × query) pair concurrently and de-duplicate posts.

    Posts are de-duplicated on ``(platform, cid)``, where ``cid`` comes from
    ``_post_cid``. Non-dict posts and posts without any usable id are skipped.

    Args:
        platforms: Platform ids to search; repeated ids are searched once.
        queries: Search keywords; repeated keywords are searched once.
        max_count: Per-request result limit forwarded to each search.
        max_concurrent: Maximum number of searches in flight; values below 1
            are treated as 1.

    Returns:
        One dict per unique post with the keys ``case_id``, ``platform``,
        ``channel_content_id``, ``source_url``, ``post``, ``comments`` and
        ``found_by_queries`` (the queries that returned the post, kept so
        query quality can be traced back).

    Raises:
        ValueError: If a platform id is unknown or has no ``search_impl``.
    """
    import agent.tools.builtin.content.tools  # noqa: F401  imported so platforms self-register
    from agent.tools.builtin.content.registry import get_platform, all_platforms

    # Semaphore(0) would block every task forever; a negative value raises.
    max_concurrent = max(1, max_concurrent)
    # Drop repeats (order kept) so the same request is never issued twice.
    queries = list(dict.fromkeys(queries))

    pdefs = []
    for p in dict.fromkeys(platforms):
        pdef = get_platform(p)
        if not pdef:
            avail = ", ".join(x.id for x in all_platforms())
            raise ValueError(f"未知平台 '{p}'。可用: {avail}")
        if not pdef.search_impl:
            raise ValueError(f"平台 '{p}' 不支持搜索")
        pdefs.append(pdef)

    sem = asyncio.Semaphore(max_concurrent)
    tasks = [_search_one(pdef, q, max_count, sem) for pdef in pdefs for q in queries]
    print(f"🔎 搜索 {len(pdefs)} 渠道 × {len(queries)} query = {len(tasks)} 次请求 (并发 {max_concurrent})")
    results = await asyncio.gather(*tasks)

    collected: Dict[Tuple[str, str], Dict[str, Any]] = {}
    per_query_counts: Dict[str, int] = {}
    for platform, query, posts in results:
        per_query_counts[f"{platform}/{query}"] = len(posts)
        for post in posts:
            if not isinstance(post, dict):
                continue
            cid = _post_cid(post)
            if not cid:
                continue
            key = (platform, cid)
            existing = collected.get(key)
            if existing is not None:
                # The same post can occur twice in one result list; record
                # each query only once.
                if query not in existing["found_by_queries"]:
                    existing["found_by_queries"].append(query)
                continue
            link = post.get("link") or post.get("url") or ""
            collected[key] = {
                "case_id": f"{platform}_{cid}",
                "platform": platform,
                "channel_content_id": cid,
                "source_url": link,
                "post": post,
                "comments": post.get("author_comments", []) or [],
                "found_by_queries": [query],
            }

    print(" 每个 (渠道/query) 命中数:")
    for k, n in sorted(per_query_counts.items()):
        print(f" - {k}: {n}")
    print(f" 去重后唯一 post:{len(collected)}")
    return list(collected.values())
# ── Image fetching (download → base64 data URL, fed to the multimodal evaluation) ──
- import base64
- _MIME_BY_EXT = {".png": "image/png", ".webp": "image/webp", ".gif": "image/gif"}
- def _collect_post_image_urls(post: Dict[str, Any], max_images: int) -> List[str]:
- """从 post 收集图片 URL(复用 generate_case 的字段映射),截断到 max_images。"""
- try:
- from examples.process_pipeline.script.generate_case import _extract_raw_images
- urls = _extract_raw_images(post, post.get("channel") or "")
- except Exception:
- urls = post.get("images") or []
- urls = [u for u in urls if isinstance(u, str) and u.startswith("http")]
- return urls[:max_images]
async def _fetch_data_url(url: str, sem: asyncio.Semaphore) -> Optional[str]:
    """Download one image and return it as a base64 ``data:`` URL.

    The download goes through the project's ``_download_image`` helper (per
    the original note it sends Referer/UA headers to get past hotlink
    protection). The MIME type is guessed from the URL's extension, ignoring
    any query string, and defaults to ``image/jpeg``.

    Args:
        url: Image URL to download.
        sem: Semaphore limiting concurrent downloads.

    Returns:
        The data URL, or ``None`` when the download raises or yields no data.
    """
    from agent.tools.builtin.file.image_cdn import _download_image

    async with sem:
        try:
            payload = await _download_image(url)
        except Exception as e:
            logger.warning("图片下载失败 %s: %s", url[:60], e)
            return None

    if not payload:
        return None

    # Match the extension against the part of the URL before any "?".
    path_part = url.lower().split("?")[0]
    mime = "image/jpeg"
    for suffix, candidate in _MIME_BY_EXT.items():
        if path_part.endswith(suffix):
            mime = candidate
            break

    encoded = base64.b64encode(payload).decode("ascii")
    return f"data:{mime};base64,{encoded}"
async def _attach_image_data_urls(
    sources: List[Dict[str, Any]], max_images: int, max_concurrent: int,
) -> int:
    """Download each source's images and attach them as data URLs.

    Every source gets an ``_image_data_urls`` list (possibly empty) holding
    the data URLs of its successfully downloaded images. The key is temporary:
    the caller removes it after evaluation so it does not reach the report.

    Args:
        sources: Source dicts; images are looked up in each one's ``post``.
        max_images: Per-post cap on the number of images collected.
        max_concurrent: Maximum number of downloads in flight.

    Returns:
        Total number of data URLs attached across all sources.
    """
    sem = asyncio.Semaphore(max_concurrent)

    # Flatten to (source, url) pairs so all downloads share one gather call.
    pairs: List[Tuple[Dict[str, Any], str]] = [
        (src, link)
        for src in sources
        for link in _collect_post_image_urls(src.get("post", {}) or {}, max_images)
    ]
    fetched = await asyncio.gather(*[_fetch_data_url(link, sem) for _, link in pairs])

    # Group results by the owning dict's identity; failed downloads are falsy.
    by_owner: Dict[int, List[str]] = {}
    for (src, _), data_url in zip(pairs, fetched):
        if data_url:
            by_owner.setdefault(id(src), []).append(data_url)

    attached = 0
    for src in sources:
        images = by_owner.get(id(src), [])
        src["_image_data_urls"] = images
        attached += len(images)
    return attached
# ── Per-post evaluation (direct; does not write source.json) ────────────────────
async def evaluate_posts(
    sources: List[Dict[str, Any]],
    requirement: str,
    llm_call: Callable,
    model: str,
    max_concurrent: int,
    include_images: bool = True,
    max_images: int = 4,
) -> Tuple[List[Dict[str, Any]], float]:
    """Evaluate every post against the rubric and attach ``llm_evaluation``.

    Rubric loading and the single-post evaluation (``_evaluate_one``) are
    reused from ``llm_evaluate_sources`` so results match the pipeline's
    evaluation. With ``include_images`` the post images are first downloaded
    as base64 data URLs and passed to the model as ``image_urls``.

    A post whose evaluation comes back as ``None`` is kept and marked with
    ``{"decision": "report", "reason": "llm_eval_failed", "error": True}``.

    Args:
        sources: Source dicts to evaluate; mutated in place.
        requirement: Collection goal forwarded to the evaluator.
        llm_call: LLM callable forwarded to the evaluator.
        model: Model identifier forwarded to the evaluator.
        max_concurrent: Evaluation concurrency; image downloads use twice it.
        include_images: Whether to download and send post images.
        max_images: Per-post image cap when ``include_images`` is set.

    Returns:
        ``(sources, total_cost)``: the same list object plus the summed cost.
    """
    from examples.process_pipeline.script.llm_evaluate_sources import (
        load_post_rubric, load_rubric_md, _evaluate_one,
    )

    post_rubric = load_post_rubric()
    rubric_md = load_rubric_md()
    sem = asyncio.Semaphore(max_concurrent)

    if include_images:
        n_img = await _attach_image_data_urls(sources, max_images, max_concurrent * 2)
        print(f"🖼️  下载配图 {n_img} 张(每帖≤{max_images})用于多模态评估")

    print(f"🧠 逐条评估 {len(sources)} 条 (并发 {max_concurrent}) ...")
    results = await asyncio.gather(*[
        _evaluate_one(s, post_rubric, rubric_md, requirement, llm_call, model, sem,
                      image_urls=(s.get("_image_data_urls") if include_images else None))
        for s in sources
    ])

    total_cost = 0.0
    rep = dis = failed = 0
    for s, (llm_eval, cost) in zip(sources, results):
        total_cost += cost or 0.0
        if llm_eval is None:
            s["llm_evaluation"] = {"decision": "report", "reason": "llm_eval_failed", "error": True}
            failed += 1
        else:
            s["llm_evaluation"] = llm_eval
            if llm_eval.get("decision") == "discard":
                dis += 1
            else:
                rep += 1
        # Progress line only: a null title or a missing decision must not
        # crash the run after the LLM cost has been spent. `None[:30]` and
        # `f"{None:7}"` both raise TypeError.
        title = str((s.get("post", {}) or {}).get("title") or "")[:30]
        dec = str(s["llm_evaluation"].get("decision"))
        case_id = str(s.get("case_id", ""))[:24]
        print(f" - [{dec:7}] {case_id} {title}")

    print(f" 汇总:report={rep} discard={dis} failed={failed} cost=${total_cost:.4f}")
    return sources, total_cost
# ── Main flow ───────────────────────────────────────────────────────────────────
async def run(args: argparse.Namespace) -> None:
    """Run the whole search + evaluate flow for the parsed CLI arguments.

    Steps: load the queries, optionally rewrite them with an LLM, search all
    platforms, convert timestamps, optionally transcribe video posts,
    evaluate (unless ``--no-eval``) and write ``evaluated.json`` into the
    output directory. Exits with status 1 when there are no queries or no
    platforms.
    """
    from examples.process_pipeline.script.llm_evaluate_sources import build_eval_llm_call

    queries, file_requirement = load_queries(Path(args.queries))
    if not queries:
        print("❌ queries 为空")
        sys.exit(1)

    platforms = [name.strip() for name in args.platforms.split(",") if name.strip()]
    if not platforms:
        print("❌ 未指定 --platforms")
        sys.exit(1)

    # Requirement priority: CLI flag, then the queries file, then the joined
    # queries cut to 200 characters.
    requirement = args.requirement or file_requirement or (";".join(queries))[:200]
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"📋 需求: {requirement[:80]}")
    print(f"📡 渠道: {platforms} | 原始 query 数: {len(queries)}")
    total_cost = 0.0

    # 1. Optional: let an LLM rewrite / expand the queries.
    search_queries = queries
    if args.gen_query:
        gen_llm, gen_model_id = build_eval_llm_call(args.gen_model)
        print(f"✍️  query 生成模型: {args.gen_model} -> {gen_model_id}")
        gen_queries, gen_cost = await generate_queries(
            queries, requirement, gen_llm, gen_model_id, args.gen_count,
        )
        total_cost += gen_cost
        if not gen_queries:
            print(" 生成失败,回退原始 query")
        else:
            pool = (queries + gen_queries) if args.keep_original else gen_queries
            search_queries = list(dict.fromkeys(pool))
            record = {
                "requirement": requirement,
                "original": queries,
                "generated": gen_queries,
                "used": search_queries,
            }
            (output_dir / "generated_queries.json").write_text(
                json.dumps(record, ensure_ascii=False, indent=2), encoding="utf-8",
            )
            print(f" 生成 {len(gen_queries)} 个 query,实际搜索 {len(search_queries)} 个"
                  f"({'含' if args.keep_original else '不含'}原始)→ generated_queries.json")

    # 2. Multi-channel search.
    sources = await search_all(platforms, search_queries, args.max_count, args.max_concurrent)

    # 3. Make timestamps human-readable (best-effort).
    try:
        from examples.process_pipeline.script.extract_sources import _convert_timestamps
        _convert_timestamps(sources)
    except Exception as e:
        logger.warning("时间戳转换跳过: %s", e)

    # 4. Optional: transcribe video posts (best-effort).
    if args.transcribe and sources:
        try:
            from examples.process_pipeline.script.extract_sources import (
                _transcribe_pending_async, _merge_transcript_into_body,
            )
            updates = await _transcribe_pending_async(sources, concurrency=3)
            print(f"🎙️  转写: {len(updates)} 条视频获得字幕")
            for src in sources:
                post = src.get("post")
                if not (isinstance(post, dict) and post.get("video_transcript")):
                    continue
                merged = _merge_transcript_into_body(post)
                if merged is not post:
                    post["body_text"] = merged.get("body_text", post.get("body_text", ""))
        except Exception as e:
            logger.warning("转写跳过: %s", e)

    # 5. Evaluation (skipped with --no-eval or when nothing was found).
    eval_model_id = None
    include_images = not args.no_images
    if sources and not args.no_eval:
        eval_llm, eval_model_id = build_eval_llm_call(args.eval_model)
        print(f"🧠 评估模型: {args.eval_model} -> {eval_model_id} | 多模态图片: {'开' if include_images else '关'}")
        sources, eval_cost = await evaluate_posts(
            sources, requirement, eval_llm, eval_model_id, args.max_concurrent,
            include_images=include_images, max_images=args.max_images,
        )
        total_cost += eval_cost
        # Drop the temporary base64 images so the report stays small; keep
        # only how many were sent.
        for src in sources:
            sent = src.pop("_image_data_urls", None)
            if sent is not None:
                src["images_sent"] = len(sent)

    # 6. Write evaluated.json.
    out_file = output_dir / "evaluated.json"
    rep = dis = 0
    for src in sources:
        verdict = src.get("llm_evaluation") or {}
        decision = verdict.get("decision")
        if decision == "discard":
            dis += 1
        elif decision == "report" and not verdict.get("error"):
            # Failed evaluations carry decision "report" plus an error flag
            # and are not counted as reports.
            rep += 1
    out = {
        "requirement": requirement,
        "platforms": platforms,
        "queries_used": search_queries,
        "eval_model": eval_model_id,
        "total": len(sources),
        "report": rep,
        "discard": dis,
        "results": sources,
    }
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)
    print(f"💾 evaluated.json: {len(sources)} 条 (report={rep} discard={dis}) → {out_file}")
    print(f"💰 累计成本: ${total_cost:.4f}")
def main() -> None:
    """Parse the command line and run the async flow.

    ``load_dotenv()`` is called before the evaluation module is imported, so
    values from a ``.env`` file are already in the environment at that point.
    """
    from dotenv import load_dotenv
    load_dotenv()
    from examples.process_pipeline.script.llm_evaluate_sources import EVAL_MODELS, DEFAULT_EVAL_MODEL

    model_choices = list(EVAL_MODELS)
    parser = argparse.ArgumentParser(description="query 词多渠道直搜 + LLM 逐条评估")
    add = parser.add_argument

    add("--queries", required=True, help="query 词 json 文件(数组或 {queries:[...]})")
    add("--platforms", required=True, help="逗号分隔渠道,如 xhs,zhihu,gzh,douyin,sph,youtube")
    add("--output-dir", required=True, help="输出目录(写 evaluated.json)")
    add("--max-count", type=int, default=20,
        help="每个 (渠道,query) 返回条数上限(部分渠道如 xhs 由后端决定)")
    add("--requirement", default="",
        help="评估用采集目标描述(缺省取 queries 文件 requirement 或 query 拼接)")
    add("--eval-model", default=DEFAULT_EVAL_MODEL, choices=model_choices,
        help=f"评估模型(默认 {DEFAULT_EVAL_MODEL})")
    add("--max-concurrent", type=int, default=3, help="搜索 / 评估并发上限")
    add("--no-eval", action="store_true", help="只搜索,跳过 LLM 评估")
    add("--no-images", action="store_true",
        help="不把帖子配图发给模型(默认发,多模态评估;纯文本模型请加此项)")
    add("--max-images", type=int, default=4, help="每帖最多发给模型几张配图(默认 4)")
    add("--transcribe", action="store_true", help="评估前对视频帖跑 Deepgram 转写")

    # Query-generation mode.
    add("--gen-query", action="store_true", help="搜索前用 LLM 改写 / 扩展 query")
    add("--gen-model", default=DEFAULT_EVAL_MODEL, choices=model_choices,
        help=f"query 生成模型(默认 {DEFAULT_EVAL_MODEL})")
    add("--gen-count", type=int, default=10, help="生成 query 的目标数量")
    add("--keep-original", action="store_true",
        help="生成的 query 与原始 query 合并搜索(默认只用生成的)")

    asyncio.run(run(parser.parse_args()))


if __name__ == "__main__":
    main()