from __future__ import annotations import math from contextlib import contextmanager from contextvars import ContextVar from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple import pymysql from pymysql.cursors import DictCursor from .errors import ( MySQLConnectionError, MySQLConfigError, MySQLQueryError, MySQLTransactionError, ) from .mysql_manager import MySQLClientManager, get_global_manager from .mysql_client import MySQLClient _tx_connection_var: ContextVar[Optional[pymysql.connections.Connection]] = ContextVar( "examples_how_db_utils_tx_connection", default=None ) def _normalize_where_params(where_params: Any) -> Optional[Tuple[Any, ...]]: if where_params is None: return None if isinstance(where_params, (list, tuple)): return tuple(where_params) if isinstance(where_params, dict): return tuple(where_params.values()) # Fallback: keep as-is (pymysql supports sequences/tuples or mapping) return (where_params,) class MySQLDB: """ High-level MySQL API (CRUD + advanced queries + transaction). Interface is aligned with `how_decode/utils/mysql/mysql_db` style. """ def __init__(self, *, manager: MySQLClientManager, source: str = "default"): self._manager = manager self._source = source @property def source(self) -> str: return self._source def _client(self) -> MySQLClient: return self._manager.get_client(self._source) def _is_in_transaction(self) -> bool: return _tx_connection_var.get() is not None @contextmanager def _get_connection_and_cursor( self, connection: Optional[pymysql.connections.Connection] = None ) -> Iterator[Tuple[pymysql.connections.Connection, DictCursor, bool]]: """ Returns (connection, cursor, should_close_connection). - If `connection` is provided: uses it and should_close_connection=False - Else if transaction connection exists: uses it and should_close_connection=False - Else opens a new connection: should_close_connection=True """ client = self._client() tx_conn = _tx_connection_var.get() should_close = False actual_conn: Optional[pymysql.connections.Connection] = None if connection is not None: actual_conn = connection elif tx_conn is not None: actual_conn = tx_conn else: actual_conn = client.open_connection() should_close = True cursor = actual_conn.cursor(DictCursor) try: yield actual_conn, cursor, should_close finally: try: cursor.close() except Exception: pass if should_close: try: actual_conn.close() except Exception: pass @contextmanager def transaction(self, isolation_level: Optional[str] = None): """ Transaction context manager. Important: when you call CRUD methods inside this context without passing `connection`, they will automatically reuse the same transaction connection (via ContextVar). """ client = self._client() conn = None token = None try: conn = client.open_connection() # Ensure explicit transaction. conn.autocommit(False) if isolation_level: conn.execute( f"SET SESSION TRANSACTION ISOLATION LEVEL {isolation_level}" ) conn.begin() token = _tx_connection_var.set(conn) yield conn conn.commit() except Exception as e: if conn is not None: try: conn.rollback() except Exception: pass raise MySQLTransactionError( message=f"transaction failed (source={self._source}): {e}", original_error=e, ) from e finally: if token is not None: _tx_connection_var.reset(token) if conn is not None: try: conn.close() except Exception: pass # ----------------------- # Basic CRUD # ----------------------- def select( self, table: str, columns: str = "*", where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, order_by: str = "", limit: Optional[int] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> List[Dict[str, Any]]: sql = f"SELECT {columns} FROM {table}" if where: sql += f" WHERE {where}" if order_by: sql += f" ORDER BY {order_by}" if limit is not None: sql += f" LIMIT {limit}" params = _normalize_where_params(where_params) try: with self._get_connection_and_cursor(connection) as (conn, cursor, should_close): cursor.execute(sql, params) return list(cursor.fetchall()) except Exception as e: raise MySQLQueryError( message=f"select failed (source={self._source}): {e}", original_error=e, ) from e def select_one( self, table: str, columns: str = "*", where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> Optional[Dict[str, Any]]: sql = f"SELECT {columns} FROM {table}" if where: sql += f" WHERE {where}" sql += " LIMIT 1" params = _normalize_where_params(where_params) try: with self._get_connection_and_cursor(connection) as (_conn, cursor, _should_close): cursor.execute(sql, params) return cursor.fetchone() except Exception as e: raise MySQLQueryError( message=f"select_one failed (source={self._source}): {e}", original_error=e, ) from e def insert( self, table: str, data: Dict[str, Any], connection: Optional[pymysql.connections.Connection] = None, ) -> int: if not data: raise ValueError("insert data must not be empty") columns = list(data.keys()) placeholders = ", ".join(["%s"] * len(columns)) sql = f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})" params = tuple(data.values()) conn: Optional[pymysql.connections.Connection] = None try: with self._get_connection_and_cursor(connection) as (conn, cursor, should_close): cursor.execute(sql, params) # If not inside transaction, auto commit. if not self._is_in_transaction() and should_close: conn.commit() return int(getattr(cursor, "lastrowid", 0) or 0) except Exception as e: # If we opened a connection ourselves, rollback to be safe. if conn is not None and connection is None and not self._is_in_transaction(): try: conn.rollback() except Exception: pass raise MySQLQueryError( message=f"insert failed (source={self._source}): {e}", original_error=e, ) from e def insert_many( self, table: str, data_list: List[Dict[str, Any]], connection: Optional[pymysql.connections.Connection] = None, ) -> int: if not data_list: raise ValueError("insert_many data_list must not be empty") columns = list(data_list[0].keys()) placeholders = ", ".join(["%s"] * len(columns)) sql = f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})" params_list = [tuple(d[col] for col in columns) for d in data_list] conn: Optional[pymysql.connections.Connection] = None try: with self._get_connection_and_cursor(connection) as (conn, cursor, should_close): cursor.executemany(sql, params_list) if not self._is_in_transaction() and should_close: conn.commit() return int(cursor.rowcount or 0) except Exception as e: if conn is not None and connection is None and not self._is_in_transaction(): try: conn.rollback() except Exception: pass raise MySQLQueryError( message=f"insert_many failed (source={self._source}): {e}", original_error=e, ) from e def update( self, table: str, data: Dict[str, Any], where: str, where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> int: if not data: raise ValueError("update data must not be empty") set_clause = ", ".join([f"{col}=%s" for col in data.keys()]) sql = f"UPDATE {table} SET {set_clause} WHERE {where}" params: List[Any] = list(data.values()) wp = _normalize_where_params(where_params) if wp is not None: params.extend(list(wp)) conn: Optional[pymysql.connections.Connection] = None try: with self._get_connection_and_cursor(connection) as (conn, cursor, should_close): cursor.execute(sql, tuple(params)) if not self._is_in_transaction() and should_close: conn.commit() return int(cursor.rowcount or 0) except Exception as e: if conn is not None and connection is None and not self._is_in_transaction(): try: conn.rollback() except Exception: pass raise MySQLQueryError( message=f"update failed (source={self._source}): {e}", original_error=e, ) from e def delete( self, table: str, where: str, where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> int: sql = f"DELETE FROM {table} WHERE {where}" params = _normalize_where_params(where_params) conn: Optional[pymysql.connections.Connection] = None try: with self._get_connection_and_cursor(connection) as (conn, cursor, should_close): cursor.execute(sql, params) if not self._is_in_transaction() and should_close: conn.commit() return int(cursor.rowcount or 0) except Exception as e: if conn is not None and connection is None and not self._is_in_transaction(): try: conn.rollback() except Exception: pass raise MySQLQueryError( message=f"delete failed (source={self._source}): {e}", original_error=e, ) from e def execute_many( self, sql: str, params_list: List[Sequence[Any] | Mapping[str, Any]], connection: Optional[pymysql.connections.Connection] = None, ) -> int: params_seq = [_normalize_where_params(p) for p in params_list] conn: Optional[pymysql.connections.Connection] = None try: with self._get_connection_and_cursor(connection) as (conn, cursor, should_close): cursor.executemany(sql, params_seq) if not self._is_in_transaction() and should_close: conn.commit() return int(cursor.rowcount or 0) except Exception as e: if conn is not None and connection is None and not self._is_in_transaction(): try: conn.rollback() except Exception: pass raise MySQLQueryError( message=f"execute_many failed (source={self._source}): {e}", original_error=e, ) from e # ----------------------- # Query helpers # ----------------------- def count( self, table: str, where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> int: sql = f"SELECT COUNT(*) as count FROM {table}" if where: sql += f" WHERE {where}" try: params = _normalize_where_params(where_params) with self._get_connection_and_cursor(connection) as (_conn, cursor, _should_close): cursor.execute(sql, params) r = cursor.fetchone() if not r: return 0 return int(r.get("count") or 0) except Exception as e: raise MySQLQueryError( message=f"count failed (source={self._source}): {e}", original_error=e, ) from e def exists( self, table: str, where: str, where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> bool: return self.count( table, where=where, where_params=where_params, connection=connection ) > 0 def paginate( self, table: str, page: int = 1, page_size: int = 20, columns: str = "*", where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, order_by: str = "", connection: Optional[pymysql.connections.Connection] = None, ) -> Dict[str, Any]: if page < 1: page = 1 if page_size < 1: page_size = 20 total_count = self.count( table, where=where, where_params=where_params, connection=connection ) total_pages = math.ceil(total_count / page_size) if total_count > 0 else 1 offset = (page - 1) * page_size sql = f"SELECT {columns} FROM {table}" if where: sql += f" WHERE {where}" if order_by: sql += f" ORDER BY {order_by}" sql += f" LIMIT {page_size} OFFSET {offset}" params = _normalize_where_params(where_params) try: with self._get_connection_and_cursor(connection) as (_conn, cursor, _should_close): cursor.execute(sql, params) data = list(cursor.fetchall()) except Exception as e: raise MySQLQueryError( message=f"paginate failed (source={self._source}): {e}", original_error=e, ) from e return { "data": data, "pagination": { "current_page": page, "page_size": page_size, "total_count": total_count, "total_pages": total_pages, "has_prev": page > 1, "has_next": page < total_pages, "prev_page": page - 1 if page > 1 else None, "next_page": page + 1 if page < total_pages else None, }, } def select_with_sort( self, table: str, columns: str = "*", where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, sort_field: str = "id", sort_order: str = "ASC", limit: Optional[int] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> List[Dict[str, Any]]: sort_order = (sort_order or "").upper() if sort_order not in ["ASC", "DESC"]: sort_order = "ASC" order_by = f"{sort_field} {sort_order}" return self.select( table, columns=columns, where=where, where_params=where_params, order_by=order_by, limit=limit, connection=connection, ) def select_with_multiple_sort( self, table: str, columns: str = "*", where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, sort_fields: Optional[List[Tuple[str, str]]] = None, limit: Optional[int] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> List[Dict[str, Any]]: order_by = "" if sort_fields: parts: List[str] = [] for field, order in sort_fields: order_u = (order or "").upper() if order_u not in ["ASC", "DESC"]: order_u = "ASC" parts.append(f"{field} {order_u}") order_by = ", ".join(parts) return self.select( table, columns=columns, where=where, where_params=where_params, order_by=order_by, limit=limit, connection=connection, ) def aggregate( self, table: str, agg_functions: Dict[str, str], where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, group_by: str = "", having: str = "", having_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> List[Dict[str, Any]]: if not agg_functions: raise ValueError("agg_functions must not be empty") select_parts: List[str] = [] if group_by: select_parts.append(group_by) for alias, func in agg_functions.items(): select_parts.append(f"{func} AS {alias}") sql = f"SELECT {', '.join(select_parts)} FROM {table}" if where: sql += f" WHERE {where}" if group_by: sql += f" GROUP BY {group_by}" if having: sql += f" HAVING {having}" params: List[Any] = [] wp = _normalize_where_params(where_params) if wp is not None: params.extend(list(wp)) hp = _normalize_where_params(having_params) if hp is not None: params.extend(list(hp)) try: with self._get_connection_and_cursor(connection) as (_conn, cursor, _should_close): cursor.execute(sql, tuple(params) if params else None) return list(cursor.fetchall()) except Exception as e: raise MySQLQueryError( message=f"aggregate failed (source={self._source}): {e}", original_error=e, ) from e def sum( self, table: str, column: str, where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> float: rows = self.aggregate( table=table, agg_functions={"sum_result": f"SUM({column})"}, where=where, where_params=where_params, connection=connection, ) v = rows[0].get("sum_result") if rows else None return float(v) if v is not None else 0.0 def avg( self, table: str, column: str, where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> float: rows = self.aggregate( table=table, agg_functions={"avg_result": f"AVG({column})"}, where=where, where_params=where_params, connection=connection, ) v = rows[0].get("avg_result") if rows else None return float(v) if v is not None else 0.0 def max( self, table: str, column: str, where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> Any: rows = self.aggregate( table=table, agg_functions={"max_result": f"MAX({column})"}, where=where, where_params=where_params, connection=connection, ) return rows[0].get("max_result") if rows and rows[0].get("max_result") is not None else None def min( self, table: str, column: str, where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> Any: rows = self.aggregate( table=table, agg_functions={"min_result": f"MIN({column})"}, where=where, where_params=where_params, connection=connection, ) return rows[0].get("min_result") if rows and rows[0].get("min_result") is not None else None def group_count( self, table: str, group_column: str, where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, order_by: str = "", limit: Optional[int] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> List[Dict[str, Any]]: sql = f"SELECT {group_column}, COUNT(*) as count FROM {table}" if where: sql += f" WHERE {where}" sql += f" GROUP BY {group_column}" if order_by: sql += f" ORDER BY {order_by}" else: sql += " ORDER BY count DESC" if limit is not None: sql += f" LIMIT {limit}" params = _normalize_where_params(where_params) try: with self._get_connection_and_cursor(connection) as (_conn, cursor, _should_close): cursor.execute(sql, params) return list(cursor.fetchall()) except Exception as e: raise MySQLQueryError( message=f"group_count failed (source={self._source}): {e}", original_error=e, ) from e def search( self, table: str, search_columns: List[str], keyword: str, columns: str = "*", where: str = "", where_params: Optional[Sequence[Any] | Mapping[str, Any]] = None, order_by: str = "", limit: Optional[int] = None, connection: Optional[pymysql.connections.Connection] = None, ) -> List[Dict[str, Any]]: if not search_columns or not keyword: return [] search_conditions: List[str] = [] search_params: List[Any] = [] for col in search_columns: search_conditions.append(f"{col} LIKE %s") search_params.append(f"%{keyword}%") search_where = f"({' OR '.join(search_conditions)})" final_where = search_where final_params: List[Any] = list(search_params) if where: final_where = f"{search_where} AND ({where})" wp = _normalize_where_params(where_params) if wp is not None: final_params.extend(list(wp)) return self.select( table, columns=columns, where=final_where, where_params=tuple(final_params), order_by=order_by, limit=limit, connection=connection, ) # ----------------------- # Functional transaction helpers (optional) # ----------------------- def execute_in_transaction( self, func, *args, isolation_level: Optional[str] = None, **kwargs, ) -> Any: with self.transaction(isolation_level=isolation_level) as conn: return func(conn, *args, **kwargs) def batch_operations( self, operations: list, isolation_level: Optional[str] = None, ) -> list: results: list = [] with self.transaction(isolation_level=isolation_level) as conn: for op in operations: method_name, args, op_kwargs = op op_kwargs = op_kwargs or {} op_kwargs["connection"] = conn method = getattr(self, method_name) results.append(method(*args, **op_kwargs)) return results _GLOBAL_DB: Dict[str, MySQLDB] = {} def get_mysql_db(source: str = "default") -> MySQLDB: if source not in _GLOBAL_DB: mgr = get_global_manager() _GLOBAL_DB[source] = MySQLDB(manager=mgr, source=source) return _GLOBAL_DB[source] # For compatibility with how_decode/utils/mysql (global mysql_db) mysql_db = get_mysql_db("default")