| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198 |
- import logging
- import pymysql
- from contextlib import contextmanager
- from typing import Any, Callable, Optional
- from loguru import logger
- from .mysql_pool import mysql_pool
- from .mysql_advanced import MySQLAdvanced
class MySQLTransaction(MySQLAdvanced):
    """MySQL transaction management on top of MySQLAdvanced.

    Every transaction borrows one connection from ``self.pool`` and always
    hands it back to the pool, whether the transaction commits or rolls back.
    """

    # Isolation levels accepted by SET TRANSACTION ISOLATION LEVEL. The level
    # is interpolated into SQL text, so nothing outside this tuple is allowed.
    _TX_ISOLATION_LEVELS = (
        'READ UNCOMMITTED',
        'READ COMMITTED',
        'REPEATABLE READ',
        'SERIALIZABLE',
    )

    def __init__(self):
        super().__init__()

    @classmethod
    def _tx_isolation_level(cls, isolation_level: str) -> str:
        """Return *isolation_level* upper-cased with whitespace collapsed.

        Raises:
            ValueError: if the value is not one of ``_TX_ISOLATION_LEVELS``.
        """
        normalized = (
            " ".join(isolation_level.split()).upper()
            if isinstance(isolation_level, str)
            else None
        )
        if normalized not in cls._TX_ISOLATION_LEVELS:
            raise ValueError(f"无效的隔离级别: {isolation_level}")
        return normalized

    @staticmethod
    def _tx_execute(connection, sql: str) -> None:
        """Run one statement that returns no rows, through a short-lived cursor.

        pymysql connections have no ``execute`` method, so statements go
        through a cursor. The cursor is closed even if the statement fails.
        """
        cursor = connection.cursor()
        try:
            cursor.execute(sql)
        finally:
            cursor.close()

    @contextmanager
    def transaction(self, isolation_level: Optional[str] = None):
        """
        Transaction context manager.

        Borrows a connection from the pool, starts a transaction and yields
        the connection. It commits when the ``with`` body completes normally
        and rolls back when the body (or begin/commit) raises. The connection
        is always returned to the pool.

        Args:
            isolation_level: optional isolation level for THIS transaction
                only (case-insensitive):
                - 'READ UNCOMMITTED'
                - 'READ COMMITTED'
                - 'REPEATABLE READ'
                - 'SERIALIZABLE'
                It is issued as ``SET TRANSACTION ISOLATION LEVEL`` (no
                SESSION keyword), so the pooled connection does not keep a
                modified session level after this transaction. MySQL rejects
                that statement while another transaction is still open on the
                connection.

        Yields:
            The pooled connection the transaction runs on.

        Raises:
            ValueError: if *isolation_level* is not a valid level.
            Any exception raised inside the block is re-raised after rollback.

        Usage:
            with mysql_transaction.transaction() as connection:
                # run database operations on `connection`
                pass
        """
        if isolation_level:
            # Validate before borrowing a connection so bad input costs nothing.
            isolation_level = self._tx_isolation_level(isolation_level)

        connection = self.pool.get_connection()
        try:
            if isolation_level:
                self._tx_execute(
                    connection,
                    f"SET TRANSACTION ISOLATION LEVEL {isolation_level}",
                )
            connection.begin()
            yield connection
            connection.commit()
        except BaseException as e:
            # BaseException so that KeyboardInterrupt / GeneratorExit also roll
            # back instead of returning a connection with an open transaction.
            try:
                connection.rollback()
            except Exception as rollback_error:
                # Keep the original error as the one that propagates.
                logger.error(f"事务执行失败,且回滚失败: {e}; 回滚错误: {rollback_error}")
            else:
                logger.error(f"事务执行失败,已回滚: {e}")
            raise
        finally:
            self.pool.return_connection(connection)

    def execute_in_transaction(self, func: Callable, *args, isolation_level: Optional[str] = None, **kwargs) -> Any:
        """
        Run *func* inside a transaction and return its result.

        Args:
            func: callable whose first positional argument is the connection.
            *args: extra positional arguments forwarded to *func*.
            isolation_level: optional isolation level (keyword-only), as in
                ``transaction``.
            **kwargs: extra keyword arguments forwarded to *func*.

        Returns:
            Whatever *func* returns. The transaction commits after *func*
            returns and rolls back if it raises.

        Usage:
            def my_operations(connection, param1, param2):
                # run database operations on `connection`
                return result

            result = mysql_transaction.execute_in_transaction(my_operations, param1, param2)
        """
        with self.transaction(isolation_level) as connection:
            return func(connection, *args, **kwargs)

    def batch_operations(self, operations: list, isolation_level: Optional[str] = None) -> list:
        """
        Run several methods of this object inside a single transaction.

        Args:
            operations: sequence of ``(method_name, args, kwargs)`` tuples.
                Each method is looked up on ``self`` and called with the
                transaction's connection passed as ``connection=``. *kwargs*
                may be ``None``. The caller's dicts are not modified.
            isolation_level: optional isolation level, as in ``transaction``.

        Returns:
            The results of the operations, in order. If any operation raises,
            the whole batch is rolled back and the exception propagates.

        Usage:
            operations = [
                ('insert', ('table1', {'col1': 'value1'}), {}),
                ('update', ('table2', {'col2': 'value2'}, 'id = %s', (1,)), {}),
                ('delete', ('table3', 'id = %s', (2,)), {})
            ]
            results = mysql_transaction.batch_operations(operations)
        """
        results = []
        with self.transaction(isolation_level) as connection:
            for method_name, args, kwargs in operations:
                method = getattr(self, method_name)
                # Build a fresh kwargs dict. Mutating the caller's dict would
                # leave it holding a connection that goes back to the pool.
                call_kwargs = {**(kwargs or {}), 'connection': connection}
                results.append(method(*args, **call_kwargs))
        return results

    def savepoint_transaction(self, savepoint_name: str = "sp1"):
        """
        Return a SavepointManager for *savepoint_name*.

        NOTE: this method has no access to the connection of an enclosing
        ``transaction()`` block. The manager is therefore created unbound,
        and entering it raises NotImplementedError. To use savepoints, issue
        ``SAVEPOINT`` / ``ROLLBACK TO SAVEPOINT`` / ``RELEASE SAVEPOINT``
        through a cursor of the connection yielded by ``transaction()``.

        Args:
            savepoint_name: name of the savepoint.
        """
        return SavepointManager(savepoint_name)
