liuzhiheng 21 часов назад
Родитель
Сommit
b57e065037

+ 146 - 0
examples_how/db_utils/library_db_manager.py

@@ -0,0 +1,146 @@
+from __future__ import annotations
+
+import os
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
+
+from dotenv import load_dotenv
+from sqlalchemy import create_engine, text
+from sqlalchemy.orm import Session, sessionmaker
+
+load_dotenv()
+
+
+def _normalize_where_params(where_params: Any) -> Optional[Tuple[Any, ...]]:
+    """Coerce caller-supplied bind values into a positional tuple.
+
+    Args:
+        where_params: ``None``, a list/tuple of values, a mapping (its values
+            are used in iteration order), or a single scalar value.
+
+    Returns:
+        ``None`` when no parameters were given, otherwise a tuple of values in
+        the order they are matched against the ``%s`` placeholders.
+    """
+    if where_params is None:
+        return None
+    if isinstance(where_params, (list, tuple)):
+        return tuple(where_params)
+    # Accept any Mapping, not only dict: the public query methods are
+    # annotated with Mapping[str, Any], so e.g. a MappingProxyType must not be
+    # mistaken for a single scalar value.
+    if isinstance(where_params, Mapping):
+        return tuple(where_params.values())
+    # Anything else (str, int, ...) is treated as one positional value.
+    return (where_params,)
+
+
+def _percent_s_sql_to_text(
+    sql: str, params: Optional[Tuple[Any, ...]]
+) -> tuple[Any, dict[str, Any]]:
+    """Rewrite PyMySQL-style ``%s`` placeholders as SQLAlchemy ``:pN`` binds.
+
+    Every literal ``%s`` in ``sql`` is counted as a placeholder.
+
+    Args:
+        sql: Statement using ``%s`` for each positional parameter.
+        params: Positional values, or ``None``/empty when the SQL has none.
+
+    Returns:
+        A ``(text(...), bind_dict)`` pair in which the i-th ``%s`` has become
+        ``:p{i}`` and ``bind_dict`` maps ``"p{i}"`` to the i-th value.
+
+    Raises:
+        ValueError: If ``%s`` is present without values, or the number of
+            ``%s`` markers differs from the number of values.
+    """
+    supplied = 0 if params is None else len(params)
+    if supplied == 0:
+        if "%s" in sql:
+            raise ValueError("SQL 含 %s 但未提供绑定参数")
+        return text(sql), {}
+
+    expected = sql.count("%s")
+    if expected != supplied:
+        raise ValueError(
+            f"%s 个数与参数不一致:SQL 中 {expected} 个 %s,参数 {supplied} 个"
+        )
+    segments = sql.split("%s")
+    names = [f"p{idx}" for idx in range(expected)]
+    # split() yields expected + 1 segments: zip pairs each leading segment
+    # with its bind name, and the trailing segment is appended afterwards.
+    rewritten = "".join(
+        f"{segment}:{name}" for segment, name in zip(segments, names)
+    ) + segments[-1]
+    return text(rewritten), dict(zip(names, params))
+
+
+# PostgreSQL open_aigc DSN, taken from the PGVECTOR_DSN environment variable
+# (kept free of any import dependency on pattern_global_v2).
+# There is deliberately no hard-coded fallback: database credentials must not
+# live in source control. When the variable is unset this is an empty string,
+# and DatabaseManager.__init__ raises ValueError explaining how to set it.
+# NOTE(review): a password was previously committed here as the default;
+# treat it as leaked and rotate it.
+_OPEN_AIGC_PG_DSN = os.getenv("PGVECTOR_DSN", "")
+
+
+class DatabaseManager:
+    """Connection-pool wrapper for the open_aigc PostgreSQL database.
+
+    Used by library_data for export-style queries. The query helpers accept
+    PyMySQL-style ``%s`` placeholders, which are converted to SQLAlchemy named
+    binds before execution.
+
+    Note:
+        ``table``, ``columns``, ``where`` and ``order_by`` are interpolated
+        into the SQL text verbatim; only the ``%s`` values are bound. Never
+        pass untrusted input through those arguments.
+    """
+
+    def __init__(self):
+        """Create the engine and session factory from ``_OPEN_AIGC_PG_DSN``.
+
+        Raises:
+            ValueError: If no DSN is configured.
+        """
+        dsn = (_OPEN_AIGC_PG_DSN or "").strip()
+        if not dsn:
+            raise ValueError(
+                "未配置 open_aigc PostgreSQL,请设置环境变量 PGVECTOR_DSN,"
+                "格式: postgresql://user:pass@host:5432/dbname"
+            )
+        # Default to psycopg 3 (postgresql+psycopg): `pip install "psycopg[binary]"`
+        # ships wheels, so no local pg_config is needed. A DSN that already
+        # names a driver (postgresql+psycopg2://, postgresql+psycopg://) is
+        # left untouched. The libpq alias `postgres://` is normalised as well
+        # because SQLAlchemy does not accept that scheme.
+        for bare_scheme in ("postgresql://", "postgres://"):
+            if dsn.startswith(bare_scheme):
+                dsn = "postgresql+psycopg://" + dsn[len(bare_scheme):]
+                break
+        self.engine = create_engine(
+            dsn,
+            pool_pre_ping=True,
+            pool_recycle=1800,
+            pool_size=5,
+            max_overflow=10,
+            connect_args={
+                "connect_timeout": 30,
+                # Passed to the server as a session option.
+                "options": "-c statement_timeout=300000",
+            },
+        )
+        self.SessionLocal = sessionmaker(bind=self.engine, autoflush=False, autocommit=False)
+
+    def get_session(self) -> Session:
+        """Return a new ORM session bound to this manager's engine."""
+        return self.SessionLocal()
+
+    def _execute_dict_query(
+        self, sql: str, params: Optional[Tuple[Any, ...]]
+    ) -> List[Dict[str, Any]]:
+        """Run ``sql`` with positional ``params`` and return rows as dicts."""
+        stmt, bind = _percent_s_sql_to_text(sql, params)
+        with self.engine.connect() as conn:
+            result = conn.execute(stmt, bind)
+            # Materialise the rows while the connection is still open.
+            return [dict(row._mapping) for row in result]
+
+    @staticmethod
+    def _build_select_sql(
+        table: str,
+        columns: str,
+        where: str,
+        order_by: str = "",
+        limit: Optional[int] = None,
+    ) -> str:
+        """Assemble a SELECT statement from raw (trusted) SQL fragments."""
+        clauses = [f"SELECT {columns} FROM {table}"]
+        if where:
+            clauses.append(f"WHERE {where}")
+        if order_by:
+            clauses.append(f"ORDER BY {order_by}")
+        if limit is not None:
+            # int() guards against a non-numeric limit reaching the SQL text.
+            clauses.append(f"LIMIT {int(limit)}")
+        return " ".join(clauses)
+
+    def select(
+        self,
+        table: str,
+        columns: str = "*",
+        where: str = "",
+        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
+        order_by: str = "",
+        limit: Optional[int] = None,
+    ) -> List[Dict[str, Any]]:
+        """Select rows using the same ``%s`` placeholder style as ``mysql_db.select``.
+
+        Args:
+            table: Table name, inserted verbatim.
+            columns: Column list, inserted verbatim.
+            where: Optional WHERE expression containing ``%s`` placeholders.
+            where_params: Values for the ``%s`` placeholders in ``where``.
+            order_by: Optional ORDER BY expression, inserted verbatim.
+            limit: Optional row limit; coerced with ``int()``.
+
+        Returns:
+            One dict per row, keyed by column name.
+        """
+        sql = self._build_select_sql(table, columns, where, order_by, limit)
+        return self._execute_dict_query(sql, _normalize_where_params(where_params))
+
+    def select_one(
+        self,
+        table: str,
+        columns: str = "*",
+        where: str = "",
+        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
+    ) -> Optional[Dict[str, Any]]:
+        """Select at most one row (``LIMIT 1``), ``%s`` style as ``mysql_db.select_one``.
+
+        Returns:
+            The first row as a dict, or ``None`` when nothing matched.
+        """
+        sql = self._build_select_sql(table, columns, where, limit=1)
+        rows = self._execute_dict_query(sql, _normalize_where_params(where_params))
+        return rows[0] if rows else None
+
+    def fetchall(
+        self,
+        sql: str,
+        params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
+    ) -> List[Dict[str, Any]]:
+        """Execute raw SQL with ``%s`` placeholders (pymysql convention).
+
+        Returns:
+            One dict per row, keyed by column name.
+        """
+        return self._execute_dict_query(sql, _normalize_where_params(params))
+
+
+# Lazily created module-level instance shared by all callers.
+_open_aigc_db_singleton: Optional[DatabaseManager] = None
+
+
+def get_open_aigc_db() -> DatabaseManager:
+    """Return the shared open_aigc (PostgreSQL) manager, creating it on first use.
+
+    Intended for read-only queries such as the export scripts.
+    """
+    global _open_aigc_db_singleton
+    manager = _open_aigc_db_singleton
+    if manager is not None:
+        return manager
+    manager = DatabaseManager()
+    _open_aigc_db_singleton = manager
+    return manager

