Procházet zdrojové kódy

feat(mode_workflow): 三表 DDL 与 MySQL 读写层

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
刘文武 před 5 dny
rodič
revize
891358dd7c
1 změnil soubory, kde provedl 513 přidání a 0 odebrání
  1. 513 0
      examples/mode_workflow/db.py

+ 513 - 0
examples/mode_workflow/db.py

@@ -0,0 +1,513 @@
+# -*- coding: utf-8 -*-
+"""mode_workflow · MySQL 持久化(DB 为唯一事实源)
+================================================================================
+读 .env 的 MYSQL_* 连接 MySQL。三张表:
+  search_data  —— 每行一个 (query, 帖子):搜索 + llm 评估结果
+  mode_process —— 每行一个解构出的工序(steps 等嵌套结构存 JSON 列)
+  mode_tools   —— 每行一个解构出的工具
+
+与旧 fixed_query_eval/db.py 的关键差异:本系统 DB 是主存储,写入失败直接 raise,
+不做"失败不阻断"。读侧保留防御(返回空/None)。
+
+用法:
+  python db.py init    # 建表(幂等)
+  python db.py check   # 打印三表行数
+"""
+import json
+import os
+import sys
+from pathlib import Path
+
+# Directory two levels above the folder that contains this file.
+PROJECT_ROOT = Path(__file__).resolve().parents[2]
+# Put that directory first on sys.path — presumably so project-level packages
+# resolve when this file is run directly as a script; verify against callers.
+sys.path.insert(0, str(PROJECT_ROOT))
+
+from dotenv import load_dotenv
+# Load .env before anything reads the MYSQL_* settings used by _conn().
+load_dotenv()
+
+import pymysql
+from pymysql.cursors import DictCursor
+
+
+def _conn():
+    """Open a new MySQL connection configured from the MYSQL_* environment variables.
+
+    Returns:
+        A pymysql connection using utf8mb4, ``DictCursor`` rows, autocommit
+        enabled and a 10-second connect timeout. The caller must close it.
+
+    Raises:
+        RuntimeError: if ``MYSQL_HOST`` is unset or empty.
+    """
+    host = os.getenv("MYSQL_HOST")
+    if not host:
+        raise RuntimeError("缺 MYSQL_HOST:检查 .env 的 MYSQL_* 配置")
+    return pymysql.connect(
+        host=host,
+        # ``or`` also covers MYSQL_PORT being present but empty (a bare
+        # ``MYSQL_PORT=`` line), which int() would otherwise reject.
+        port=int(os.getenv("MYSQL_PORT") or 3306),
+        user=os.getenv("MYSQL_USER"),
+        password=os.getenv("MYSQL_PASSWORD"),
+        database=os.getenv("MYSQL_DATABASE"),
+        charset="utf8mb4", cursorclass=DictCursor,
+        autocommit=True, connect_timeout=10,
+    )
+
+
+# ── DDL ──────────────────────────────────────────────────────────────────────
+
+# DDL for `search_data`: one row per (query_id, case_id) pair. The uk_qid_case
+# unique key is what lets upsert_search_posts() use ON DUPLICATE KEY UPDATE.
+# NOTE(review): the overall_score column COMMENT says "/2", but overall_score()
+# divides by the number of sections that actually produced scores.
+DDL_SEARCH = """
+CREATE TABLE IF NOT EXISTS search_data (
+  id            BIGINT AUTO_INCREMENT PRIMARY KEY,
+  query_id      VARCHAR(32)   NOT NULL COMMENT 'q0000',
+  query_text    VARCHAR(512)  NULL,
+  case_id       VARCHAR(128)  NOT NULL COMMENT 'platform_channelContentId',
+  platform      VARCHAR(32)   NULL,
+  channel_content_id VARCHAR(128) NULL,
+  title         VARCHAR(512)  NULL,
+  url           VARCHAR(1024) NULL,
+  content_type  VARCHAR(32)   NULL,
+  body          LONGTEXT      NULL,
+  images        JSON          NULL,
+  videos        JSON          NULL,
+  like_count    INT           NULL,
+  publish_time  VARCHAR(64)   NULL,
+  quality_score FLOAT         NULL COMMENT 'post._quality_score',
+  quality_grade VARCHAR(8)    NULL,
+  found_by      JSON          NULL COMMENT '命中的措辞数组',
+  knowledge_type JSON         NULL COMMENT '["能力","工序","工具"] 子集',
+  overall_score FLOAT         NULL COMMENT '(相关均值+质量均值)/2',
+  llm_evaluation JSON         NULL COMMENT '评估全量 blob',
+  created_at    TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+  updated_at    TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+  UNIQUE KEY uk_qid_case (query_id, case_id),
+  KEY idx_platform (platform)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='搜索+评估结果';
+"""
+
+# DDL for `mode_process`: one row per extracted procedure, with nested data
+# (source, declarations, type_registry, steps, tools_used) in JSON columns.
+# There is no unique key; replace_process() deletes and re-inserts rows per
+# (case_id, version), which the idx_case_ver index serves.
+DDL_PROCESS = """
+CREATE TABLE IF NOT EXISTS mode_process (
+  id            BIGINT AUTO_INCREMENT PRIMARY KEY,
+  query_id      VARCHAR(32)   NOT NULL,
+  case_id       VARCHAR(128)  NOT NULL,
+  platform      VARCHAR(32)   NULL,
+  post_title    VARCHAR(512)  NULL,
+  source        JSON          NULL COMMENT '解构返回的 source 块',
+  procedure_id  VARCHAR(16)   NULL COMMENT 'p1,p2…',
+  name          VARCHAR(255)  NULL,
+  purpose       TEXT          NULL,
+  category      VARCHAR(32)   NULL COMMENT '产物创造/资产建设/自动化/分析/学习',
+  declarations  JSON          NULL,
+  type_registry JSON          NULL,
+  steps         JSON          NULL COMMENT '步骤数组全量',
+  step_count    INT           NULL,
+  tools_used    JSON          NULL COMMENT '从 steps[].via 去重提取',
+  model         VARCHAR(64)   NULL,
+  version       VARCHAR(16)   NULL COMMENT 'v_MMDDHHMM,保留历史',
+  cost_usd      DECIMAL(10,6) NULL COMMENT '本次解构调用成本(同版本各行相同,聚合需按 case+version 去重)',
+  duration_s    FLOAT         NULL,
+  created_at    TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+  KEY idx_case_ver (case_id, version),
+  KEY idx_qid (query_id)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='工序解构结果(每行一个工序)';
+"""
+
+# DDL for `mode_tools`: one row per extracted tool, with scopes, usage, cases
+# and defects in JSON columns. As with mode_process there is no unique key;
+# replace_tools() deletes and re-inserts rows per (case_id, version).
+DDL_TOOLS = """
+CREATE TABLE IF NOT EXISTS mode_tools (
+  id            BIGINT AUTO_INCREMENT PRIMARY KEY,
+  query_id      VARCHAR(32)   NOT NULL,
+  case_id       VARCHAR(128)  NOT NULL,
+  platform      VARCHAR(32)   NULL,
+  post_title    VARCHAR(512)  NULL,
+  tool_name     VARCHAR(255)  NULL,
+  substance_scope JSON        NULL COMMENT '实质作用域(数组)',
+  form_scope    JSON          NULL COMMENT '形式作用域(数组或null)',
+  creation_layer VARCHAR(32)  NULL COMMENT '制作层/创作层',
+  source_link   VARCHAR(1024) NULL,
+  input_desc    TEXT          NULL,
+  output_desc   TEXT          NULL,
+  usage_json    JSON          NULL,
+  cases_json    JSON          NULL,
+  defects_json  JSON          NULL,
+  updated_time  VARCHAR(64)   NULL COMMENT '工具最新更新时间',
+  model         VARCHAR(64)   NULL,
+  version       VARCHAR(16)   NULL,
+  cost_usd      DECIMAL(10,6) NULL COMMENT '同 mode_process,聚合按 case+version 去重',
+  duration_s    FLOAT         NULL,
+  created_at    TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+  KEY idx_case_ver (case_id, version),
+  KEY idx_qid (query_id),
+  KEY idx_tool_name (tool_name)
+) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='工具解构结果(每行一个工具)';
+"""
+
+
+def init_tables():
+    """Create the three tables if they do not exist yet, then print a confirmation.
+
+    The statements use ``CREATE TABLE IF NOT EXISTS``, so running this again
+    leaves existing tables untouched.
+    """
+    statements = (DDL_SEARCH, DDL_PROCESS, DDL_TOOLS)
+    conn = _conn()
+    try:
+        with conn.cursor() as cur:
+            for ddl in statements:
+                cur.execute(ddl)
+        print("✅ 建表完成:search_data, mode_process, mode_tools")
+    finally:
+        conn.close()
+
+
+# ── 工具函数 ──────────────────────────────────────────────────────────────────
+
+def _loads(v, default=None):
+    """Normalize a value read from a JSON column into a Python object.
+
+    pymysql may hand JSON columns back as strings, so anything that is not
+    already a list or dict is parsed with ``json.loads``.
+
+    Args:
+        v: raw column value.
+        default: returned when ``v`` is ``None`` or cannot be parsed.
+
+    Returns:
+        ``v`` itself if it is a list or dict, the parsed value otherwise,
+        or ``default`` as described above.
+    """
+    if isinstance(v, (list, dict)):
+        return v
+    if v is None:
+        return default
+    try:
+        parsed = json.loads(v)
+    except Exception:
+        # Read-side defence: a bad value degrades to the default instead of raising.
+        return default
+    return parsed
+
+
+def _j(v):
+    """Serialize a value for a JSON column.
+
+    ``None`` stays ``None`` so the column is written as NULL; everything else
+    is dumped with ``ensure_ascii=False`` so non-ASCII text is stored as-is.
+    """
+    if v is None:
+        return None
+    return json.dumps(v, ensure_ascii=False)
+
+
+def _collect_scores(node):
+    """Recursively collect every numeric "得分" (score) value in a nested evaluation.
+
+    Dicts and lists are walked depth-first in iteration order. A "得分" key
+    whose value is an int or float contributes that value; any other value
+    (including one stored under "得分") is descended into instead.
+
+    Args:
+        node: arbitrary nested structure of dicts, lists and scalars.
+
+    Returns:
+        A list of floats in traversal order; empty when nothing matches.
+    """
+    out = []
+    if isinstance(node, dict):
+        for k, v in node.items():
+            # bool is a subclass of int, so it must be excluded explicitly:
+            # a JSON true/false under "得分" is not a score.
+            is_number = isinstance(v, (int, float)) and not isinstance(v, bool)
+            if k == "得分" and is_number:
+                out.append(float(v))
+            else:
+                out.extend(_collect_scores(v))
+    elif isinstance(node, list):
+        for v in node:
+            out.extend(_collect_scores(v))
+    return out
+
+
+def overall_score(e):
+    """Compute the overall score of one LLM evaluation.
+
+    The numeric scores found under the "相关性" (relevance) and "质量"
+    (quality) sections are averaged per section; the section means that exist
+    are then averaged and rounded to two decimals. With only one section
+    present, that section's mean is the result.
+
+    Args:
+        e: evaluation dict. ``None`` or any non-dict value yields ``None``.
+
+    Returns:
+        The rounded score, or ``None`` when neither section has a numeric score.
+    """
+    if not isinstance(e, dict):
+        # Nothing to read sections from; honour the "return None when it
+        # cannot be computed" contract instead of failing on ``.get``.
+        return None
+    parts = []
+    for key in ("相关性", "质量"):
+        scores = _collect_scores(e.get(key))
+        if scores:
+            parts.append(sum(scores) / len(scores))
+    return round(sum(parts) / len(parts), 2) if parts else None
+
+
+# ── search_data ──────────────────────────────────────────────────────────────
+
+def upsert_search_posts(query_id, query_text, results):
+    """Upsert one query's search results into ``search_data``.
+
+    Rows are keyed by (query_id, case_id); an existing row has all of its
+    other columns overwritten.
+
+    Args:
+        query_id: identifier of the query the results belong to.
+        query_text: text of that query.
+        results: iterable of result dicts; the "post" and "llm_evaluation"
+            entries are read as nested dicts and may be missing.
+
+    Returns:
+        Number of rows submitted, or 0 when ``results`` is empty.
+    """
+    if not results:
+        return 0
+
+    def to_row(item):
+        # Flatten one search result into the column order used by the INSERT below.
+        post = item.get("post") or {}
+        evaluation = item.get("llm_evaluation") or {}
+        title = post.get("title") or post.get("desc") or ""
+        body = post.get("body_text") or post.get("desc") or ""
+        published = post.get("publish_time") or post.get("publish_timestamp") or ""
+        return (
+            query_id,
+            query_text,
+            item.get("case_id"),
+            item.get("platform"),
+            item.get("channel_content_id"),
+            title[:500],
+            item.get("source_url"),
+            post.get("content_type"),
+            body,
+            _j(post.get("images") or []),
+            _j(post.get("videos") or []),
+            post.get("like_count"),
+            str(published)[:64],
+            post.get("_quality_score"),
+            post.get("_quality_grade"),
+            _j(item.get("found_by_queries") or []),
+            _j(evaluation.get("知识类型") or []),
+            overall_score(evaluation),
+            _j(evaluation),
+        )
+
+    rows = [to_row(item) for item in results]
+    sql = """
+    INSERT INTO search_data
+      (query_id, query_text, case_id, platform, channel_content_id, title, url,
+       content_type, body, images, videos, like_count, publish_time,
+       quality_score, quality_grade, found_by, knowledge_type, overall_score, llm_evaluation)
+    VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
+    ON DUPLICATE KEY UPDATE
+      query_text=VALUES(query_text), platform=VALUES(platform),
+      channel_content_id=VALUES(channel_content_id), title=VALUES(title), url=VALUES(url),
+      content_type=VALUES(content_type), body=VALUES(body), images=VALUES(images),
+      videos=VALUES(videos), like_count=VALUES(like_count), publish_time=VALUES(publish_time),
+      quality_score=VALUES(quality_score), quality_grade=VALUES(quality_grade),
+      found_by=VALUES(found_by), knowledge_type=VALUES(knowledge_type),
+      overall_score=VALUES(overall_score), llm_evaluation=VALUES(llm_evaluation);
+    """
+    conn = _conn()
+    try:
+        with conn.cursor() as cur:
+            cur.executemany(sql, rows)
+        return len(rows)
+    finally:
+        conn.close()
+
+
+def fetch_queries():
+    """List every query with its post count and extraction progress.
+
+    Returns:
+        The ``search_data`` rows grouped by ``query_id`` (keys ``query_id``,
+        ``query_text``, ``post_count``), each extended with ``process_done``
+        and ``tools_done``: how many distinct ``case_id`` values that query
+        has in ``mode_process`` / ``mode_tools`` (0 when none).
+    """
+    conn = _conn()
+    try:
+        with conn.cursor() as cur:
+            cur.execute("""SELECT query_id, MAX(query_text) AS query_text,
+                                  COUNT(*) AS post_count
+                           FROM search_data GROUP BY query_id ORDER BY query_id""")
+            queries = cur.fetchall()
+
+            cur.execute("SELECT query_id, COUNT(DISTINCT case_id) AS n FROM mode_process GROUP BY query_id")
+            process_counts = {}
+            for rec in cur.fetchall():
+                process_counts[rec["query_id"]] = rec["n"]
+
+            cur.execute("SELECT query_id, COUNT(DISTINCT case_id) AS n FROM mode_tools GROUP BY query_id")
+            tool_counts = {}
+            for rec in cur.fetchall():
+                tool_counts[rec["query_id"]] = rec["n"]
+    finally:
+        conn.close()
+
+    for entry in queries:
+        qid = entry["query_id"]
+        entry["process_done"] = process_counts.get(qid, 0)
+        entry["tools_done"] = tool_counts.get(qid, 0)
+    return queries
+
+
+def fetch_posts(query_id):
+    """Return all posts of one query, best ``overall_score`` first.
+
+    JSON columns are parsed, ``has_process`` / ``has_tools`` tell whether the
+    post has rows in ``mode_process`` / ``mode_tools`` for this query, and
+    the ``created_at`` / ``updated_at`` columns are dropped from each row.
+    """
+    json_columns = ("images", "videos", "found_by", "knowledge_type", "llm_evaluation")
+    conn = _conn()
+    try:
+        with conn.cursor() as cur:
+            cur.execute("""SELECT * FROM search_data WHERE query_id=%s
+                           ORDER BY overall_score DESC, id""", (query_id,))
+            posts = cur.fetchall()
+
+            cur.execute("SELECT DISTINCT case_id FROM mode_process WHERE query_id=%s", (query_id,))
+            with_process = set()
+            for rec in cur.fetchall():
+                with_process.add(rec["case_id"])
+
+            cur.execute("SELECT DISTINCT case_id FROM mode_tools WHERE query_id=%s", (query_id,))
+            with_tools = set()
+            for rec in cur.fetchall():
+                with_tools.add(rec["case_id"])
+    finally:
+        conn.close()
+
+    for post in posts:
+        for name in json_columns:
+            post[name] = _loads(post[name])
+        cid = post["case_id"]
+        post["has_process"] = cid in with_process
+        post["has_tools"] = cid in with_tools
+        for stamp in ("created_at", "updated_at"):
+            post.pop(stamp, None)
+    return posts
+
+
+def fetch_post(query_id, case_id):
+    """Return the full ``search_data`` row of one post, or ``None`` if absent.
+
+    The JSON columns of the row are parsed before it is returned.
+    """
+    lookup = "SELECT * FROM search_data WHERE query_id=%s AND case_id=%s"
+    conn = _conn()
+    try:
+        with conn.cursor() as cur:
+            cur.execute(lookup, (query_id, case_id))
+            post = cur.fetchone()
+    finally:
+        conn.close()
+
+    if not post:
+        return None
+    json_columns = ("images", "videos", "found_by", "knowledge_type", "llm_evaluation")
+    for name in json_columns:
+        post[name] = _loads(post[name])
+    return post
+
+
+# ── mode_process ─────────────────────────────────────────────────────────────
+
+def replace_process(query_id, case_id, platform, post_title, payload,
+                    model, version, cost_usd, duration_s):
+    """Store one post's procedure-extraction result for a given version.
+
+    Existing rows for (case_id, version) are deleted and the new procedures
+    inserted, so re-running a version is idempotent while other versions stay
+    as history. Delete and insert run in one transaction: if the insert
+    fails, the delete is rolled back and the previously stored rows survive.
+
+    Args:
+        query_id, case_id, platform, post_title: identify the post; the
+            title is cut to 500 characters.
+        payload: dict with optional "source" and "procedures" entries.
+        model, version, cost_usd, duration_s: run metadata written to every
+            row. ``version`` should not be ``None``: ``version=%s`` never
+            matches NULL, so older NULL-version rows would not be replaced.
+
+    Returns:
+        Number of procedures in the payload.
+
+    Raises:
+        Whatever the database driver raises; the transaction is rolled back first.
+    """
+    source = payload.get("source")
+    procedures = payload.get("procedures") or []
+
+    # Build all rows before touching the database so a malformed payload
+    # fails without deleting anything.
+    rows = []
+    for p in procedures:
+        steps = p.get("steps") or []
+        vias = []
+        for s in steps:
+            v = s.get("via")
+            if v and v not in vias:  # de-duplicate, keep first-seen order
+                vias.append(v)
+        rows.append((
+            query_id, case_id, platform, (post_title or "")[:500],
+            _j(source), p.get("id"), (p.get("name") or "")[:250],
+            p.get("purpose"), p.get("category"),
+            _j(p.get("declarations")), _j(p.get("type_registry")),
+            _j(steps), len(steps), _j(vias),
+            model, version, cost_usd, duration_s,
+        ))
+
+    conn = _conn()
+    try:
+        # The connection is in autocommit mode; an explicit BEGIN suspends
+        # that until COMMIT/ROLLBACK so both statements are atomic.
+        conn.begin()
+        with conn.cursor() as cur:
+            cur.execute("DELETE FROM mode_process WHERE case_id=%s AND version=%s",
+                        (case_id, version))
+            if rows:
+                cur.executemany("""
+                INSERT INTO mode_process
+                  (query_id, case_id, platform, post_title, source, procedure_id, name,
+                   purpose, category, declarations, type_registry, steps, step_count,
+                   tools_used, model, version, cost_usd, duration_s)
+                VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
+                """, rows)
+        conn.commit()
+        return len(procedures)
+    except Exception:
+        conn.rollback()
+        raise
+    finally:
+        conn.close()
+
+
+def fetch_process_versions(case_id):
+    """List the procedure-extraction versions stored for one post.
+
+    Returns:
+        One row per version (keys ``version``, ``n`` = row count, ``model``),
+        ordered by ``version`` descending.
+    """
+    conn = _conn()
+    try:
+        with conn.cursor() as cur:
+            cur.execute("""SELECT version, COUNT(*) AS n, MAX(model) AS model
+                           FROM mode_process WHERE case_id=%s
+                           GROUP BY version ORDER BY version DESC""", (case_id,))
+            versions = cur.fetchall()
+        return versions
+    finally:
+        conn.close()
+
+
+def fetch_process(case_id, version=None):
+    """Rebuild the stored procedure-extraction result of one post.
+
+    Args:
+        case_id: post identifier.
+        version: version label to load; ``None`` picks the version that
+            sorts first by ``version DESC, id DESC``.
+
+    Returns:
+        ``None`` when nothing is stored. Otherwise a dict with ``case_id``,
+        ``version``, ``platform``, ``title``, ``model``, ``cost_usd``,
+        ``duration_s``, ``source`` and ``procedures``; the post-level fields
+        are taken from the first row.
+    """
+    conn = _conn()
+    try:
+        with conn.cursor() as cur:
+            if version is None:
+                # NOTE(review): "latest" is decided by string order of the
+                # version label. The DDL comment describes it as v_MMDDHHMM
+                # (no year), so this may misorder across a year boundary — confirm.
+                cur.execute("""SELECT version FROM mode_process WHERE case_id=%s
+                               ORDER BY version DESC, id DESC LIMIT 1""", (case_id,))
+                newest = cur.fetchone()
+                if not newest:
+                    return None
+                version = newest["version"]
+            cur.execute("""SELECT * FROM mode_process WHERE case_id=%s AND version=%s
+                           ORDER BY id""", (case_id, version))
+            records = cur.fetchall()
+    finally:
+        conn.close()
+
+    if not records:
+        return None
+
+    procedures = []
+    for rec in records:
+        procedures.append({
+            "id": rec["procedure_id"],
+            "name": rec["name"],
+            "purpose": rec["purpose"],
+            "category": rec["category"],
+            "declarations": _loads(rec["declarations"]),
+            "type_registry": _loads(rec["type_registry"]),
+            "steps": _loads(rec["steps"], []),
+            "tools_used": _loads(rec["tools_used"], []),
+        })
+
+    head = records[0]
+    cost = head["cost_usd"]
+    return {
+        "case_id": case_id,
+        "version": version,
+        "platform": head["platform"],
+        "title": head["post_title"],
+        "model": head["model"],
+        "cost_usd": None if cost is None else float(cost),
+        "duration_s": head["duration_s"],
+        "source": _loads(head["source"]),
+        "procedures": procedures,
+    }
+
+
+# ── mode_tools ───────────────────────────────────────────────────────────────
+
+def replace_tools(query_id, case_id, platform, post_title, tools,
+                  model, version, cost_usd, duration_s):
+    """Store one post's tool-extraction result for a given version.
+
+    Existing rows for (case_id, version) are deleted and the new tools
+    inserted, so re-running a version is idempotent while other versions stay
+    as history. Delete and insert run in one transaction: if the insert
+    fails, the delete is rolled back and the previously stored rows survive.
+
+    Args:
+        query_id, case_id, platform, post_title: identify the post; the
+            title is cut to 500 characters.
+        tools: list of tool dicts keyed by the Chinese field names read
+            below; ``None`` is treated as an empty list.
+        model, version, cost_usd, duration_s: run metadata written to every
+            row. ``version`` should not be ``None``: ``version=%s`` never
+            matches NULL, so older NULL-version rows would not be replaced.
+
+    Returns:
+        Number of tools written.
+
+    Raises:
+        Whatever the database driver raises; the transaction is rolled back first.
+    """
+    # Same normalisation as replace_process: without it ``tools=None`` ran
+    # the DELETE and then failed on ``len(None)``.
+    tools = tools or []
+
+    # Build all rows before touching the database so a malformed entry
+    # fails without deleting anything.
+    rows = [(
+        query_id, case_id, platform, (post_title or "")[:500],
+        (t.get("工具名称") or "")[:250],
+        _j(t.get("实质作用域")), _j(t.get("形式作用域")),
+        t.get("创作层级"), t.get("来源链接"), t.get("输入"), t.get("输出"),
+        _j(t.get("用法")), _j(t.get("案例")), _j(t.get("缺点")),
+        t.get("最新更新时间"), model, version, cost_usd, duration_s,
+    ) for t in tools]
+
+    conn = _conn()
+    try:
+        # The connection is in autocommit mode; an explicit BEGIN suspends
+        # that until COMMIT/ROLLBACK so both statements are atomic.
+        conn.begin()
+        with conn.cursor() as cur:
+            cur.execute("DELETE FROM mode_tools WHERE case_id=%s AND version=%s",
+                        (case_id, version))
+            if rows:
+                cur.executemany("""
+                INSERT INTO mode_tools
+                  (query_id, case_id, platform, post_title, tool_name, substance_scope,
+                   form_scope, creation_layer, source_link, input_desc, output_desc,
+                   usage_json, cases_json, defects_json, updated_time, model, version,
+                   cost_usd, duration_s)
+                VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
+                """, rows)
+        conn.commit()
+        return len(tools)
+    except Exception:
+        conn.rollback()
+        raise
+    finally:
+        conn.close()
+
+
+def fetch_tools_versions(case_id):
+    """List the tool-extraction versions stored for one post.
+
+    Returns:
+        One row per version (keys ``version``, ``n`` = row count, ``model``),
+        ordered by ``version`` descending.
+    """
+    conn = _conn()
+    try:
+        with conn.cursor() as cur:
+            cur.execute("""SELECT version, COUNT(*) AS n, MAX(model) AS model
+                           FROM mode_tools WHERE case_id=%s
+                           GROUP BY version ORDER BY version DESC""", (case_id,))
+            versions = cur.fetchall()
+        return versions
+    finally:
+        conn.close()
+
+
+def fetch_tools(case_id, version=None):
+    """Rebuild the stored tool-extraction result of one post.
+
+    Args:
+        case_id: post identifier.
+        version: version label to load; ``None`` picks the version that
+            sorts first by ``version DESC, id DESC``.
+
+    Returns:
+        ``None`` when nothing is stored. Otherwise a dict with ``case_id``,
+        ``version``, ``platform``, ``title``, ``model``, ``cost_usd``,
+        ``duration_s``, ``tool_count`` and ``tools``; the post-level fields
+        are taken from the first row.
+    """
+    conn = _conn()
+    try:
+        with conn.cursor() as cur:
+            if version is None:
+                cur.execute("""SELECT version FROM mode_tools WHERE case_id=%s
+                               ORDER BY version DESC, id DESC LIMIT 1""", (case_id,))
+                newest = cur.fetchone()
+                if not newest:
+                    return None
+                version = newest["version"]
+            cur.execute("""SELECT * FROM mode_tools WHERE case_id=%s AND version=%s
+                           ORDER BY id""", (case_id, version))
+            records = cur.fetchall()
+    finally:
+        conn.close()
+
+    if not records:
+        return None
+
+    tools = []
+    for rec in records:
+        tools.append({
+            "工具名称": rec["tool_name"],
+            "实质作用域": _loads(rec["substance_scope"]),
+            "形式作用域": _loads(rec["form_scope"]),
+            "创作层级": rec["creation_layer"],
+            "来源链接": rec["source_link"],
+            "输入": rec["input_desc"],
+            "输出": rec["output_desc"],
+            "用法": _loads(rec["usage_json"]),
+            "案例": _loads(rec["cases_json"]),
+            "缺点": _loads(rec["defects_json"]),
+            "最新更新时间": rec["updated_time"],
+        })
+
+    head = records[0]
+    cost = head["cost_usd"]
+    return {
+        "case_id": case_id,
+        "version": version,
+        "platform": head["platform"],
+        "title": head["post_title"],
+        "model": head["model"],
+        "cost_usd": None if cost is None else float(cost),
+        "duration_s": head["duration_s"],
+        "tool_count": len(tools),
+        "tools": tools,
+    }
+
+
+# ── Dashboard 原始行(指标计算在 server.py)─────────────────────────────────────
+
+def fetch_dashboard_rows():
+    """Fetch the lightweight rows the dashboard computes its metrics from.
+
+    All three tables are read in full and post-processed in Python: JSON
+    columns are parsed (defaulting to ``[]``), ``cost_usd`` becomes a float
+    and ``created_at`` a string, each staying ``None`` when empty.
+
+    Returns:
+        A ``(posts, procs, tools)`` tuple of row collections from
+        ``search_data``, ``mode_process`` and ``mode_tools``.
+    """
+    def normalize_run_fields(rec):
+        # Make cost and timestamp plain float / str values; None stays None.
+        cost = rec["cost_usd"]
+        rec["cost_usd"] = None if cost is None else float(cost)
+        stamp = rec["created_at"]
+        rec["created_at"] = str(stamp) if stamp else None
+
+    conn = _conn()
+    try:
+        with conn.cursor() as cur:
+            cur.execute("SELECT query_id, case_id, knowledge_type FROM search_data")
+            posts = cur.fetchall()
+            cur.execute("""SELECT case_id, version, steps, tools_used, cost_usd,
+                                  duration_s, created_at FROM mode_process""")
+            procs = cur.fetchall()
+            cur.execute("""SELECT case_id, version, tool_name, substance_scope,
+                                  form_scope, cost_usd, duration_s, created_at
+                           FROM mode_tools""")
+            tools = cur.fetchall()
+    finally:
+        conn.close()
+
+    for post in posts:
+        post["knowledge_type"] = _loads(post["knowledge_type"], [])
+    for rec in procs:
+        for name in ("steps", "tools_used"):
+            rec[name] = _loads(rec[name], [])
+        normalize_run_fields(rec)
+    for rec in tools:
+        for name in ("substance_scope", "form_scope"):
+            rec[name] = _loads(rec[name], [])
+        normalize_run_fields(rec)
+    return posts, procs, tools
+
+
+def check():
+    """Print the row count of each of the three tables."""
+    tables = ("search_data", "mode_process", "mode_tools")
+    conn = _conn()
+    try:
+        with conn.cursor() as cur:
+            for table in tables:
+                # Table names come from the fixed tuple above, never from input.
+                cur.execute(f"SELECT COUNT(*) AS n FROM {table}")
+                print(f"{table}: {cur.fetchone()['n']} 行")
+    finally:
+        conn.close()
+
+
+if __name__ == "__main__":
+    # Command-line entry: dispatch the first argument to its handler and
+    # print the usage text for anything else, including no argument.
+    commands = {"init": init_tables, "check": check}
+    handler = commands.get(sys.argv[1]) if len(sys.argv) > 1 else None
+    if handler is not None:
+        handler()
+    else:
+        print("用法:\n  python db.py init    # 建表\n  python db.py check   # 三表行数")