| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 |
- # -*- coding: utf-8 -*-
- """工具解构 · search_tools 帖子 → 结构化工具条目 → mode_tools 表
- ================================================================================
- - 帖子源:search_tools 表(--query-id + --case-ids 定位)
- - 模型默认 google/gemini-3.1-flash-lite,可 --model 传任意 OpenRouter id
- - 多模态:复用 search_and_evaluate._attach_image_refs(URL 直传)
- - 写库:db.replace_tools(同版本幂等,跨版本保留);runs/mode_tools/ 留调试副本
- 用法(一般由 server.py 起子进程调):
- python pipeline/tool_extract.py --query-id q0000 --case-ids xhs_abc,gzh_def
- python pipeline/tool_extract.py --query-id q0000 --case-ids xhs_abc --model anthropic/claude-sonnet-4-6
- """
- import argparse
- import asyncio
- import json
- import sys
- import time
- from datetime import datetime
- from pathlib import Path
- PROJECT_ROOT = Path(__file__).resolve().parents[3] # …/Agent
- sys.path.insert(0, str(PROJECT_ROOT))
- from dotenv import load_dotenv
- load_dotenv()
- from examples.process_pipeline.script.search_eval.search_and_evaluate import _attach_image_refs
- from examples.process_pipeline.script.llm_evaluate_sources import _format_post_for_eval, build_eval_llm_call
- from examples.process_pipeline.script.llm_helper import call_llm_with_retry
- HERE = Path(__file__).resolve().parent # pipeline/
- MW = HERE.parent # mode_workflow/
- sys.path.insert(0, str(MW))
- import db
# System prompt for tool extraction; read once when the module is imported.
TOOL_SYSTEM = MW.joinpath("prompts", "tool_extract_system.md").read_text(encoding="utf-8")
# Model choice handed to build_eval_llm_call when --model is not supplied.
DEFAULT_MODEL_CHOICE = "gemini-flash-lite"
# Image limit forwarded to _attach_image_refs (presumably per post — confirm there).
MAX_IMAGES = 6
- def _row_to_source(row):
- """search_tools 行 → 引擎函数认的 source dict。"""
- return {
- "case_id": row["case_id"], "platform": row["platform"],
- "channel_content_id": row["channel_content_id"],
- "source_url": row["url"],
- "post": {
- "title": row["title"], "body_text": row["body"],
- "images": row["images"] or [], "like_count": row["like_count"],
- "publish_timestamp": row["publish_time"], "link": row["url"],
- },
- }
- def _validate_tools(data):
- if not isinstance(data, dict) or "tools" not in data:
- return '缺少顶层 "tools" 字段'
- if not isinstance(data["tools"], list):
- return '"tools" 必须是数组'
- return None
async def extract_one(source, llm_call, model):
    """Ask the LLM to extract tools from a single source post.

    The user message is the formatted post text. When the source carries
    ``_image_data_urls``, the message becomes a content-part list: the text
    first, then one ``image_url`` part per URL.

    Returns:
        ``(tools, cost)``: the ``"tools"`` list from the validated response
        (``[]`` when no data comes back) and the cost reported by
        ``call_llm_with_retry``.
    """
    text = "【内容】\n" + _format_post_for_eval(source)
    images = source.get("_image_data_urls") or None
    if images:
        content = [{"type": "text", "text": text}]
        content.extend({"type": "image_url", "image_url": {"url": url}} for url in images)
    else:
        content = text
    messages = [
        {"role": "system", "content": TOOL_SYSTEM},
        {"role": "user", "content": content},
    ]
    data, cost = await call_llm_with_retry(
        llm_call=llm_call,
        messages=messages,
        model=model,
        temperature=0.1,
        max_tokens=4000,
        validate_fn=_validate_tools,
        task_name=f"ToolExtract[{source.get('case_id', '?')}]",
    )
    tools = data.get("tools", []) if data else []
    return tools, cost
