server.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
# -*- coding: utf-8 -*-
"""mode_workflow server · 页面 + API + 解构任务管理
================================================================================
单服务(默认 8772):
- GET  /                    index.html
- GET  /search.html         知识检索页(聚类库 tab 内嵌;接口走相对路径 /api/v1/knowledge,由本服务同源反代)
- GET  /api/dashboard       Dashboard 全部聚合指标(含内容树覆盖)
- GET  /api/queries|posts|process|tools(+_versions)   Dataset 数据
- POST /api/run_search|extract_process|extract_tools  起子进程跑 pipeline
- GET  /api/task_status     轮询任务状态(读日志尾部)
- GET  /api/img?u=<url>     同源图片反代(绕过防盗链)
- GET|POST /api/v1/knowledge*  同源反代到 .env 中的 KNOWLEDGE_API_BASE
用法:python server.py [port]
"""
import ipaddress
import json
import os
import socket
import subprocess
import sys
import threading
import urllib.error
import urllib.request
from collections import Counter
from datetime import datetime
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from urllib.parse import urlparse, parse_qs
  25. try:
  26. sys.stdout.reconfigure(encoding="utf-8")
  27. except Exception:
  28. pass
  29. HERE = Path(__file__).resolve().parent
  30. sys.path.insert(0, str(HERE))
  31. import db
  32. PORT = int(sys.argv[1]) if len(sys.argv) > 1 else 8772
  33. MATRIX_FILE = HERE / "reference" / "judged_matrix.json"
  34. LOG_DIR = HERE / "runs" / "logs"
  35. # 知识检索后端地址:从 .env 的 KNOWLEDGE_API_BASE 读取(db.py 已 load_dotenv)。
  36. # 注意:不能把它注入到 search.html 让浏览器直连——后端是明文 http://,而页面
  37. # 经 Cloudflare 隧道是 https://,浏览器会以「混合内容(Mixed Content)」拦截请求。
  38. # 因此 search.html 保持相对路径 '/api/v1/knowledge',由本服务同源反代到后端,
  39. # 这样既无混合内容、也无跨域(CORS)问题。
  40. KNOWLEDGE_API_BASE = os.getenv("KNOWLEDGE_API_BASE", "").rstrip("/")
  41. def _render_search_html():
  42. # 保持相对路径,接口走本服务的 /api/v1/knowledge 反向代理。
  43. return (HERE / "search.html").read_bytes()
  44. # ── 任务管理:task_id → {proc, log, status} ──────────────────────────────────
  45. TASKS = {}
  46. _TASK_LOCK = threading.Lock()
  47. def _spawn_task(kind, cmd):
  48. LOG_DIR.mkdir(parents=True, exist_ok=True)
  49. task_id = f"{kind}_{datetime.now().strftime('%m%d%H%M%S%f')}"
  50. log_path = LOG_DIR / f"{task_id}.log"
  51. f = open(log_path, "w", encoding="utf-8")
  52. proc = subprocess.Popen(cmd, stdout=f, stderr=subprocess.STDOUT,
  53. cwd=str(HERE), text=True)
  54. with _TASK_LOCK:
  55. TASKS[task_id] = {"proc": proc, "log": log_path, "status": "running"}
  56. def _wait():
  57. rc = proc.wait()
  58. f.close()
  59. with _TASK_LOCK:
  60. TASKS[task_id]["status"] = "done" if rc == 0 else "failed"
  61. threading.Thread(target=_wait, daemon=True).start()
  62. return task_id
  63. def _task_status(task_id):
  64. with _TASK_LOCK:
  65. t = TASKS.get(task_id)
  66. if not t:
  67. return None
  68. tail = ""
  69. try:
  70. text = t["log"].read_text(encoding="utf-8", errors="replace")
  71. tail = text[-3000:]
  72. except Exception:
  73. pass
  74. return {"status": t["status"], "log_tail": tail}
  75. def _next_query_id():
  76. """两张搜索表统一编号,避免跨方向撞 ID。"""
  77. qs = [q["query_id"] for m in ("process", "tools") for q in db.fetch_queries(m)]
  78. nums = [int(q[1:]) for q in qs if q.startswith("q") and q[1:].isdigit()]
  79. return f"q{(max(nums) + 1 if nums else 0):04d}"
  80. # ── Dashboard 聚合 ────────────────────────────────────────────────────────────
  81. def _split_values(v):
  82. """substance/form 字段:数组直接用;字符串按 、,/ 分割;None 丢弃。"""
  83. out = []
  84. items = v if isinstance(v, list) else [v]
  85. for it in items:
  86. if not it or not isinstance(it, str):
  87. continue
  88. for piece in it.replace(",", "、").replace("/", "、").split("、"):
  89. piece = piece.strip()
  90. if piece:
  91. out.append(piece)
  92. return out
  93. def _dashboard():
  94. posts, procs, tools = db.fetch_dashboard_rows()
  95. # 最新版本行集(覆盖度/Top10 用最新版,成本/耗时按全部版本累计)
  96. def latest(rows):
  97. best = {}
  98. for r in rows:
  99. cid = r["case_id"]
  100. if cid not in best or (r["version"] or "") > (best[cid] or ""):
  101. best[cid] = r["version"]
  102. return [r for r in rows if r["version"] == best[r["case_id"]]]
  103. latest_procs = latest(procs)
  104. latest_tools = latest(tools)
  105. # 内容树覆盖:steps 的 (action 叶子 × 输入/输出 type) ∩ 有效节点(tier≥1)
  106. jm = json.loads(MATRIX_FILE.read_text(encoding="utf-8"))
  107. a_idx = {a["name"]: i for i, a in enumerate(jm["actions"])}
  108. t_idx = {t["name"]: i for i, t in enumerate(jm["types"])}
  109. valid = set()
  110. for ai, row in enumerate(jm["matrix"]):
  111. for ti, cell in enumerate(row):
  112. if isinstance(cell, dict) and cell.get("tier", 0) >= 1:
  113. valid.add((ai, ti))
  114. covered = set()
  115. via_counter = Counter()
  116. substance_counter = Counter()
  117. form_counter = Counter()
  118. for r in latest_procs:
  119. for s in r["steps"]:
  120. if not isinstance(s, dict):
  121. continue
  122. leaf = (s.get("action") or "").split("/")[-1].strip()
  123. types = []
  124. for io in ("inputs", "outputs"):
  125. for x in s.get(io) or []:
  126. if isinstance(x, dict) and x.get("type"):
  127. types.append(str(x["type"]).strip())
  128. if leaf in a_idx:
  129. for tp in types:
  130. if tp in t_idx and (a_idx[leaf], t_idx[tp]) in valid:
  131. covered.add((a_idx[leaf], t_idx[tp]))
  132. via = (s.get("via") or "").strip()
  133. if via:
  134. via_counter[via] += 1
  135. for v in _split_values(s.get("substance")):
  136. substance_counter[v] += 1
  137. for v in _split_values(s.get("form")):
  138. form_counter[v] += 1
  139. for r in latest_tools:
  140. for v in _split_values(r["substance_scope"]):
  141. substance_counter[v] += 1
  142. for v in _split_values(r["form_scope"]):
  143. form_counter[v] += 1
  144. # 成本/耗时:同一 (case_id, version) 只计一次(各行重复存同一次调用的值)
  145. def cost_groups(rows):
  146. g = {}
  147. for r in rows:
  148. key = (r["case_id"], r["version"])
  149. if key not in g and r["cost_usd"] is not None:
  150. g[key] = (r["cost_usd"], r["duration_s"] or 0.0, r["created_at"])
  151. return list(g.values())
  152. runs = cost_groups(procs) + cost_groups(tools)
  153. total_cost = round(sum(c for c, _, _ in runs), 4)
  154. total_dur = round(sum(d for _, d, _ in runs), 1)
  155. # 按日成本趋势
  156. daily = Counter()
  157. for c, _, ts in runs:
  158. if ts:
  159. daily[ts[:10]] += c
  160. cost_trend = [{"date": d, "cost": round(v, 4)} for d, v in sorted(daily.items())]
  161. # 进度:分子分母同口径,都走「采纳」。分母 = 该方向 search 表里采纳的帖(distinct case),
  162. # 即「需解构」;分子 = 采纳帖里已解构的(∩ 保证 ≤ 分母,杜绝越界/虚高)。
  163. # 方向由 p["mode"] 区分(process=search_process,tools=search_tools),不再看 knowledge_type。
  164. proc_targets = {p["case_id"] for p in posts if p["mode"] == "process" and p["adopted"]}
  165. tool_targets = {p["case_id"] for p in posts if p["mode"] == "tools" and p["adopted"]}
  166. proc_extracted = {r["case_id"] for r in procs}
  167. tool_extracted = {r["case_id"] for r in tools}
  168. proc_done = proc_extracted & proc_targets
  169. tool_done = tool_extracted & tool_targets
  170. # 渠道分项/解构总数:按实际解构过的 distinct case(不限采纳),平台由 case 内禀。
  171. extracted_all = proc_extracted | tool_extracted
  172. case_plat = {p["case_id"]: (p["platform"] or "other") for p in posts}
  173. collected_by_plat = Counter((p["platform"] or "other") for p in posts)
  174. extracted_by_plat = Counter(case_plat.get(c, "other") for c in extracted_all)
  175. return {
  176. "result": {
  177. "collected_by_platform": collected_by_plat.most_common(),
  178. "extracted_by_platform": extracted_by_plat.most_common(),
  179. "matrix_covered": len(covered), "matrix_valid": len(valid),
  180. "matrix_cells": sorted([ai, ti] for ai, ti in covered),
  181. "matrix_actions": [a["name"] for a in jm["actions"]],
  182. "matrix_types": [t["name"] for t in jm["types"]],
  183. "substance_count": len(substance_counter),
  184. "substance_top": substance_counter.most_common(15),
  185. "form_count": len(form_counter),
  186. "form_top": form_counter.most_common(15),
  187. "post_count": len(posts),
  188. "extracted_post_count": len(extracted_all),
  189. "tool_count": len({r["tool_name"] for r in latest_tools if r["tool_name"]}),
  190. "via_top10": via_counter.most_common(10),
  191. },
  192. "process_data": {
  193. "run_count": len(runs),
  194. "avg_cost": round(total_cost / len(runs), 4) if runs else 0,
  195. "total_cost": total_cost,
  196. "avg_duration": round(total_dur / len(runs), 1) if runs else 0,
  197. "total_duration": total_dur,
  198. "cost_trend": cost_trend,
  199. "process_progress": {"done": len(proc_done), "total": len(proc_targets)},
  200. "tools_progress": {"done": len(tool_done), "total": len(tool_targets)},
  201. },
  202. }
  203. # ── HTTP handler ─────────────────────────────────────────────────────────────
  204. class Handler(BaseHTTPRequestHandler):
  205. def _json(self, data, code=200):
  206. body = json.dumps(data, ensure_ascii=False, default=str).encode("utf-8")
  207. self.send_response(code)
  208. self.send_header("Content-Type", "application/json; charset=utf-8")
  209. self.send_header("Content-Length", str(len(body)))
  210. self.end_headers()
  211. self.wfile.write(body)
  212. def _err(self, msg, code=400):
  213. self._json({"error": msg}, code)
  214. def _proxy_image(self, url):
  215. """同源图片反代:绕过公众号(mmbiz.qpic.cn)等站点的防盗链。
  216. 浏览器侧 referrerpolicy=no-referrer 偶尔仍被拦,服务端直取最稳:
  217. 不带 Referer、给个常规 UA,把图片字节原样转回,并加长缓存。"""
  218. if not url or not (url.startswith("http://") or url.startswith("https://")):
  219. return self._err("非法图片地址", 400)
  220. host = (urlparse(url).hostname or "").lower()
  221. # 防 SSRF:挡掉内网/本机地址
  222. if host in ("localhost", "127.0.0.1", "0.0.0.0", "::1") or \
  223. host.startswith("10.") or host.startswith("192.168.") or \
  224. host.startswith("169.254.") or host.endswith(".internal"):
  225. return self._err("禁止的图片地址", 403)
  226. req = urllib.request.Request(url, headers={
  227. "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
  228. "AppleWebKit/537.36 (KHTML, like Gecko) "
  229. "Chrome/120.0 Safari/537.36",
  230. "Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8",
  231. })
  232. try:
  233. with urllib.request.urlopen(req, timeout=30) as resp:
  234. payload = resp.read()
  235. ct = resp.headers.get("Content-Type", "image/jpeg")
  236. except urllib.error.HTTPError as e:
  237. return self._err(f"上游图片返回 {e.code}", e.code if 400 <= e.code < 600 else 502)
  238. except Exception as e:
  239. return self._err(f"图片不可达:{type(e).__name__}: {e}", 502)
  240. self.send_response(200)
  241. self.send_header("Content-Type", ct)
  242. self.send_header("Content-Length", str(len(payload)))
  243. self.send_header("Cache-Control", "public, max-age=86400")
  244. self.end_headers()
  245. self.wfile.write(payload)
  246. def _proxy_knowledge(self, body=None):
  247. """把 /api/v1/knowledge* 同源反代到 KNOWLEDGE_API_BASE(明文后端)。
  248. 浏览器只跟本服务(经隧道走 https)通信,规避混合内容 + 跨域。"""
  249. if not KNOWLEDGE_API_BASE:
  250. return self._err("KNOWLEDGE_API_BASE 未配置", 502)
  251. target = KNOWLEDGE_API_BASE + self.path # self.path 含 query string
  252. headers = {}
  253. ct = self.headers.get("Content-Type")
  254. if ct:
  255. headers["Content-Type"] = ct
  256. req = urllib.request.Request(target, data=body, headers=headers,
  257. method=self.command)
  258. try:
  259. with urllib.request.urlopen(req, timeout=120) as resp:
  260. payload = resp.read()
  261. code = resp.status
  262. rct = resp.headers.get("Content-Type", "application/json; charset=utf-8")
  263. except urllib.error.HTTPError as e:
  264. payload = e.read()
  265. code = e.code
  266. rct = e.headers.get("Content-Type", "application/json; charset=utf-8")
  267. except Exception as e:
  268. return self._err(f"知识检索后端不可达:{type(e).__name__}: {e}", 502)
  269. self.send_response(code)
  270. self.send_header("Content-Type", rct)
  271. self.send_header("Content-Length", str(len(payload)))
  272. self.end_headers()
  273. self.wfile.write(payload)
  274. def do_GET(self):
  275. u = urlparse(self.path)
  276. qs = {k: v[0] for k, v in parse_qs(u.query).items()}
  277. try:
  278. if u.path == "/" or u.path == "/index.html":
  279. body = (HERE / "index.html").read_bytes()
  280. self.send_response(200)
  281. self.send_header("Content-Type", "text/html; charset=utf-8")
  282. self.send_header("Content-Length", str(len(body)))
  283. # 单文件前端,改版频繁:禁缓存,避免浏览器拿到旧 index.html
  284. self.send_header("Cache-Control", "no-cache, no-store, must-revalidate")
  285. self.end_headers()
  286. self.wfile.write(body)
  287. elif u.path == "/search.html":
  288. # 聚类库 tab 内嵌的知识检索页;API 域名由 .env 注入
  289. body = _render_search_html()
  290. self.send_response(200)
  291. self.send_header("Content-Type", "text/html; charset=utf-8")
  292. self.send_header("Content-Length", str(len(body)))
  293. self.send_header("Cache-Control", "no-cache, no-store, must-revalidate")
  294. self.end_headers()
  295. self.wfile.write(body)
  296. elif u.path == "/api/dashboard":
  297. self._json(_dashboard())
  298. elif u.path == "/api/queries":
  299. self._json(db.fetch_queries(qs.get("mode", "process")))
  300. elif u.path == "/api/posts":
  301. self._json(db.fetch_posts(qs.get("query_id", ""), qs.get("mode", "process")))
  302. elif u.path == "/api/process_versions":
  303. self._json(db.fetch_process_versions(qs.get("case_id", "")))
  304. elif u.path == "/api/process":
  305. r = db.fetch_process(qs.get("case_id", ""), qs.get("version"))
  306. self._json(r) if r else self._err("无解构记录", 404)
  307. elif u.path == "/api/tools_versions":
  308. self._json(db.fetch_tools_versions(qs.get("case_id", "")))
  309. elif u.path == "/api/tools":
  310. r = db.fetch_tools(qs.get("case_id", ""), qs.get("version"))
  311. self._json(r) if r else self._err("无解构记录", 404)
  312. elif u.path == "/api/task_status":
  313. r = _task_status(qs.get("task_id", ""))
  314. self._json(r) if r else self._err("未知 task_id", 404)
  315. elif u.path == "/api/img":
  316. self._proxy_image(qs.get("u", ""))
  317. elif u.path.startswith("/api/v1/knowledge"):
  318. self._proxy_knowledge()
  319. else:
  320. self._err("not found", 404)
  321. except Exception as e:
  322. self._err(f"{type(e).__name__}: {e}", 500)
  323. def do_POST(self):
  324. u = urlparse(self.path)
  325. try:
  326. n = int(self.headers.get("Content-Length") or 0)
  327. raw = self.rfile.read(n)
  328. except Exception:
  329. return self._err("读取请求体失败")
  330. # 知识检索接口:原样反代到后端,不在本服务做 JSON 解析
  331. if u.path.startswith("/api/v1/knowledge"):
  332. return self._proxy_knowledge(body=raw)
  333. try:
  334. payload = json.loads(raw or b"{}")
  335. except Exception:
  336. return self._err("body 必须是 JSON")
  337. try:
  338. if u.path in ("/api/extract_process", "/api/extract_tools"):
  339. qid = payload.get("query_id")
  340. cids = payload.get("case_ids") or []
  341. if not qid or not cids:
  342. return self._err("缺 query_id / case_ids")
  343. script = ("pipeline/procedure_extract.py" if u.path.endswith("process")
  344. else "pipeline/tool_extract.py")
  345. cmd = [sys.executable, script, "--query-id", qid,
  346. "--case-ids", ",".join(cids)]
  347. if payload.get("model"):
  348. cmd += ["--model", payload["model"]]
  349. if payload.get("force"): # 默认按 case 全局去重;force 才强制重解构
  350. cmd += ["--force"]
  351. kind = "proc" if u.path.endswith("process") else "tool"
  352. self._json({"task_id": _spawn_task(kind, cmd)})
  353. elif u.path == "/api/run_search":
  354. query = (payload.get("query") or "").strip()
  355. if not query:
  356. return self._err("缺 query")
  357. qid = payload.get("query_id") or _next_query_id()
  358. cmd = [sys.executable, "pipeline/search_eval.py",
  359. "--query-id", qid, "--query", query]
  360. if payload.get("synonyms"):
  361. cmd += ["--synonyms", payload["synonyms"]]
  362. if payload.get("mode_type") in ("工序", "工具"):
  363. cmd += ["--mode-type", payload["mode_type"]]
  364. if payload.get("platforms"):
  365. cmd += ["--platforms", payload["platforms"]]
  366. if payload.get("max_count"):
  367. cmd += ["--max-count", str(payload["max_count"])]
  368. self._json({"task_id": _spawn_task("search", cmd), "query_id": qid})
  369. else:
  370. self._err("not found", 404)
  371. except Exception as e:
  372. self._err(f"{type(e).__name__}: {e}", 500)
  373. def log_message(self, fmt, *a):
  374. pass # 静默访问日志
  375. if __name__ == "__main__":
  376. print(f"🚀 mode_workflow server → http://0.0.0.0:{PORT}")
  377. ThreadingHTTPServer(("0.0.0.0", PORT), Handler).serve_forever()