class SavepointManager:
    """Context manager that scopes a block of work to a MySQL savepoint.

    When constructed with a connection that has an open transaction:

    * ``__enter__`` issues ``SAVEPOINT <name>``;
    * a clean exit issues ``RELEASE SAVEPOINT <name>``;
    * an exception issues ``ROLLBACK TO SAVEPOINT <name>``, and the exception
      then propagates so the enclosing transaction decides what to do next.

    Failures of the release/rollback statements are logged, not raised.
    Without a connection the manager cannot reach any transaction, and
    ``__enter__`` raises NotImplementedError.
    """

    def __init__(self, savepoint_name: str, connection: Any = None):
        self.savepoint_name = savepoint_name
        # Connection whose open transaction owns the savepoint. None means the
        # manager is unbound, and __enter__ refuses to run.
        self.connection = connection
        self.logger = logging.getLogger(__name__)

    def _quoted_name(self) -> str:
        """Return the savepoint name as a backtick-quoted MySQL identifier.

        Embedded backticks are doubled, so the name cannot close the
        identifier early and inject extra SQL.
        """
        return "`" + self.savepoint_name.replace("`", "``") + "`"

    def _execute(self, sql: str) -> None:
        """Execute *sql* on the bound connection via a short-lived cursor.

        pymysql connections have no ``execute`` method, hence the cursor.
        """
        cursor = self.connection.cursor()
        try:
            cursor.execute(sql)
        finally:
            cursor.close()

    def __enter__(self):
        if self.connection is None:
            raise NotImplementedError(
                "保存点功能需要在事务上下文中使用,请传入事务中的connection: SavepointManager(name, connection)"
            )
        self._execute(f"SAVEPOINT {self._quoted_name()}")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            try:
                self._execute(f"ROLLBACK TO SAVEPOINT {self._quoted_name()}")
                self.logger.info(f"回滚到保存点: {self.savepoint_name}")
            except Exception as e:
                self.logger.error(f"回滚保存点失败: {e}")
        else:
            try:
                self._execute(f"RELEASE SAVEPOINT {self._quoted_name()}")
            except Exception as e:
                self.logger.error(f"释放保存点失败: {e}")
        # Never suppress the exception raised inside the block.
        return False