+ 5 - 7
examples_how/overall_derivation/data_export_from_db/export_account_element_classification.py

@@ -14,7 +14,7 @@ _EXAMPLES_HOW_DIR = Path(__file__).resolve().parents[2]
 if str(_EXAMPLES_HOW_DIR) not in sys.path:
     sys.path.insert(0, str(_EXAMPLES_HOW_DIR))
 
-from db_utils.mysql_db import get_mysql_db  # noqa: E402
+from db_utils.library_db_manager import get_open_aigc_db  # noqa: E402
 
 _SOURCE_TYPES = ("实质", "形式", "意图")
 _FILE_NAMES = {"实质": "实质_tree.json", "形式": "形式_tree.json", "意图": "意图_tree.json"}
@@ -64,8 +64,7 @@ def _fetch_classification_rows_from_ecm(account_name: str) -> list[dict[str, Any
       AND TRIM(ecm.post_id) <> ''
       AND ecm.element_type IN ('实质', '形式')
     """
-    client = get_mysql_db("default")._client()
-    return client.fetchall(sql, (account_name,))
+    return get_open_aigc_db().fetchall(sql, (account_name,))
 
 
 def _fetch_intent_rows_from_decode(account_name: str) -> list[dict[str, Any]]:
@@ -92,8 +91,7 @@ def _fetch_intent_rows_from_decode(account_name: str) -> list[dict[str, Any]]:
       AND pdte.post_id IS NOT NULL
       AND TRIM(pdte.post_id) <> ''
     """
-    client = get_mysql_db("default")._client()
-    return client.fetchall(sql, (account_name,))
+    return get_open_aigc_db().fetchall(sql, (account_name,))
 
 
 def _fetch_classification_rows(account_name: str) -> list[dict[str, Any]]:
@@ -103,7 +101,7 @@ def _fetch_classification_rows(account_name: str) -> list[dict[str, Any]]:
 
 
 def _fetch_account_platform(account_name: str) -> str:
-    row = get_mysql_db("default").select_one(
+    row = get_open_aigc_db().select_one(
         table="post",
         columns="platform",
         where="platform_account_name=%s",
@@ -269,4 +267,4 @@ def main(account_name) -> None:
 
 
 if __name__ == "__main__":
-    main(account_name="家有大志")
+    main(account_name="秒懂金融")

+ 5 - 4
examples_how/overall_derivation/data_export_from_db/export_post.py

@@ -14,7 +14,7 @@ _EXAMPLES_HOW_DIR = Path(__file__).resolve().parents[2]
 if str(_EXAMPLES_HOW_DIR) not in sys.path:
     sys.path.insert(0, str(_EXAMPLES_HOW_DIR))
 
-from db_utils.mysql_db import mysql_db  # noqa: E402
+from db_utils.library_db_manager import get_open_aigc_db  # noqa: E402
 
 
 def _to_ms(v: Any) -> Optional[int]:
@@ -25,7 +25,7 @@ def _to_ms(v: Any) -> Optional[int]:
     if isinstance(v, datetime):
         return int(v.timestamp() * 1000)
 
-    # pymysql 可能把 BIGINT/JSON 数字返回为 int/float,或把 datetime 返回为字符串
+    # 驱动可能把 BIGINT/JSON 数字返回为 int/float,或把 datetime 返回为字符串
     if isinstance(v, (int, float)):
         n = float(v)
         # 常见情况:ms (>= 1e12),s (>= 1e9)
@@ -160,10 +160,11 @@ def export_posts_by_account(
     last_id = 0
     exported = 0
 
+    db = get_open_aigc_db()
     while True:
         where = "platform_account_name=%s AND id>%s"
         where_params = (account_name, last_id)
-        rows = mysql_db.select(
+        rows = db.select(
             table="post",
             columns=columns,
             where=where,
@@ -207,5 +208,5 @@ def main(account_name) -> None:
 
 
 if __name__ == "__main__":
-    main("家有大志")
+    main("秒懂金融")
 

+ 8 - 6
examples_how/overall_derivation/data_export_from_db/export_post_decode.py

@@ -13,7 +13,7 @@ _EXAMPLES_HOW_DIR = Path(__file__).resolve().parents[2]
 if str(_EXAMPLES_HOW_DIR) not in sys.path:
     sys.path.insert(0, str(_EXAMPLES_HOW_DIR))
 
-from db_utils.mysql_db import mysql_db  # noqa: E402
+from db_utils.library_db_manager import get_open_aigc_db  # noqa: E402
 
 
 def _empty_decode_json(post_id: str) -> dict[str, Any]:
@@ -57,9 +57,10 @@ def _build_topic_point_item(
 
 def _iter_post_ids_for_account(account_name: str, *, page_size: int = 500) -> Any:
     """按 keyset 分页产出 post_id。"""
+    db = get_open_aigc_db()
     last_id = 0
     while True:
-        rows = mysql_db.select(
+        rows = db.select(
             table="post",
             columns="id,post_id",
             where="platform_account_name=%s AND id>%s",
@@ -102,9 +103,10 @@ def export_post_decode_for_account(
         if not post_ids:
             return False
 
+        db = get_open_aigc_db()
         # 同一帖子多条解构结果时,取 id 最大的一条
         placeholders = ",".join(["%s"] * len(post_ids))
-        decode_rows = mysql_db.select(
+        decode_rows = db.select(
             table="post_decode_result",
             columns="*",
             where=f"post_id IN ({placeholders})",
@@ -125,7 +127,7 @@ def export_post_decode_for_account(
         topic_by_result: DefaultDict[int, list[dict[str, Any]]] = defaultdict(list)
         if result_ids:
             ph2 = ",".join(["%s"] * len(result_ids))
-            tps = mysql_db.select(
+            tps = db.select(
                 table="post_decode_topic_point",
                 columns="*",
                 where=f"post_decode_result_id IN ({ph2})",
@@ -145,7 +147,7 @@ def export_post_decode_for_account(
         elements_by_tp: dict[int, list[dict[str, Any]]] = defaultdict(list)
         if tp_ids:
             ph3 = ",".join(["%s"] * len(tp_ids))
-            elems = mysql_db.select(
+            elems = db.select(
                 table="post_decode_topic_point_element",
                 columns="*",
                 where=f"topic_point_id IN ({ph3})",
@@ -223,4 +225,4 @@ def main(account_name) -> None:
 
 
 if __name__ == "__main__":
-    main("家有大志")
+    main("秒懂金融")