# -*- coding: utf-8 -*-
"""mode_procedure · MySQL persistence for procedure-deconstruction results.

Results are written twice: to a local file (the primary store) and to MySQL.
The connection is opened with pymysql from the ``MYSQL_*`` environment
variables (loaded from ``.env`` when python-dotenv is available).

Two tables with identical structure:
    agent_dsl -- Agent procedure-deconstruction results (run_cyber, one row per version)
    mode_dsl  -- LLM procedure-deconstruction results (single direct LLM output,
                 one row per version)

Design principles (carried over from fixed_query_eval/db.py):
- Failures never block: DB writes are wrapped in try/except, so a broken
  database does not affect the local file.
- History is kept: one row per (q, case_id, version); re-running the same
  version deletes that version's row and inserts the new one (idempotent).
- workflow.json is deeply nested (procedures -> steps -> io); it is stored
  whole in a LONGTEXT column instead of being split into relational rows.

Create the tables with ``python db.py init``.
"""
import os
import json
import sys
from pathlib import Path
from datetime import datetime

# The project root is expected six directory levels above this file. If the
# file lives in a shallower tree, fall back to its own directory instead of
# failing the import with an IndexError.
_THIS_FILE = Path(__file__).resolve()
_ANCESTORS = _THIS_FILE.parents
PROJECT_ROOT = _ANCESTORS[5] if len(_ANCESTORS) > 5 else _THIS_FILE.parent
sys.path.insert(0, str(PROJECT_ROOT))

# python-dotenv is a convenience only: without it the MYSQL_* variables can
# still come from the real process environment.
try:
    from dotenv import load_dotenv
except ImportError:
    load_dotenv = None

if load_dotenv is not None:
    load_dotenv()

# pymysql is optional as well; every public function degrades to a no-op
# result when it is missing (see _enabled()).
try:
    import pymysql
    from pymysql.cursors import DictCursor
except ImportError:
    pymysql = None
    DictCursor = None

# Whitelist: table names cannot be bound as SQL placeholders, so they must be
# validated before being interpolated into a statement.
TABLES = ("agent_dsl", "mode_dsl")

_DEFAULT_MYSQL_PORT = 3306


def _enabled() -> bool:
    """Return True when pymysql is importable and MYSQL_HOST is set (non-empty)."""
    return pymysql is not None and bool(os.getenv("MYSQL_HOST"))


def _mysql_port() -> int:
    """Return the port from MYSQL_PORT, defaulting to 3306 when unset or empty.

    A non-numeric value still raises ValueError, exactly like ``int()`` does.
    """
    # `or` (rather than a getenv default) also covers `MYSQL_PORT=` left
    # blank in a .env file, which would otherwise hit int("").
    return int(os.getenv("MYSQL_PORT") or _DEFAULT_MYSQL_PORT)


def _conn():
    """Open a new autocommit pymysql connection from the MYSQL_* variables.

    Only call this after ``_enabled()`` returned True. The caller owns the
    connection and must close it.
    """
    return pymysql.connect(
        host=os.getenv("MYSQL_HOST"),
        port=_mysql_port(),
        user=os.getenv("MYSQL_USER"),
        password=os.getenv("MYSQL_PASSWORD"),
        database=os.getenv("MYSQL_DATABASE"),
        charset="utf8mb4",
        cursorclass=DictCursor,
        autocommit=True,
        connect_timeout=10,
    )


# ── DDL ──────────────────────────────────────────────────────────────────────
def _ddl(table, comment):
    """Build the CREATE TABLE IF NOT EXISTS statement for one result table.

    Args:
        table: target table name; must be one of ``TABLES``.
        comment: table-level COMMENT text.

    Returns:
        The DDL statement as a string.

    Raises:
        ValueError: if ``table`` is not whitelisted.
    """
    if table not in TABLES:
        raise ValueError(f"非法表名: {table}")
    # The comment is embedded in a single-quoted SQL literal, so escape the
    # characters that would terminate or alter that literal.
    safe_comment = comment.replace("\\", "\\\\").replace("'", "''")
    return f"""
    CREATE TABLE IF NOT EXISTS {table} (
        id              BIGINT AUTO_INCREMENT PRIMARY KEY,
        q               VARCHAR(255)  NOT NULL COMMENT 'query 目录名(自由中文)',
        case_id         VARCHAR(128)  NOT NULL COMMENT 'platform_channelContentId',
        platform        VARCHAR(32)   NULL,
        post_title      VARCHAR(512)  NULL,
        source_link     VARCHAR(1024) NULL,
        model           VARCHAR(64)   NULL COMMENT '提取模型',
        version         VARCHAR(16)   NULL COMMENT 'v_MMDDHHMM(每次生成;保留历史多版本共存)',
        procedure_count INT           NULL COMMENT '工序数',
        workflow_json   LONGTEXT      NULL COMMENT '整份 workflow.json',
        source_json     LONGTEXT      NULL COMMENT '_source.json(喂给模型的帖子源)',
        created_time    VARCHAR(64)   NULL,
        KEY idx_q_case_ver (q(128), case_id, version)
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='{safe_comment}';
    """


def init_tables():
    """Create both result tables (idempotent).

    Returns:
        True when the DDL ran, False when MySQL is not enabled.

    Unlike the read/write helpers, connection and SQL errors propagate to
    the caller here.
    """
    if not _enabled():
        print("⚠️ MySQL 未启用(缺 pymysql 或 MYSQL_HOST),跳过建表")
        return False
    table_comments = (
        ("agent_dsl", "Agent 工序解构结果(每版本一行)"),
        ("mode_dsl", "大模型工序解构结果(每版本一行)"),
    )
    conn = _conn()
    try:
        with conn.cursor() as cur:
            for name, text in table_comments:
                cur.execute(_ddl(name, text))
        print("✅ 建表完成:agent_dsl, mode_dsl")
        return True
    finally:
        conn.close()