class TransactionHelper:
    """Stateless helpers for savepoints and isolation levels.

    Every helper takes the connection explicitly, for example the one
    yielded by ``MySQLTransaction.transaction()``.
    """

    # Isolation levels accepted by set_isolation_level. The level is
    # interpolated into SQL text, so nothing outside this tuple is allowed.
    VALID_ISOLATION_LEVELS = (
        'READ UNCOMMITTED',
        'READ COMMITTED',
        'REPEATABLE READ',
        'SERIALIZABLE',
    )

    @staticmethod
    def _execute(connection: "pymysql.Connection", sql: str) -> None:
        """Run one statement that returns no rows, via a short-lived cursor.

        pymysql connections have no ``execute`` method. The cursor is closed
        even if the statement fails.
        """
        cursor = connection.cursor()
        try:
            cursor.execute(sql)
        finally:
            cursor.close()

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Backtick-quote *name* as a MySQL identifier, doubling any backticks."""
        return "`" + name.replace("`", "``") + "`"

    @staticmethod
    def create_savepoint(connection: "pymysql.Connection", savepoint_name: str):
        """Create savepoint *savepoint_name* in the connection's open transaction."""
        TransactionHelper._execute(
            connection, f"SAVEPOINT {TransactionHelper._quote_identifier(savepoint_name)}"
        )

    @staticmethod
    def rollback_to_savepoint(connection: "pymysql.Connection", savepoint_name: str):
        """Roll the open transaction back to savepoint *savepoint_name*."""
        TransactionHelper._execute(
            connection,
            f"ROLLBACK TO SAVEPOINT {TransactionHelper._quote_identifier(savepoint_name)}",
        )

    @staticmethod
    def release_savepoint(connection: "pymysql.Connection", savepoint_name: str):
        """Release (remove) savepoint *savepoint_name* without rolling back."""
        TransactionHelper._execute(
            connection,
            f"RELEASE SAVEPOINT {TransactionHelper._quote_identifier(savepoint_name)}",
        )

    @staticmethod
    def set_isolation_level(connection: "pymysql.Connection", level: str):
        """Set the SESSION isolation level of *connection*.

        *level* is matched case-insensitively, with runs of whitespace
        collapsed, against VALID_ISOLATION_LEVELS.

        Raises:
            ValueError: if *level* is not a valid isolation level.
        """
        normalized = " ".join(level.split()).upper() if isinstance(level, str) else None
        if normalized not in TransactionHelper.VALID_ISOLATION_LEVELS:
            raise ValueError(f"无效的隔离级别: {level}")
        TransactionHelper._execute(
            connection, f"SET SESSION TRANSACTION ISOLATION LEVEL {normalized}"
        )

    @staticmethod
    def get_isolation_level(connection: "pymysql.Connection") -> Optional[str]:
        """Return the session's current isolation level, or None if no row comes back.

        The value is returned exactly as the server reports it.
        NOTE(review): ``@@SESSION.transaction_isolation`` needs MySQL 5.7.20+;
        older servers only expose ``tx_isolation``. Confirm the target version.
        """
        cursor = connection.cursor()
        try:
            cursor.execute("SELECT @@SESSION.transaction_isolation")
            row = cursor.fetchone()
        finally:
            cursor.close()
        if not row:
            return None
        # Dict-style cursors return a mapping rather than a tuple.
        if isinstance(row, dict):
            return next(iter(row.values()), None)
        return row[0]
# Shared module-level instances, constructed once when the module is imported.
mysql_transaction: MySQLTransaction = MySQLTransaction()
transaction_helper: TransactionHelper = TransactionHelper()
|