| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- from __future__ import annotations
- import os
- from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
- from dotenv import load_dotenv
- from sqlalchemy import create_engine, text
- from sqlalchemy.orm import Session, sessionmaker
- load_dotenv()
def _normalize_where_params(where_params: Any) -> Optional[Tuple[Any, ...]]:
    """Coerce caller-supplied bind parameters into a positional tuple.

    Accepted shapes:

    * ``None`` -> ``None`` (no parameters).
    * ``list`` / ``tuple`` -> the same items as a tuple.
    * Any ``Mapping`` (not only ``dict``) -> its values in iteration order.
      Keys are ignored: values are matched positionally against the ``%s``
      placeholders, so the mapping's order must follow the SQL.
    * Anything else (including ``str`` / ``bytes``) -> a 1-tuple holding it,
      i.e. a single scalar parameter.

    Args:
        where_params: Parameters in one of the shapes above.

    Returns:
        A tuple of positional values, or ``None`` when *where_params* is None.
    """
    # The module-level ``Mapping`` is the typing alias used for annotations;
    # the runtime ABC lives in collections.abc.
    from collections.abc import Mapping as _AbcMapping

    if where_params is None:
        return None
    if isinstance(where_params, (list, tuple)):
        return tuple(where_params)
    # Callers' signatures advertise ``Mapping[str, Any]``, so honour every
    # Mapping (MappingProxyType, custom mappings, ...), not just ``dict``.
    if isinstance(where_params, _AbcMapping):
        return tuple(where_params.values())
    return (where_params,)
def _percent_s_sql_to_text(
    sql: str, params: Optional[Tuple[Any, ...]]
) -> tuple[Any, dict[str, Any]]:
    """Convert PyMySQL-style ``%s`` SQL into a SQLAlchemy ``text()`` clause.

    Each ``%s`` placeholder becomes a named bind ``:p0``, ``:p1``, ... and the
    returned dict maps those names to the matching entries of *params*.

    PyMySQL semantics that are preserved:

    * When parameters are supplied, ``%%`` is an escaped literal percent (so
      ``%%s`` is the literal text ``%s``, not a placeholder) and is emitted as
      a single ``%``; SQLAlchemy re-escapes ``%`` for the driver on its own.
    * Without parameters the SQL text is passed through unchanged.

    A space is inserted between a bind and an adjacent ``:``, word character
    or backslash, because ``text()`` would otherwise mis-read the bind: e.g.
    ``%s::int`` must not become ``:p0::int`` (parsed as a bind named ``p``),
    and ``%s%s`` must not become ``:p0:p1`` (``:p1`` not recognised).

    Args:
        sql: Statement using ``%s`` placeholders.
        params: Positional values, one per placeholder; ``None`` or empty
            when the statement takes no parameters.

    Returns:
        ``(text_clause, bind_dict)`` ready for ``Connection.execute``.

    Raises:
        ValueError: if the SQL contains placeholders but no parameters were
            given, or the placeholder count differs from ``len(params)``.
    """
    import re

    # Left-to-right scan: "%%" is consumed first, so the "%s" in "%%s" is
    # never mistaken for a placeholder.
    tokens = list(re.finditer(r"%%|%s", sql))
    n = sum(1 for match in tokens if match.group() == "%s")

    if params is None or len(params) == 0:
        if n:
            raise ValueError("SQL 含 %s 但未提供绑定参数")
        return text(sql), {}
    if n != len(params):
        raise ValueError(
            f"%s 个数与参数不一致:SQL 中 {n} 个 %s,参数 {len(params)} 个"
        )

    # Characters that, touching a ":name" bind, change how text() parses it.
    sticky = re.compile(r"[:\w\\]")
    bind: dict[str, Any] = {}
    chunks: list[str] = []
    cursor = 0
    for match in tokens:
        start, end = match.span()
        chunks.append(sql[cursor:start])
        cursor = end
        if match.group() == "%%":
            chunks.append("%")
            continue
        index = len(bind)
        name = f"p{index}"
        bind[name] = params[index]
        # Raw neighbours are enough: a preceding "%s" ends in "s" (a word
        # char, which also matches the "p<N>" it becomes), and a following
        # "%s" pads itself on its own left side.
        before = sql[start - 1:start]
        after = sql[end:end + 1]
        chunks.append(
            (" " if sticky.match(before) else "")
            + f":{name}"
            + (" " if sticky.match(after) else "")
        )
    chunks.append(sql[cursor:])
    return text("".join(chunks)), bind
