db.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # -*- coding: utf-8 -*-
  2. """mode_procedure · MySQL 持久化(工序解构结果双写:本地文件 + 数据库)
  3. ================================================================================
  4. 读 .env 的 MYSQL_* 连接 MySQL(pymysql)。两张表,结构相同:
  5. agent_dsl —— Agent 工序解构结果(run_cyber,每版本一行)
  6. mode_dsl —— 大模型工序解构结果(单次大模型直出,每版本一行)
  7. 设计原则(沿用 fixed_query_eval/db.py):
  8. - **失败不阻断**:写库 try/except 包,DB 挂了不影响本地文件(文件是主存储)。
  9. - **保留历史**:(q, case_id, version) 每版本一行;同版本重跑先删本版本再插(幂等)。
  10. - workflow.json 嵌套深(procedures→steps→io),整份存 LONGTEXT,不拆关系行。
  11. 建表:跑 `python db.py init`。
  12. """
  13. import os
  14. import json
  15. import sys
  16. from pathlib import Path
  17. from datetime import datetime
  18. PROJECT_ROOT = Path(__file__).resolve().parents[5]
  19. sys.path.insert(0, str(PROJECT_ROOT))
  20. from dotenv import load_dotenv
  21. load_dotenv()
  22. try:
  23. import pymysql
  24. from pymysql.cursors import DictCursor
  25. except ImportError:
  26. pymysql = None
  27. TABLES = ("agent_dsl", "mode_dsl") # 白名单:表名不能用占位符,必须校验
  28. def _enabled() -> bool:
  29. return pymysql is not None and bool(os.getenv("MYSQL_HOST"))
  30. def _conn():
  31. return pymysql.connect(
  32. host=os.getenv("MYSQL_HOST"),
  33. port=int(os.getenv("MYSQL_PORT", 3306)),
  34. user=os.getenv("MYSQL_USER"),
  35. password=os.getenv("MYSQL_PASSWORD"),
  36. database=os.getenv("MYSQL_DATABASE"),
  37. charset="utf8mb4",
  38. cursorclass=DictCursor,
  39. autocommit=True,
  40. connect_timeout=10,
  41. )
  42. # ── DDL ──────────────────────────────────────────────────────────────────────
  43. def _ddl(table, comment):
  44. return f"""
  45. CREATE TABLE IF NOT EXISTS {table} (
  46. id BIGINT AUTO_INCREMENT PRIMARY KEY,
  47. q VARCHAR(255) NOT NULL COMMENT 'query 目录名(自由中文)',
  48. case_id VARCHAR(128) NOT NULL COMMENT 'platform_channelContentId',
  49. platform VARCHAR(32) NULL,
  50. post_title VARCHAR(512) NULL,
  51. source_link VARCHAR(1024) NULL,
  52. model VARCHAR(64) NULL COMMENT '提取模型',
  53. version VARCHAR(16) NULL COMMENT 'v_MMDDHHMM(每次生成;保留历史多版本共存)',
  54. procedure_count INT NULL COMMENT '工序数',
  55. workflow_json LONGTEXT NULL COMMENT '整份 workflow.json',
  56. source_json LONGTEXT NULL COMMENT '_source.json(喂给模型的帖子源)',
  57. created_time VARCHAR(64) NULL,
  58. KEY idx_q_case_ver (q(128), case_id, version)
  59. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='{comment}';
  60. """
  61. def init_tables():
  62. """建表(幂等)。"""
  63. if not _enabled():
  64. print("⚠️ MySQL 未启用(缺 pymysql 或 MYSQL_HOST),跳过建表")
  65. return False
  66. conn = _conn()
  67. try:
  68. with conn.cursor() as cur:
  69. cur.execute(_ddl("agent_dsl", "Agent 工序解构结果(每版本一行)"))
  70. cur.execute(_ddl("mode_dsl", "大模型工序解构结果(每版本一行)"))
  71. print("✅ 建表完成:agent_dsl, mode_dsl")
  72. return True
  73. finally:
  74. conn.close()
  75. # ── 写入 ─────────────────────────────────────────────────────────────────────
  76. def upsert_dsl(table, q, case_id, version, model, workflow, source=None,
  77. platform=None, post_title=None, source_link=None):
  78. """写入一帖某版本的工序解构结果。**保留历史**:只删本 (q,case_id,version) 旧行再插。
  79. table 必须是 'agent_dsl' / 'mode_dsl'。失败返回 0(不阻断本地文件)。"""
  80. if not _enabled():
  81. return 0
  82. if table not in TABLES:
  83. raise ValueError(f"非法表名: {table}")
  84. procs = (workflow or {}).get("procedures") or []
  85. try:
  86. conn = _conn()
  87. try:
  88. with conn.cursor() as cur:
  89. cur.execute(f"DELETE FROM {table} WHERE q=%s AND case_id=%s AND version=%s",
  90. (q, case_id, version))
  91. cur.execute(f"""
  92. INSERT INTO {table}
  93. (q, case_id, platform, post_title, source_link, model, version,
  94. procedure_count, workflow_json, source_json, created_time)
  95. VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
  96. """, (
  97. q, case_id, platform, (post_title or "")[:500], source_link, model, version,
  98. len(procs),
  99. json.dumps(workflow, ensure_ascii=False) if workflow is not None else None,
  100. json.dumps(source, ensure_ascii=False) if source is not None else None,
  101. datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
  102. ))
  103. return 1
  104. finally:
  105. conn.close()
  106. except Exception as ex:
  107. print(f"⚠️ {table} 写库失败(不影响本地文件):{ex}")
  108. return 0
  109. # ── 读取 ─────────────────────────────────────────────────────────────────────
  110. def fetch_versions(table, q, case_id):
  111. """列出某帖所有版本(DESC,最新在前)。返回 [{version, model, procedure_count, created_time}]。"""
  112. if not _enabled() or table not in TABLES:
  113. return []
  114. try:
  115. conn = _conn()
  116. try:
  117. with conn.cursor() as cur:
  118. cur.execute(f"""SELECT version, model, procedure_count, created_time
  119. FROM {table} WHERE q=%s AND case_id=%s
  120. ORDER BY version DESC""", (q, case_id))
  121. return list(cur.fetchall())
  122. finally:
  123. conn.close()
  124. except Exception as ex:
  125. print(f"⚠️ {table} 读版本失败:{ex}")
  126. return []
  127. def fetch_dsl(table, q, case_id, version=None):
  128. """取某帖某版本的 workflow(version=None 取最新)。返回 dict 或 None。"""
  129. if not _enabled() or table not in TABLES:
  130. return None
  131. try:
  132. conn = _conn()
  133. try:
  134. with conn.cursor() as cur:
  135. if version:
  136. cur.execute(f"SELECT * FROM {table} WHERE q=%s AND case_id=%s AND version=%s",
  137. (q, case_id, version))
  138. else:
  139. cur.execute(f"""SELECT * FROM {table} WHERE q=%s AND case_id=%s
  140. ORDER BY version DESC LIMIT 1""", (q, case_id))
  141. row = cur.fetchone()
  142. if row and row.get("workflow_json"):
  143. row["workflow"] = _loads(row["workflow_json"])
  144. return row
  145. finally:
  146. conn.close()
  147. except Exception as ex:
  148. print(f"⚠️ {table} 读取失败:{ex}")
  149. return None
  150. def _loads(v, default=None):
  151. if v is None:
  152. return default
  153. if isinstance(v, (list, dict)):
  154. return v
  155. try:
  156. return json.loads(v)
  157. except Exception:
  158. return default
  159. if __name__ == "__main__":
  160. cmd = sys.argv[1] if len(sys.argv) > 1 else "init"
  161. if cmd == "init":
  162. init_tables()
  163. else:
  164. print(f"用法: python db.py init")