| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732 |
- from __future__ import annotations
- import math
- from contextlib import contextmanager
- from contextvars import ContextVar
- from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple
- import pymysql
- from pymysql.cursors import DictCursor
- from .errors import (
- MySQLConnectionError,
- MySQLConfigError,
- MySQLQueryError,
- MySQLTransactionError,
- )
- from .mysql_manager import MySQLClientManager, get_global_manager
- from .mysql_client import MySQLClient
# Connection of the transaction currently open in this execution context, or
# None when no transaction is active. Because this is a ContextVar, a value set
# in one thread or asyncio task is not visible to others.
_tx_connection_var: ContextVar[Optional[pymysql.connections.Connection]]
_tx_connection_var = ContextVar("examples_how_db_utils_tx_connection", default=None)
def _normalize_where_params(where_params: Any) -> Optional[Tuple[Any, ...]]:
    """
    Flatten query parameters into a positional tuple (or None).

    - None -> None (no parameters are bound).
    - list / tuple -> tuple of the same items.
    - any Mapping (dict, MappingProxyType, ...) -> tuple of its values in
      iteration order. The SQL must therefore use positional ``%s``
      placeholders in that order; named ``%(key)s`` placeholders are not
      supported by this normalization.
    - any other single value (including str / bytes) -> 1-tuple.

    Returning a tuple lets callers concatenate parameter groups, e.g. the
    SET values of an UPDATE followed by its WHERE values.
    """
    from collections import abc

    if where_params is None:
        return None
    if isinstance(where_params, (list, tuple)):
        return tuple(where_params)
    # Accept every Mapping, matching the `Mapping[str, Any]` hints of callers;
    # previously non-dict mappings were wrapped as one bogus parameter.
    if isinstance(where_params, abc.Mapping):
        return tuple(where_params.values())
    # A lone scalar is bound as the single positional parameter.
    return (where_params,)
