Kaynağa Gözat

feat(mode_workflow): 添加连接池、合并接口与缓存优化

为数据库连接添加连接池,减少远程RDS的握手开销
新增`/api/extract`接口,合并版本列表和解构详情请求,减少前端往返次数
为`/api/dashboard`添加缓存,任务完成时主动失效并设置60s兜底TTL
为相关接口添加ETag支持,启用浏览器304缓存减少重复传输
优化SQL查询,仅提取所需字段降低数据传输量
更新README文档,补充模块职责与新特性说明
将服务器输出日志加入`.gitignore`并清理旧日志文件
刘文武 1 gün önce
ebeveyn
işleme
0b23364816

+ 1 - 0
examples/mode_workflow/.gitignore

@@ -7,3 +7,4 @@ __pycache__/
 .server.pid
 .cloudflared.log
 .cloudflared.pid
+.server_8772.out

+ 0 - 52
examples/mode_workflow/.server_8772.out

@@ -1,52 +0,0 @@
-----------------------------------------
-Exception occurred during processing of request from ('127.0.0.1', 56682)
-Traceback (most recent call last):
-  File "/Users/max_liu/max_liu/company/Agent/examples/mode_workflow/server.py", line 301, in do_GET
-    self._json(_dashboard())
-    ~~~~~~~~~~^^^^^^^^^^^^^^
-  File "/Users/max_liu/max_liu/company/Agent/examples/mode_workflow/server.py", line 244, in _json
-    self.wfile.write(body)
-    ~~~~~~~~~~~~~~~~^^^^^^
-  File "/usr/local/anaconda3/lib/python3.13/socketserver.py", line 845, in write
-    self._sock.sendall(b)
-    ~~~~~~~~~~~~~~~~~~^^^
-BrokenPipeError: [Errno 32] Broken pipe
-
-During handling of the above exception, another exception occurred:
-
-Traceback (most recent call last):
-  File "/usr/local/anaconda3/lib/python3.13/socketserver.py", line 697, in process_request_thread
-    self.finish_request(request, client_address)
-    ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
-  File "/usr/local/anaconda3/lib/python3.13/socketserver.py", line 362, in finish_request
-    self.RequestHandlerClass(request, client_address, self)
-    ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-  File "/usr/local/anaconda3/lib/python3.13/socketserver.py", line 766, in __init__
-    self.handle()
-    ~~~~~~~~~~~^^
-  File "/usr/local/anaconda3/lib/python3.13/http/server.py", line 436, in handle
-    self.handle_one_request()
-    ~~~~~~~~~~~~~~~~~~~~~~~^^
-  File "/usr/local/anaconda3/lib/python3.13/http/server.py", line 424, in handle_one_request
-    method()
-    ~~~~~~^^
-  File "/Users/max_liu/max_liu/company/Agent/examples/mode_workflow/server.py", line 324, in do_GET
-    self._err(f"{type(e).__name__}: {e}", 500)
-    ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-  File "/Users/max_liu/max_liu/company/Agent/examples/mode_workflow/server.py", line 247, in _err
-    self._json({"error": msg}, code)
-    ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^
-  File "/Users/max_liu/max_liu/company/Agent/examples/mode_workflow/server.py", line 243, in _json
-    self.end_headers()
-    ~~~~~~~~~~~~~~~~^^
-  File "/usr/local/anaconda3/lib/python3.13/http/server.py", line 538, in end_headers
-    self.flush_headers()
-    ~~~~~~~~~~~~~~~~~~^^
-  File "/usr/local/anaconda3/lib/python3.13/http/server.py", line 542, in flush_headers
-    self.wfile.write(b"".join(self._headers_buffer))
-    ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-  File "/usr/local/anaconda3/lib/python3.13/socketserver.py", line 845, in write
-    self._sock.sendall(b)
-    ~~~~~~~~~~~~~~~~~~^^^
-BrokenPipeError: [Errno 32] Broken pipe
-----------------------------------------

+ 2 - 2
examples/mode_workflow/README.md

@@ -18,8 +18,8 @@ python server.py              # http://localhost:8772
 
 | 文件 | 职责 |
 |---|---|
