Преглед на файлове

feat(mode_workflow): server(API+任务管理+Dashboard聚合)

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
刘文武 преди 4 дни
родител
ревизия
07a8e08946
променени са 1 файла, в които са добавени 301 реда и са изтрити 0 реда
  1. 301 0
      examples/mode_workflow/server.py

+ 301 - 0
examples/mode_workflow/server.py

@@ -0,0 +1,301 @@
+# -*- coding: utf-8 -*-
+"""mode_workflow server · page + API + extraction task management
+================================================================================
+Single service (default port 8772):
+  - GET  /                    index.html
+  - GET  /api/dashboard       all aggregated Dashboard metrics (incl. content-tree coverage)
+  - GET  /api/queries|posts|process|tools(+_versions)   Dataset data
+  - POST /api/run_search|extract_process|extract_tools  spawn a subprocess that runs the pipeline
+  - GET  /api/task_status     poll task status (reads the tail of the task log)
+
+Usage: python server.py [port]
+"""
+import json
+import subprocess
+import sys
+import threading
+from collections import Counter
+from datetime import datetime
+from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
+from pathlib import Path
+from urllib.parse import urlparse, parse_qs
+
+try:
+    # Best effort: switch stdout to UTF-8; any failure is deliberately ignored.
+    sys.stdout.reconfigure(encoding="utf-8")
+except Exception:
+    pass
+
+# Directory containing this file; every other path below is derived from it.
+HERE = Path(__file__).resolve().parent
+# HERE must be on sys.path before `db` is imported on the next line
+# (presumably `db` is a sibling module in this directory — confirm).
+sys.path.insert(0, str(HERE))
+import db
+
+# Listening port: first command-line argument if given, otherwise 8772.
+# A non-integer argument raises ValueError at import time.
+PORT = int(sys.argv[1]) if len(sys.argv) > 1 else 8772
+# JSON file read by _dashboard() to compute matrix coverage.
+MATRIX_FILE = HERE / "reference" / "judged_matrix.json"
+# Directory in which _spawn_task() creates one log file per task.
+LOG_DIR = HERE / "runs" / "logs"
+
+# ── Task management: task_id → {proc, log, status} ───────────────────────────
+# In-memory registry; all access goes through _TASK_LOCK.
+TASKS = {}
+_TASK_LOCK = threading.Lock()
+
+
+def _spawn_task(kind, cmd):
+    """Start ``cmd`` as a background subprocess and register it in ``TASKS``.
+
+    The child's stdout and stderr are both redirected to
+    ``LOG_DIR/<task_id>.log`` and it runs with ``HERE`` as its working
+    directory.  A daemon thread waits for the process to exit, closes the log
+    file and then sets the task status to ``"done"`` (exit code 0) or
+    ``"failed"`` (any other exit code).
+
+    Args:
+        kind: Short label used as the prefix of the generated task id.
+        cmd: Argument list handed to ``subprocess.Popen``.
+
+    Returns:
+        The generated task id, ``<kind>_<MMDDHHMMSSffffff>``.
+
+    Raises:
+        OSError: If the log directory/file cannot be created or the process
+            cannot be started.  Any other exception raised by ``Popen``
+            (e.g. ``TypeError`` for non-string arguments) propagates as well.
+    """
+    LOG_DIR.mkdir(parents=True, exist_ok=True)
+    task_id = f"{kind}_{datetime.now().strftime('%m%d%H%M%S%f')}"
+    log_path = LOG_DIR / f"{task_id}.log"
+    log_file = open(log_path, "w", encoding="utf-8")
+    try:
+        proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT,
+                                cwd=str(HERE), text=True)
+    except Exception:
+        # No waiter thread exists yet, so nothing else would ever close the
+        # handle: release it here before reporting the failure.
+        log_file.close()
+        raise
+    with _TASK_LOCK:
+        TASKS[task_id] = {"proc": proc, "log": log_path, "status": "running"}
+
+    def _wait():
+        try:
+            rc = proc.wait()
+        finally:
+            # Close the log before the status leaves "running".
+            log_file.close()
+        with _TASK_LOCK:
+            TASKS[task_id]["status"] = "done" if rc == 0 else "failed"
+
+    threading.Thread(target=_wait, daemon=True).start()
+    return task_id
+
+
+def _task_status(task_id):
+    """Return the current status and log tail of a registered task.
+
+    Args:
+        task_id: Id previously returned by ``_spawn_task``.
+
+    Returns:
+        ``None`` when ``task_id`` is not in ``TASKS``; otherwise a dict
+        ``{"status": ..., "log_tail": ...}`` where ``log_tail`` is the last
+        3000 characters of the task's log file (empty string if the file
+        cannot be read).
+    """
+    with _TASK_LOCK:
+        task = TASKS.get(task_id)
+        if not task:
+            return None
+        # Take the status BEFORE reading the log.  If it were read afterwards,
+        # a task finishing in between would be reported as finished together
+        # with a tail that lacks its last output, and a client that stops
+        # polling on "done"/"failed" would never see those lines.
+        status = task["status"]
+        log_path = task["log"]
+    tail = ""
+    try:
+        tail = log_path.read_text(encoding="utf-8", errors="replace")[-3000:]
+    except OSError:
+        # Log missing or unreadable: still report the status, with no tail.
+        pass
+    return {"status": status, "log_tail": tail}
+
+
+def _next_query_id():
+    """Return the next query id in the ``q`` + zero-padded-number series.
+
+    Scans ``query_id`` of every row returned by ``db.fetch_queries()``,
+    considers only ids made of ``"q"`` followed by decimal digits, and returns
+    ``q<max + 1>`` padded to at least four digits.  When no such id exists the
+    result is ``"q0000"``.
+    """
+    numbers = []
+    for row in db.fetch_queries():
+        qid = row["query_id"]
+        # isdecimal() rather than isdigit(): isdigit() also accepts characters
+        # such as "²" that int() cannot parse.  Non-string ids are skipped
+        # instead of raising AttributeError.
+        if isinstance(qid, str) and qid.startswith("q") and qid[1:].isdecimal():
+            numbers.append(int(qid[1:]))
+    return f"q{(max(numbers) + 1 if numbers else 0):04d}"
+
+
+# ── Dashboard aggregation ────────────────────────────────────────────────────
+
+def _split_values(v):
+    """Flatten a substance/form field into a list of stripped, non-empty tokens.
+
+    A list is processed element by element; any other value is treated as a
+    single element.  Elements that are falsy or not ``str`` are dropped.  Each
+    remaining string is split on "、", "," and "/".
+    """
+    candidates = v if isinstance(v, list) else [v]
+    # Map the two alternative separators onto "、" so one split() is enough.
+    to_separator = str.maketrans({",": "、", "/": "、"})
+    return [
+        token
+        for item in candidates
+        if item and isinstance(item, str)
+        for token in map(str.strip, item.translate(to_separator).split("、"))
+        if token
+    ]
+
+
+def _dashboard():
+    """Build the aggregated payload served by ``GET /api/dashboard``.
+
+    Reads the post / process / tool rows from ``db.fetch_dashboard_rows()``
+    and the judged action × type matrix from ``MATRIX_FILE``.
+
+    Returns:
+        A dict with two keys:
+
+        * ``"result"``: matrix coverage (covered / valid cell counts, the
+          covered ``[action_index, type_index]`` pairs, action and type
+          names), substance / form / via frequency tables, and post / tool
+          counts.  Coverage and frequency tables use only the latest version
+          of every case.
+        * ``"process_data"``: run count, total and average cost and duration,
+          per-day cost trend, and extraction progress.  Cost and duration are
+          accumulated over all versions, one entry per
+          ``(case_id, version)``.
+
+    Raises:
+        OSError, ValueError, KeyError: If the matrix file is missing,
+            is not valid JSON, or lacks ``actions`` / ``types`` / ``matrix``.
+        Exceptions raised by ``db.fetch_dashboard_rows()`` propagate unchanged.
+    """
+    posts, procs, tools = db.fetch_dashboard_rows()
+
+    # Rows of the latest version per case.  Versions are compared with ``>``
+    # after replacing a falsy version by "".
+    def latest(rows):
+        best = {}
+        for r in rows:
+            cid = r["case_id"]
+            if cid not in best or (r["version"] or "") > (best[cid] or ""):
+                best[cid] = r["version"]
+        return [r for r in rows if r["version"] == best[r["case_id"]]]
+
+    latest_procs = latest(procs)
+    latest_tools = latest(tools)
+
+    # Content-tree coverage: (action leaf × input/output type) pairs found in
+    # the steps, restricted to the valid matrix cells (tier >= 1).
+    jm = json.loads(MATRIX_FILE.read_text(encoding="utf-8"))
+    a_idx = {a["name"]: i for i, a in enumerate(jm["actions"])}
+    t_idx = {t["name"]: i for i, t in enumerate(jm["types"])}
+    valid = set()
+    for ai, row in enumerate(jm["matrix"]):
+        for ti, cell in enumerate(row):
+            # ``or 0`` also covers an explicit ``"tier": null`` in the file.
+            if isinstance(cell, dict) and (cell.get("tier") or 0) >= 1:
+                valid.add((ai, ti))
+    covered = set()
+    via_counter = Counter()
+    substance_counter = Counter()
+    form_counter = Counter()
+    for r in latest_procs:
+        # A row without steps (None) contributes nothing instead of raising.
+        for s in r["steps"] or []:
+            if not isinstance(s, dict):
+                continue
+            # Only the last "/"-separated segment of the action is matched.
+            leaf = (s.get("action") or "").split("/")[-1].strip()
+            types = []
+            for io in ("inputs", "outputs"):
+                for x in s.get(io) or []:
+                    if isinstance(x, dict) and x.get("type"):
+                        types.append(str(x["type"]).strip())
+            if leaf in a_idx:
+                for tp in types:
+                    if tp in t_idx and (a_idx[leaf], t_idx[tp]) in valid:
+                        covered.add((a_idx[leaf], t_idx[tp]))
+            via = (s.get("via") or "").strip()
+            if via:
+                via_counter[via] += 1
+            for v in _split_values(s.get("substance")):
+                substance_counter[v] += 1
+            for v in _split_values(s.get("form")):
+                form_counter[v] += 1
+    for r in latest_tools:
+        for v in _split_values(r["substance_scope"]):
+            substance_counter[v] += 1
+        for v in _split_values(r["form_scope"]):
+            form_counter[v] += 1
+
+    # Cost / duration: each (case_id, version) is counted once, taken from the
+    # first row of that key that has a cost (rows of one run repeat the value).
+    def cost_groups(rows):
+        g = {}
+        for r in rows:
+            key = (r["case_id"], r["version"])
+            if key not in g and r["cost_usd"] is not None:
+                g[key] = (r["cost_usd"], r["duration_s"] or 0.0, r["created_at"])
+        return list(g.values())
+
+    runs = cost_groups(procs) + cost_groups(tools)
+    total_cost = round(sum(c for c, _, _ in runs), 4)
+    total_dur = round(sum(d for _, d, _ in runs), 1)
+    # Cost trend per day, keyed by the first 10 characters of the timestamp.
+    daily = Counter()
+    for c, _, ts in runs:
+        if ts:
+            # str() keeps string timestamps unchanged and makes the slice
+            # safe when the driver returns a non-string timestamp object.
+            daily[str(ts)[:10]] += c
+    cost_trend = [{"date": d, "cost": round(v, 4)} for d, v in sorted(daily.items())]
+
+    # Progress: denominator = distinct cases among posts whose knowledge_type
+    # contains the corresponding kind; numerator = distinct cases with rows.
+    proc_targets = {p["case_id"] for p in posts if "工序" in (p["knowledge_type"] or [])}
+    tool_targets = {p["case_id"] for p in posts if "工具" in (p["knowledge_type"] or [])}
+    proc_done = {r["case_id"] for r in procs}
+    tool_done = {r["case_id"] for r in tools}
+
+    return {
+        "result": {
+            "matrix_covered": len(covered), "matrix_valid": len(valid),
+            "matrix_cells": sorted([ai, ti] for ai, ti in covered),
+            "matrix_actions": [a["name"] for a in jm["actions"]],
+            "matrix_types": [t["name"] for t in jm["types"]],
+            "substance_count": len(substance_counter),
+            "substance_top": substance_counter.most_common(15),
+            "form_count": len(form_counter),
+            "form_top": form_counter.most_common(15),
+            "post_count": len(posts),
+            "extracted_post_count": len(proc_done | tool_done),
+            "tool_count": len({r["tool_name"] for r in latest_tools if r["tool_name"]}),
+            "via_top10": via_counter.most_common(10),
+        },
+        "process_data": {
+            "run_count": len(runs),
+            "avg_cost": round(total_cost / len(runs), 4) if runs else 0,
+            "total_cost": total_cost,
+            "avg_duration": round(total_dur / len(runs), 1) if runs else 0,
+            "total_duration": total_dur,
+            "cost_trend": cost_trend,
+            "process_progress": {"done": len(proc_done), "total": len(proc_targets)},
+            "tools_progress": {"done": len(tool_done), "total": len(tool_targets)},
+        },
+    }
+
+
+# ── HTTP handler ─────────────────────────────────────────────────────────────
+
+class Handler(BaseHTTPRequestHandler):
+    """Request handler serving index.html, the read-only JSON API and the
+    POST endpoints that start pipeline subprocesses.
+
+    Every API response is JSON; errors have the shape ``{"error": <message>}``.
+    Client mistakes (malformed body, missing or wrongly typed fields) are
+    answered with 400, unknown paths and missing records with 404, and any
+    unexpected exception with 500.
+    """
+
+    def _json(self, data, code=200):
+        """Send ``data`` as a UTF-8 JSON response with status ``code``.
+
+        Values that ``json`` cannot serialise are converted with ``str``.
+        """
+        body = json.dumps(data, ensure_ascii=False, default=str).encode("utf-8")
+        self.send_response(code)
+        self.send_header("Content-Type", "application/json; charset=utf-8")
+        self.send_header("Content-Length", str(len(body)))
+        self.end_headers()
+        self.wfile.write(body)
+
+    def _err(self, msg, code=400):
+        """Send ``{"error": msg}`` with status ``code`` (default 400)."""
+        self._json({"error": msg}, code)
+
+    def _json_or_404(self, result, missing_msg):
+        """Send ``result`` as JSON, or a 404 with ``missing_msg`` if it is falsy."""
+        if result:
+            self._json(result)
+        else:
+            self._err(missing_msg, 404)
+
+    @staticmethod
+    def _non_text_field(payload, keys):
+        """Return the first key in ``keys`` whose value is truthy but not a
+        ``str``, or ``None`` when every present value is a string."""
+        for key in keys:
+            value = payload.get(key)
+            if value and not isinstance(value, str):
+                return key
+        return None
+
+    def do_GET(self):
+        """Serve index.html and the read-only ``/api/...`` endpoints."""
+        u = urlparse(self.path)
+        # Only the first value of a repeated query parameter is used.
+        qs = {k: v[0] for k, v in parse_qs(u.query).items()}
+        try:
+            if u.path == "/" or u.path == "/index.html":
+                body = (HERE / "index.html").read_bytes()
+                self.send_response(200)
+                self.send_header("Content-Type", "text/html; charset=utf-8")
+                self.send_header("Content-Length", str(len(body)))
+                self.end_headers()
+                self.wfile.write(body)
+            elif u.path == "/api/dashboard":
+                self._json(_dashboard())
+            elif u.path == "/api/queries":
+                self._json(db.fetch_queries())
+            elif u.path == "/api/posts":
+                self._json(db.fetch_posts(qs.get("query_id", "")))
+            elif u.path == "/api/process_versions":
+                self._json(db.fetch_process_versions(qs.get("case_id", "")))
+            elif u.path == "/api/process":
+                r = db.fetch_process(qs.get("case_id", ""), qs.get("version"))
+                self._json_or_404(r, "无解构记录")
+            elif u.path == "/api/tools_versions":
+                self._json(db.fetch_tools_versions(qs.get("case_id", "")))
+            elif u.path == "/api/tools":
+                r = db.fetch_tools(qs.get("case_id", ""), qs.get("version"))
+                self._json_or_404(r, "无解构记录")
+            elif u.path == "/api/task_status":
+                r = _task_status(qs.get("task_id", ""))
+                self._json_or_404(r, "未知 task_id")
+            else:
+                self._err("not found", 404)
+        except Exception as e:
+            self._err(f"{type(e).__name__}: {e}", 500)
+
+    def do_POST(self):
+        """Validate the JSON body and start the requested pipeline task.
+
+        On success responds with ``{"task_id": ...}`` (plus ``"query_id"`` for
+        ``/api/run_search``).
+        """
+        u = urlparse(self.path)
+        try:
+            # Clamp to 0: a negative length would make read() block until EOF.
+            n = max(int(self.headers.get("Content-Length") or 0), 0)
+            payload = json.loads(self.rfile.read(n) or b"{}")
+        except Exception:
+            return self._err("body 必须是 JSON")
+        if not isinstance(payload, dict):
+            # Valid JSON that is not an object (list, number, ...) has no fields.
+            return self._err("body 必须是 JSON 对象")
+        try:
+            if u.path in ("/api/extract_process", "/api/extract_tools"):
+                qid = payload.get("query_id")
+                cids = payload.get("case_ids") or []
+                if not qid or not cids:
+                    return self._err("缺 query_id / case_ids")
+                if not isinstance(cids, list):
+                    # Joining a bare string would split it into characters.
+                    return self._err("case_ids 必须是数组")
+                bad = self._non_text_field(payload, ("query_id", "model"))
+                if bad:
+                    return self._err(f"{bad} 必须是字符串")
+                is_process = u.path.endswith("process")
+                script = ("pipeline/procedure_extract.py" if is_process
+                          else "pipeline/tool_extract.py")
+                cmd = [sys.executable, script, "--query-id", qid,
+                       "--case-ids", ",".join(str(c) for c in cids)]
+                if payload.get("model"):
+                    cmd += ["--model", payload["model"]]
+                kind = "proc" if is_process else "tool"
+                self._json({"task_id": _spawn_task(kind, cmd)})
+            elif u.path == "/api/run_search":
+                bad = self._non_text_field(
+                    payload, ("query", "query_id", "synonyms", "platforms"))
+                if bad:
+                    return self._err(f"{bad} 必须是字符串")
+                query = (payload.get("query") or "").strip()
+                if not query:
+                    return self._err("缺 query")
+                qid = payload.get("query_id") or _next_query_id()
+                cmd = [sys.executable, "pipeline/search_eval.py",
+                       "--query-id", qid, "--query", query]
+                if payload.get("synonyms"):
+                    cmd += ["--synonyms", payload["synonyms"]]
+                if payload.get("platforms"):
+                    cmd += ["--platforms", payload["platforms"]]
+                if payload.get("max_count"):
+                    cmd += ["--max-count", str(payload["max_count"])]
+                self._json({"task_id": _spawn_task("search", cmd), "query_id": qid})
+            else:
+                self._err("not found", 404)
+        except Exception as e:
+            self._err(f"{type(e).__name__}: {e}", 500)
+
+    def log_message(self, fmt, *a):
+        """Silence the per-request access log."""
+        pass
+
+
+if __name__ == "__main__":
+    # NOTE(review): listens on every interface, and no authentication is
+    # visible in this file although the POST endpoints start subprocesses —
+    # confirm that this exposure is intended.
+    listen_on = ("0.0.0.0", PORT)
+    print(f"🚀 mode_workflow server → http://0.0.0.0:{PORT}")
+    httpd = ThreadingHTTPServer(listen_on, Handler)
+    httpd.serve_forever()