| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- # -*- coding: utf-8 -*-
- """mode_procedure · MySQL 持久化(工序解构结果双写:本地文件 + 数据库)
- ================================================================================
- 读 .env 的 MYSQL_* 连接 MySQL(pymysql)。两张表,结构相同:
- agent_dsl —— Agent 工序解构结果(run_cyber,每版本一行)
- mode_dsl —— 大模型工序解构结果(单次大模型直出,每版本一行)
- 设计原则(沿用 fixed_query_eval/db.py):
- - **失败不阻断**:写库 try/except 包,DB 挂了不影响本地文件(文件是主存储)。
- - **保留历史**:(q, case_id, version) 每版本一行;同版本重跑先删本版本再插(幂等)。
- - workflow.json 嵌套深(procedures→steps→io),整份存 LONGTEXT,不拆关系行。
- 建表:跑 `python db.py init`。
- """
- import os
- import json
- import sys
- from pathlib import Path
- from datetime import datetime
- PROJECT_ROOT = Path(__file__).resolve().parents[5]
- sys.path.insert(0, str(PROJECT_ROOT))
- from dotenv import load_dotenv
- load_dotenv()
- try:
- import pymysql
- from pymysql.cursors import DictCursor
- except ImportError:
- pymysql = None
# Whitelist of table names. A table name cannot be bound through a SQL
# placeholder, so the read/write functions below check the requested name
# against this tuple before formatting it into a statement.
TABLES = (
    "agent_dsl",
    "mode_dsl",
)
def _enabled() -> bool:
    """Report whether MySQL persistence can be used.

    Returns:
        True only when the optional ``pymysql`` import succeeded and the
        ``MYSQL_HOST`` environment variable is set to a non-empty value.
    """
    if pymysql is None:
        return False
    host = os.getenv("MYSQL_HOST")
    return bool(host)
def _conn():
    """Open a new MySQL connection configured from the ``MYSQL_*`` environment.

    Reads ``MYSQL_HOST``, ``MYSQL_PORT`` (default 3306), ``MYSQL_USER``,
    ``MYSQL_PASSWORD`` and ``MYSQL_DATABASE``. The connection uses the
    ``utf8mb4`` charset, returns rows as dicts (``DictCursor``), runs in
    autocommit mode and gives up connecting after 10 seconds.

    Returns:
        An open pymysql connection. The caller is responsible for closing it.

    Raises:
        ValueError: if ``MYSQL_PORT`` is set to a non-blank value that is not
            an integer.
    """
    # ``os.getenv(name, default)`` only falls back when the variable is
    # absent. A variable that is present but blank (e.g. a ``MYSQL_PORT=``
    # line in .env) yields "" and ``int("")`` raises, so treat blank as unset.
    port = int(os.getenv("MYSQL_PORT") or 3306)
    return pymysql.connect(
        host=os.getenv("MYSQL_HOST"),
        port=port,
        user=os.getenv("MYSQL_USER"),
        password=os.getenv("MYSQL_PASSWORD"),
        database=os.getenv("MYSQL_DATABASE"),
        charset="utf8mb4",
        cursorclass=DictCursor,
        autocommit=True,
        connect_timeout=10,
    )
- # ── DDL ──────────────────────────────────────────────────────────────────────
- def _ddl(table, comment):
- return f"""
- CREATE TABLE IF NOT EXISTS {table} (
- id BIGINT AUTO_INCREMENT PRIMARY KEY,
- q VARCHAR(255) NOT NULL COMMENT 'query 目录名(自由中文)',
- case_id VARCHAR(128) NOT NULL COMMENT 'platform_channelContentId',
- platform VARCHAR(32) NULL,
- post_title VARCHAR(512) NULL,
- source_link VARCHAR(1024) NULL,
- model VARCHAR(64) NULL COMMENT '提取模型',
- version VARCHAR(16) NULL COMMENT 'v_MMDDHHMM(每次生成;保留历史多版本共存)',
- procedure_count INT NULL COMMENT '工序数',
- workflow_json LONGTEXT NULL COMMENT '整份 workflow.json',
- source_json LONGTEXT NULL COMMENT '_source.json(喂给模型的帖子源)',
- created_time VARCHAR(64) NULL,
- KEY idx_q_case_ver (q(128), case_id, version)
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='{comment}';
- """
def init_tables():
    """Create both result tables if they do not exist yet (idempotent).

    Returns:
        False, after printing a notice, when MySQL is not enabled; True once
        both ``CREATE TABLE IF NOT EXISTS`` statements have been executed.

    Connection and SQL errors are not caught here and propagate to the caller.
    """
    if not _enabled():
        print("⚠️ MySQL 未启用(缺 pymysql 或 MYSQL_HOST),跳过建表")
        return False

    # (table name, table comment) pairs, created in this order.
    table_specs = (
        ("agent_dsl", "Agent 工序解构结果(每版本一行)"),
        ("mode_dsl", "大模型工序解构结果(每版本一行)"),
    )
    connection = _conn()
    try:
        with connection.cursor() as cursor:
            for table_name, table_comment in table_specs:
                cursor.execute(_ddl(table_name, table_comment))
        print("✅ 建表完成:agent_dsl, mode_dsl")
        return True
    finally:
        connection.close()
