from __future__ import annotations import os from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple from dotenv import load_dotenv from sqlalchemy import create_engine, text from sqlalchemy.orm import Session, sessionmaker load_dotenv() def _normalize_where_params(where_params: Any) -> Optional[Tuple[Any, ...]]: if where_params is None: return None if isinstance(where_params, (list, tuple)): return tuple(where_params) if isinstance(where_params, dict): return tuple(where_params.values()) return (where_params,) def _percent_s_sql_to_text( sql: str, params: Optional[Tuple[Any, ...]] ) -> tuple[Any, dict[str, Any]]: """把 PyMySQL 风格的 %s 转为 SQLAlchemy `text()` 的 :p0 绑定参数。""" if params is None or len(params) == 0: if "%s" in sql: raise ValueError("SQL 含 %s 但未提供绑定参数") return text(sql), {} n = sql.count("%s") if n != len(params): raise ValueError( f"%s 个数与参数不一致:SQL 中 {n} 个 %s,参数 {len(params)} 个" ) parts = sql.split("%s") chunks: list[str] = [] bind: dict[str, Any] = {} for i in range(n): name = f"p{i}" chunks.append(parts[i]) chunks.append(f":{name}") bind[name] = params[i] chunks.append(parts[-1]) return text("".join(chunks)), bind # PostgreSQL open_aigc:优先环境变量 PGVECTOR_DSN,缺省与业务侧常用默认一致(与 pattern_global_v2 无 import 依赖) _OPEN_AIGC_PG_DSN = os.getenv( "PGVECTOR_DSN", "postgresql://aiddit_aigc:%25a%26%26yqNxg%5EV1%24toJ%2AWOa%5E-b%5EX%3DQJ@gp-t4n72471pkmt4b9q7o-master.gpdbmaster.singapore.rds.aliyuncs.com:5432/open_aigc", ) class DatabaseManager: """library_data 专用:open_aigc(PostgreSQL)连接池。""" def __init__(self): dsn = (_OPEN_AIGC_PG_DSN or "").strip() if not dsn: raise ValueError( "未配置 open_aigc PostgreSQL,请设置环境变量 PGVECTOR_DSN," "格式: postgresql://user:pass@host:5432/dbname" ) # 默认使用 psycopg3(postgresql+psycopg),`pip install "psycopg[binary]"` 即有轮子,无需本机 pg_config。 # 若已在 DSN 中指定 postgresql+psycopg2:// 或 postgresql+psycopg://,则不再改写。 if dsn.startswith("postgresql://"): dsn = dsn.replace("postgresql://", "postgresql+psycopg://", 1) self.engine = create_engine( dsn, pool_pre_ping=True, pool_recycle=1800, pool_size=5, max_overflow=10, connect_args={ "connect_timeout": 30, "options": "-c statement_timeout=300000", }, ) self.SessionLocal = sessionmaker(bind=self.engine, autoflush=False, autocommit=False) def get_session(self) -> Session: """获取数据库会话""" return self.SessionLocal() def _execute_dict_query( self, sql: str, params: Optional[Tuple[Any, ...]] ) -> List[Dict[str, Any]]: stmt, bind = _percent_s_sql_to_text(sql, params) with self.engine.connect() as conn: result = conn.execute(stmt, bind) return [dict(row._mapping) for row in result] def select( self, table: str, columns: str = "*", where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, order_by: str = "", limit: Optional[int] = None, ) -> List[Dict[str, Any]]: """与 `mysql_db.select` 相同占位符风格(%s),用于只读导出。""" sql = f"SELECT {columns} FROM {table}" if where: sql += f" WHERE {where}" if order_by: sql += f" ORDER BY {order_by}" if limit is not None: sql += f" LIMIT {int(limit)}" return self._execute_dict_query(sql, _normalize_where_params(where_params)) def select_one( self, table: str, columns: str = "*", where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, ) -> Optional[Dict[str, Any]]: """与 `mysql_db.select_one` 相同占位符风格(%s)。""" sql = f"SELECT {columns} FROM {table}" if where: sql += f" WHERE {where}" sql += " LIMIT 1" rows = self._execute_dict_query(sql, _normalize_where_params(where_params)) return rows[0] if rows else None def fetchall( self, sql: str, params: Optional[Sequence[Any] | Mapping[str, Any]] = None, ) -> List[Dict[str, Any]]: """执行原生 SQL,占位符为 %s(与 pymysql 习惯一致)。""" return self._execute_dict_query(sql, _normalize_where_params(params)) _open_aigc_db_singleton: Optional[DatabaseManager] = None def get_open_aigc_db() -> DatabaseManager: """open_aigc(PostgreSQL)单例,供导出脚本等只读查询。""" global _open_aigc_db_singleton if _open_aigc_db_singleton is None: _open_aigc_db_singleton = DatabaseManager() return _open_aigc_db_singleton