server.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. # -*- coding: utf-8 -*-
  2. """mode_workflow server · 页面 + API + 解构任务管理
  3. ================================================================================
  4. 单服务(默认 8772):
  5. - GET / index.html
  6. - GET /search.html 知识检索页(聚类库 tab 内嵌;API 域名由 .env 注入)
  7. - GET /api/dashboard Dashboard 全部聚合指标(含内容树覆盖)
  8. - GET /api/queries|posts|process|tools(+_versions) Dataset 数据
  9. - POST /api/run_search|extract_process|extract_tools 起子进程跑 pipeline
  10. - GET /api/task_status 轮询任务状态(读日志尾部)
  11. 用法:python server.py [port]
  12. """
  13. import json
  14. import os
  15. import subprocess
  16. import sys
  17. import threading
  18. from collections import Counter
  19. from datetime import datetime
  20. from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
  21. from pathlib import Path
  22. from urllib.parse import urlparse, parse_qs
# Best-effort: switch stdout to UTF-8; any failure of reconfigure() is ignored.
try:
    sys.stdout.reconfigure(encoding="utf-8")
except Exception:
    pass

# Directory that contains this file; all page/data paths below are built from it.
HERE = Path(__file__).resolve().parent
# Put this directory first on sys.path before importing the sibling `db` module.
sys.path.insert(0, str(HERE))
import db

# Listening port: first command-line argument if given, otherwise 8772.
PORT = int(sys.argv[1]) if len(sys.argv) > 1 else 8772
MATRIX_FILE = HERE / "reference" / "judged_matrix.json"
LOG_DIR = HERE / "runs" / "logs"

# Knowledge-search API base URL, read from KNOWLEDGE_API_BASE in .env
# (db.py has already called load_dotenv).
# search.html hard-codes the relative path '/api/v1/knowledge'; when serving the
# page it is replaced with an absolute address, so the page is hosted by this
# service while its requests still go to the separate knowledge-search backend.
KNOWLEDGE_API_BASE = os.getenv("KNOWLEDGE_API_BASE", "").rstrip("/")
  37. def _render_search_html():
  38. html = (HERE / "search.html").read_text(encoding="utf-8")
  39. if KNOWLEDGE_API_BASE:
  40. html = html.replace(
  41. "const API = '/api/v1/knowledge';",
  42. f"const API = '{KNOWLEDGE_API_BASE}/api/v1/knowledge';",
  43. )
  44. return html.encode("utf-8")