-| `db.py` | 四表 DDL + 全部读写(读 .env MYSQL_*) |
-| `server.py` | 页面 + API + 解构任务子进程管理(端口 8772) |
+| `db.py` | 四表 DDL + 全部读写(读 .env MYSQL_*);连接走 `PooledDB` 池(远程 RDS 每次握手 ~0.5s,池复用避免每请求重连) |
+| `server.py` | 页面 + API + 解构任务子进程管理(端口 8772);`/api/dashboard` 结果带缓存(任务完成时作废 + 60s 兜底 TTL),`/api/extract` 等带 ETag/304 |
 | `index.html` | 单文件前端:Dashboard / Dataset / 聚类库 |
 | `pipeline/search_eval.py` | 任意 query 搜索+评估 → search_process / search_tools(按解构方向分表) |
 | `pipeline/procedure_extract.py` | 工序解构(LLM 直出)→ mode_process |

+ 110 - 22
examples/mode_workflow/db.py

@@ -29,20 +29,43 @@ load_dotenv()
 
 import pymysql
 from pymysql.cursors import DictCursor
+from dbutils.pooled_db import PooledDB
+
+# ── 连接池 ──────────────────────────────────────────────────────────────────
+# MySQL 是远程 RDS,每次 pymysql.connect() 的 TCP+鉴权握手 ~0.5s。旧实现每个
+# 请求新建一条连接,一次"点开帖子"要 2~3 个请求 = 2~3 次握手 ≈ 1s。改用连接池
+# 复用长连接后,握手只在池初始化时各发生一次,后续取连接近乎零开销。
+# server.py 是 ThreadingHTTPServer(每请求一线程),PooledDB 线程安全,正好匹配。
+# 注意:fetch_* 里的 conn.close() 在池连接上语义是"归还池中"而非真正断开。
+_POOL = None
+
+
+def _pool():
+    global _POOL
+    if _POOL is None:
+        if not os.getenv("MYSQL_HOST"):
+            raise RuntimeError("缺 MYSQL_HOST:检查 .env 的 MYSQL_* 配置")
+        _POOL = PooledDB(
+            creator=pymysql,
+            mincached=2,          # 启动即预热 2 条,首点不再吃冷握手
+            maxcached=5,          # 空闲保留上限
+            maxconnections=20,    # 并发上限(ThreadingHTTPServer 线程数)
+            blocking=True,        # 连接耗尽时等待而非报错
+            ping=1,               # 取用前 ping,自动剔除被 RDS 掐断的死连接
+            host=os.getenv("MYSQL_HOST"),
+            port=int(os.getenv("MYSQL_PORT", 3306)),
+            user=os.getenv("MYSQL_USER"),
+            password=os.getenv("MYSQL_PASSWORD"),
+            database=os.getenv("MYSQL_DATABASE"),
+            charset="utf8mb4", cursorclass=DictCursor,
+            autocommit=True, connect_timeout=10,
+        )
+    return _POOL
 
 
 def _conn():
-    if not os.getenv("MYSQL_HOST"):
-        raise RuntimeError("缺 MYSQL_HOST:检查 .env 的 MYSQL_* 配置")
-    return pymysql.connect(
-        host=os.getenv("MYSQL_HOST"),
-        port=int(os.getenv("MYSQL_PORT", 3306)),
-        user=os.getenv("MYSQL_USER"),
-        password=os.getenv("MYSQL_PASSWORD"),
-        database=os.getenv("MYSQL_DATABASE"),
-        charset="utf8mb4", cursorclass=DictCursor,
-        autocommit=True, connect_timeout=10,
-    )
+    """从池取一条连接;用法不变(with cursor / conn.close() 归还池)。"""
+    return _pool().connection()
 
 
 # ── DDL ──────────────────────────────────────────────────────────────────────
@@ -267,6 +290,23 @@ def is_adopted(overall, evaluation, publish_time):
     return True
 
 
+def is_adopted_rel(overall, rel, publish_time):
+    """is_adopted 的轻量版:相关性得分(rel)已由 SQL JSON_EXTRACT 直接取出,
+    无需传输/解析整块 llm_evaluation。判定口径与 is_adopted 完全一致。"""
+    try:
+        rel = float(rel) if rel is not None else None
+    except (TypeError, ValueError):
+        rel = None
+    if rel is not None and rel < 4:
+        return False
+    rh = _recency_hard(publish_time)
+    if rh is not None and rh < 2:
+        return False
+    if overall is not None and float(overall) < 6:
+        return False
+    return True
+
+
 # ── search_process / search_tools ────────────────────────────────────────────
 
 def upsert_search_posts(query_id, query_text, results, table="search_process"):
@@ -462,6 +502,11 @@ def fetch_process(case_id, version=None):
             rows = cur.fetchall()
     finally:
         conn.close()