class MySQLDB:
    """
    High-level MySQL API (CRUD + advanced queries + transaction).

    Interface is aligned with `how_decode/utils/mysql/mysql_db` style.

    Connection resolution used by every query method:

    1. An explicit ``connection`` argument is used as-is. It is never
       committed, rolled back, or closed here; the caller owns it.
    2. Otherwise, if a ``transaction()`` opened through a MySQLDB bound to the
       same manager and source is active in the current context, its
       connection is reused. Commit/rollback is left to ``transaction()``.
    3. Otherwise a new connection is opened from the source's client. Writes
       are committed on success, the connection is rolled back on failure,
       and it is always closed.

    SQL fragments (``table``, ``columns``, ``where``, ``order_by``,
    ``group_by``, ``having``, sort fields, aggregate expressions) are
    interpolated verbatim and must come from trusted code. Only ``data``
    values and ``*_params`` are sent as bound parameters. Query failures are
    re-raised as ``MySQLQueryError``.
    """

    # Isolation levels accepted by transaction(). The level is spliced into the
    # SET statement as raw SQL, so it is checked against this whitelist first.
    _ISOLATION_LEVELS = frozenset(
        {"READ UNCOMMITTED", "READ COMMITTED", "REPEATABLE READ", "SERIALIZABLE"}
    )

    # Per-context mapping (id(manager), source) -> open transaction connection.
    # The module-level `_tx_connection_var` alone cannot tell which source a
    # transaction belongs to. Without this mapping, a transaction on one source
    # would be reused by a MySQLDB bound to another source (a different DB).
    _tx_registry: ContextVar[
        Optional[Dict[Tuple[int, str], pymysql.connections.Connection]]
    ] = ContextVar("examples_how_db_utils_tx_registry", default=None)

    def __init__(self, *, manager: MySQLClientManager, source: str = "default"):
        """
        Args:
            manager: client manager used to look up the client for `source`.
            source: name of the data source this instance targets.
        """
        self._manager = manager
        self._source = source

    @property
    def source(self) -> str:
        """Name of the data source this instance is bound to."""
        return self._source

    def _client(self) -> MySQLClient:
        """Return the manager's client for this instance's source."""
        return self._manager.get_client(self._source)

    def _tx_key(self) -> Tuple[int, str]:
        """Identify this instance's (manager, source) pair in `_tx_registry`."""
        return (id(self._manager), self._source)

    def _current_tx_connection(self) -> Optional[pymysql.connections.Connection]:
        """Return the active transaction connection for this source, or None."""
        registry = self._tx_registry.get()
        if registry is not None:
            return registry.get(self._tx_key())
        # No transaction was opened through this class in the current context;
        # honour the module-level variable in case other code set it directly.
        return _tx_connection_var.get()

    def _is_in_transaction(self) -> bool:
        """True when a transaction connection for this source is active."""
        return self._current_tx_connection() is not None

    @contextmanager
    def _get_connection_and_cursor(
        self, connection: Optional[pymysql.connections.Connection] = None
    ) -> Iterator[Tuple[pymysql.connections.Connection, DictCursor, bool]]:
        """
        Yield (connection, cursor, should_close_connection).

        - If `connection` is provided: uses it and should_close_connection=False.
        - Else if a transaction connection for this source is active: uses it
          and should_close_connection=False.
        - Else opens a new connection: should_close_connection=True.

        The cursor is always closed. A connection opened here is rolled back
        if the body raises. The rollback happens before the connection is
        closed, so it still reaches the server. The connection is always
        closed on exit.
        """
        should_close = False
        if connection is not None:
            actual_conn = connection
        else:
            tx_conn = self._current_tx_connection()
            if tx_conn is not None:
                actual_conn = tx_conn
            else:
                actual_conn = self._client().open_connection()
                should_close = True
        cursor: Optional[DictCursor] = None
        try:
            # Created inside the try so a connection opened above is still
            # closed if cursor creation fails.
            cursor = actual_conn.cursor(DictCursor)
            yield actual_conn, cursor, should_close
        except BaseException:
            if should_close:
                try:
                    actual_conn.rollback()
                except Exception:
                    pass  # best effort; the original error is re-raised below
            raise
        finally:
            if cursor is not None:
                try:
                    cursor.close()
                except Exception:
                    pass
            if should_close:
                try:
                    actual_conn.close()
                except Exception:
                    pass

    @contextmanager
    def transaction(self, isolation_level: Optional[str] = None):
        """
        Transaction context manager; yields the transaction connection.

        CRUD methods of a MySQLDB bound to the same manager and source reuse
        the transaction connection automatically when called inside this
        context without a `connection` argument (tracked via ContextVar). A
        MySQLDB bound to a different source is not affected. The transaction
        is committed when the block exits normally and rolled back on any
        exception.

        Args:
            isolation_level: optional isolation level, one of
                "READ UNCOMMITTED", "READ COMMITTED", "REPEATABLE READ",
                "SERIALIZABLE". Case and extra whitespace are ignored.

        Raises:
            MySQLTransactionError: if the isolation level is unsupported,
                opening/committing fails, or the block raises. The original
                exception is chained as the cause.
        """
        client = self._client()
        conn = None
        token = None
        registry_token = None
        try:
            level = None
            if isolation_level:
                level = " ".join(isolation_level.split()).upper()
                if level not in self._ISOLATION_LEVELS:
                    raise ValueError(
                        f"unsupported isolation level: {isolation_level!r}"
                    )
            conn = client.open_connection()
            # Ensure explicit transaction.
            conn.autocommit(False)
            if level:
                # pymysql Connection objects have no execute(); go via a cursor.
                with conn.cursor() as cur:
                    cur.execute(f"SET SESSION TRANSACTION ISOLATION LEVEL {level}")
            conn.begin()
            # Copy before adding so contexts that share the previous mapping
            # (e.g. copied into other asyncio tasks) are never mutated.
            registry = dict(self._tx_registry.get() or {})
            registry[self._tx_key()] = conn
            registry_token = self._tx_registry.set(registry)
            token = _tx_connection_var.set(conn)
            yield conn
            conn.commit()
        except Exception as e:
            if conn is not None:
                try:
                    conn.rollback()
                except Exception:
                    pass  # best effort; the original error is what gets reported
            raise MySQLTransactionError(
                message=f"transaction failed (source={self._source}): {e}",
                original_error=e,
            ) from e
        finally:
            if token is not None:
                _tx_connection_var.reset(token)
            if registry_token is not None:
                self._tx_registry.reset(registry_token)
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

    # -----------------------
    # Internal execution helpers
    # -----------------------
    def _fetch(
        self,
        op_name: str,
        sql: str,
        params: Any,
        connection: Optional[pymysql.connections.Connection],
        *,
        one: bool = False,
    ) -> Any:
        """
        Run a read query and return all rows as a list (or the first row / None).

        Raises:
            MySQLQueryError: wrapping any failure, prefixed with `op_name`.
        """
        try:
            with self._get_connection_and_cursor(connection) as (_conn, cursor, _should_close):
                cursor.execute(sql, params)
                if one:
                    return cursor.fetchone()
                return list(cursor.fetchall())
        except Exception as e:
            raise MySQLQueryError(
                message=f"{op_name} failed (source={self._source}): {e}",
                original_error=e,
            ) from e

    def _execute_write(
        self,
        op_name: str,
        sql: str,
        params: Any,
        connection: Optional[pymysql.connections.Connection],
        *,
        many: bool = False,
        return_lastrowid: bool = False,
    ) -> int:
        """
        Execute a data-modifying statement and return an int result.

        Args:
            op_name: public method name, used as the error-message prefix.
            params: parameters for cursor.execute, or a list of parameter
                sets for cursor.executemany when `many` is True.
            many: use executemany instead of execute.
            return_lastrowid: return cursor.lastrowid instead of rowcount.

        Only a connection opened by this call is committed. An explicit
        `connection` or a transaction connection is left to its owner.
        Rollback on failure is done by `_get_connection_and_cursor`.

        Raises:
            MySQLQueryError: wrapping any failure.
        """
        try:
            with self._get_connection_and_cursor(connection) as (conn, cursor, should_close):
                if many:
                    cursor.executemany(sql, params)
                else:
                    cursor.execute(sql, params)
                if should_close:
                    conn.commit()
                if return_lastrowid:
                    return int(getattr(cursor, "lastrowid", 0) or 0)
                return int(cursor.rowcount or 0)
        except Exception as e:
            raise MySQLQueryError(
                message=f"{op_name} failed (source={self._source}): {e}",
                original_error=e,
            ) from e

    @staticmethod
    def _normalize_sort_order(order: Optional[str]) -> str:
        """Upper-case `order`; anything other than ASC/DESC becomes ASC."""
        order_u = (order or "").upper()
        return order_u if order_u in ("ASC", "DESC") else "ASC"

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape LIKE wildcards in `value` using `!` as the escape character."""
        # Escape the escape character first so it is not doubled afterwards.
        return value.replace("!", "!!").replace("%", "!%").replace("_", "!_")

    # -----------------------
    # Basic CRUD
    # -----------------------
    def select(
        self,
        table: str,
        columns: str = "*",
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        order_by: str = "",
        limit: Optional[int] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> List[Dict[str, Any]]:
        """
        Return rows of `table` as dicts.

        `where` may contain `%s` placeholders bound from `where_params`.
        `limit` is coerced with int() so it cannot inject SQL text.
        """
        sql = f"SELECT {columns} FROM {table}"
        if where:
            sql += f" WHERE {where}"
        if order_by:
            sql += f" ORDER BY {order_by}"
        if limit is not None:
            sql += f" LIMIT {int(limit)}"
        return self._fetch(
            "select", sql, _normalize_where_params(where_params), connection
        )

    def select_one(
        self,
        table: str,
        columns: str = "*",
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> Optional[Dict[str, Any]]:
        """Return the first matching row as a dict, or None when none match."""
        sql = f"SELECT {columns} FROM {table}"
        if where:
            sql += f" WHERE {where}"
        sql += " LIMIT 1"
        return self._fetch(
            "select_one",
            sql,
            _normalize_where_params(where_params),
            connection,
            one=True,
        )

    def insert(
        self,
        table: str,
        data: Dict[str, Any],
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> int:
        """
        Insert one row and return cursor.lastrowid (0 when unavailable).

        Raises:
            ValueError: if `data` is empty.
            MySQLQueryError: if the statement fails.
        """
        if not data:
            raise ValueError("insert data must not be empty")
        columns = list(data)
        placeholders = ", ".join(["%s"] * len(columns))
        sql = f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})"
        return self._execute_write(
            "insert", sql, tuple(data.values()), connection, return_lastrowid=True
        )

    def insert_many(
        self,
        table: str,
        data_list: List[Dict[str, Any]],
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> int:
        """
        Insert several rows with executemany and return the affected row count.

        Column names come from the first dict; every dict must contain those
        keys (a missing key raises KeyError, extra keys are ignored).

        Raises:
            ValueError: if `data_list` is empty.
            MySQLQueryError: if the statement fails.
        """
        if not data_list:
            raise ValueError("insert_many data_list must not be empty")
        columns = list(data_list[0])
        placeholders = ", ".join(["%s"] * len(columns))
        sql = f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})"
        params_list = [tuple(d[col] for col in columns) for d in data_list]
        return self._execute_write(
            "insert_many", sql, params_list, connection, many=True
        )

    def update(
        self,
        table: str,
        data: Dict[str, Any],
        where: str,
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> int:
        """
        Update matching rows with `data` and return the affected row count.

        Bound parameters are the `data` values followed by `where_params`.

        Raises:
            ValueError: if `data` is empty.
            MySQLQueryError: if the statement fails.
        """
        if not data:
            raise ValueError("update data must not be empty")
        set_clause = ", ".join(f"{col}=%s" for col in data)
        sql = f"UPDATE {table} SET {set_clause} WHERE {where}"
        params: List[Any] = list(data.values())
        wp = _normalize_where_params(where_params)
        if wp is not None:
            params.extend(wp)
        return self._execute_write("update", sql, tuple(params), connection)

    def delete(
        self,
        table: str,
        where: str,
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> int:
        """Delete matching rows and return the affected row count."""
        sql = f"DELETE FROM {table} WHERE {where}"
        return self._execute_write(
            "delete", sql, _normalize_where_params(where_params), connection
        )

    def execute_many(
        self,
        sql: str,
        params_list: List[Sequence[Any] | Mapping[str, Any]],
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> int:
        """
        Run `sql` once per parameter set via executemany; return rowcount.

        Each parameter set is normalized with `_normalize_where_params`, so
        `sql` must use positional `%s` placeholders.
        """
        params_seq = [_normalize_where_params(p) for p in params_list]
        return self._execute_write(
            "execute_many", sql, params_seq, connection, many=True
        )

    # -----------------------
    # Query helpers
    # -----------------------
    def count(
        self,
        table: str,
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> int:
        """Return COUNT(*) of the matching rows (0 if no row is returned)."""
        sql = f"SELECT COUNT(*) as count FROM {table}"
        if where:
            sql += f" WHERE {where}"
        row = self._fetch(
            "count", sql, _normalize_where_params(where_params), connection, one=True
        )
        if not row:
            return 0
        return int(row.get("count") or 0)

    def exists(
        self,
        table: str,
        where: str,
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> bool:
        """Return True if at least one row matches."""
        # SELECT 1 ... LIMIT 1 stops at the first match instead of counting all.
        sql = f"SELECT 1 FROM {table}"
        if where:
            sql += f" WHERE {where}"
        sql += " LIMIT 1"
        row = self._fetch(
            "exists", sql, _normalize_where_params(where_params), connection, one=True
        )
        return row is not None

    def paginate(
        self,
        table: str,
        page: int = 1,
        page_size: int = 20,
        columns: str = "*",
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        order_by: str = "",
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> Dict[str, Any]:
        """
        Return one page of rows plus pagination metadata.

        `page` < 1 is treated as 1 and `page_size` < 1 as 20. The result has
        keys "data" (rows) and "pagination" (current_page, page_size,
        total_count, total_pages, has_prev, has_next, prev_page, next_page).
        total_pages is 1 when there are no matching rows.
        """
        if page < 1:
            page = 1
        if page_size < 1:
            page_size = 20
        total_count = self.count(
            table, where=where, where_params=where_params, connection=connection
        )
        total_pages = math.ceil(total_count / page_size) if total_count > 0 else 1
        offset = (page - 1) * page_size
        sql = f"SELECT {columns} FROM {table}"
        if where:
            sql += f" WHERE {where}"
        if order_by:
            sql += f" ORDER BY {order_by}"
        sql += f" LIMIT {page_size} OFFSET {offset}"
        data = self._fetch(
            "paginate", sql, _normalize_where_params(where_params), connection
        )
        return {
            "data": data,
            "pagination": {
                "current_page": page,
                "page_size": page_size,
                "total_count": total_count,
                "total_pages": total_pages,
                "has_prev": page > 1,
                "has_next": page < total_pages,
                "prev_page": page - 1 if page > 1 else None,
                "next_page": page + 1 if page < total_pages else None,
            },
        }

    def select_with_sort(
        self,
        table: str,
        columns: str = "*",
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        sort_field: str = "id",
        sort_order: str = "ASC",
        limit: Optional[int] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> List[Dict[str, Any]]:
        """`select` ordered by one field; invalid `sort_order` falls back to ASC."""
        order_by = f"{sort_field} {self._normalize_sort_order(sort_order)}"
        return self.select(
            table,
            columns=columns,
            where=where,
            where_params=where_params,
            order_by=order_by,
            limit=limit,
            connection=connection,
        )

    def select_with_multiple_sort(
        self,
        table: str,
        columns: str = "*",
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        sort_fields: Optional[List[Tuple[str, str]]] = None,
        limit: Optional[int] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> List[Dict[str, Any]]:
        """`select` ordered by several (field, order) pairs; bad orders -> ASC."""
        order_by = ""
        if sort_fields:
            order_by = ", ".join(
                f"{field} {self._normalize_sort_order(order)}"
                for field, order in sort_fields
            )
        return self.select(
            table,
            columns=columns,
            where=where,
            where_params=where_params,
            order_by=order_by,
            limit=limit,
            connection=connection,
        )

    def aggregate(
        self,
        table: str,
        agg_functions: Dict[str, str],
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        group_by: str = "",
        having: str = "",
        having_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> List[Dict[str, Any]]:
        """
        Run aggregate expressions, optionally grouped, and return the rows.

        `agg_functions` maps result alias -> SQL expression (e.g.
        {"total": "SUM(amount)"}). When `group_by` is set it is also selected.
        Bound parameters are `where_params` followed by `having_params`.

        Raises:
            ValueError: if `agg_functions` is empty.
            MySQLQueryError: if the query fails.
        """
        if not agg_functions:
            raise ValueError("agg_functions must not be empty")
        select_parts: List[str] = [group_by] if group_by else []
        select_parts.extend(f"{func} AS {alias}" for alias, func in agg_functions.items())
        sql = f"SELECT {', '.join(select_parts)} FROM {table}"
        if where:
            sql += f" WHERE {where}"
        if group_by:
            sql += f" GROUP BY {group_by}"
        if having:
            sql += f" HAVING {having}"
        params: List[Any] = []
        for group in (where_params, having_params):
            normalized = _normalize_where_params(group)
            if normalized is not None:
                params.extend(normalized)
        return self._fetch(
            "aggregate", sql, tuple(params) if params else None, connection
        )

    def sum(
        self,
        table: str,
        column: str,
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> float:
        """Return SUM(column) as float (0.0 when the result is NULL)."""
        rows = self.aggregate(
            table=table,
            agg_functions={"sum_result": f"SUM({column})"},
            where=where,
            where_params=where_params,
            connection=connection,
        )
        v = rows[0].get("sum_result") if rows else None
        return float(v) if v is not None else 0.0

    def avg(
        self,
        table: str,
        column: str,
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> float:
        """Return AVG(column) as float (0.0 when the result is NULL)."""
        rows = self.aggregate(
            table=table,
            agg_functions={"avg_result": f"AVG({column})"},
            where=where,
            where_params=where_params,
            connection=connection,
        )
        v = rows[0].get("avg_result") if rows else None
        return float(v) if v is not None else 0.0

    def max(
        self,
        table: str,
        column: str,
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> Any:
        """Return MAX(column), or None when there are no matching rows."""
        rows = self.aggregate(
            table=table,
            agg_functions={"max_result": f"MAX({column})"},
            where=where,
            where_params=where_params,
            connection=connection,
        )
        return rows[0].get("max_result") if rows else None

    def min(
        self,
        table: str,
        column: str,
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> Any:
        """Return MIN(column), or None when there are no matching rows."""
        rows = self.aggregate(
            table=table,
            agg_functions={"min_result": f"MIN({column})"},
            where=where,
            where_params=where_params,
            connection=connection,
        )
        return rows[0].get("min_result") if rows else None

    def group_count(
        self,
        table: str,
        group_column: str,
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        order_by: str = "",
        limit: Optional[int] = None,
        connection: Optional[pymysql.connections.Connection] = None,
    ) -> List[Dict[str, Any]]:
        """
        Count rows per value of `group_column`.

        Results are ordered by `order_by`, or by count descending by default.
        `limit` is coerced with int() so it cannot inject SQL text.
        """
        sql = f"SELECT {group_column}, COUNT(*) as count FROM {table}"
        if where:
            sql += f" WHERE {where}"
        sql += f" GROUP BY {group_column}"
        sql += f" ORDER BY {order_by}" if order_by else " ORDER BY count DESC"
        if limit is not None:
            sql += f" LIMIT {int(limit)}"
        return self._fetch(
            "group_count", sql, _normalize_where_params(where_params), connection
        )

    def search(
        self,
        table: str,
        search_columns: List[str],
        keyword: str,
        columns: str = "*",
        where: str = "",
        where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None,
        order_by: str = "",
        limit: Optional[int] = None,
        connection: Optional[pymysql.connections.Connection] = None,
        escape_wildcards: bool = True,
    ) -> List[Dict[str, Any]]:
        """
        Return rows where any of `search_columns` contains `keyword`.

        Builds `(col1 LIKE %s OR col2 LIKE %s ...)`, AND-ed with the optional
        extra `where` condition (whose `where_params` are bound after the
        search patterns). Returns [] when `search_columns` or `keyword` is
        empty.

        Args:
            escape_wildcards: when True (default), `%`, `_` and `!` in
                `keyword` match literally (escaped with `ESCAPE '!'`) instead
                of acting as LIKE wildcards. Pass False to let callers embed
                their own wildcards.
        """
        if not search_columns or not keyword:
            return []
        if escape_wildcards:
            needle = self._escape_like(str(keyword))
            escape_clause = " ESCAPE '!'"
        else:
            needle = keyword
            escape_clause = ""
        conditions = " OR ".join(
            f"{col} LIKE %s{escape_clause}" for col in search_columns
        )
        search_where = f"({conditions})"
        final_params: List[Any] = [f"%{needle}%"] * len(search_columns)
        final_where = search_where
        if where:
            final_where = f"{search_where} AND ({where})"
            wp = _normalize_where_params(where_params)
            if wp is not None:
                final_params.extend(wp)
        return self.select(
            table,
            columns=columns,
            where=final_where,
            where_params=tuple(final_params),
            order_by=order_by,
            limit=limit,
            connection=connection,
        )

    # -----------------------
    # Functional transaction helpers (optional)
    # -----------------------
    def execute_in_transaction(
        self,
        func,
        *args,
        isolation_level: Optional[str] = None,
        **kwargs,
    ) -> Any:
        """Call `func(conn, *args, **kwargs)` inside `transaction()`; return its result."""
        with self.transaction(isolation_level=isolation_level) as conn:
            return func(conn, *args, **kwargs)

    def batch_operations(
        self,
        operations: list,
        isolation_level: Optional[str] = None,
    ) -> list:
        """
        Run several methods of this object in one transaction.

        Each operation is a 3-tuple `(method_name, args, kwargs_or_None)`. The
        transaction connection is passed as `connection=`. The caller's kwargs
        dicts are copied, not mutated. Returns the list of method results.
        """
        results: list = []
        with self.transaction(isolation_level=isolation_level) as conn:
            for method_name, args, op_kwargs in operations:
                call_kwargs = dict(op_kwargs or {})
                call_kwargs["connection"] = conn
                results.append(getattr(self, method_name)(*args, **call_kwargs))
        return results
# Process-wide cache of MySQLDB facades, keyed by data source name.
_GLOBAL_DB: Dict[str, MySQLDB] = {}


def get_mysql_db(source: str = "default") -> MySQLDB:
    """
    Return the shared MySQLDB for `source`, creating it on first use.

    New instances are bound to the manager returned by
    `get_global_manager()`; later calls return the cached instance.
    """
    db = _GLOBAL_DB.get(source)
    if db is None:
        db = MySQLDB(manager=get_global_manager(), source=source)
        _GLOBAL_DB[source] = db
    return db
# Module-level default instance, kept for compatibility with the global
# `mysql_db` object exposed by how_decode/utils/mysql.
mysql_db: MySQLDB = get_mysql_db(source="default")
|