|
|
@@ -0,0 +1,145 @@
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+"""工具解构 · search_data 帖子 → 结构化工具条目 → mode_tools 表
|
|
|
+================================================================================
|
|
|
+- 帖子源:search_data 表(--query-id + --case-ids 定位)
|
|
|
+- 模型默认 google/gemini-3.1-flash-lite,可 --model 传任意 OpenRouter id
|
|
|
+- 多模态:复用 search_and_evaluate._attach_image_refs(URL 直传)
|
|
|
+- 写库:db.replace_tools(同版本幂等,跨版本保留);runs/tools/ 留调试副本
|
|
|
+
|
|
|
+用法(一般由 server.py 起子进程调):
|
|
|
+ python pipeline/tool_extract.py --query-id q0000 --case-ids xhs_abc,gzh_def
|
|
|
+ python pipeline/tool_extract.py --query-id q0000 --case-ids xhs_abc --model anthropic/claude-sonnet-4-6
|
|
|
+"""
|
|
|
+import argparse
|
|
|
+import asyncio
|
|
|
+import json
|
|
|
+import sys
|
|
|
+import time
|
|
|
+from datetime import datetime
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
+PROJECT_ROOT = Path(__file__).resolve().parents[3] # …/Agent
|
|
|
+sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
+
|
|
|
+from dotenv import load_dotenv
|
|
|
+load_dotenv()
|
|
|
+
|
|
|
+from examples.process_pipeline.script.search_eval.search_and_evaluate import _attach_image_refs
|
|
|
+from examples.process_pipeline.script.llm_evaluate_sources import _format_post_for_eval, build_eval_llm_call
|
|
|
+from examples.process_pipeline.script.llm_helper import call_llm_with_retry
|
|
|
+
|
|
|
+HERE = Path(__file__).resolve().parent # pipeline/
|
|
|
+MW = HERE.parent # mode_workflow/
|
|
|
+sys.path.insert(0, str(MW))
|
|
|
+import db
|
|
|
+
|
|
|
+TOOL_SYSTEM = (MW / "prompts" / "tool_extract_system.md").read_text(encoding="utf-8")
|
|
|
+DEFAULT_MODEL_CHOICE = "gemini-flash-lite"
|
|
|
+MAX_IMAGES = 6
|
|
|
+
|
|
|
+
|
|
|
+def _row_to_source(row):
|
|
|
+ """search_data 行 → 引擎函数认的 source dict。"""
|
|
|
+ return {
|
|
|
+ "case_id": row["case_id"], "platform": row["platform"],
|
|
|
+ "channel_content_id": row["channel_content_id"],
|
|
|
+ "source_url": row["url"],
|
|
|
+ "post": {
|
|
|
+ "title": row["title"], "body_text": row["body"],
|
|
|
+ "images": row["images"] or [], "like_count": row["like_count"],
|
|
|
+ "publish_timestamp": row["publish_time"], "link": row["url"],
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+def _validate_tools(data):
|
|
|
+ if not isinstance(data, dict) or "tools" not in data:
|
|
|
+ return '缺少顶层 "tools" 字段'
|
|
|
+ if not isinstance(data["tools"], list):
|
|
|
+ return '"tools" 必须是数组'
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
+async def extract_one(source, llm_call, model):
|
|
|
+ """对一条 source 抽工具,返回 (tools_list, cost)。失败返回 ([], cost)。"""
|
|
|
+ user_text = "【内容】\n" + _format_post_for_eval(source)
|
|
|
+ image_urls = source.get("_image_data_urls") or None
|
|
|
+ if image_urls:
|
|
|
+ user_content = [{"type": "text", "text": user_text}]
|
|
|
+ for u in image_urls:
|
|
|
+ user_content.append({"type": "image_url", "image_url": {"url": u}})
|
|
|
+ messages = [{"role": "system", "content": TOOL_SYSTEM},
|
|
|
+ {"role": "user", "content": user_content}]
|
|
|
+ else:
|
|
|
+ messages = [{"role": "system", "content": TOOL_SYSTEM},
|
|
|
+ {"role": "user", "content": user_text}]
|
|
|
+ data, cost = await call_llm_with_retry(
|
|
|
+ llm_call=llm_call, messages=messages, model=model,
|
|
|
+ temperature=0.1, max_tokens=4000,
|
|
|
+ validate_fn=_validate_tools,
|
|
|
+ task_name=f"ToolExtract[{source.get('case_id', '?')}]",
|
|
|
+ )
|
|
|
+ if not data:
|
|
|
+ return [], cost
|
|
|
+ return data.get("tools", []), cost
|
|
|
+
|
|
|
+
|
|
|
+async def run(args):
|
|
|
+ qid = args.query_id
|
|
|
+ case_ids = [c.strip() for c in args.case_ids.split(",") if c.strip()]
|
|
|
+ sources = []
|
|
|
+ for cid in case_ids:
|
|
|
+ row = db.fetch_post(qid, cid)
|
|
|
+ if row is None:
|
|
|
+ print(f"⚠️ {qid}/{cid} 不在 search_data,跳过")
|
|
|
+ continue
|
|
|
+ sources.append(_row_to_source(row))
|
|
|
+ if not sources:
|
|
|
+ print("❌ 没有可解构的帖子"); return 1
|
|
|
+
|
|
|
+ if args.model and "/" in args.model:
|
|
|
+ from agent.llm.openrouter import create_openrouter_llm_call
|
|
|
+ llm_call, model_id = create_openrouter_llm_call(model=args.model), args.model
|
|
|
+ else:
|
|
|
+ llm_call, model_id = build_eval_llm_call(args.model or DEFAULT_MODEL_CHOICE)
|
|
|
+ version = args.version or ("v_" + datetime.now().strftime("%m%d%H%M"))
|
|
|
+ print(f"🔧 工具解构 {len(sources)} 帖 · 模型 {model_id} · 版本 {version}")
|
|
|
+
|
|
|
+ await _attach_image_refs(sources, MAX_IMAGES, max(2, args.max_concurrent * 2), "url")
|
|
|
+ sem = asyncio.Semaphore(args.max_concurrent)
|
|
|
+ out_dir = MW / "runs" / "tools"
|
|
|
+ out_dir.mkdir(parents=True, exist_ok=True)
|
|
|
+
|
|
|
+ async def _work(s):
|
|
|
+ t0 = time.monotonic()
|
|
|
+ async with sem:
|
|
|
+ tools, cost = await extract_one(s, llm_call, model_id)
|
|
|
+ dur = round(time.monotonic() - t0, 1)
|
|
|
+ n = db.replace_tools(qid, s["case_id"], s.get("platform"),
|
|
|
+ (s.get("post") or {}).get("title", ""),
|
|
|
+ tools, model_id, version, cost, dur)
|
|
|
+ (out_dir / f"{s['case_id']}_{version}.json").write_text(json.dumps({
|
|
|
+ "case_id": s["case_id"], "version": version, "model": model_id,
|
|
|
+ "cost_usd": cost, "duration_s": dur, "tools": tools,
|
|
|
+ }, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
|
+ print(f" ✅ {s['case_id']} → {n} 个工具 · ${cost:.4f} · {dur}s")
|
|
|
+ return cost
|
|
|
+
|
|
|
+ costs = await asyncio.gather(*[_work(s) for s in sources])
|
|
|
+ print(f"\n📊 完成 {len(sources)} 帖 · 总成本 ${sum(costs):.4f}")
|
|
|
+ return 0
|
|
|
+
|
|
|
+
|
|
|
+def main():
|
|
|
+ p = argparse.ArgumentParser(description="工具解构:search_data 帖子 → mode_tools")
|
|
|
+ p.add_argument("--query-id", required=True)
|
|
|
+ p.add_argument("--case-ids", required=True, help="逗号分隔 case_id 列表")
|
|
|
+ p.add_argument("--model", default=None, help="默认 gemini-flash-lite,可传 OpenRouter id")
|
|
|
+ p.add_argument("--max-concurrent", type=int, default=3)
|
|
|
+ p.add_argument("--version", default=None, help="默认自动 v_月日时分")
|
|
|
+ args = p.parse_args()
|
|
|
+ raise SystemExit(asyncio.run(run(args)))
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|