server.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. # -*- coding: utf-8 -*-
  2. """mode_workflow server · 页面 + API + 解构任务管理
  3. ================================================================================
  4. 单服务(默认 8772):
  5. - GET / index.html
  6. - GET /api/dashboard Dashboard 全部聚合指标(含内容树覆盖)
  7. - GET /api/queries|posts|process|tools(+_versions) Dataset 数据
  8. - POST /api/run_search|extract_process|extract_tools 起子进程跑 pipeline
  9. - GET /api/task_status 轮询任务状态(读日志尾部)
  10. 用法:python server.py [port]
  11. """
import json
import os
import subprocess
import sys
import threading
import traceback
from collections import Counter
from datetime import datetime
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from urllib.parse import urlparse, parse_qs
  21. try:
  22. sys.stdout.reconfigure(encoding="utf-8")
  23. except Exception:
  24. pass
  25. HERE = Path(__file__).resolve().parent
  26. sys.path.insert(0, str(HERE))
  27. import db
  28. PORT = int(sys.argv[1]) if len(sys.argv) > 1 else 8772
  29. MATRIX_FILE = HERE / "reference" / "judged_matrix.json"
  30. LOG_DIR = HERE / "runs" / "logs"
  31. # ── 任务管理:task_id → {proc, log, status} ──────────────────────────────────
  32. TASKS = {}
  33. _TASK_LOCK = threading.Lock()
  34. def _spawn_task(kind, cmd):
  35. LOG_DIR.mkdir(parents=True, exist_ok=True)
  36. task_id = f"{kind}_{datetime.now().strftime('%m%d%H%M%S%f')}"
  37. log_path = LOG_DIR / f"{task_id}.log"
  38. f = open(log_path, "w", encoding="utf-8")
  39. proc = subprocess.Popen(cmd, stdout=f, stderr=subprocess.STDOUT,
  40. cwd=str(HERE), text=True)
  41. with _TASK_LOCK:
  42. TASKS[task_id] = {"proc": proc, "log": log_path, "status": "running"}
  43. def _wait():
  44. rc = proc.wait()
  45. f.close()
  46. with _TASK_LOCK:
  47. TASKS[task_id]["status"] = "done" if rc == 0 else "failed"
  48. threading.Thread(target=_wait, daemon=True).start()
  49. return task_id
  50. def _task_status(task_id):
  51. with _TASK_LOCK:
  52. t = TASKS.get(task_id)
  53. if not t:
  54. return None
  55. tail = ""
  56. try:
  57. text = t["log"].read_text(encoding="utf-8", errors="replace")
  58. tail = text[-3000:]
  59. except Exception:
  60. pass
  61. return {"status": t["status"], "log_tail": tail}
  62. def _next_query_id():
  63. qs = [q["query_id"] for q in db.fetch_queries()]
  64. nums = [int(q[1:]) for q in qs if q.startswith("q") and q[1:].isdigit()]
  65. return f"q{(max(nums) + 1 if nums else 0):04d}"
  66. # ── Dashboard 聚合 ────────────────────────────────────────────────────────────
  67. def _split_values(v):
  68. """substance/form 字段:数组直接用;字符串按 、,/ 分割;None 丢弃。"""
  69. out = []
  70. items = v if isinstance(v, list) else [v]
  71. for it in items:
  72. if not it or not isinstance(it, str):
  73. continue
  74. for piece in it.replace(",", "、").replace("/", "、").split("、"):
  75. piece = piece.strip()
  76. if piece:
  77. out.append(piece)
  78. return out
  79. def _dashboard():
  80. posts, procs, tools = db.fetch_dashboard_rows()
  81. # 最新版本行集(覆盖度/Top10 用最新版,成本/耗时按全部版本累计)
  82. def latest(rows):
  83. best = {}
  84. for r in rows:
  85. cid = r["case_id"]
  86. if cid not in best or (r["version"] or "") > (best[cid] or ""):
  87. best[cid] = r["version"]
  88. return [r for r in rows if r["version"] == best[r["case_id"]]]
  89. latest_procs = latest(procs)
  90. latest_tools = latest(tools)
  91. # 内容树覆盖:steps 的 (action 叶子 × 输入/输出 type) ∩ 有效节点(tier≥1)
  92. jm = json.loads(MATRIX_FILE.read_text(encoding="utf-8"))
  93. a_idx = {a["name"]: i for i, a in enumerate(jm["actions"])}
  94. t_idx = {t["name"]: i for i, t in enumerate(jm["types"])}
  95. valid = set()
  96. for ai, row in enumerate(jm["matrix"]):
  97. for ti, cell in enumerate(row):
  98. if isinstance(cell, dict) and cell.get("tier", 0) >= 1:
  99. valid.add((ai, ti))
  100. covered = set()
  101. via_counter = Counter()
  102. substance_counter = Counter()
  103. form_counter = Counter()
  104. for r in latest_procs:
  105. for s in r["steps"]:
  106. if not isinstance(s, dict):
  107. continue
  108. leaf = (s.get("action") or "").split("/")[-1].strip()
  109. types = []
  110. for io in ("inputs", "outputs"):
  111. for x in s.get(io) or []:
  112. if isinstance(x, dict) and x.get("type"):
  113. types.append(str(x["type"]).strip())
  114. if leaf in a_idx:
  115. for tp in types:
  116. if tp in t_idx and (a_idx[leaf], t_idx[tp]) in valid:
  117. covered.add((a_idx[leaf], t_idx[tp]))
  118. via = (s.get("via") or "").strip()
  119. if via:
  120. via_counter[via] += 1
  121. for v in _split_values(s.get("substance")):
  122. substance_counter[v] += 1
  123. for v in _split_values(s.get("form")):
  124. form_counter[v] += 1
  125. for r in latest_tools:
  126. for v in _split_values(r["substance_scope"]):
  127. substance_counter[v] += 1
  128. for v in _split_values(r["form_scope"]):
  129. form_counter[v] += 1
  130. # 成本/耗时:同一 (case_id, version) 只计一次(各行重复存同一次调用的值)
  131. def cost_groups(rows):
  132. g = {}
  133. for r in rows:
  134. key = (r["case_id"], r["version"])
  135. if key not in g and r["cost_usd"] is not None:
  136. g[key] = (r["cost_usd"], r["duration_s"] or 0.0, r["created_at"])
  137. return list(g.values())
  138. runs = cost_groups(procs) + cost_groups(tools)
  139. total_cost = round(sum(c for c, _, _ in runs), 4)
  140. total_dur = round(sum(d for _, d, _ in runs), 1)
  141. # 按日成本趋势
  142. daily = Counter()
  143. for c, _, ts in runs:
  144. if ts:
  145. daily[ts[:10]] += c
  146. cost_trend = [{"date": d, "cost": round(v, 4)} for d, v in sorted(daily.items())]
  147. # 进度:分母 = knowledge_type 含对应类型的帖子(distinct case)
  148. proc_targets = {p["case_id"] for p in posts if "工序" in (p["knowledge_type"] or [])}
  149. tool_targets = {p["case_id"] for p in posts if "工具" in (p["knowledge_type"] or [])}
  150. proc_done = {r["case_id"] for r in procs}
  151. tool_done = {r["case_id"] for r in tools}
  152. return {
  153. "result": {
  154. "matrix_covered": len(covered), "matrix_valid": len(valid),
  155. "matrix_cells": sorted([ai, ti] for ai, ti in covered),
  156. "matrix_actions": [a["name"] for a in jm["actions"]],
  157. "matrix_types": [t["name"] for t in jm["types"]],
  158. "substance_count": len(substance_counter),
  159. "substance_top": substance_counter.most_common(15),
  160. "form_count": len(form_counter),
  161. "form_top": form_counter.most_common(15),
  162. "post_count": len(posts),
  163. "extracted_post_count": len(proc_done | tool_done),
  164. "tool_count": len({r["tool_name"] for r in latest_tools if r["tool_name"]}),
  165. "via_top10": via_counter.most_common(10),
  166. },
  167. "process_data": {
  168. "run_count": len(runs),
  169. "avg_cost": round(total_cost / len(runs), 4) if runs else 0,
  170. "total_cost": total_cost,
  171. "avg_duration": round(total_dur / len(runs), 1) if runs else 0,
  172. "total_duration": total_dur,
  173. "cost_trend": cost_trend,
  174. "process_progress": {"done": len(proc_done), "total": len(proc_targets)},
  175. "tools_progress": {"done": len(tool_done), "total": len(tool_targets)},
  176. },
  177. }
  178. # ── HTTP handler ─────────────────────────────────────────────────────────────
  179. class Handler(BaseHTTPRequestHandler):
  180. def _json(self, data, code=200):
  181. body = json.dumps(data, ensure_ascii=False, default=str).encode("utf-8")
  182. self.send_response(code)
  183. self.send_header("Content-Type", "application/json; charset=utf-8")
  184. self.send_header("Content-Length", str(len(body)))
  185. self.end_headers()
  186. self.wfile.write(body)
  187. def _err(self, msg, code=400):
  188. self._json({"error": msg}, code)
  189. def do_GET(self):
  190. u = urlparse(self.path)
  191. qs = {k: v[0] for k, v in parse_qs(u.query).items()}
  192. try:
  193. if u.path == "/" or u.path == "/index.html":
  194. body = (HERE / "index.html").read_bytes()
  195. self.send_response(200)
  196. self.send_header("Content-Type", "text/html; charset=utf-8")
  197. self.send_header("Content-Length", str(len(body)))
  198. self.end_headers()
  199. self.wfile.write(body)
  200. elif u.path == "/api/dashboard":
  201. self._json(_dashboard())
  202. elif u.path == "/api/queries":
  203. self._json(db.fetch_queries())
  204. elif u.path == "/api/posts":
  205. self._json(db.fetch_posts(qs.get("query_id", "")))
  206. elif u.path == "/api/process_versions":
  207. self._json(db.fetch_process_versions(qs.get("case_id", "")))
  208. elif u.path == "/api/process":
  209. r = db.fetch_process(qs.get("case_id", ""), qs.get("version"))
  210. self._json(r) if r else self._err("无解构记录", 404)
  211. elif u.path == "/api/tools_versions":
  212. self._json(db.fetch_tools_versions(qs.get("case_id", "")))
  213. elif u.path == "/api/tools":
  214. r = db.fetch_tools(qs.get("case_id", ""), qs.get("version"))
  215. self._json(r) if r else self._err("无解构记录", 404)
  216. elif u.path == "/api/task_status":
  217. r = _task_status(qs.get("task_id", ""))
  218. self._json(r) if r else self._err("未知 task_id", 404)
  219. else:
  220. self._err("not found", 404)
  221. except Exception as e:
  222. self._err(f"{type(e).__name__}: {e}", 500)
  223. def do_POST(self):
  224. u = urlparse(self.path)
  225. try:
  226. n = int(self.headers.get("Content-Length") or 0)
  227. payload = json.loads(self.rfile.read(n) or b"{}")
  228. except Exception:
  229. return self._err("body 必须是 JSON")
  230. try:
  231. if u.path in ("/api/extract_process", "/api/extract_tools"):
  232. qid = payload.get("query_id")
  233. cids = payload.get("case_ids") or []
  234. if not qid or not cids:
  235. return self._err("缺 query_id / case_ids")
  236. script = ("pipeline/procedure_extract.py" if u.path.endswith("process")
  237. else "pipeline/tool_extract.py")
  238. cmd = [sys.executable, script, "--query-id", qid,
  239. "--case-ids", ",".join(cids)]
  240. if payload.get("model"):
  241. cmd += ["--model", payload["model"]]
  242. kind = "proc" if u.path.endswith("process") else "tool"
  243. self._json({"task_id": _spawn_task(kind, cmd)})
  244. elif u.path == "/api/run_search":
  245. query = (payload.get("query") or "").strip()
  246. if not query:
  247. return self._err("缺 query")
  248. qid = payload.get("query_id") or _next_query_id()
  249. cmd = [sys.executable, "pipeline/search_eval.py",
  250. "--query-id", qid, "--query", query]
  251. if payload.get("synonyms"):
  252. cmd += ["--synonyms", payload["synonyms"]]
  253. if payload.get("mode_type") in ("工序", "工具"):
  254. cmd += ["--mode-type", payload["mode_type"]]
  255. if payload.get("platforms"):
  256. cmd += ["--platforms", payload["platforms"]]
  257. if payload.get("max_count"):
  258. cmd += ["--max-count", str(payload["max_count"])]
  259. self._json({"task_id": _spawn_task("search", cmd), "query_id": qid})
  260. else:
  261. self._err("not found", 404)
  262. except Exception as e:
  263. self._err(f"{type(e).__name__}: {e}", 500)
  264. def log_message(self, fmt, *a):
  265. pass # 静默访问日志
  266. if __name__ == "__main__":
  267. print(f"🚀 mode_workflow server → http://0.0.0.0:{PORT}")
  268. ThreadingHTTPServer(("0.0.0.0", PORT), Handler).serve_forever()