- # ── 写入 ─────────────────────────────────────────────────────────────────────
def upsert_dsl(table, q, case_id, version, model, workflow, source=None,
               platform=None, post_title=None, source_link=None):
    """Write one post's procedure-deconstruction result for one version.

    History is preserved: only the existing rows for this exact
    ``(q, case_id, version)`` are deleted before the new row is inserted, so
    re-running the same version is idempotent and other versions are left
    untouched. The delete and the insert run in a single transaction, so a
    failed insert does not remove the previously stored row.

    The database is a secondary store: failures are printed and signalled
    through the return value instead of being raised.

    Args:
        table: Target table; must be one of ``TABLES``.
        q: Query directory name.
        case_id: Case identifier of the post.
        version: Version label of this run.
        model: Name of the extraction model.
        workflow: Workflow document, stored as JSON text. The length of its
            ``"procedures"`` entry is stored as ``procedure_count``.
        source: Optional post source document, stored as JSON text.
        platform: Optional platform name.
        post_title: Optional post title; truncated to 500 characters.
        source_link: Optional link to the original post.

    Returns:
        1 if the row was written, 0 if MySQL is disabled or the write failed.

    Raises:
        ValueError: if MySQL is enabled and ``table`` is not whitelisted.
    """
    if not _enabled():
        return 0
    if table not in TABLES:
        raise ValueError(f"非法表名: {table}")
    try:
        # Do everything that can fail without the database (reading the
        # procedure list, JSON serialization) before any statement runs, so a
        # bad payload can neither delete the stored row nor escape as an
        # exception.
        procs = (workflow or {}).get("procedures") or []
        params = (
            q, case_id, platform, (post_title or "")[:500], source_link, model, version,
            len(procs),
            json.dumps(workflow, ensure_ascii=False) if workflow is not None else None,
            json.dumps(source, ensure_ascii=False) if source is not None else None,
            datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        )
        conn = _conn()
        try:
            # The connection is opened in autocommit mode; an explicit
            # transaction keeps DELETE + INSERT atomic so a failed INSERT
            # rolls the DELETE back.
            conn.begin()
            try:
                with conn.cursor() as cur:
                    cur.execute(f"DELETE FROM {table} WHERE q=%s AND case_id=%s AND version=%s",
                                (q, case_id, version))
                    cur.execute(f"""
                        INSERT INTO {table}
                        (q, case_id, platform, post_title, source_link, model, version,
                        procedure_count, workflow_json, source_json, created_time)
                        VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
                    """, params)
                conn.commit()
            except Exception:
                conn.rollback()
                raise
            return 1
        finally:
            conn.close()
    except Exception as ex:
        print(f"⚠️ {table} 写库失败(不影响本地文件):{ex}")
        return 0
- # ── 读取 ─────────────────────────────────────────────────────────────────────
def fetch_versions(table, q, case_id):
    """List every stored version of one post, highest version string first.

    Args:
        table: Table to read; must be one of ``TABLES``.
        q: Query directory name.
        case_id: Case identifier of the post.

    Returns:
        A list of dicts with the keys ``version``, ``model``,
        ``procedure_count`` and ``created_time``. The list is empty when MySQL
        is disabled, the table is not whitelisted, or the query fails (the
        failure is printed, not raised).

    NOTE(review): rows are ordered by the version string. The DDL comment
    describes versions as ``v_MMDDHHMM``, which carries no year, so this order
    may not be chronological across a year boundary -- confirm.
    """
    if not _enabled():
        return []
    if table not in TABLES:
        return []

    query = f"""SELECT version, model, procedure_count, created_time
                FROM {table} WHERE q=%s AND case_id=%s
                ORDER BY version DESC"""
    try:
        connection = _conn()
        try:
            with connection.cursor() as cursor:
                cursor.execute(query, (q, case_id))
                rows = cursor.fetchall()
                return list(rows)
        finally:
            connection.close()
    except Exception as ex:
        print(f"⚠️ {table} 读版本失败:{ex}")
        return []
def fetch_dsl(table, q, case_id, version=None):
    """Fetch one stored row for a post, with its workflow JSON decoded.

    Args:
        table: Table to read; must be one of ``TABLES``.
        q: Query directory name.
        case_id: Case identifier of the post.
        version: Version to fetch. When falsy, the row with the greatest
            version string is returned.

    Returns:
        The full row as a dict with an extra ``"workflow"`` key holding the
        decoded ``workflow_json``. That key is always present on a returned
        row and is None when the column is empty or not valid JSON. Returns
        None when MySQL is disabled, the table is not whitelisted, no row
        matches, or the query fails (the failure is printed, not raised).

    NOTE(review): "latest" is decided by comparing version strings. The DDL
    comment describes versions as ``v_MMDDHHMM``, which carries no year, so
    this may not be the most recent run across a year boundary -- confirm.
    """
    if not _enabled() or table not in TABLES:
        return None
    try:
        conn = _conn()
        try:
            with conn.cursor() as cur:
                if version:
                    cur.execute(f"SELECT * FROM {table} WHERE q=%s AND case_id=%s AND version=%s",
                                (q, case_id, version))
                else:
                    cur.execute(f"""SELECT * FROM {table} WHERE q=%s AND case_id=%s
                                    ORDER BY version DESC LIMIT 1""", (q, case_id))
                row = cur.fetchone()
                if row:
                    # Always expose the key so callers can read
                    # row["workflow"] without first checking whether the
                    # stored JSON was empty.
                    row["workflow"] = _loads(row.get("workflow_json"))
                return row
        finally:
            conn.close()
    except Exception as ex:
        print(f"⚠️ {table} 读取失败:{ex}")
        return None
- def _loads(v, default=None):
- if v is None:
- return default
- if isinstance(v, (list, dict)):
- return v
- try:
- return json.loads(v)
- except Exception:
- return default
if __name__ == "__main__":
    # Command-line entry point. "init" is the only sub-command and is also
    # the default when no argument is given; anything else prints the usage.
    cli_args = sys.argv[1:]
    command = cli_args[0] if cli_args else "init"
    if command != "init":
        print("用法: python db.py init")
    else:
        init_tables()
|