# DSN for the open_aigc PostgreSQL database. The PGVECTOR_DSN environment variable
# wins when it is set (even to an empty string, which DatabaseManager then rejects).
# Otherwise the hard-coded default below is used, intended to match the default
# commonly used on the business side. Deliberately has no import dependency on
# pattern_global_v2.
# NOTE(review): the fallback embeds what appears to be a username and a
# percent-encoded password plus a production host directly in source. Consider
# requiring PGVECTOR_DSN (dropping the default) and rotating the credential.
# TODO confirm with the owners before changing behaviour.
_OPEN_AIGC_PG_DSN = os.getenv(
    "PGVECTOR_DSN",
    "postgresql://aiddit_aigc:%25a%26%26yqNxg%5EV1%24toJ%2AWOa%5E-b%5EX%3DQJ@gp-t4n72471pkmt4b9q7o-master.gpdbmaster.singapore.rds.aliyuncs.com:5432/open_aigc",
)
class DatabaseManager:
    """Connection pool and read helpers for the open_aigc PostgreSQL database.

    Dedicated to library_data. The query helpers accept PyMySQL-style ``%s``
    placeholders (the same convention as ``mysql_db``). These are rewritten
    into SQLAlchemy named binds before execution.

    Security: ``table``, ``columns``, ``where`` and ``order_by`` are pasted
    into the SQL text verbatim. Only the values in ``where_params`` /
    ``params`` are bound, so never build those strings from untrusted input.
    """

    def __init__(self):
        """Create the SQLAlchemy engine and session factory.

        Raises:
            ValueError: if the resolved DSN (``_OPEN_AIGC_PG_DSN``) is empty or
                whitespace-only.
        """
        dsn = (_OPEN_AIGC_PG_DSN or "").strip()
        if not dsn:
            raise ValueError(
                "未配置 open_aigc PostgreSQL,请设置环境变量 PGVECTOR_DSN,"
                "格式: postgresql://user:pass@host:5432/dbname"
            )
        # Default to psycopg 3 ("postgresql+psycopg"). `pip install "psycopg[binary]"`
        # ships wheels, so no local pg_config is required. A DSN that already
        # names a driver (postgresql+psycopg2://, postgresql+psycopg://) is
        # left as is. The legacy "postgres://" scheme is rewritten too, since
        # SQLAlchemy 1.4+ no longer accepts it as a dialect name.
        for scheme in ("postgresql://", "postgres://"):
            if dsn.startswith(scheme):
                dsn = "postgresql+psycopg://" + dsn[len(scheme):]
                break
        self.engine = create_engine(
            dsn,
            pool_pre_ping=True,  # test a pooled connection before handing it out
            pool_recycle=1800,  # seconds: replace connections older than 30 min
            pool_size=5,
            max_overflow=10,
            connect_args={
                "connect_timeout": 30,  # seconds to wait when opening a connection
                # Server-side per-statement cap; a bare number is milliseconds (5 min).
                "options": "-c statement_timeout=300000",
            },
        )
        self.SessionLocal = sessionmaker(
            bind=self.engine, autoflush=False, autocommit=False
        )

    def get_session(self) -> Session:
        """Return a new ORM session bound to this engine.

        The caller is responsible for closing it.
        """
        return self.SessionLocal()

    def _execute_dict_query(
        self, sql: str, params: Optional[Tuple[Any, ...]]
    ) -> List[Dict[str, Any]]:
        """Run *sql* (``%s`` placeholders) and return all rows as column->value dicts.

        The connection goes back to the pool when the ``with`` block exits.
        """
        stmt, bind = _percent_s_sql_to_text(sql, params)
        with self.engine.connect() as conn:
            result = conn.execute(stmt, bind)
            return [dict(row._mapping) for row in result]

    def select(
        self,
        table: str,
        columns: str = "*",
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        order_by: str = "",
        limit: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Run a read-only SELECT using the same ``%s`` convention as ``mysql_db.select``.

        Args:
            table: Table (or FROM expression), inserted verbatim.
            columns: Column list, inserted verbatim; defaults to ``*``.
            where: Condition without the ``WHERE`` keyword; may contain ``%s``.
            where_params: Values for the ``%s`` placeholders in *where*.
            order_by: Ordering without the ``ORDER BY`` keyword.
            limit: Maximum number of rows; coerced with ``int()``.

        Returns:
            One dict per row.
        """
        clauses = [f"SELECT {columns} FROM {table}"]
        if where:
            clauses.append(f"WHERE {where}")
        if order_by:
            clauses.append(f"ORDER BY {order_by}")
        if limit is not None:
            clauses.append(f"LIMIT {int(limit)}")
        return self._execute_dict_query(
            " ".join(clauses), _normalize_where_params(where_params)
        )

    def select_one(
        self,
        table: str,
        columns: str = "*",
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        order_by: str = "",
    ) -> Optional[Dict[str, Any]]:
        """Return the first matching row, or None (``%s`` style, like ``mysql_db.select_one``).

        Args:
            table, columns, where, where_params: As for :meth:`select`.
            order_by: Optional ordering (no ``ORDER BY`` keyword). Without it
                PostgreSQL does not define which row ``LIMIT 1`` returns.

        Returns:
            The first row as a dict, or ``None`` when nothing matches.
        """
        rows = self.select(table, columns, where, where_params, order_by, limit=1)
        return rows[0] if rows else None

    def fetchall(
        self,
        sql: str,
        params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Execute raw SQL with ``%s`` placeholders (PyMySQL habit) and return all rows."""
        return self._execute_dict_query(sql, _normalize_where_params(params))
# Lazily created, module-wide DatabaseManager handed out by get_open_aigc_db().
_open_aigc_db_singleton: Optional[DatabaseManager] = None


def get_open_aigc_db() -> DatabaseManager:
    """Return the shared open_aigc (PostgreSQL) manager, building it on first call.

    Meant for read-only queries such as export scripts. If construction
    raises, nothing is cached and the next call tries again. No locking is
    done here, so simultaneous first calls from several threads could each
    build a manager.
    """
    global _open_aigc_db_singleton
    manager = _open_aigc_db_singleton
    if manager is None:
        manager = DatabaseManager()
        _open_aigc_db_singleton = manager
    return manager
|