+    return _proc_payload(case_id, version, rows)
+
+
+def _proc_payload(case_id, version, rows):
+    """mode_process 行集 → {case_id, version, …, procedures:[...]}。无行返回 None。"""
     if not rows:
         return None
     procedures = [{
@@ -538,6 +583,11 @@ def fetch_tools(case_id, version=None):
             rows = cur.fetchall()
     finally:
         conn.close()
+    return _tools_payload(case_id, version, rows)
+
+
+def _tools_payload(case_id, version, rows):
+    """mode_tools 行集 → {case_id, version, …, tools:[...]}。无行返回 None。"""
     if not rows:
         return None
     tools = [{
@@ -554,6 +604,33 @@ def fetch_tools(case_id, version=None):
             "tool_count": len(tools), "tools": tools}
 
 
+# ── 点击帖子合一查询(单连接,最少往返;远程 RDS 每次往返 ~80ms,故按次数优化)──
+
+def fetch_extract(mode, case_id, version=None):
+    """一次取版本列表 + 解构详情,复用同一条池连接、最少往返。
+    返回 {versions, data, missing}。mode: process / tools。"""
+    is_proc = mode != "tools"
+    mtable = _mode_table("process" if is_proc else "tools")
+    conn = _conn()
+    try:
+        with conn.cursor() as cur:
+            cur.execute(f"""SELECT version, COUNT(*) AS n, MAX(model) AS model
+                            FROM {mtable} WHERE case_id=%s
+                            GROUP BY version ORDER BY version DESC""", (case_id,))
+            versions = cur.fetchall()
+            # 详情:把"取最新版本"折进同一条 SQL,版本指定时直接用;省一次往返。
+            target = version or (versions[0]["version"] if versions else None)
+            rows = []
+            if target is not None:
+                cur.execute(f"SELECT * FROM {mtable} WHERE case_id=%s AND version=%s ORDER BY id",
+                            (case_id, target))
+                rows = cur.fetchall()
+    finally:
+        conn.close()
+    payload = (_proc_payload if is_proc else _tools_payload)(case_id, target, rows)
+    return {"versions": versions, "data": payload, "missing": payload is None}
+
+
 # ── 跨 query 去重 / link 复制(方案A:解构前先去重,避免重复花钱)──────────────
 # case_id 是帖子物理身份(platform_channelContentId),与 query 无关。同一帖被多个
 # query 搜到时只需真实解构一次;其余 query 用 link_* 复制行补齐关联(cost=0)。
@@ -613,15 +690,23 @@ def link_process(query_id, case_id, mode="process"):
 
 # ── Dashboard 原始行(指标计算在 server.py)─────────────────────────────────────
 
+# 采纳判定只需「和内容制作知识相关」的得分,用 SQL JSON_EXTRACT 直取这一个标量,
+# 避免把整块 llm_evaluation(本库 ~1.5MB)拉到 Python 再解析。得分可能直接是数字,
+# 也可能裹在 {"得分": x} 里,COALESCE 两条路径覆盖两种存法,口径同 is_adopted。
+_REL_SQL = ("JSON_UNQUOTE(COALESCE("
+            "JSON_EXTRACT(llm_evaluation,'$.\"相关性\".\"和内容制作知识相关\".\"得分\"'),"
+            "JSON_EXTRACT(llm_evaluation,'$.\"相关性\".\"和内容制作知识相关\"')))")
+
+
 def fetch_dashboard_rows():
-    """拉 Dashboard 计算所需的轻量行。数据量级:百~千行,Python 聚合足够。"""
+    """拉 Dashboard 计算所需的轻量行。数据量级:百~千行,Python 聚合足够。
+    优化:① 不传 llm_evaluation 整块,SQL 只取采纳判定要的相关性得分;
+    ② steps 只取每个 case 的最新版本(覆盖度只看最新版),历史/link_ 版本不传 steps。"""
     conn = _conn()
     try:
         with conn.cursor() as cur:
-            # 进度分母走「采纳」口径,需带上 is_adopted 判定所需字段;
-            # mode 标方向(工序帖来自 search_process,工具帖来自 search_tools)。
-            cols = ("query_id, case_id, platform, knowledge_type, "
-                    "overall_score, publish_time, llm_evaluation")
+            # 进度分母走「采纳」口径;mode 标方向(工序帖来自 search_process)。
+            cols = f"query_id, case_id, platform, overall_score, publish_time, {_REL_SQL} AS rel"
             cur.execute(f"SELECT {cols} FROM search_process")
             posts = cur.fetchall()
             for p in posts:
@@ -631,8 +716,14 @@ def fetch_dashboard_rows():
             for p in st:
                 p["mode"] = "tools"
             posts += st
-            cur.execute("""SELECT case_id, version, steps, tools_used, cost_usd,
-                                  duration_s, created_at FROM mode_process""")
+            # 成本/耗时按全部版本计;steps 仅最新版需要 → 非最新版只回 NULL,省传输。
+            cur.execute("""SELECT p.case_id, p.version, p.cost_usd, p.duration_s, p.created_at,
+                                  CASE WHEN p.version = m.maxv THEN p.steps END AS steps
+                           FROM mode_process p
+                           JOIN (SELECT case_id, MAX(version) AS maxv
+                                 FROM mode_process GROUP BY case_id) m
+                             ON p.case_id = m.case_id
+                           ORDER BY p.id""")
             procs = cur.fetchall()
             cur.execute("""SELECT case_id, version, tool_name, substance_scope,
                                   form_scope, cost_usd, duration_s, created_at
@@ -641,13 +732,10 @@ def fetch_dashboard_rows():
     finally:
         conn.close()
     for p in posts:
-        p["knowledge_type"] = _loads(p["knowledge_type"], [])
         # 采纳判定:口径同帖子列表(is_adopted),作为「需解构」分母依据
-        p["adopted"] = is_adopted(
-            p["overall_score"], _loads(p["llm_evaluation"]), p["publish_time"])
+        p["adopted"] = is_adopted_rel(p["overall_score"], p["rel"], p["publish_time"])
     for r in procs:
         r["steps"] = _loads(r["steps"], [])
-        r["tools_used"] = _loads(r["tools_used"], [])
         r["cost_usd"] = float(r["cost_usd"]) if r["cost_usd"] is not None else None
         r["created_at"] = str(r["created_at"]) if r["created_at"] else None
     for r in tools:

+ 7 - 10
examples/mode_workflow/index.html

@@ -2890,23 +2890,20 @@
       async function loadExtract() {
         if (!state.caseId) return renderExtractEmpty();
         const isProc = state.mode === "process";
-        const vURL = `/api/${isProc ? "process" : "tools"}_versions?case_id=` + encodeURIComponent(state.caseId);
-        const dURL =
-          `/api/${isProc ? "process" : "tools"}?case_id=` +
+        // 版本列表 + 解构详情合一,一个请求拿全(服务端同连接两查,ETag 命中可走 304)
+        const url =
+          `/api/extract?mode=${state.mode}&case_id=` +
           encodeURIComponent(state.caseId) +
           (state.version ? "&version=" + encodeURIComponent(state.version) : "");
         let versions = [],
           data = null,
           missing = false;
         try {
-          versions = await api(vURL);
+          const res = await api(url);
+          versions = res.versions || [];
+          data = res.data;
+          missing = res.missing || !data;
         } catch (e) {}
-        try {
-          data = await api(dURL);
-        } catch (e) {
-          if (e.status === 404) missing = true;
-          else throw e;
-        }
         renderExtractHead(versions, data, missing);
         const body = $("#xp-body");
         if (missing || !data) {

+ 62 - 1
examples/mode_workflow/server.py

@@ -6,16 +6,19 @@
   - GET  /search.html         知识检索页(聚类库 tab 内嵌;API 域名由 .env 注入)
   - GET  /api/dashboard       Dashboard 全部聚合指标(含内容树覆盖)
   - GET  /api/queries|posts|process|tools(+_versions)   Dataset 数据
+  - GET  /api/extract         点帖子合一:版本列表+解构详情(单连接,带 ETag/304)
   - POST /api/run_search|extract_process|extract_tools  起子进程跑 pipeline
   - GET  /api/task_status     轮询任务状态(读日志尾部)
 
 用法:python server.py [port]
 """
+import hashlib
 import json
 import os
 import subprocess
 import sys
 import threading
+import time
 from collections import Counter
 from datetime import datetime
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
@@ -71,6 +74,8 @@ def _spawn_task(kind, cmd):
         f.close()
         with _TASK_LOCK:
             TASKS[task_id]["status"] = "done" if rc == 0 else "failed"
+        # 搜索/解构任务一结束,四表数据可能变,作废 Dashboard 缓存,下次重算
+        _invalidate_dashboard()
 
     threading.Thread(target=_wait, daemon=True).start()
     return task_id
@@ -113,6 +118,31 @@ def _split_values(v):
     return out
 
 
+# ── Dashboard 结果缓存 ────────────────────────────────────────────────────────
+# Dashboard 要拉/聚合四表(本库远程 RDS 下整体 ~2s),但数据只在搜索/解构任务
+# 完成时才变。故缓存计算结果:命中即 <1ms 返回;任务结束时主动作废(见 _spawn_task),
+# 另设兜底 TTL 兜住外部直接改库的情况。
+_DASH_CACHE = {"data": None, "ts": 0.0}
+_DASH_LOCK = threading.Lock()
+_DASH_TTL = 60.0   # 秒
+
+
+def _invalidate_dashboard():
+    with _DASH_LOCK:
+        _DASH_CACHE["ts"] = 0.0
+
+
+def _dashboard_cached():
+    with _DASH_LOCK:
+        if _DASH_CACHE["data"] is not None and time.monotonic() - _DASH_CACHE["ts"] < _DASH_TTL:
+            return _DASH_CACHE["data"]
+    data = _dashboard()   # 计算放锁外,不阻塞其它请求(偶发并发重算可接受)
+    with _DASH_LOCK:
+        _DASH_CACHE["data"] = data
+        _DASH_CACHE["ts"] = time.monotonic()
+    return data
+
+
 def _dashboard():
     posts, procs, tools = db.fetch_dashboard_rows()
 
@@ -249,6 +279,27 @@ class Handler(BaseHTTPRequestHandler):
     def _err(self, msg, code=400):
         self._json({"error": msg}, code)
 
+    def _json_etag(self, data):
+        """带 ETag 的 JSON 响应:解构结果按 (case_id,version) 内容不变,
+        浏览器再次点开同一帖时带 If-None-Match 命中 304,免传 body。
+        Cache-Control: no-cache = 每次都回服务端校验(连接池下校验仅 ~50ms),
+        永不返回过期数据(重新解构后 ETag 变化即自动失效)。"""
+        body = json.dumps(data, ensure_ascii=False, default=str).encode("utf-8")
+        etag = '"' + hashlib.md5(body).hexdigest() + '"'
+        if self.headers.get("If-None-Match") == etag:
+            self.send_response(304)
+            self.send_header("ETag", etag)
+            self.send_header("Cache-Control", "no-cache")
+            self.end_headers()
+            return
+        self.send_response(200)
+        self.send_header("Content-Type", "application/json; charset=utf-8")
+        self.send_header("Content-Length", str(len(body)))
+        self.send_header("ETag", etag)
+        self.send_header("Cache-Control", "no-cache")
+        self.end_headers()
+        self.wfile.write(body)
+
     def _proxy_image(self, url):
         """同源图片反代:绕过公众号(mmbiz.qpic.cn)等站点的防盗链。
         浏览器侧 referrerpolicy=no-referrer 偶尔仍被拦,服务端直取最稳:
@@ -334,11 +385,15 @@ class Handler(BaseHTTPRequestHandler):
                 self.end_headers()
                 self.wfile.write(body)
             elif u.path == "/api/dashboard":
-                self._json(_dashboard())
+                self._json_etag(_dashboard_cached())
             elif u.path == "/api/queries":
                 self._json(db.fetch_queries(qs.get("mode", "process")))
             elif u.path == "/api/posts":
                 self._json(db.fetch_posts(qs.get("query_id", ""), qs.get("mode", "process")))
+            elif u.path == "/api/extract":
+                # 一次点击合一:单连接同时取版本列表 + 解构详情,前端少一次往返。
+                self._json_etag(db.fetch_extract(
+                    qs.get("mode", "process"), qs.get("case_id", ""), qs.get("version")))
             elif u.path == "/api/process_versions":
                 self._json(db.fetch_process_versions(qs.get("case_id", "")))
             elif u.path == "/api/process":
@@ -431,4 +486,10 @@ class Handler(BaseHTTPRequestHandler):
 
 if __name__ == "__main__":
     print(f"🚀 mode_workflow server → http://0.0.0.0:{PORT}")
+    # 预热连接池:把首请求要付的 RDS 握手提前到启动阶段(失败不阻断启动)
+    try:
+        db._conn().close()
+        print("✅ DB 连接池已预热")
+    except Exception as e:
+        print(f"⚠ 连接池预热失败(忽略,首请求会重试):{type(e).__name__}: {e}")
     ThreadingHTTPServer(("0.0.0.0", PORT), Handler).serve_forever()