|
@@ -0,0 +1,146 @@
|
|
|
|
|
+from __future__ import annotations
|
|
|
|
|
+
|
|
|
|
|
+import os
|
|
|
|
|
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
|
|
|
|
|
+
|
|
|
|
|
+from dotenv import load_dotenv
|
|
|
|
|
+from sqlalchemy import create_engine, text
|
|
|
|
|
+from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
|
+
|
|
|
|
|
+load_dotenv()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _normalize_where_params(where_params: Any) -> Optional[Tuple[Any, ...]]:
    """Coerce caller-supplied bind values into a positional tuple.

    Args:
        where_params: One of
            * ``None`` - no parameters;
            * a ``list`` or ``tuple`` - used as the positional values;
            * any mapping - its values are taken in iteration order and the
              keys are discarded (placeholders are positional ``%s``);
            * anything else - treated as a single scalar value.

    Returns:
        ``None`` when ``where_params`` is ``None``, otherwise a tuple of the
        positional bind values.
    """
    if where_params is None:
        return None
    if isinstance(where_params, (list, tuple)):
        return tuple(where_params)
    # Public callers annotate this argument as ``Mapping``; accept every
    # mapping type (not just ``dict``) so that e.g. a read-only mapping proxy
    # is unpacked into its values instead of being bound as one scalar.
    if isinstance(where_params, Mapping):
        return tuple(where_params.values())
    # Scalars (including ``str``/``bytes``) become a one-element tuple.
    return (where_params,)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _percent_s_sql_to_text(
    sql: str, params: Optional[Tuple[Any, ...]]
) -> tuple[Any, dict[str, Any]]:
    """Rewrite PyMySQL-style ``%s`` placeholders as SQLAlchemy named binds.

    Each ``%s`` is replaced, left to right, by ``:p0``, ``:p1``, ... and the
    matching positional value is stored under that name.

    Args:
        sql: SQL string using ``%s`` placeholders.
        params: Positional values, or ``None``/empty when there are none.

    Returns:
        A ``(statement, bind_dict)`` pair where ``statement`` is the result of
        ``text()`` on the rewritten SQL.

    Raises:
        ValueError: If the SQL contains ``%s`` but no values were supplied, or
            if the number of ``%s`` differs from the number of values.
    """
    supplied = 0 if params is None else len(params)
    if supplied == 0:
        if "%s" in sql:
            raise ValueError("SQL 含 %s 但未提供绑定参数")
        return text(sql), {}

    # Splitting on the placeholder yields one more segment than placeholders.
    segments = sql.split("%s")
    expected = len(segments) - 1
    if expected != supplied:
        raise ValueError(
            f"%s 个数与参数不一致:SQL 中 {expected} 个 %s,参数 {supplied} 个"
        )

    bound: dict[str, Any] = {
        f"p{position}": value for position, value in enumerate(params)
    }
    # Interleave the literal SQL segments with the generated bind names.
    pieces: list[str] = [segments[0]]
    for position, tail in enumerate(segments[1:]):
        pieces.append(f":p{position}")
        pieces.append(tail)
    return text("".join(pieces)), bound
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
# open_aigc PostgreSQL DSN, read from the PGVECTOR_DSN environment variable
# (or a .env file loaded above). Kept free of any import dependency on
# pattern_global_v2.
#
# SECURITY: no credential-bearing fallback is embedded here - secrets must not
# live in source control. When the variable is unset the DSN is empty and
# DatabaseManager() raises a ValueError explaining how to configure it.
# NOTE(review): a previous revision hard-coded a full DSN (user, password and
# host) as the default; those credentials should be rotated.
_OPEN_AIGC_PG_DSN = os.getenv("PGVECTOR_DSN", "")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
class DatabaseManager:
    """Pooled connection manager for the open_aigc PostgreSQL database.

    Intended for library_data use. Provides small read-oriented helpers
    (``select``, ``select_one``, ``fetchall``) that accept PyMySQL-style
    ``%s`` placeholders and return rows as plain dictionaries.
    """

    def __init__(self):
        """Create the engine and session factory from the configured DSN.

        Raises:
            ValueError: If the module-level DSN is empty or whitespace.
        """
        dsn = (_OPEN_AIGC_PG_DSN or "").strip()
        if not dsn:
            raise ValueError(
                "未配置 open_aigc PostgreSQL,请设置环境变量 PGVECTOR_DSN,"
                "格式: postgresql://user:pass@host:5432/dbname"
            )
        self.engine = create_engine(
            self._with_default_driver(dsn),
            pool_pre_ping=True,  # validate pooled connections before use
            pool_recycle=1800,  # recycle connections older than 1800 s
            pool_size=5,
            max_overflow=10,
            connect_args={
                "connect_timeout": 30,
                # Server-side cap on a single statement (value in ms).
                "options": "-c statement_timeout=300000",
            },
        )
        self.SessionLocal = sessionmaker(
            bind=self.engine, autoflush=False, autocommit=False
        )

    @staticmethod
    def _with_default_driver(dsn: str) -> str:
        """Return ``dsn`` with the psycopg3 driver applied when none is given.

        Driver-less URLs (``postgresql://`` and the legacy ``postgres://``
        alias, which current SQLAlchemy no longer accepts) are rewritten to
        ``postgresql+psycopg://``; psycopg3 ships binary wheels via
        ``pip install "psycopg[binary]"`` and needs no local ``pg_config``.
        A DSN that already names a driver (e.g. ``postgresql+psycopg2://``)
        is returned unchanged.
        """
        for scheme in ("postgresql://", "postgres://"):
            if dsn.startswith(scheme):
                return "postgresql+psycopg://" + dsn[len(scheme):]
        return dsn

    def get_session(self) -> Session:
        """Return a new ORM session bound to this manager's engine."""
        return self.SessionLocal()

    def _execute_dict_query(
        self, sql: str, params: Optional[Tuple[Any, ...]]
    ) -> List[Dict[str, Any]]:
        """Run ``sql`` with positional ``%s`` values and return dict rows.

        Raises:
            ValueError: If the ``%s`` count and the number of values disagree.
        """
        stmt, bind = _percent_s_sql_to_text(sql, params)
        # The connection is returned to the pool when the block exits.
        with self.engine.connect() as conn:
            result = conn.execute(stmt, bind)
            return [dict(row._mapping) for row in result]

    def select(
        self,
        table: str,
        columns: str = "*",
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        order_by: str = "",
        limit: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Run a ``SELECT`` using the same ``%s`` style as ``mysql_db.select``.

        Meant for read-only exports.

        WARNING: ``table``, ``columns``, ``where`` and ``order_by`` are
        interpolated into the SQL text verbatim; they must come from trusted
        code. Only the values in ``where_params`` are sent as bind parameters.

        Args:
            table: Table name (or any ``FROM`` expression).
            columns: Column list; defaults to ``*``.
            where: Optional ``WHERE`` body containing ``%s`` placeholders.
            where_params: Values for the placeholders in ``where``.
            order_by: Optional ``ORDER BY`` body.
            limit: Optional row cap; coerced with ``int()``.

        Returns:
            One dictionary per result row.
        """
        sql = f"SELECT {columns} FROM {table}"
        if where:
            sql += f" WHERE {where}"
        if order_by:
            sql += f" ORDER BY {order_by}"
        if limit is not None:
            sql += f" LIMIT {int(limit)}"
        return self._execute_dict_query(sql, _normalize_where_params(where_params))

    def select_one(
        self,
        table: str,
        columns: str = "*",
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Return the first matching row, or ``None`` when nothing matches.

        Uses the same ``%s`` placeholder style as ``mysql_db.select_one`` and
        the same verbatim interpolation of ``table``/``columns``/``where`` as
        ``select`` (trusted input only).
        """
        # Delegate to select() so the SQL is assembled in exactly one place.
        rows = self.select(table, columns, where, where_params, limit=1)
        return rows[0] if rows else None

    def fetchall(
        self,
        sql: str,
        params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Execute raw SQL with ``%s`` placeholders (pymysql convention).

        Args:
            sql: Complete SQL statement.
            params: Values for the ``%s`` placeholders, if any.

        Returns:
            One dictionary per result row.
        """
        return self._execute_dict_query(sql, _normalize_where_params(params))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
# Lazily created, process-wide DatabaseManager instance (None until first use).
_open_aigc_db_singleton: Optional[DatabaseManager] = None


def get_open_aigc_db() -> DatabaseManager:
    """Return the shared open_aigc (PostgreSQL) manager, creating it on first call.

    Used by export scripts and other read-only query code.
    """
    global _open_aigc_db_singleton
    manager = _open_aigc_db_singleton
    if manager is None:
        manager = DatabaseManager()
        _open_aigc_db_singleton = manager
    return manager
|