# ── Write ────────────────────────────────────────────────────────────────────
def _procedure_count(workflow):
    """Return the number of entries under ``workflow["procedures"]``.

    Anything that is not a dict holding a list/tuple of procedures counts as
    0, so a malformed workflow can never raise from here.
    """
    if not isinstance(workflow, dict):
        return 0
    procs = workflow.get("procedures")
    return len(procs) if isinstance(procs, (list, tuple)) else 0


def upsert_dsl(table, q, case_id, version, model, workflow, source=None,
               platform=None, post_title=None, source_link=None):
    """Write one post's deconstruction result for one version.

    History is kept: only the existing row for this exact
    (q, case_id, version) is deleted before the new row is inserted.

    Args:
        table: 'agent_dsl' or 'mode_dsl'.
        q, case_id, version: identify the row being replaced.
        model: name of the extraction model.
        workflow: workflow object, stored as JSON (NULL when None).
        source: optional source object, stored as JSON (NULL when None).
        platform, post_title, source_link: optional metadata; ``post_title``
            is cut to 500 characters and None becomes an empty string.

    Returns:
        1 on success; 0 when MySQL is disabled or the write failed (the
        failure is printed and never raised, so the local file is unaffected).

    Raises:
        ValueError: if MySQL is enabled and ``table`` is not whitelisted.
    """
    if not _enabled():
        return 0
    if table not in TABLES:
        raise ValueError(f"非法表名: {table}")
    try:
        params = (
            q, case_id, platform, (post_title or "")[:500], source_link,
            model, version, _procedure_count(workflow),
            json.dumps(workflow, ensure_ascii=False) if workflow is not None else None,
            json.dumps(source, ensure_ascii=False) if source is not None else None,
            datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        )
        conn = _conn()
        try:
            # The connection is autocommit, so without an explicit
            # transaction a failed INSERT would leave the DELETE committed
            # and the previous row for this version lost.
            conn.begin()
            try:
                with conn.cursor() as cur:
                    cur.execute(
                        f"DELETE FROM {table} WHERE q=%s AND case_id=%s AND version=%s",
                        (q, case_id, version),
                    )
                    cur.execute(f"""
                        INSERT INTO {table}
                        (q, case_id, platform, post_title, source_link, model, version,
                         procedure_count, workflow_json, source_json, created_time)
                        VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
                    """, params)
                conn.commit()
            except Exception:
                conn.rollback()
                raise
            return 1
        finally:
            conn.close()
    except Exception as ex:
        print(f"⚠️ {table} 写库失败(不影响本地文件):{ex}")
        return 0


# ── Read ─────────────────────────────────────────────────────────────────────
def fetch_versions(table, q, case_id):
    """List every stored version of one post, ordered by ``version`` DESC.

    Returns:
        A list of dicts with keys version, model, procedure_count and
        created_time; an empty list when MySQL is disabled, ``table`` is not
        whitelisted, or the query failed (the failure is printed).
    """
    if not _enabled() or table not in TABLES:
        return []
    try:
        conn = _conn()
        try:
            with conn.cursor() as cur:
                # NOTE(review): ordering relies on the version string; if it
                # is v_MMDDHHMM with no year, "newest first" may not hold
                # across a year boundary -- confirm with the version producer.
                cur.execute(f"""SELECT version, model, procedure_count, created_time
                                FROM {table} WHERE q=%s AND case_id=%s
                                ORDER BY version DESC""", (q, case_id))
                return list(cur.fetchall())
        finally:
            conn.close()
    except Exception as ex:
        print(f"⚠️ {table} 读版本失败:{ex}")
        return []


def fetch_dsl(table, q, case_id, version=None):
    """Fetch the stored row for one post and version.

    Args:
        table: 'agent_dsl' or 'mode_dsl'.
        q, case_id: identify the post.
        version: exact version to fetch; a falsy value selects the row with
            the highest ``version`` string.

    Returns:
        The row as a dict, or None when nothing matched, MySQL is disabled,
        ``table`` is not whitelisted, or the query failed (the failure is
        printed). When the row has a non-empty ``workflow_json``, the parsed
        value is added under the ``workflow`` key.
    """
    if not _enabled() or table not in TABLES:
        return None
    try:
        conn = _conn()
        try:
            with conn.cursor() as cur:
                if version:
                    cur.execute(
                        f"SELECT * FROM {table} WHERE q=%s AND case_id=%s AND version=%s",
                        (q, case_id, version),
                    )
                else:
                    cur.execute(f"""SELECT * FROM {table} WHERE q=%s AND case_id=%s
                                    ORDER BY version DESC LIMIT 1""", (q, case_id))
                row = cur.fetchone()
            if row and row.get("workflow_json"):
                row["workflow"] = _loads(row["workflow_json"])
            return row
        finally:
            conn.close()
    except Exception as ex:
        print(f"⚠️ {table} 读取失败:{ex}")
        return None


def _loads(v, default=None):
    """Parse ``v`` as JSON, returning ``default`` when it cannot be parsed.

    None yields ``default``; a value that is already a list or dict is
    returned unchanged.
    """
    if v is None:
        return default
    if isinstance(v, (list, dict)):
        return v
    try:
        return json.loads(v)
    except (TypeError, ValueError, RecursionError):
        # TypeError: not str/bytes; ValueError: invalid JSON or undecodable
        # bytes; RecursionError: pathologically deep nesting.
        return default


if __name__ == "__main__":
    cmd = sys.argv[1] if len(sys.argv) > 1 else "init"
    if cmd == "init":
        init_tables()
    else:
        print("用法: python db.py init")