Przeglądaj źródła

feat(mode_workflow): 工具解构 pipeline(读库→LLM→mode_tools)

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
刘文武 5 dni temu
rodzic
commit
3cdc8ef3ac
1 zmienionych plików z 145 dodań i 0 usunięć
  1. 145 0
      examples/mode_workflow/pipeline/tool_extract.py

+ 145 - 0
examples/mode_workflow/pipeline/tool_extract.py

@@ -0,0 +1,145 @@
+# -*- coding: utf-8 -*-
+"""工具解构 · search_data 帖子 → 结构化工具条目 → mode_tools 表
+================================================================================
+- 帖子源:search_data 表(--query-id + --case-ids 定位)
+- 模型默认 google/gemini-3.1-flash-lite,可 --model 传任意 OpenRouter id
+- 多模态:复用 search_and_evaluate._attach_image_refs(URL 直传)
+- 写库:db.replace_tools(同版本幂等,跨版本保留);runs/tools/ 留调试副本
+
+用法(一般由 server.py 起子进程调):
+  python pipeline/tool_extract.py --query-id q0000 --case-ids xhs_abc,gzh_def
+  python pipeline/tool_extract.py --query-id q0000 --case-ids xhs_abc --model anthropic/claude-sonnet-4-6
+"""
+import argparse
+import asyncio
+import json
+import sys
+import time
+from datetime import datetime
+from pathlib import Path
+
+# Repository root: parents[3] of this file's resolved path, i.e. three levels
+# above the directory that contains it. It is put first on sys.path so the
+# absolute `examples.…` imports below resolve when this file runs as a script.
+PROJECT_ROOT = Path(__file__).resolve().parents[3]   # …/Agent
+sys.path.insert(0, str(PROJECT_ROOT))
+
+# NOTE(review): load_dotenv() runs before the project imports, presumably so
+# that values from .env are already in the environment when those modules are
+# imported — confirm before reordering.
+from dotenv import load_dotenv
+load_dotenv()
+
+from examples.process_pipeline.script.search_eval.search_and_evaluate import _attach_image_refs
+from examples.process_pipeline.script.llm_evaluate_sources import _format_post_for_eval, build_eval_llm_call
+from examples.process_pipeline.script.llm_helper import call_llm_with_retry
+
+HERE = Path(__file__).resolve().parent          # pipeline/
+MW = HERE.parent                                 # mode_workflow/
+# MW is added to sys.path right before `import db`; presumably `db` is a module
+# that lives directly in mode_workflow/ — verify against the repository layout.
+sys.path.insert(0, str(MW))
+import db
+
+# System prompt for the extraction call. It is read once at import time, so a
+# missing prompts/tool_extract_system.md fails the import itself.
+TOOL_SYSTEM = (MW / "prompts" / "tool_extract_system.md").read_text(encoding="utf-8")
+# Model choice handed to build_eval_llm_call() when --model is not given.
+DEFAULT_MODEL_CHOICE = "gemini-flash-lite"
+# Passed to _attach_image_refs() in run(); presumably the per-post image cap.
+MAX_IMAGES = 6
+
+
+def _row_to_source(row):
+    """Convert one search_data row into a source dict.
+
+    Args:
+        row: Mapping-like row exposing the keys ``case_id``, ``platform``,
+            ``channel_content_id``, ``url``, ``title``, ``body``, ``images``,
+            ``like_count`` and ``publish_time``.
+
+    Returns:
+        A dict with the identifying fields at the top level and the post
+        content nested under ``"post"``. A falsy ``images`` value becomes an
+        empty list, and the row's URL is exposed both as ``source_url`` and
+        as ``post["link"]``.
+    """
+    source = dict(
+        case_id=row["case_id"],
+        platform=row["platform"],
+        channel_content_id=row["channel_content_id"],
+        source_url=row["url"],
+    )
+    source["post"] = dict(
+        title=row["title"],
+        body_text=row["body"],
+        images=row["images"] or [],
+        like_count=row["like_count"],
+        publish_timestamp=row["publish_time"],
+        link=row["url"],
+    )
+    return source
+
+
+def _validate_tools(data):
+    """Check the shape of a parsed LLM reply.
+
+    Used as the ``validate_fn`` of ``call_llm_with_retry``.
+
+    Args:
+        data: The parsed reply; any type is accepted.
+
+    Returns:
+        ``None`` when ``data`` is a dict whose ``"tools"`` entry is a list,
+        otherwise a message describing what is wrong.
+    """
+    has_tools_key = isinstance(data, dict) and "tools" in data
+    if not has_tools_key:
+        return '缺少顶层 "tools" 字段'
+    return None if isinstance(data["tools"], list) else '"tools" 必须是数组'
+
+
+async def extract_one(source, llm_call, model):
+    """Extract tools from a single source.
+
+    Args:
+        source: Source dict; its ``_image_data_urls`` entry, when non-empty,
+            is attached to the user message as ``image_url`` parts.
+        llm_call: Callable forwarded to ``call_llm_with_retry``.
+        model: Model id forwarded to ``call_llm_with_retry``.
+
+    Returns:
+        ``(tools, cost)``. ``tools`` is an empty list when the helper returns
+        no data.
+    """
+    prompt_text = "【内容】\n" + _format_post_for_eval(source)
+    image_refs = source.get("_image_data_urls") or None
+    if image_refs:
+        # Multimodal request: one text part followed by one part per image.
+        user_payload = [{"type": "text", "text": prompt_text}]
+        user_payload.extend(
+            {"type": "image_url", "image_url": {"url": ref}} for ref in image_refs
+        )
+    else:
+        user_payload = prompt_text
+    messages = [
+        {"role": "system", "content": TOOL_SYSTEM},
+        {"role": "user", "content": user_payload},
+    ]
+    data, cost = await call_llm_with_retry(
+        llm_call=llm_call,
+        messages=messages,
+        model=model,
+        temperature=0.1,
+        max_tokens=4000,
+        validate_fn=_validate_tools,
+        task_name=f"ToolExtract[{source.get('case_id', '?')}]",
+    )
+    tools = data.get("tools", []) if data else []
+    return tools, cost
+
+
+async def run(args):
+    """Run tool extraction for the posts named on the command line.
+
+    Loads each requested post with ``db.fetch_post``, extracts tools from all
+    of them concurrently via ``extract_one``, stores each result through
+    ``db.replace_tools`` and writes a JSON debug copy under ``runs/tools/``.
+
+    Args:
+        args: Parsed CLI namespace providing ``query_id``, ``case_ids``
+            (comma-separated), ``model``, ``max_concurrent`` and ``version``.
+
+    Returns:
+        ``0`` when every loaded post was processed; ``1`` when no post could
+        be loaded or at least one post failed.
+    """
+    qid = args.query_id
+    # dict.fromkeys drops repeated ids while keeping first-seen order, so the
+    # same post is never extracted (and paid for) twice in one run, and two
+    # workers never write the same rows / debug file concurrently.
+    case_ids = list(dict.fromkeys(
+        c.strip() for c in args.case_ids.split(",") if c.strip()
+    ))
+    sources = []
+    for cid in case_ids:
+        row = db.fetch_post(qid, cid)
+        if row is None:
+            print(f"⚠️ {qid}/{cid} 不在 search_data,跳过")
+            continue
+        sources.append(_row_to_source(row))
+    if not sources:
+        print("❌ 没有可解构的帖子")
+        return 1
+
+    # A "/" in --model is treated as a full OpenRouter model id; anything else
+    # (or the default) is resolved by build_eval_llm_call.
+    if args.model and "/" in args.model:
+        from agent.llm.openrouter import create_openrouter_llm_call
+        llm_call, model_id = create_openrouter_llm_call(model=args.model), args.model
+    else:
+        llm_call, model_id = build_eval_llm_call(args.model or DEFAULT_MODEL_CHOICE)
+    version = args.version or ("v_" + datetime.now().strftime("%m%d%H%M"))
+    print(f"🔧 工具解构 {len(sources)} 帖 · 模型 {model_id} · 版本 {version}")
+
+    # Semaphore(0) would block every worker forever and a negative value
+    # raises ValueError, so always allow at least one concurrent call.
+    max_concurrent = max(1, args.max_concurrent)
+    await _attach_image_refs(sources, MAX_IMAGES, max(2, max_concurrent * 2), "url")
+    sem = asyncio.Semaphore(max_concurrent)
+    out_dir = MW / "runs" / "tools"
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    async def _work(s):
+        # The timer starts before the semaphore is acquired, so the reported
+        # duration includes time spent waiting for a free slot.
+        t0 = time.monotonic()
+        async with sem:
+            tools, cost = await extract_one(s, llm_call, model_id)
+        dur = round(time.monotonic() - t0, 1)
+        n = db.replace_tools(qid, s["case_id"], s.get("platform"),
+                             (s.get("post") or {}).get("title", ""),
+                             tools, model_id, version, cost, dur)
+        (out_dir / f"{s['case_id']}_{version}.json").write_text(json.dumps({
+            "case_id": s["case_id"], "version": version, "model": model_id,
+            "cost_usd": cost, "duration_s": dur, "tools": tools,
+        }, ensure_ascii=False, indent=2), encoding="utf-8")
+        print(f"   ✅ {s['case_id']} → {n} 个工具 · ${cost:.4f} · {dur}s")
+        return cost
+
+    # return_exceptions=True keeps one failing post from aborting the run and
+    # cancelling the other in-flight posts; failures are reported per post.
+    results = await asyncio.gather(*[_work(s) for s in sources],
+                                   return_exceptions=True)
+    costs = []
+    failed = 0
+    for s, result in zip(sources, results):
+        if isinstance(result, BaseException):
+            failed += 1
+            print(f"   ❌ {s['case_id']} 失败: {result!r}")
+        else:
+            costs.append(result)
+    print(f"\n📊 完成 {len(costs)} 帖 · 总成本 ${sum(costs):.4f}")
+    if failed:
+        print(f"❌ {failed} 帖失败")
+        return 1
+    return 0
+
+
+def main():
+    """Parse the command line, run the extraction and exit with its code.
+
+    Raises:
+        SystemExit: Always; the exit status is the value returned by ``run``.
+    """
+    parser = argparse.ArgumentParser(description="工具解构:search_data 帖子 → mode_tools")
+    # Options are registered in this order so the generated help is unchanged.
+    option_table = (
+        ("--query-id", {"required": True}),
+        ("--case-ids", {"required": True, "help": "逗号分隔 case_id 列表"}),
+        ("--model", {"default": None, "help": "默认 gemini-flash-lite,可传 OpenRouter id"}),
+        ("--max-concurrent", {"type": int, "default": 3}),
+        ("--version", {"default": None, "help": "默认自动 v_月日时分"}),
+    )
+    for flag, options in option_table:
+        parser.add_argument(flag, **options)
+    exit_code = asyncio.run(run(parser.parse_args()))
+    raise SystemExit(exit_code)
+
+
+if __name__ == "__main__":
+    main()