tool_extract.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # -*- coding: utf-8 -*-
  2. """工具解构 · search_tools 帖子 → 结构化工具条目 → mode_tools 表
  3. ================================================================================
  4. - 帖子源:search_tools 表(--query-id + --case-ids 定位)
  5. - 模型默认 google/gemini-3.1-flash-lite,可 --model 传任意 OpenRouter id
  6. - 多模态:复用 search_and_evaluate._attach_image_refs(URL 直传)
  7. - 写库:db.replace_tools(同版本幂等,跨版本保留);runs/mode_tools/ 留调试副本
  8. 用法(一般由 server.py 起子进程调):
  9. python pipeline/tool_extract.py --query-id q0000 --case-ids xhs_abc,gzh_def
  10. python pipeline/tool_extract.py --query-id q0000 --case-ids xhs_abc --model anthropic/claude-sonnet-4-6
  11. """
  12. import argparse
  13. import asyncio
  14. import json
  15. import sys
  16. import time
  17. from datetime import datetime
  18. from pathlib import Path
  19. PROJECT_ROOT = Path(__file__).resolve().parents[3] # …/Agent
  20. sys.path.insert(0, str(PROJECT_ROOT))
  21. from dotenv import load_dotenv
  22. load_dotenv()
  23. from examples.process_pipeline.script.search_eval.search_and_evaluate import _attach_image_refs
  24. from examples.process_pipeline.script.llm_evaluate_sources import _format_post_for_eval, build_eval_llm_call
  25. from examples.process_pipeline.script.llm_helper import call_llm_with_retry
  26. HERE = Path(__file__).resolve().parent # pipeline/
  27. MW = HERE.parent # mode_workflow/
  28. sys.path.insert(0, str(MW))
  29. import db
  30. TOOL_SYSTEM = (MW / "prompts" / "tool_extract_system.md").read_text(encoding="utf-8")
  31. DEFAULT_MODEL_CHOICE = "gemini-flash-lite"
  32. MAX_IMAGES = 6
  33. def _row_to_source(row):
  34. """search_tools 行 → 引擎函数认的 source dict。"""
  35. return {
  36. "case_id": row["case_id"], "platform": row["platform"],
  37. "channel_content_id": row["channel_content_id"],
  38. "source_url": row["url"],
  39. "post": {
  40. "title": row["title"], "body_text": row["body"],
  41. "images": row["images"] or [], "like_count": row["like_count"],
  42. "publish_timestamp": row["publish_time"], "link": row["url"],
  43. },
  44. }
  45. def _validate_tools(data):
  46. if not isinstance(data, dict) or "tools" not in data:
  47. return '缺少顶层 "tools" 字段'
  48. if not isinstance(data["tools"], list):
  49. return '"tools" 必须是数组'
  50. return None
  51. async def extract_one(source, llm_call, model):
  52. """对一条 source 抽工具,返回 (tools_list, cost)。失败返回 ([], cost)。"""
  53. user_text = "【内容】\n" + _format_post_for_eval(source)
  54. image_urls = source.get("_image_data_urls") or None
  55. if image_urls:
  56. user_content = [{"type": "text", "text": user_text}]
  57. for u in image_urls:
  58. user_content.append({"type": "image_url", "image_url": {"url": u}})
  59. messages = [{"role": "system", "content": TOOL_SYSTEM},
  60. {"role": "user", "content": user_content}]
  61. else:
  62. messages = [{"role": "system", "content": TOOL_SYSTEM},
  63. {"role": "user", "content": user_text}]
  64. data, cost = await call_llm_with_retry(
  65. llm_call=llm_call, messages=messages, model=model,
  66. temperature=0.1, max_tokens=4000,
  67. validate_fn=_validate_tools,
  68. task_name=f"ToolExtract[{source.get('case_id', '?')}]",
  69. )
  70. if not data:
  71. return [], cost
  72. return data.get("tools", []), cost
  73. async def run(args):
  74. qid = args.query_id
  75. case_ids = [c.strip() for c in args.case_ids.split(",") if c.strip()]
  76. # 方案A:解构前按 case 全局去重(同 procedure_extract)。已解构的不再调 LLM,
  77. # 跨 query 的用 link_* 复制补齐关联。--force 强制重解构。
  78. linked = skipped = 0
  79. todo = []
  80. for cid in dict.fromkeys(case_ids):
  81. if not args.force:
  82. ex = db.latest_real_version(cid, mode="tools")
  83. if ex:
  84. if ex["query_id"] == qid:
  85. print(f"♻️ {cid} 本 query 已解构(版本 {ex['version']}),跳过")
  86. skipped += 1
  87. else:
  88. n = db.link_process(qid, cid, mode="tools")
  89. print(f"♻️ {cid} 已在 {ex['query_id']} 解构(版本 {ex['version']}),"
  90. f"link 补齐 {n} 行 · $0")
  91. linked += 1
  92. continue
  93. todo.append(cid)
  94. sources = []
  95. for cid in todo:
  96. row = db.fetch_post(qid, cid, table="search_tools")
  97. if row is None:
  98. print(f"⚠️ {qid}/{cid} 不在 search_tools,跳过")
  99. continue
  100. sources.append(_row_to_source(row))
  101. if not sources:
  102. if linked or skipped:
  103. print(f"✅ 无需 LLM 解构(link 补齐 {linked} 帖 · 已存在跳过 {skipped} 帖)")
  104. return 0
  105. print("❌ 没有可解构的帖子"); return 1
  106. if args.model and "/" in args.model:
  107. from agent.llm.openrouter import create_openrouter_llm_call
  108. llm_call, model_id = create_openrouter_llm_call(model=args.model), args.model
  109. else:
  110. llm_call, model_id = build_eval_llm_call(args.model or DEFAULT_MODEL_CHOICE)
  111. version = args.version or ("v_" + datetime.now().strftime("%m%d%H%M"))
  112. print(f"🔧 工具解构 {len(sources)} 帖 · 模型 {model_id} · 版本 {version}")
  113. await _attach_image_refs(sources, MAX_IMAGES, max(2, args.max_concurrent * 2), "url")
  114. sem = asyncio.Semaphore(args.max_concurrent)
  115. out_dir = MW / "runs" / "mode_tools"
  116. out_dir.mkdir(parents=True, exist_ok=True)
  117. async def _work(s):
  118. t0 = time.monotonic()
  119. async with sem:
  120. tools, cost = await extract_one(s, llm_call, model_id)
  121. dur = round(time.monotonic() - t0, 1)
  122. n = db.replace_tools(qid, s["case_id"], s.get("platform"),
  123. (s.get("post") or {}).get("title", ""),
  124. tools, model_id, version, cost, dur)
  125. (out_dir / f"{s['case_id']}_{version}.json").write_text(json.dumps({
  126. "case_id": s["case_id"], "version": version, "model": model_id,
  127. "cost_usd": cost, "duration_s": dur, "tools": tools,
  128. }, ensure_ascii=False, indent=2), encoding="utf-8")
  129. print(f" ✅ {s['case_id']} → {n} 个工具 · ${cost:.4f} · {dur}s")
  130. return cost
  131. costs = await asyncio.gather(*[_work(s) for s in sources])
  132. print(f"\n📊 完成 {len(sources)} 帖 · link 补齐 {linked} 帖 · 总成本 ${sum(costs):.4f}")
  133. return 0
  134. def main():
  135. p = argparse.ArgumentParser(description="工具解构:search_tools 帖子 → mode_tools")
  136. p.add_argument("--query-id", required=True)
  137. p.add_argument("--case-ids", required=True, help="逗号分隔 case_id 列表")
  138. p.add_argument("--model", default=None, help="默认 gemini-flash-lite,可传 OpenRouter id")
  139. p.add_argument("--max-concurrent", type=int, default=3)
  140. p.add_argument("--version", default=None, help="默认自动 v_月日时分")
  141. p.add_argument("--force", action="store_true",
  142. help="跳过去重,强制重解构(换 prompt/模型做对比时用)")
  143. p.add_argument("--prompt-file", default=None,
  144. help="覆盖默认解构 prompt(临时,仅本次;不改 prompts/*.md)")
  145. args = p.parse_args()
  146. if args.prompt_file:
  147. global TOOL_SYSTEM
  148. TOOL_SYSTEM = Path(args.prompt_file).read_text(encoding="utf-8")
  149. raise SystemExit(asyncio.run(run(args)))
  150. if __name__ == "__main__":
  151. main()