import logging
from contextlib import contextmanager
from typing import Any, Callable, Optional

import pymysql
from loguru import logger

from .mysql_pool import mysql_pool
from .mysql_advanced import MySQLAdvanced

# Isolation levels accepted by MySQL's SET TRANSACTION statement.
_VALID_ISOLATION_LEVELS = (
    'READ UNCOMMITTED',
    'READ COMMITTED',
    'REPEATABLE READ',
    'SERIALIZABLE',
)


def _normalize_isolation_level(level: str) -> str:
    """Return *level* upper-cased with single spaces, or raise ValueError.

    The value is interpolated into a SQL statement (it cannot be sent as a
    bound parameter), so it is checked against a fixed whitelist first.
    """
    normalized = " ".join(str(level).upper().split())
    if normalized not in _VALID_ISOLATION_LEVELS:
        raise ValueError(f"无效的隔离级别: {level}")
    return normalized


def _quote_identifier(name: str) -> str:
    """Return *name* as a backtick-quoted MySQL identifier.

    Embedded backticks are doubled so the name cannot terminate the quoted
    identifier and inject additional SQL.
    """
    if not isinstance(name, str) or not name:
        raise ValueError(f"invalid identifier: {name!r}")
    return "`" + name.replace("`", "``") + "`"


def _execute(connection, sql: str) -> None:
    """Run a statement that produces no result set on *connection*."""
    # NOTE(review): a plain pymysql connection has no ``execute`` method.
    # The pool may hand out a wrapper that does, so that path is kept for
    # compatibility; otherwise a short-lived cursor is used.
    direct_execute = getattr(connection, "execute", None)
    if callable(direct_execute):
        direct_execute(sql)
        return
    cursor = connection.cursor()
    try:
        cursor.execute(sql)
    finally:
        cursor.close()


class MySQLTransaction(MySQLAdvanced):
    """Transaction management on top of ``MySQLAdvanced``."""

    def __init__(self):
        super().__init__()

    @contextmanager
    def transaction(self, isolation_level: Optional[str] = None):
        """
        Context manager that runs its body inside one transaction.

        A connection is taken from ``self.pool``, a transaction is begun,
        and the connection is yielded. The transaction is committed when the
        body finishes normally and rolled back when it raises; the
        connection is handed back to the pool in both cases.

        Args:
            isolation_level: Optional isolation level, one of
                - 'READ UNCOMMITTED'
                - 'READ COMMITTED'
                - 'REPEATABLE READ'
                - 'SERIALIZABLE'

        Raises:
            ValueError: If ``isolation_level`` is not one of the values above.

        Usage:
            with mysql_transaction.transaction() as connection:
                # run database operations
                pass
        """
        # Validate before borrowing a connection so a bad argument costs nothing.
        level = _normalize_isolation_level(isolation_level) if isolation_level else None

        connection = None
        try:
            connection = self.pool.get_connection()

            if level:
                # NOTE(review): SESSION scope presumably stays in effect on
                # this connection after it goes back to the pool -- confirm
                # whether the pool resets session state.
                _execute(connection, f"SET SESSION TRANSACTION ISOLATION LEVEL {level}")

            connection.begin()

            yield connection

            connection.commit()

        except BaseException as e:
            # BaseException (not just Exception) so that KeyboardInterrupt or
            # GeneratorExit cannot return an open transaction to the pool.
            if connection is not None:
                try:
                    connection.rollback()
                except Exception as rollback_error:
                    # Do not let a failed rollback hide the original error.
                    logger.error(f"事务回滚失败: {rollback_error}")
            logger.error(f"事务执行失败,已回滚: {e}")
            raise

        finally:
            if connection is not None:
                self.pool.return_connection(connection)

    def execute_in_transaction(self, func: Callable, *args,
                               isolation_level: Optional[str] = None, **kwargs) -> Any:
        """
        Call ``func`` inside a transaction and return its result.

        Args:
            func: Callable invoked as ``func(connection, *args, **kwargs)``.
            isolation_level: Optional isolation level for the transaction.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns.

        Usage:
            def my_operations(connection, param1, param2):
                # run database operations
                return result

            result = mysql_transaction.execute_in_transaction(my_operations, param1, param2)
        """
        with self.transaction(isolation_level) as connection:
            return func(connection, *args, **kwargs)

    def batch_operations(self, operations: list, isolation_level: Optional[str] = None) -> list:
        """
        Run several methods of this object inside one transaction.

        Args:
            operations: List of ``(method_name, args, kwargs)`` tuples. Each
                method is looked up on ``self`` and called with the given
                arguments plus a ``connection`` keyword argument.
            isolation_level: Optional isolation level for the transaction.

        Returns:
            The results of all operations, in order.

        Usage:
            operations = [
                ('insert', ('table1', {'col1': 'value1'}), {}),
                ('update', ('table2', {'col2': 'value2'}, 'id = %s', (1,)), {}),
                ('delete', ('table3', 'id = %s', (2,)), {})
            ]
            results = mysql_transaction.batch_operations(operations)
        """
        results = []

        with self.transaction(isolation_level) as connection:
            for method_name, args, kwargs in operations:
                # Build a fresh dict: writing into the caller's kwargs would
                # leave a stale pooled connection in their data afterwards.
                call_kwargs = {**(kwargs or {}), 'connection': connection}
                method = getattr(self, method_name)
                results.append(method(*(args or ()), **call_kwargs))

        return results

    def savepoint_transaction(self, savepoint_name: str = "sp1", connection=None):
        """
        Return a context manager that wraps its body in a savepoint.

        Args:
            savepoint_name: Name of the savepoint.
            connection: Connection of the surrounding transaction. Without
                it the returned manager raises ``NotImplementedError`` on
                entry, because it cannot discover the active connection.

        Usage:
            with mysql_transaction.transaction() as connection:
                # some operations
                with mysql_transaction.savepoint_transaction("sp1", connection):
                    # operations guarded by the savepoint
                    pass
        """
        return SavepointManager(savepoint_name, connection)


