|
@@ -0,0 +1,184 @@
|
|
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
|
|
+"""Query 正交格评分 · 对 judged_matrix 的 tier≥1 格(动作×类型)在当前维度上下文下用 Sonnet 打分,
|
|
|
|
|
+挑出有意义、人话、有助于内容制作知识库目的的组合。结果原子写 .cache/query_score/<sel>.json。
|
|
|
|
|
+
|
|
|
|
|
+由 server.py /api/query_score 起子进程调;也可独立跑:
|
|
|
|
|
+ python stages/query_score.py --tool-type AI --modality 图片 --suffix 怎么做 \
|
|
|
|
|
+ --substance-path 表象,实体 --form-path 呈现,视觉 --sel adhoc --dry-run
|
|
|
|
|
+"""
|
|
|
|
|
+import argparse
|
|
|
|
|
+import asyncio
|
|
|
|
|
+import json
|
|
|
|
|
+import re
|
|
|
|
|
+import sys
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+
|
|
|
|
|
+PROJECT_ROOT = Path(__file__).resolve().parents[3] # …/Agent
|
|
|
|
|
+sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
|
+
|
|
|
|
|
+from dotenv import load_dotenv
|
|
|
|
|
+load_dotenv()
|
|
|
|
|
+
|
|
|
|
|
+HERE = Path(__file__).resolve().parent
|
|
|
|
|
+MW = HERE.parent
|
|
|
|
|
+MATRIX_FILE = MW / "reference" / "judged_matrix.json"
|
|
|
|
|
+PROMPT_FILE = MW / "prompts" / "query_score_system.md"
|
|
|
|
|
+CACHE_DIR = MW / ".cache" / "query_score"
|
|
|
|
|
+DEFAULT_MODEL = "anthropic/claude-sonnet-4-6"
|
|
|
|
|
+BATCH = 40
|
|
|
|
|
+CONCURRENCY = 5
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _build_cells(matrix, tool_type, modality, suffix):
|
|
|
|
|
+ """筛 tier≥1 格,产出 [{a_idx,t_idx,action,type,tier,query}]。
|
|
|
|
|
+ query = [工具类型] 动作叶 类型叶 [模态] [后缀],"无"/空跳过。"""
|
|
|
|
|
+ actions, types, grid = matrix["actions"], matrix["types"], matrix["matrix"]
|
|
|
|
|
+ pre, mod, suf = (tool_type or "").strip(), (modality or "").strip(), (suffix or "").strip()
|
|
|
|
|
+ cells = []
|
|
|
|
|
+ for ai, arow in enumerate(grid):
|
|
|
|
|
+ action = actions[ai]["name"]
|
|
|
|
|
+ for ti, cell in enumerate(arow):
|
|
|
|
|
+ if cell.get("tier", 0) < 1:
|
|
|
|
|
+ continue
|
|
|
|
|
+ typ = types[ti]["name"]
|
|
|
|
|
+ parts = [p for p in (pre, action, typ, mod, suf) if p and p != "无"]
|
|
|
|
|
+ cells.append({"a_idx": ai, "t_idx": ti, "action": action,
|
|
|
|
|
+ "type": typ, "tier": cell["tier"], "query": " ".join(parts)})
|
|
|
|
|
+ return cells
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _build_user(batch, ctx):
|
|
|
|
|
+ lines = [f'{i}. "{c["query"]}" (动作={c["action"]} 类型={c["type"]} 内容树tier={c["tier"]})'
|
|
|
|
|
+ for i, c in enumerate(batch)]
|
|
|
|
|
+ sub = (ctx["substance"] or "无").replace(",", "›")
|
|
|
|
|
+ form = (ctx["form"] or "无").replace(",", "›")
|
|
|
|
|
+ return (f"【固定上下文(本批共享)】\n"
|
|
|
|
|
+ f"工具类型: {ctx['tool_type'] or '无'} 模态: {ctx['modality'] or '无'} 后缀: {ctx['suffix'] or '无'}\n"
|
|
|
|
|
+ f"(实质/形式不参与拼词,仅供理解领域定位: 实质路径={sub} 形式路径={form})\n\n"
|
|
|
|
|
+ f"【候选列表】每条 = 动作 + 类型 + 上下文词拼成的 query:\n" + "\n".join(lines))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def _call_with_retry(llm_call, messages, model, task_name, max_retries=3):
|
|
|
|
|
+ """直接调 llm_call 并解析 JSON 数组(call_llm_with_retry 的正则只捕获 {…},不适用数组)。"""
|
|
|
|
|
+ total_cost = 0.0
|
|
|
|
|
+ last_err = None
|
|
|
|
|
+ cur_messages = list(messages)
|
|
|
|
|
+ for attempt in range(max_retries):
|
|
|
|
|
+ if attempt > 0 and last_err:
|
|
|
|
|
+ cur_messages = list(messages) + [
|
|
|
|
|
+ {"role": "user",
|
|
|
|
|
+ "content": f"上次输出未通过校验:{last_err}\n请重新输出完整 JSON 数组,不含其他内容。"}]
|
|
|
|
|
+ print(f" [{task_name}] Retry {attempt}/{max_retries - 1}: {last_err[:80]}...")
|
|
|
|
|
+ try:
|
|
|
|
|
+ resp = await llm_call(messages=cur_messages, model=model,
|
|
|
|
|
+ temperature=0.1, max_tokens=4000)
|
|
|
|
|
+ cost = resp.get("cost") or 0.0
|
|
|
|
|
+ total_cost += cost
|
|
|
|
|
+ content = resp.get("content", "")
|
|
|
|
|
+ if isinstance(content, list):
|
|
|
|
|
+ first = content[0] if content else ""
|
|
|
|
|
+ content = first.get("text", "") if isinstance(first, dict) else str(first)
|
|
|
|
|
+ # 提取 JSON 数组(支持裸数组和 markdown 围栏包裹)
|
|
|
|
|
+ arr_match = re.search(r"\[[\s\S]*\]", content)
|
|
|
|
|
+ if not arr_match:
|
|
|
|
|
+ last_err = "LLM 输出中未找到 JSON 数组"
|
|
|
|
|
+ continue
|
|
|
|
|
+ try:
|
|
|
|
|
+ data = json.loads(arr_match.group())
|
|
|
|
|
+ except json.JSONDecodeError as e:
|
|
|
|
|
+ last_err = f"JSON 解析失败: {e}"
|
|
|
|
|
+ continue
|
|
|
|
|
+ if not isinstance(data, list):
|
|
|
|
|
+ last_err = "需 JSON 数组"
|
|
|
|
|
+ continue
|
|
|
|
|
+ return data, total_cost
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ last_err = f"LLM 调用异常: {type(e).__name__}: {e}"
|
|
|
|
|
+ print(f" [{task_name}] Error: {last_err}")
|
|
|
|
|
+ print(f" [{task_name}] All {max_retries} attempts failed. Last error: {last_err}")
|
|
|
|
|
+ return None, total_cost
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def _score_batch(batch, ctx, system, llm_call, model, sem):
|
|
|
|
|
+ messages = [{"role": "system", "content": system},
|
|
|
|
|
+ {"role": "user", "content": _build_user(batch, ctx)}]
|
|
|
|
|
+ task_name = f"QueryScore[{batch[0]['query'][:12]}]"
|
|
|
|
|
+ async with sem:
|
|
|
|
|
+ data, cost = await _call_with_retry(llm_call, messages, model, task_name)
|
|
|
|
|
+ out = {}
|
|
|
|
|
+ for v in (data or []):
|
|
|
|
|
+ if not isinstance(v, dict):
|
|
|
|
|
+ continue
|
|
|
|
|
+ i = v.get("idx")
|
|
|
|
|
+ if not isinstance(i, int) or not (0 <= i < len(batch)):
|
|
|
|
|
+ continue
|
|
|
|
|
+ c = batch[i]
|
|
|
|
|
+ try:
|
|
|
|
|
+ score = round(float(v.get("natural")) * 0.4 + float(v.get("findable")) * 0.3
|
|
|
|
|
+ + float(v.get("useful")) * 0.3, 1)
|
|
|
|
|
+ except (TypeError, ValueError):
|
|
|
|
|
+ score = None
|
|
|
|
|
+ out[f"{c['a_idx']}_{c['t_idx']}"] = {
|
|
|
|
|
+ "query": c["query"], "natural": v.get("natural"), "findable": v.get("findable"),
|
|
|
|
|
+ "useful": v.get("useful"), "keep": bool(v.get("keep")),
|
|
|
|
|
+ "rewrite": (v.get("rewrite") or c["query"]), "reason": v.get("reason", ""),
|
|
|
|
|
+ "score": score}
|
|
|
|
|
+ return out, cost
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def run(args):
|
|
|
|
|
+ matrix = json.loads(MATRIX_FILE.read_text(encoding="utf-8"))
|
|
|
|
|
+ ctx = {"tool_type": args.tool_type, "modality": args.modality, "suffix": args.suffix,
|
|
|
|
|
+ "substance": args.substance_path, "form": args.form_path}
|
|
|
|
|
+ cells = _build_cells(matrix, args.tool_type, args.modality, args.suffix)
|
|
|
|
|
+ if args.limit:
|
|
|
|
|
+ cells = cells[:args.limit]
|
|
|
|
|
+ print(f"📋 tier≥1 候选 {len(cells)} 格" + (f" (--limit {args.limit})" if args.limit else ""))
|
|
|
|
|
+ if args.dry_run:
|
|
|
|
|
+ for c in cells[:10]:
|
|
|
|
|
+ print(f" [{c['tier']}] {c['query']}")
|
|
|
|
|
+ print(f"…共 {len(cells)} 格(dry-run,未调 LLM)")
|
|
|
|
|
+ return 0
|
|
|
|
|
+
|
|
|
|
|
+ system = PROMPT_FILE.read_text(encoding="utf-8")
|
|
|
|
|
+ from agent.llm.openrouter import create_openrouter_llm_call
|
|
|
|
|
+ llm_call = create_openrouter_llm_call(model=args.model)
|
|
|
|
|
+ sem = asyncio.Semaphore(CONCURRENCY)
|
|
|
|
|
+ batches = [cells[i:i + BATCH] for i in range(0, len(cells), BATCH)]
|
|
|
|
|
+ print(f"🤖 {len(batches)} 批 × ≤{BATCH} 格 · 并发 {CONCURRENCY} · 模型 {args.model}")
|
|
|
|
|
+ results = await asyncio.gather(*[
|
|
|
|
|
+ _score_batch(b, ctx, system, llm_call, args.model, sem) for b in batches])
|
|
|
|
|
+ merged, cost = {}, 0.0
|
|
|
|
|
+ for cmap, c in results:
|
|
|
|
|
+ merged.update(cmap)
|
|
|
|
|
+ cost += c
|
|
|
|
|
+ kept = sum(1 for v in merged.values() if v.get("keep"))
|
|
|
|
|
+ out = {"sel": ctx, "model": args.model, "kept": kept, "total": len(merged),
|
|
|
|
|
+ "cost_usd": round(cost, 4), "cells": merged}
|
|
|
|
|
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
+ dest = CACHE_DIR / f"{args.sel}.json"
|
|
|
|
|
+ tmp = dest.with_suffix(".tmp")
|
|
|
|
|
+ tmp.write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
|
|
|
+ tmp.replace(dest) # 原子落盘,避免前端读到半截
|
|
|
|
|
+ print(f"✅ 评分完成 {len(merged)} 格 · keep {kept} · ${cost:.4f} → {dest.name}")
|
|
|
|
|
+ return 0
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def main():
|
|
|
|
|
+ p = argparse.ArgumentParser(description="Query 正交格评分(tier≥1 × 当前维度 → Sonnet 打分)")
|
|
|
|
|
+ p.add_argument("--tool-type", default="")
|
|
|
|
|
+ p.add_argument("--modality", default="")
|
|
|
|
|
+ p.add_argument("--suffix", default="")
|
|
|
|
|
+ p.add_argument("--substance-path", default="", help="实质祖先路径,逗号分隔(仅作上下文)")
|
|
|
|
|
+ p.add_argument("--form-path", default="", help="形式祖先路径,逗号分隔(仅作上下文)")
|
|
|
|
|
+ p.add_argument("--model", default=DEFAULT_MODEL)
|
|
|
|
|
+ p.add_argument("--sel", default="adhoc", help="缓存文件名(server 传 sel_hash)")
|
|
|
|
|
+ p.add_argument("--limit", type=int, default=None, help="只评前 N 格(调试)")
|
|
|
|
|
+ p.add_argument("--dry-run", action="store_true", help="只拼词打印,不调 LLM、不落盘")
|
|
|
|
|
+ p.add_argument("--force", action="store_true", help="(占位)缓存短路在 server 侧,本脚本恒重算")
|
|
|
|
|
+ args = p.parse_args()
|
|
|
|
|
+ raise SystemExit(asyncio.run(run(args)))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
|
+ main()
|