async def run(args):
    """Extract tools for the requested posts and persist them to mode_tools.

    Steps:
      1. De-duplicate ``--case-ids`` (first-seen order kept). Unless
         ``--force``, a case already extracted for this query is skipped, and
         one extracted under another query is linked via ``db.link_process``
         without calling the LLM.
      2. Load the remaining posts from the ``search_tools`` table.
      3. Attach image references, call the LLM per post under a concurrency
         limit, store results with ``db.replace_tools`` and dump a debug JSON
         copy under ``runs/mode_tools/``.

    Args:
        args: the ``argparse.Namespace`` built by :func:`main`.

    Returns:
        Exit code: 0 on success (including "nothing to extract"), 1 when no
        post could be loaded or at least one post failed during extraction.
    """
    qid = args.query_id
    case_ids = [c.strip() for c in args.case_ids.split(",") if c.strip()]

    # Plan A: dedupe per case globally before extraction (same as
    # procedure_extract). Already-extracted cases never hit the LLM again;
    # cases extracted under another query get their association copied via
    # link_*. --force re-extracts unconditionally.
    linked = skipped = 0
    todo = []
    for cid in dict.fromkeys(case_ids):  # dedupe while keeping input order
        if not args.force:
            ex = db.latest_real_version(cid, mode="tools")
            if ex:
                if ex["query_id"] == qid:
                    print(f"♻️ {cid} 本 query 已解构(版本 {ex['version']}),跳过")
                    skipped += 1
                else:
                    n = db.link_process(qid, cid, mode="tools")
                    print(f"♻️ {cid} 已在 {ex['query_id']} 解构(版本 {ex['version']}),"
                          f"link 补齐 {n} 行 · $0")
                    linked += 1
                continue
        todo.append(cid)

    sources = []
    for cid in todo:
        row = db.fetch_post(qid, cid, table="search_tools")
        if row is None:
            print(f"⚠️ {qid}/{cid} 不在 search_tools,跳过")
            continue
        sources.append(_row_to_source(row))
    if not sources:
        if linked or skipped:
            print(f"✅ 无需 LLM 解构(link 补齐 {linked} 帖 · 已存在跳过 {skipped} 帖)")
            return 0
        print("❌ 没有可解构的帖子")
        return 1

    # An id containing "/" is treated as an OpenRouter model id; anything else
    # (including None -> default) goes through build_eval_llm_call.
    if args.model and "/" in args.model:
        from agent.llm.openrouter import create_openrouter_llm_call
        llm_call, model_id = create_openrouter_llm_call(model=args.model), args.model
    else:
        llm_call, model_id = build_eval_llm_call(args.model or DEFAULT_MODEL_CHOICE)
    version = args.version or ("v_" + datetime.now().strftime("%m%d%H%M"))
    print(f"🔧 工具解构 {len(sources)} 帖 · 模型 {model_id} · 版本 {version}")

    # asyncio.Semaphore(0) would block every worker forever (and a negative
    # value raises ValueError), so never allow fewer than one slot.
    max_concurrent = max(1, args.max_concurrent)
    await _attach_image_refs(sources, MAX_IMAGES, max(2, max_concurrent * 2), "url")
    sem = asyncio.Semaphore(max_concurrent)
    out_dir = MW / "runs" / "mode_tools"
    out_dir.mkdir(parents=True, exist_ok=True)

    async def _work(s):
        async with sem:
            # Start the clock only once a slot is held; timing from before the
            # semaphore would fold queueing behind other posts into duration_s.
            t0 = time.monotonic()
            tools, cost = await extract_one(s, llm_call, model_id)
            dur = round(time.monotonic() - t0, 1)
        n = db.replace_tools(qid, s["case_id"], s.get("platform"),
                             (s.get("post") or {}).get("title", ""),
                             tools, model_id, version, cost, dur)
        (out_dir / f"{s['case_id']}_{version}.json").write_text(json.dumps({
            "case_id": s["case_id"], "version": version, "model": model_id,
            "cost_usd": cost, "duration_s": dur, "tools": tools,
        }, ensure_ascii=False, indent=2), encoding="utf-8")
        # Defensive: a missing cost must not crash the progress line.
        print(f" ✅ {s['case_id']} → {n} 个工具 · ${(cost or 0.0):.4f} · {dur}s")
        return cost

    # return_exceptions=True: one failing post must not throw away the
    # summary and cost accounting of the posts that did succeed.
    results = await asyncio.gather(*(_work(s) for s in sources), return_exceptions=True)
    failed = []
    total_cost = 0.0
    for s, res in zip(sources, results):
        if isinstance(res, Exception):
            failed.append(s["case_id"])
            print(f"  ❌ {s['case_id']} 解构失败:{res!r}")
        elif isinstance(res, BaseException):
            raise res  # e.g. cancellation: propagate, don't swallow
        else:
            total_cost += res or 0.0
    print(f"\n📊 完成 {len(sources) - len(failed)} 帖 · link 补齐 {linked} 帖 · 总成本 ${total_cost:.4f}")
    if failed:
        print(f"❌ {len(failed)} 帖解构失败:{', '.join(failed)}")
        return 1
    return 0
def main():
    """CLI entry point: parse the command line and exit with run()'s return code."""
    parser = argparse.ArgumentParser(description="工具解构:search_tools 帖子 → mode_tools")
    add = parser.add_argument
    add("--query-id", required=True)
    add("--case-ids", required=True, help="逗号分隔 case_id 列表")
    add("--model", default=None, help="默认 gemini-flash-lite,可传 OpenRouter id")
    add("--max-concurrent", type=int, default=3)
    add("--version", default=None, help="默认自动 v_月日时分")
    add("--force", action="store_true",
        help="跳过去重,强制重解构(换 prompt/模型做对比时用)")
    exit_code = asyncio.run(run(parser.parse_args()))
    raise SystemExit(exit_code)


if __name__ == "__main__":
    main()
|