library_db_manager.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. from __future__ import annotations
  2. import os
  3. from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
  4. from dotenv import load_dotenv
  5. from sqlalchemy import create_engine, text
  6. from sqlalchemy.orm import Session, sessionmaker
# Populate the process environment from a .env file (python-dotenv) at import
# time, before the PGVECTOR_DSN lookup further down in this module runs.
load_dotenv()
  8. def _normalize_where_params(where_params: Any) -> Optional[Tuple[Any, ...]]:
  9. if where_params is None:
  10. return None
  11. if isinstance(where_params, (list, tuple)):
  12. return tuple(where_params)
  13. if isinstance(where_params, dict):
  14. return tuple(where_params.values())
  15. return (where_params,)
  16. def _percent_s_sql_to_text(
  17. sql: str, params: Optional[Tuple[Any, ...]]
  18. ) -> tuple[Any, dict[str, Any]]:
  19. """把 PyMySQL 风格的 %s 转为 SQLAlchemy `text()` 的 :p0 绑定参数。"""
  20. if params is None or len(params) == 0:
  21. if "%s" in sql:
  22. raise ValueError("SQL 含 %s 但未提供绑定参数")
  23. return text(sql), {}
  24. n = sql.count("%s")
  25. if n != len(params):
  26. raise ValueError(
  27. f"%s 个数与参数不一致:SQL 中 {n} 个 %s,参数 {len(params)} 个"
  28. )
  29. parts = sql.split("%s")
  30. chunks: list[str] = []
  31. bind: dict[str, Any] = {}
  32. for i in range(n):
  33. name = f"p{i}"
  34. chunks.append(parts[i])
  35. chunks.append(f":{name}")
  36. bind[name] = params[i]
  37. chunks.append(parts[-1])
  38. return text("".join(chunks)), bind
  39. # PostgreSQL open_aigc:优先环境变量 PGVECTOR_DSN,缺省与业务侧常用默认一致(与 pattern_global_v2 无 import 依赖)
  40. _OPEN_AIGC_PG_DSN = os.getenv(
  41. "PGVECTOR_DSN",
  42. "postgresql://aiddit_aigc:%25a%26%26yqNxg%5EV1%24toJ%2AWOa%5E-b%5EX%3DQJ@gp-t4n72471pkmt4b9q7o-master.gpdbmaster.singapore.rds.aliyuncs.com:5432/open_aigc",
  43. )
  44. class DatabaseManager:
  45. """library_data 专用:open_aigc(PostgreSQL)连接池。"""
  46. def __init__(self):
  47. dsn = (_OPEN_AIGC_PG_DSN or "").strip()
  48. if not dsn:
  49. raise ValueError(
  50. "未配置 open_aigc PostgreSQL,请设置环境变量 PGVECTOR_DSN,"
  51. "格式: postgresql://user:pass@host:5432/dbname"
  52. )
  53. # 默认使用 psycopg3(postgresql+psycopg),`pip install "psycopg[binary]"` 即有轮子,无需本机 pg_config。
  54. # 若已在 DSN 中指定 postgresql+psycopg2:// 或 postgresql+psycopg://,则不再改写。
  55. if dsn.startswith("postgresql://"):
  56. dsn = dsn.replace("postgresql://", "postgresql+psycopg://", 1)
  57. self.engine = create_engine(
  58. dsn,
  59. pool_pre_ping=True,
  60. pool_recycle=1800,
  61. pool_size=5,
  62. max_overflow=10,
  63. connect_args={
  64. "connect_timeout": 30,
  65. "options": "-c statement_timeout=300000",
  66. },
  67. )
  68. self.SessionLocal = sessionmaker(bind=self.engine, autoflush=False, autocommit=False)
  69. def get_session(self) -> Session:
  70. """获取数据库会话"""
  71. return self.SessionLocal()
  72. def _execute_dict_query(
  73. self, sql: str, params: Optional[Tuple[Any, ...]]
  74. ) -> List[Dict[str, Any]]:
  75. stmt, bind = _percent_s_sql_to_text(sql, params)
  76. with self.engine.connect() as conn:
  77. result = conn.execute(stmt, bind)
  78. return [dict(row._mapping) for row in result]
  79. def select(
  80. self,
  81. table: str,
  82. columns: str = "*",
  83. where: str = "",
  84. where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
  85. order_by: str = "",
  86. limit: Optional[int] = None,
  87. ) -> List[Dict[str, Any]]:
  88. """与 `mysql_db.select` 相同占位符风格(%s),用于只读导出。"""
  89. sql = f"SELECT {columns} FROM {table}"
  90. if where:
  91. sql += f" WHERE {where}"
  92. if order_by:
  93. sql += f" ORDER BY {order_by}"
  94. if limit is not None:
  95. sql += f" LIMIT {int(limit)}"
  96. return self._execute_dict_query(sql, _normalize_where_params(where_params))
  97. def select_one(
  98. self,
  99. table: str,
  100. columns: str = "*",
  101. where: str = "",
  102. where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
  103. ) -> Optional[Dict[str, Any]]:
  104. """与 `mysql_db.select_one` 相同占位符风格(%s)。"""
  105. sql = f"SELECT {columns} FROM {table}"
  106. if where:
  107. sql += f" WHERE {where}"
  108. sql += " LIMIT 1"
  109. rows = self._execute_dict_query(sql, _normalize_where_params(where_params))
  110. return rows[0] if rows else None
  111. def fetchall(
  112. self,
  113. sql: str,
  114. params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
  115. ) -> List[Dict[str, Any]]:
  116. """执行原生 SQL,占位符为 %s(与 pymysql 习惯一致)。"""
  117. return self._execute_dict_query(sql, _normalize_where_params(params))
  118. _open_aigc_db_singleton: Optional[DatabaseManager] = None
  119. def get_open_aigc_db() -> DatabaseManager:
  120. """open_aigc(PostgreSQL)单例,供导出脚本等只读查询。"""
  121. global _open_aigc_db_singleton
  122. if _open_aigc_db_singleton is None:
  123. _open_aigc_db_singleton = DatabaseManager()
  124. return _open_aigc_db_singleton