from __future__ import annotations

from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple

import pymysql
from pymysql.cursors import DictCursor

from .errors import MySQLConnectionError, MySQLQueryError
from .types import ExecResult, MySQLConfig, Params

try:
    # Optional dependency for connection pooling.
    from dbutils.pooled_db import PooledDB  # type: ignore

    _HAS_DBUTILS_POOL = True
except Exception:  # pragma: no cover
    PooledDB = None  # type: ignore
    _HAS_DBUTILS_POOL = False


# MySQL server error numbers recognised by the ``ignore_*`` flags of
# MySQLClient.execute() / MySQLClient.executemany().
_ER_DUP_ENTRY = 1062  # duplicate entry for a unique/primary key
_ER_LOCK_WAIT_TIMEOUT = 1205  # lock wait timeout exceeded
_ER_LOCK_DEADLOCK = 1213  # deadlock found when trying to get lock
_LOCK_CONFLICT_ERRNOS = frozenset({_ER_LOCK_WAIT_TIMEOUT, _ER_LOCK_DEADLOCK})


def _normalize_params(params: Params) -> Any:
    """
    Convert query parameters into a form pymysql accepts.

    ``None`` means "no parameters" and is passed through. Lists and tuples are
    converted to tuples. Anything else (e.g. a mapping used with ``%(name)s``
    placeholders) is returned unchanged.
    """
    if params is None:
        return None
    # pymysql supports Mapping (dict) and Sequence (tuple/list).
    if isinstance(params, (list, tuple)):
        return tuple(params)
    return params


def _mysql_errno(exc: BaseException) -> Optional[int]:
    """Return the MySQL error number pymysql stores in ``exc.args[0]``, or None."""
    args = getattr(exc, "args", None)
    if args and isinstance(args[0], int):
        return args[0]
    return None


