Kaynağa Gözat

feat(mode_workflow): 新增 query_score 评分脚本(tier≥1×维度→Sonnet,原子落盘缓存)

刘文武 1 hafta önce
ebeveyn
işleme
0bb4664a90
1 değiştirilmiş dosya ile 184 ekleme ve 0 silme
  1. 184 0
      examples/mode_workflow/stages/query_score.py

+ 184 - 0
examples/mode_workflow/stages/query_score.py

@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+"""Query 正交格评分 · 对 judged_matrix 的 tier≥1 格(动作×类型)在当前维度上下文下用 Sonnet 打分,
+挑出有意义、人话、有助于内容制作知识库目的的组合。结果原子写 .cache/query_score/<sel>.json。
+
+由 server.py /api/query_score 起子进程调;也可独立跑:
+  python stages/query_score.py --tool-type AI --modality 图片 --suffix 怎么做 \
+      --substance-path 表象,实体 --form-path 呈现,视觉 --sel adhoc --dry-run
+"""
+import argparse
+import asyncio
+import json
+import re
+import sys
+from pathlib import Path
+
+PROJECT_ROOT = Path(__file__).resolve().parents[3]   # …/Agent
+sys.path.insert(0, str(PROJECT_ROOT))
+
+from dotenv import load_dotenv
+load_dotenv()
+
+HERE = Path(__file__).resolve().parent
+MW = HERE.parent
+MATRIX_FILE = MW / "reference" / "judged_matrix.json"
+PROMPT_FILE = MW / "prompts" / "query_score_system.md"
+CACHE_DIR = MW / ".cache" / "query_score"
+DEFAULT_MODEL = "anthropic/claude-sonnet-4-6"
+BATCH = 40
+CONCURRENCY = 5
+
+
+def _build_cells(matrix, tool_type, modality, suffix):
+    """筛 tier≥1 格,产出 [{a_idx,t_idx,action,type,tier,query}]。
+    query = [工具类型] 动作叶 类型叶 [模态] [后缀],"无"/空跳过。"""
+    actions, types, grid = matrix["actions"], matrix["types"], matrix["matrix"]
+    pre, mod, suf = (tool_type or "").strip(), (modality or "").strip(), (suffix or "").strip()
+    cells = []
+    for ai, arow in enumerate(grid):
+        action = actions[ai]["name"]
+        for ti, cell in enumerate(arow):
+            if cell.get("tier", 0) < 1:
+                continue
+            typ = types[ti]["name"]
+            parts = [p for p in (pre, action, typ, mod, suf) if p and p != "无"]
+            cells.append({"a_idx": ai, "t_idx": ti, "action": action,
+                          "type": typ, "tier": cell["tier"], "query": " ".join(parts)})
+    return cells
+
+
+def _build_user(batch, ctx):
+    lines = [f'{i}. "{c["query"]}"   (动作={c["action"]} 类型={c["type"]} 内容树tier={c["tier"]})'
+             for i, c in enumerate(batch)]
+    sub = (ctx["substance"] or "无").replace(",", "›")
+    form = (ctx["form"] or "无").replace(",", "›")
+    return (f"【固定上下文(本批共享)】\n"
+            f"工具类型: {ctx['tool_type'] or '无'}   模态: {ctx['modality'] or '无'}   后缀: {ctx['suffix'] or '无'}\n"
+            f"(实质/形式不参与拼词,仅供理解领域定位: 实质路径={sub}  形式路径={form})\n\n"
+            f"【候选列表】每条 = 动作 + 类型 + 上下文词拼成的 query:\n" + "\n".join(lines))
+
+
+async def _call_with_retry(llm_call, messages, model, task_name, max_retries=3):
+    """直接调 llm_call 并解析 JSON 数组(call_llm_with_retry 的正则只捕获 {…},不适用数组)。"""
+    total_cost = 0.0
+    last_err = None
+    cur_messages = list(messages)
+    for attempt in range(max_retries):
+        if attempt > 0 and last_err:
+            cur_messages = list(messages) + [
+                {"role": "user",
+                 "content": f"上次输出未通过校验:{last_err}\n请重新输出完整 JSON 数组,不含其他内容。"}]
+            print(f"   [{task_name}] Retry {attempt}/{max_retries - 1}: {last_err[:80]}...")
+        try:
+            resp = await llm_call(messages=cur_messages, model=model,
+                                  temperature=0.1, max_tokens=4000)
+            cost = resp.get("cost") or 0.0
+            total_cost += cost
+            content = resp.get("content", "")
+            if isinstance(content, list):
+                first = content[0] if content else ""
+                content = first.get("text", "") if isinstance(first, dict) else str(first)
+            # 提取 JSON 数组(支持裸数组和 markdown 围栏包裹)
+            arr_match = re.search(r"\[[\s\S]*\]", content)
+            if not arr_match:
+                last_err = "LLM 输出中未找到 JSON 数组"
+                continue
+            try:
+                data = json.loads(arr_match.group())
+            except json.JSONDecodeError as e:
+                last_err = f"JSON 解析失败: {e}"
+                continue
+            if not isinstance(data, list):
+                last_err = "需 JSON 数组"
+                continue
+            return data, total_cost
+        except Exception as e:
+            last_err = f"LLM 调用异常: {type(e).__name__}: {e}"
+            print(f"   [{task_name}] Error: {last_err}")
+    print(f"   [{task_name}] All {max_retries} attempts failed. Last error: {last_err}")
+    return None, total_cost
+
+
+async def _score_batch(batch, ctx, system, llm_call, model, sem):
+    messages = [{"role": "system", "content": system},
+                {"role": "user", "content": _build_user(batch, ctx)}]
+    task_name = f"QueryScore[{batch[0]['query'][:12]}]"
+    async with sem:
+        data, cost = await _call_with_retry(llm_call, messages, model, task_name)
+    out = {}
+    for v in (data or []):
+        if not isinstance(v, dict):
+            continue
+        i = v.get("idx")
+        if not isinstance(i, int) or not (0 <= i < len(batch)):
+            continue
+        c = batch[i]
+        try:
+            score = round(float(v.get("natural")) * 0.4 + float(v.get("findable")) * 0.3
+                          + float(v.get("useful")) * 0.3, 1)
+        except (TypeError, ValueError):
+            score = None
+        out[f"{c['a_idx']}_{c['t_idx']}"] = {
+            "query": c["query"], "natural": v.get("natural"), "findable": v.get("findable"),
+            "useful": v.get("useful"), "keep": bool(v.get("keep")),
+            "rewrite": (v.get("rewrite") or c["query"]), "reason": v.get("reason", ""),
+            "score": score}
+    return out, cost
+
+
+async def run(args):
+    matrix = json.loads(MATRIX_FILE.read_text(encoding="utf-8"))
+    ctx = {"tool_type": args.tool_type, "modality": args.modality, "suffix": args.suffix,
+           "substance": args.substance_path, "form": args.form_path}
+    cells = _build_cells(matrix, args.tool_type, args.modality, args.suffix)
+    if args.limit:
+        cells = cells[:args.limit]
+    print(f"📋 tier≥1 候选 {len(cells)} 格" + (f" (--limit {args.limit})" if args.limit else ""))
+    if args.dry_run:
+        for c in cells[:10]:
+            print(f"  [{c['tier']}] {c['query']}")
+        print(f"…共 {len(cells)} 格(dry-run,未调 LLM)")
+        return 0
+
+    system = PROMPT_FILE.read_text(encoding="utf-8")
+    from agent.llm.openrouter import create_openrouter_llm_call
+    llm_call = create_openrouter_llm_call(model=args.model)
+    sem = asyncio.Semaphore(CONCURRENCY)
+    batches = [cells[i:i + BATCH] for i in range(0, len(cells), BATCH)]
+    print(f"🤖 {len(batches)} 批 × ≤{BATCH} 格 · 并发 {CONCURRENCY} · 模型 {args.model}")
+    results = await asyncio.gather(*[
+        _score_batch(b, ctx, system, llm_call, args.model, sem) for b in batches])
+    merged, cost = {}, 0.0
+    for cmap, c in results:
+        merged.update(cmap)
+        cost += c
+    kept = sum(1 for v in merged.values() if v.get("keep"))
+    out = {"sel": ctx, "model": args.model, "kept": kept, "total": len(merged),
+           "cost_usd": round(cost, 4), "cells": merged}
+    CACHE_DIR.mkdir(parents=True, exist_ok=True)
+    dest = CACHE_DIR / f"{args.sel}.json"
+    tmp = dest.with_suffix(".tmp")
+    tmp.write_text(json.dumps(out, ensure_ascii=False, indent=2), encoding="utf-8")
+    tmp.replace(dest)   # 原子落盘,避免前端读到半截
+    print(f"✅ 评分完成 {len(merged)} 格 · keep {kept} · ${cost:.4f} → {dest.name}")
+    return 0
+
+
+def main():
+    p = argparse.ArgumentParser(description="Query 正交格评分(tier≥1 × 当前维度 → Sonnet 打分)")
+    p.add_argument("--tool-type", default="")
+    p.add_argument("--modality", default="")
+    p.add_argument("--suffix", default="")
+    p.add_argument("--substance-path", default="", help="实质祖先路径,逗号分隔(仅作上下文)")
+    p.add_argument("--form-path", default="", help="形式祖先路径,逗号分隔(仅作上下文)")
+    p.add_argument("--model", default=DEFAULT_MODEL)
+    p.add_argument("--sel", default="adhoc", help="缓存文件名(server 传 sel_hash)")
+    p.add_argument("--limit", type=int, default=None, help="只评前 N 格(调试)")
+    p.add_argument("--dry-run", action="store_true", help="只拼词打印,不调 LLM、不落盘")
+    p.add_argument("--force", action="store_true", help="(占位)缓存短路在 server 侧,本脚本恒重算")
+    args = p.parse_args()
+    raise SystemExit(asyncio.run(run(args)))
+
+
+if __name__ == "__main__":
+    main()