import pymysql from typing import Dict, List, Any, Optional, Union, Tuple from contextlib import contextmanager from loguru import logger from .mysql_pool import mysql_pool class MySQLHelper: """MySQL数据库操作助手类""" def __init__(self): self.pool = mysql_pool @contextmanager def get_cursor(self, connection: pymysql.Connection = None): """获取游标的上下文管理器""" if connection: cursor = connection.cursor() try: yield cursor finally: cursor.close() else: with self.pool.get_connection_context() as conn: cursor = conn.cursor() try: yield cursor finally: cursor.close() def execute_query(self, sql: str, params: Optional[Union[tuple, dict]] = None, connection: pymysql.Connection = None) -> List[Dict[str, Any]]: """ 执行查询操作 Args: sql: SQL语句 params: 参数 connection: 数据库连接(可选,用于事务) Returns: 查询结果列表 """ try: with self.get_cursor(connection) as cursor: cursor.execute(sql, params) return cursor.fetchall() except Exception as e: logger.error(f"查询执行失败: {sql}, 参数: {params}, 错误: {e}") raise def execute_one(self, sql: str, params: Optional[Union[tuple, dict]] = None, connection: pymysql.Connection = None) -> Optional[Dict[str, Any]]: """ 执行查询操作,返回单条记录 Args: sql: SQL语句 params: 参数 connection: 数据库连接(可选,用于事务) Returns: 单条记录或None """ try: with self.get_cursor(connection) as cursor: cursor.execute(sql, params) return cursor.fetchone() except Exception as e: logger.error(f"查询执行失败: {sql}, 参数: {params}, 错误: {e}") raise def execute_update(self, sql: str, params: Optional[Union[tuple, dict]] = None, connection: pymysql.Connection = None) -> int: """ 执行更新操作(INSERT、UPDATE、DELETE) Args: sql: SQL语句 params: 参数 connection: 数据库连接(可选,用于事务) Returns: 影响的行数 """ try: with self.get_cursor(connection) as cursor: affected_rows = cursor.execute(sql, params) if not connection: # 如果没有传入连接,自动提交 cursor.connection.commit() return affected_rows except Exception as e: if not connection: cursor.connection.rollback() logger.error(f"更新执行失败: {sql}, 参数: {params}, 错误: {e}") raise def execute_many(self, sql: str, params_list: List[Union[tuple, dict]], connection: pymysql.Connection = None) -> int: """ 批量执行操作 Args: sql: SQL语句 params_list: 参数列表 connection: 数据库连接(可选,用于事务) Returns: 影响的总行数 """ try: with self.get_cursor(connection) as cursor: affected_rows = cursor.executemany(sql, params_list) if not connection: cursor.connection.commit() return affected_rows except Exception as e: if not connection: cursor.connection.rollback() logger.error(f"批量执行失败: {sql}, 参数: {params_list}, 错误: {e}") raise def insert(self, table: str, data: Dict[str, Any], connection: pymysql.Connection = None) -> int: """ 插入数据 Args: table: 表名 data: 数据字典 connection: 数据库连接(可选,用于事务) Returns: 插入的记录ID """ if not data: raise ValueError("插入数据不能为空") columns = list(data.keys()) placeholders = ', '.join(['%s'] * len(columns)) sql = f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})" params = tuple(data.values()) try: with self.get_cursor(connection) as cursor: cursor.execute(sql, params) if not connection: cursor.connection.commit() return cursor.lastrowid except Exception as e: if not connection: cursor.connection.rollback() logger.error(f"插入失败: {sql}, 参数: {params}, 错误: {e}") raise def insert_many(self, table: str, data_list: List[Dict[str, Any]], connection: pymysql.Connection = None) -> int: """ 批量插入数据 Args: table: 表名 data_list: 数据列表 connection: 数据库连接(可选,用于事务) Returns: 影响的行数 """ if not data_list: raise ValueError("插入数据不能为空") # 使用第一条记录的键作为列名 columns = list(data_list[0].keys()) placeholders = ', '.join(['%s'] * len(columns)) sql = f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})" # 构建参数列表 params_list = [tuple(data[col] for col in columns) for data in data_list] return self.execute_many(sql, params_list, connection) def update(self, table: str, data: Dict[str, Any], where: str, where_params: Optional[Union[tuple, dict]] = None, connection: pymysql.Connection = None) -> int: """ 更新数据 Args: table: 表名 data: 更新数据字典 where: WHERE条件 where_params: WHERE条件参数 connection: 数据库连接(可选,用于事务) Returns: 影响的行数 """ if not data: raise ValueError("更新数据不能为空") set_clause = ', '.join([f"{col} = %s" for col in data.keys()]) sql = f"UPDATE {table} SET {set_clause} WHERE {where}" # 合并参数 params = list(data.values()) if where_params: if isinstance(where_params, (tuple, list)): params.extend(where_params) elif isinstance(where_params, dict): params.extend(where_params.values()) return self.execute_update(sql, tuple(params), connection) def delete(self, table: str, where: str, where_params: Optional[Union[tuple, dict]] = None, connection: pymysql.Connection = None) -> int: """ 删除数据 Args: table: 表名 where: WHERE条件 where_params: WHERE条件参数 connection: 数据库连接(可选,用于事务) Returns: 影响的行数 """ sql = f"DELETE FROM {table} WHERE {where}" return self.execute_update(sql, where_params, connection) def select(self, table: str, columns: str = "*", where: str = "", where_params: Optional[Union[tuple, dict]] = None, order_by: str = "", limit: Optional[int] = None, connection: pymysql.Connection = None) -> List[Dict[str, Any]]: """ 查询数据 Args: table: 表名 columns: 查询列,默认为* where: WHERE条件 where_params: WHERE条件参数 order_by: 排序条件 limit: 限制数量 connection: 数据库连接(可选,用于事务) Returns: 查询结果列表 """ sql = f"SELECT {columns} FROM {table}" if where: sql += f" WHERE {where}" if order_by: sql += f" ORDER BY {order_by}" if limit: sql += f" LIMIT {limit}" return self.execute_query(sql, where_params, connection) def select_one(self, table: str, columns: str = "*", where: str = "", where_params: Optional[Union[tuple, dict]] = None, connection: pymysql.Connection = None) -> Optional[Dict[str, Any]]: """ 查询单条数据 Args: table: 表名 columns: 查询列,默认为* where: WHERE条件 where_params: WHERE条件参数 connection: 数据库连接(可选,用于事务) Returns: 单条记录或None """ sql = f"SELECT {columns} FROM {table}" if where: sql += f" WHERE {where}" sql += " LIMIT 1" return self.execute_one(sql, where_params, connection) def count(self, table: str, where: str = "", where_params: Optional[Union[tuple, dict]] = None, connection: pymysql.Connection = None) -> int: """ 统计记录数 Args: table: 表名 where: WHERE条件 where_params: WHERE条件参数 connection: 数据库连接(可选,用于事务) Returns: 记录数 """ sql = f"SELECT COUNT(*) as count FROM {table}" if where: sql += f" WHERE {where}" result = self.execute_one(sql, where_params, connection) return result['count'] if result else 0 def exists(self, table: str, where: str, where_params: Optional[Union[tuple, dict]] = None, connection: pymysql.Connection = None) -> bool: """ 检查记录是否存在 Args: table: 表名 where: WHERE条件 where_params: WHERE条件参数 connection: 数据库连接(可选,用于事务) Returns: 是否存在 """ return self.count(table, where, where_params, connection) > 0 # 全局实例 mysql_helper = MySQLHelper()