class MySQLClient:
    """
    Synchronous MySQL client based on `pymysql`.

    Main APIs:
    - fetchone / fetchall / fetchmany for SELECT queries
    - execute / executemany for INSERT/UPDATE/DELETE/DDL
    - transaction() context manager for multi-statement transactions

    Every query/write method acquires its own connection and releases it before
    returning. The connection comes from a DBUtils pool when
    ``config.use_pool`` is set and DBUtils is installed; otherwise it is freshly
    opened. Rows are returned as dicts (``DictCursor``).
    """

    def __init__(self, config: MySQLConfig):
        self._config = config
        self._pool = None
        # Pooling needs both the config flag and DBUtils; without DBUtils the
        # client silently falls back to one fresh connection per call.
        if self._config.use_pool and _HAS_DBUTILS_POOL:
            self._pool = PooledDB(
                creator=pymysql,
                mincached=self._config.pool_mincached,
                maxconnections=self._config.pool_maxconnections,
                blocking=True,
                **self._connect_kwargs(),
            )

    @property
    def config(self) -> MySQLConfig:
        return self._config

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def _connect_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments for ``pymysql.connect``, shared by pooled and direct connections."""
        return dict(
            host=self._config.host,
            port=self._config.port,
            user=self._config.user,
            password=self._config.password,
            database=self._config.database,
            charset=self._config.charset,
            # Same default cursor class for pooled and direct connections so
            # open_connection() behaves identically in both modes.
            cursorclass=DictCursor,
            connect_timeout=self._config.connect_timeout,
            read_timeout=self._config.read_timeout,
            write_timeout=self._config.write_timeout,
            autocommit=self._config.autocommit,
        )

    def open_connection(self) -> pymysql.connections.Connection:
        """
        Open a raw connection (a DBUtils pooled proxy when pooling is enabled).

        Intended usage:
        - transaction() keeps the same connection for multiple operations
        - other cases where you want manual connection lifecycle

        Note:
        - Caller must close the connection via `close_connection()`.

        Raises:
            MySQLConnectionError: if a connection cannot be obtained.
        """
        if self._pool is None:
            return self._connect()
        try:
            return self._pool.connection()
        except Exception as e:
            raise MySQLConnectionError(
                f"MySQL pool connection failed (source={self._config.source}, host={self._config.host}, db={self._config.database}): {e}"
            ) from e

    def close_connection(self, connection: pymysql.connections.Connection) -> None:
        """Close a previously opened connection (returns to pool when applicable).

        Best-effort: errors while closing are ignored.
        """
        try:
            connection.close()
        except Exception:
            pass

    def _connect(self) -> pymysql.connections.Connection:
        """Open a fresh, unpooled pymysql connection.

        Raises:
            MySQLConnectionError: if pymysql fails to connect.
        """
        try:
            return pymysql.connect(**self._connect_kwargs())
        except Exception as e:  # pragma: no cover
            raise MySQLConnectionError(
                f"MySQL connection failed (source={self._config.source}, host={self._config.host}, db={self._config.database}): {e}"
            ) from e

    def _rollback_quietly(self, conn: Any) -> None:
        """Best-effort rollback. A no-op when autocommit is enabled.

        Rollback failures are swallowed. The caller is about to report a more
        useful error, or to discard the connection anyway.
        """
        if self._config.autocommit:
            return
        try:
            conn.rollback()
        except Exception:
            pass

    def _should_commit(self, commit: Optional[bool]) -> bool:
        """
        Decide whether a write should be followed by an explicit COMMIT.

        - autocommit=True: never (the server already committed each statement).
        - autocommit=False: commit unless the caller passed ``commit=False``.
        """
        if self._config.autocommit:
            return False
        return True if commit is None else bool(commit)

    @contextmanager
    def _cursor(self) -> Iterator[Tuple[pymysql.connections.Connection, DictCursor]]:
        """
        Yield ``(connection, cursor)`` using DictCursor.

        Uses a pool connection if configured; otherwise creates a fresh
        connection. Both the cursor and the connection are closed on exit.
        """
        conn = self.open_connection()
        try:
            cursor = conn.cursor(DictCursor)
            try:
                yield conn, cursor
            finally:
                cursor.close()
        finally:
            self.close_connection(conn)

    @contextmanager
    def transaction(self) -> Iterator[DictCursor]:
        """
        Transaction context manager yielding a DictCursor on a dedicated connection.

        - Commits after the ``with`` body completes normally, when autocommit is False.
        - Rolls back (best-effort) if the body raises, then re-raises.
        - With autocommit=True every statement is committed as it runs, so no
          commit/rollback is issued and the block is not atomic.
        """
        conn = self.open_connection()
        tx_active = not self._config.autocommit
        try:
            cursor = conn.cursor(DictCursor)
            try:
                yield cursor
            finally:
                cursor.close()
            # Only reached when the with-body finished without raising.
            if tx_active:
                conn.commit()
        except BaseException:
            # BaseException so that e.g. KeyboardInterrupt also rolls back.
            self._rollback_quietly(conn)
            raise
        finally:
            self.close_connection(conn)

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def fetchone(self, sql: str, params: Params = None) -> Optional[Dict[str, Any]]:
        """
        Run ``sql`` and return its first row, or None when there is no row.

        Raises:
            MySQLConnectionError: if a connection cannot be obtained.
            MySQLQueryError: if executing or fetching fails.
        """
        with self._cursor() as (_conn, cursor):
            try:
                cursor.execute(sql, _normalize_params(params))
                return cursor.fetchone()
            except Exception as e:
                raise MySQLQueryError(f"fetchone failed: {e} | sql={sql}") from e

    def fetchall(
        self,
        sql: str,
        params: Params = None,
        *,
        max_rows: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """
        Run ``sql`` and return its rows.

        Args:
            max_rows: If given, return at most this many rows. A value <= 0
                returns an empty list (the query is still executed).

        Raises:
            MySQLConnectionError: if a connection cannot be obtained.
            MySQLQueryError: if executing or fetching fails.
        """
        with self._cursor() as (_conn, cursor):
            try:
                cursor.execute(sql, _normalize_params(params))
                if max_rows is None:
                    return list(cursor.fetchall())
                # Cursor fetchall has no limit, so fetch in batches. The loop
                # guard matters: pymysql treats fetchmany(0) as
                # "fetch arraysize rows", which would overshoot a zero limit.
                out: List[Dict[str, Any]] = []
                while len(out) < max_rows:
                    batch = cursor.fetchmany(max_rows - len(out))
                    if not batch:
                        break
                    out.extend(batch)
                return out
            except Exception as e:
                raise MySQLQueryError(f"fetchall failed: {e} | sql={sql}") from e

    def fetchmany(
        self,
        sql: str,
        params: Params = None,
        *,
        size: int = 100,
    ) -> List[Dict[str, Any]]:
        """
        Run ``sql`` and return at most ``size`` rows.

        A ``size`` <= 0 returns an empty list. pymysql would otherwise fall
        back to ``cursor.arraysize`` for a zero size.

        Raises:
            MySQLConnectionError: if a connection cannot be obtained.
            MySQLQueryError: if executing or fetching fails.
        """
        with self._cursor() as (_conn, cursor):
            try:
                cursor.execute(sql, _normalize_params(params))
                if size <= 0:
                    return []
                return list(cursor.fetchmany(size))
            except Exception as e:
                raise MySQLQueryError(f"fetchmany failed: {e} | sql={sql}") from e

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def _run_write(
        self,
        op: str,
        sql: str,
        run: Callable[[DictCursor], Any],
        *,
        commit: Optional[bool],
        ignore_duplicate: bool,
        ignore_deadlock: bool,
    ) -> ExecResult:
        """
        Shared write path for execute()/executemany().

        Runs ``run(cursor)`` on a fresh (or pooled) connection, then commits or
        rolls back per ``_should_commit``. Selected MySQL errors are translated
        into an empty ExecResult or a MySQLQueryError. ``op`` is the public
        method name used in error messages.
        """
        commit_enabled = self._should_commit(commit)
        with self._cursor() as (conn, cursor):
            try:
                run(cursor)
                if commit_enabled:
                    conn.commit()
                else:
                    # commit=False with autocommit off: discard explicitly
                    # instead of relying on close()/pool reset.
                    self._rollback_quietly(conn)
                return ExecResult(
                    rowcount=int(cursor.rowcount or 0),
                    lastrowid=getattr(cursor, "lastrowid", None),
                )
            except pymysql.err.IntegrityError as e:
                self._rollback_quietly(conn)
                if ignore_duplicate and _mysql_errno(e) == _ER_DUP_ENTRY:
                    return ExecResult(rowcount=0, lastrowid=None)
                raise MySQLQueryError(f"{op} failed (IntegrityError): {e} | sql={sql}") from e
            except pymysql.err.OperationalError as e:
                self._rollback_quietly(conn)
                if ignore_deadlock and _mysql_errno(e) in _LOCK_CONFLICT_ERRNOS:
                    return ExecResult(rowcount=0, lastrowid=None)
                raise MySQLQueryError(f"{op} failed (OperationalError): {e} | sql={sql}") from e
            except Exception as e:
                self._rollback_quietly(conn)
                raise MySQLQueryError(f"{op} failed: {e} | sql={sql}") from e

    def execute(
        self,
        sql: str,
        params: Params = None,
        *,
        commit: Optional[bool] = None,
        ignore_duplicate: bool = False,
        ignore_deadlock: bool = False,
    ) -> ExecResult:
        """
        Execute a write statement.

        Args:
            commit:
                - None: follow config.autocommit (when `autocommit=False`, commit by default)
                - True/False: force commit/rollback behavior (only meaningful
                  when autocommit is False)
            ignore_duplicate: If True, roll back and silently ignore MySQL
                duplicate-key errors (1062).
            ignore_deadlock: If True, roll back and silently ignore lock
                conflicts: deadlocks (1213) and lock wait timeouts (1205).

        Returns:
            ExecResult with the affected row count and last insert id.
            Ignored errors yield ``ExecResult(rowcount=0, lastrowid=None)``.

        Raises:
            MySQLConnectionError: if a connection cannot be obtained.
            MySQLQueryError: on any non-ignored failure.
        """
        normalized = _normalize_params(params)
        return self._run_write(
            "execute",
            sql,
            lambda cursor: cursor.execute(sql, normalized),
            commit=commit,
            ignore_duplicate=ignore_duplicate,
            ignore_deadlock=ignore_deadlock,
        )

    def executemany(
        self,
        sql: str,
        params_seq: Sequence[Params],
        *,
        commit: Optional[bool] = None,
        ignore_duplicate: bool = False,
        ignore_deadlock: bool = False,
    ) -> ExecResult:
        """
        Execute a statement against multiple parameter sets.

        Note:
            If ignore_duplicate/ignore_deadlock are enabled, we fall back to
            per-row execution (one execute() call, and so one commit, per row)
            to emulate "ignore" semantics without terminating the whole batch.
            A non-ignored error on a later row does not undo rows already
            committed. In that mode ``lastrowid`` is the last id reported by a
            row that was not ignored.

        Returns:
            ExecResult. An empty ``params_seq`` returns ``rowcount=0`` without
            touching the database.

        Raises:
            MySQLConnectionError: if a connection cannot be obtained.
            MySQLQueryError: on any non-ignored failure.
        """
        items = list(params_seq)
        if not items:
            return ExecResult(rowcount=0, lastrowid=None)

        if ignore_duplicate or ignore_deadlock:
            total = 0
            lastrowid: Optional[int] = None
            for p in items:
                res = self.execute(
                    sql,
                    p,
                    commit=commit,
                    ignore_duplicate=ignore_duplicate,
                    ignore_deadlock=ignore_deadlock,
                )
                total += res.rowcount
                if res.lastrowid is not None:
                    lastrowid = res.lastrowid
            return ExecResult(rowcount=total, lastrowid=lastrowid)

        normalized = [_normalize_params(p) for p in items]
        return self._run_write(
            "executemany",
            sql,
            lambda cursor: cursor.executemany(sql, normalized),
            commit=commit,
            ignore_duplicate=ignore_duplicate,
            ignore_deadlock=ignore_deadlock,
        )