# ── Task management: task_id → {proc, log, status} ───────────────────────────
# In-memory registry of spawned pipeline tasks, keyed by task_id.
TASKS = {}
# Held by the functions in this file whenever they read or write TASKS.
_TASK_LOCK = threading.Lock()
  48. def _spawn_task(kind, cmd):
  49. LOG_DIR.mkdir(parents=True, exist_ok=True)
  50. task_id = f"{kind}_{datetime.now().strftime('%m%d%H%M%S%f')}"
  51. log_path = LOG_DIR / f"{task_id}.log"
  52. f = open(log_path, "w", encoding="utf-8")
  53. proc = subprocess.Popen(cmd, stdout=f, stderr=subprocess.STDOUT,
  54. cwd=str(HERE), text=True)
  55. with _TASK_LOCK:
  56. TASKS[task_id] = {"proc": proc, "log": log_path, "status": "running"}
  57. def _wait():
  58. rc = proc.wait()
  59. f.close()
  60. with _TASK_LOCK:
  61. TASKS[task_id]["status"] = "done" if rc == 0 else "failed"
  62. threading.Thread(target=_wait, daemon=True).start()
  63. return task_id
  64. def _task_status(task_id):
  65. with _TASK_LOCK:
  66. t = TASKS.get(task_id)
  67. if not t:
  68. return None
  69. tail = ""
  70. try:
  71. text = t["log"].read_text(encoding="utf-8", errors="replace")
  72. tail = text[-3000:]
  73. except Exception:
  74. pass
  75. return {"status": t["status"], "log_tail": tail}
  76. def _next_query_id():
  77. """两张搜索表统一编号,避免跨方向撞 ID。"""
  78. qs = [q["query_id"] for m in ("process", "tools") for q in db.fetch_queries(m)]
  79. nums = [int(q[1:]) for q in qs if q.startswith("q") and q[1:].isdigit()]
  80. return f"q{(max(nums) + 1 if nums else 0):04d}"
  81. # ── Dashboard 聚合 ────────────────────────────────────────────────────────────
  82. def _split_values(v):
  83. """substance/form 字段:数组直接用;字符串按 、,/ 分割;None 丢弃。"""
  84. out = []
  85. items = v if isinstance(v, list) else [v]
  86. for it in items:
  87. if not it or not isinstance(it, str):
  88. continue
  89. for piece in it.replace(",", "、").replace("/", "、").split("、"):
  90. piece = piece.strip()
  91. if piece:
  92. out.append(piece)
  93. return out
  94. def _dashboard():
  95. posts, procs, tools = db.fetch_dashboard_rows()
  96. # 最新版本行集(覆盖度/Top10 用最新版,成本/耗时按全部版本累计)
  97. def latest(rows):
  98. best = {}
  99. for r in rows:
  100. cid = r["case_id"]
  101. if cid not in best or (r["version"] or "") > (best[cid] or ""):
  102. best[cid] = r["version"]
  103. return [r for r in rows if r["version"] == best[r["case_id"]]]
  104. latest_procs = latest(procs)
  105. latest_tools = latest(tools)
  106. # 内容树覆盖:steps 的 (action 叶子 × 输入/输出 type) ∩ 有效节点(tier≥1)
  107. jm = json.loads(MATRIX_FILE.read_text(encoding="utf-8"))
  108. a_idx = {a["name"]: i for i, a in enumerate(jm["actions"])}
  109. t_idx = {t["name"]: i for i, t in enumerate(jm["types"])}
  110. valid = set()
  111. for ai, row in enumerate(jm["matrix"]):
  112. for ti, cell in enumerate(row):
  113. if isinstance(cell, dict) and cell.get("tier", 0) >= 1:
  114. valid.add((ai, ti))
  115. covered = set()
  116. via_counter = Counter()
  117. substance_counter = Counter()
  118. form_counter = Counter()
  119. for r in latest_procs:
  120. for s in r["steps"]:
  121. if not isinstance(s, dict):
  122. continue
  123. leaf = (s.get("action") or "").split("/")[-1].strip()
  124. types = []
  125. for io in ("inputs", "outputs"):
  126. for x in s.get(io) or []:
  127. if isinstance(x, dict) and x.get("type"):
  128. types.append(str(x["type"]).strip())
  129. if leaf in a_idx:
  130. for tp in types:
  131. if tp in t_idx and (a_idx[leaf], t_idx[tp]) in valid:
  132. covered.add((a_idx[leaf], t_idx[tp]))
  133. via = (s.get("via") or "").strip()
  134. if via:
  135. via_counter[via] += 1
  136. for v in _split_values(s.get("substance")):
  137. substance_counter[v] += 1
  138. for v in _split_values(s.get("form")):
  139. form_counter[v] += 1
  140. for r in latest_tools:
  141. for v in _split_values(r["substance_scope"]):
  142. substance_counter[v] += 1
  143. for v in _split_values(r["form_scope"]):
  144. form_counter[v] += 1
  145. # 成本/耗时:同一 (case_id, version) 只计一次(各行重复存同一次调用的值)
  146. def cost_groups(rows):
  147. g = {}
  148. for r in rows:
  149. key = (r["case_id"], r["version"])
  150. if key not in g and r["cost_usd"] is not None:
  151. g[key] = (r["cost_usd"], r["duration_s"] or 0.0, r["created_at"])
  152. return list(g.values())
  153. runs = cost_groups(procs) + cost_groups(tools)
  154. total_cost = round(sum(c for c, _, _ in runs), 4)
  155. total_dur = round(sum(d for _, d, _ in runs), 1)
  156. # 按日成本趋势
  157. daily = Counter()
  158. for c, _, ts in runs:
  159. if ts:
  160. daily[ts[:10]] += c
  161. cost_trend = [{"date": d, "cost": round(v, 4)} for d, v in sorted(daily.items())]
  162. # 进度:分子分母同口径,都走「采纳」。分母 = 该方向 search 表里采纳的帖(distinct case),
  163. # 即「需解构」;分子 = 采纳帖里已解构的(∩ 保证 ≤ 分母,杜绝越界/虚高)。
  164. # 方向由 p["mode"] 区分(process=search_process,tools=search_tools),不再看 knowledge_type。
  165. proc_targets = {p["case_id"] for p in posts if p["mode"] == "process" and p["adopted"]}
  166. tool_targets = {p["case_id"] for p in posts if p["mode"] == "tools" and p["adopted"]}
  167. proc_extracted = {r["case_id"] for r in procs}
  168. tool_extracted = {r["case_id"] for r in tools}
  169. proc_done = proc_extracted & proc_targets
  170. tool_done = tool_extracted & tool_targets
  171. # 渠道分项/解构总数:按实际解构过的 distinct case(不限采纳),平台由 case 内禀。
  172. extracted_all = proc_extracted | tool_extracted
  173. case_plat = {p["case_id"]: (p["platform"] or "other") for p in posts}
  174. collected_by_plat = Counter((p["platform"] or "other") for p in posts)
  175. extracted_by_plat = Counter(case_plat.get(c, "other") for c in extracted_all)
  176. return {
  177. "result": {
  178. "collected_by_platform": collected_by_plat.most_common(),
  179. "extracted_by_platform": extracted_by_plat.most_common(),
  180. "matrix_covered": len(covered), "matrix_valid": len(valid),
  181. "matrix_cells": sorted([ai, ti] for ai, ti in covered),
  182. "matrix_actions": [a["name"] for a in jm["actions"]],
  183. "matrix_types": [t["name"] for t in jm["types"]],
  184. "substance_count": len(substance_counter),
  185. "substance_top": substance_counter.most_common(15),
  186. "form_count": len(form_counter),
  187. "form_top": form_counter.most_common(15),
  188. "post_count": len(posts),
  189. "extracted_post_count": len(extracted_all),
  190. "tool_count": len({r["tool_name"] for r in latest_tools if r["tool_name"]}),
  191. "via_top10": via_counter.most_common(10),
  192. },
  193. "process_data": {
  194. "run_count": len(runs),
  195. "avg_cost": round(total_cost / len(runs), 4) if runs else 0,
  196. "total_cost": total_cost,
  197. "avg_duration": round(total_dur / len(runs), 1) if runs else 0,
  198. "total_duration": total_dur,
  199. "cost_trend": cost_trend,
  200. "process_progress": {"done": len(proc_done), "total": len(proc_targets)},
  201. "tools_progress": {"done": len(tool_done), "total": len(tool_targets)},
  202. },
  203. }
  204. # ── HTTP handler ─────────────────────────────────────────────────────────────
  205. class Handler(BaseHTTPRequestHandler):
  206. def _json(self, data, code=200):
  207. body = json.dumps(data, ensure_ascii=False, default=str).encode("utf-8")
  208. self.send_response(code)
  209. self.send_header("Content-Type", "application/json; charset=utf-8")
  210. self.send_header("Content-Length", str(len(body)))
  211. self.end_headers()
  212. self.wfile.write(body)
  213. def _err(self, msg, code=400):
  214. self._json({"error": msg}, code)
  215. def do_GET(self):
  216. u = urlparse(self.path)
  217. qs = {k: v[0] for k, v in parse_qs(u.query).items()}
  218. try:
  219. if u.path == "/" or u.path == "/index.html":
  220. body = (HERE / "index.html").read_bytes()
  221. self.send_response(200)
  222. self.send_header("Content-Type", "text/html; charset=utf-8")
  223. self.send_header("Content-Length", str(len(body)))
  224. # 单文件前端,改版频繁:禁缓存,避免浏览器拿到旧 index.html
  225. self.send_header("Cache-Control", "no-cache, no-store, must-revalidate")
  226. self.end_headers()
  227. self.wfile.write(body)
  228. elif u.path == "/search.html":
  229. # 聚类库 tab 内嵌的知识检索页;API 域名由 .env 注入
  230. body = _render_search_html()
  231. self.send_response(200)
  232. self.send_header("Content-Type", "text/html; charset=utf-8")
  233. self.send_header("Content-Length", str(len(body)))
  234. self.send_header("Cache-Control", "no-cache, no-store, must-revalidate")
  235. self.end_headers()
  236. self.wfile.write(body)
  237. elif u.path == "/api/dashboard":
  238. self._json(_dashboard())
  239. elif u.path == "/api/queries":
  240. self._json(db.fetch_queries(qs.get("mode", "process")))
  241. elif u.path == "/api/posts":
  242. self._json(db.fetch_posts(qs.get("query_id", ""), qs.get("mode", "process")))
  243. elif u.path == "/api/process_versions":
  244. self._json(db.fetch_process_versions(qs.get("case_id", "")))
  245. elif u.path == "/api/process":
  246. r = db.fetch_process(qs.get("case_id", ""), qs.get("version"))
  247. self._json(r) if r else self._err("无解构记录", 404)
  248. elif u.path == "/api/tools_versions":
  249. self._json(db.fetch_tools_versions(qs.get("case_id", "")))
  250. elif u.path == "/api/tools":
  251. r = db.fetch_tools(qs.get("case_id", ""), qs.get("version"))
  252. self._json(r) if r else self._err("无解构记录", 404)
  253. elif u.path == "/api/task_status":
  254. r = _task_status(qs.get("task_id", ""))
  255. self._json(r) if r else self._err("未知 task_id", 404)
  256. else:
  257. self._err("not found", 404)
  258. except Exception as e:
  259. self._err(f"{type(e).__name__}: {e}", 500)
  260. def do_POST(self):
  261. u = urlparse(self.path)
  262. try:
  263. n = int(self.headers.get("Content-Length") or 0)
  264. payload = json.loads(self.rfile.read(n) or b"{}")
  265. except Exception:
  266. return self._err("body 必须是 JSON")
  267. try:
  268. if u.path in ("/api/extract_process", "/api/extract_tools"):
  269. qid = payload.get("query_id")
  270. cids = payload.get("case_ids") or []
  271. if not qid or not cids:
  272. return self._err("缺 query_id / case_ids")
  273. script = ("pipeline/procedure_extract.py" if u.path.endswith("process")
  274. else "pipeline/tool_extract.py")
  275. cmd = [sys.executable, script, "--query-id", qid,
  276. "--case-ids", ",".join(cids)]
  277. if payload.get("model"):
  278. cmd += ["--model", payload["model"]]
  279. kind = "proc" if u.path.endswith("process") else "tool"
  280. self._json({"task_id": _spawn_task(kind, cmd)})
  281. elif u.path == "/api/run_search":
  282. query = (payload.get("query") or "").strip()
  283. if not query:
  284. return self._err("缺 query")
  285. qid = payload.get("query_id") or _next_query_id()
  286. cmd = [sys.executable, "pipeline/search_eval.py",
  287. "--query-id", qid, "--query", query]
  288. if payload.get("synonyms"):
  289. cmd += ["--synonyms", payload["synonyms"]]
  290. if payload.get("mode_type") in ("工序", "工具"):
  291. cmd += ["--mode-type", payload["mode_type"]]
  292. if payload.get("platforms"):
  293. cmd += ["--platforms", payload["platforms"]]
  294. if payload.get("max_count"):
  295. cmd += ["--max-count", str(payload["max_count"])]
  296. self._json({"task_id": _spawn_task("search", cmd), "query_id": qid})
  297. else:
  298. self._err("not found", 404)
  299. except Exception as e:
  300. self._err(f"{type(e).__name__}: {e}", 500)
  301. def log_message(self, fmt, *a):
  302. pass # 静默访问日志
  303. if __name__ == "__main__":
  304. print(f"🚀 mode_workflow server → http://0.0.0.0:{PORT}")
  305. ThreadingHTTPServer(("0.0.0.0", PORT), Handler).serve_forever()