class SavepointManager:
    """Context manager that creates a savepoint and releases or rolls back to it."""

    def __init__(self, savepoint_name: str, connection=None):
        self.savepoint_name = savepoint_name
        self.connection = connection
        self.logger = logging.getLogger(__name__)

    def __enter__(self):
        """Create the savepoint; requires a connection supplied at construction."""
        if self.connection is None:
            # No way to find the surrounding transaction's connection.
            raise NotImplementedError("保存点功能需要在事务上下文中使用,请直接使用connection.execute")
        _execute(self.connection, f"SAVEPOINT {_quote_identifier(self.savepoint_name)}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Roll back to the savepoint on error, release it otherwise.

        Failures of the savepoint statement itself are logged, not raised.
        An exception raised by the body is never suppressed.
        """
        if self.connection is None:
            return
        if exc_type is not None:
            try:
                _execute(
                    self.connection,
                    f"ROLLBACK TO SAVEPOINT {_quote_identifier(self.savepoint_name)}",
                )
                self.logger.info(f"回滚到保存点: {self.savepoint_name}")
            except Exception as e:
                self.logger.error(f"回滚保存点失败: {e}")
        else:
            try:
                _execute(
                    self.connection,
                    f"RELEASE SAVEPOINT {_quote_identifier(self.savepoint_name)}",
                )
            except Exception as e:
                self.logger.error(f"释放保存点失败: {e}")


class TransactionHelper:
    """Stateless helpers for savepoints and isolation levels."""

    @staticmethod
    def create_savepoint(connection: pymysql.Connection, savepoint_name: str):
        """Create a savepoint named ``savepoint_name``."""
        _execute(connection, f"SAVEPOINT {_quote_identifier(savepoint_name)}")

    @staticmethod
    def rollback_to_savepoint(connection: pymysql.Connection, savepoint_name: str):
        """Roll the transaction back to ``savepoint_name``."""
        _execute(connection, f"ROLLBACK TO SAVEPOINT {_quote_identifier(savepoint_name)}")

    @staticmethod
    def release_savepoint(connection: pymysql.Connection, savepoint_name: str):
        """Release the savepoint named ``savepoint_name``."""
        _execute(connection, f"RELEASE SAVEPOINT {_quote_identifier(savepoint_name)}")

    @staticmethod
    def set_isolation_level(connection: pymysql.Connection, level: str):
        """Set the session isolation level.

        Raises:
            ValueError: If ``level`` is not a recognised isolation level.
        """
        normalized = _normalize_isolation_level(level)
        _execute(connection, f"SET SESSION TRANSACTION ISOLATION LEVEL {normalized}")

    @staticmethod
    def get_isolation_level(connection: pymysql.Connection) -> Optional[str]:
        """Return the session isolation level, or None if no row comes back."""
        cursor = connection.cursor()
        try:
            cursor.execute("SELECT @@SESSION.transaction_isolation")
            row = cursor.fetchone()
        finally:
            # Close the cursor even when the query raises.
            cursor.close()
        if not row:
            return None
        if isinstance(row, dict):
            # Dict-style cursors key the row by column name, not by position.
            return next(iter(row.values()))
        return row[0]


# Module-level shared instances
mysql_transaction = MySQLTransaction()
transaction_helper = TransactionHelper()