mysql_transaction.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import logging
  2. import pymysql
  3. from contextlib import contextmanager
  4. from typing import Any, Callable, Optional
  5. from loguru import logger
  6. from .mysql_pool import mysql_pool
  7. from .mysql_advanced import MySQLAdvanced
  8. class MySQLTransaction(MySQLAdvanced):
  9. """MySQL事务管理类"""
  10. def __init__(self):
  11. super().__init__()
  12. @contextmanager
  13. def transaction(self, isolation_level: Optional[str] = None):
  14. """
  15. 事务上下文管理器
  16. Args:
  17. isolation_level: 事务隔离级别
  18. - 'READ UNCOMMITTED'
  19. - 'READ COMMITTED'
  20. - 'REPEATABLE READ'
  21. - 'SERIALIZABLE'
  22. Usage:
  23. with mysql_transaction.transaction():
  24. # 执行数据库操作
  25. pass
  26. """
  27. connection = None
  28. try:
  29. connection = self.pool.get_connection()
  30. # 设置事务隔离级别
  31. if isolation_level:
  32. connection.execute(f"SET SESSION TRANSACTION ISOLATION LEVEL {isolation_level}")
  33. # 开始事务
  34. connection.begin()
  35. yield connection
  36. # 提交事务
  37. connection.commit()
  38. except Exception as e:
  39. # 回滚事务
  40. if connection:
  41. connection.rollback()
  42. logger.error(f"事务执行失败,已回滚: {e}")
  43. raise
  44. finally:
  45. # 返回连接到连接池
  46. if connection:
  47. self.pool.return_connection(connection)
  48. def execute_in_transaction(self, func: Callable, *args, isolation_level: Optional[str] = None, **kwargs) -> Any:
  49. """
  50. 在事务中执行函数
  51. Args:
  52. func: 要执行的函数,第一个参数必须是connection
  53. isolation_level: 事务隔离级别
  54. *args: 函数参数
  55. **kwargs: 函数关键字参数
  56. Returns:
  57. 函数执行结果
  58. Usage:
  59. def my_operations(connection, param1, param2):
  60. # 执行数据库操作
  61. return result
  62. result = mysql_transaction.execute_in_transaction(my_operations, param1, param2)
  63. """
  64. with self.transaction(isolation_level) as connection:
  65. return func(connection, *args, **kwargs)
  66. def batch_operations(self, operations: list, isolation_level: Optional[str] = None) -> list:
  67. """
  68. 批量执行操作(在同一事务中)
  69. Args:
  70. operations: 操作列表,每个操作为 (method_name, args, kwargs) 的元组
  71. isolation_level: 事务隔离级别
  72. Returns:
  73. 所有操作的结果列表
  74. Usage:
  75. operations = [
  76. ('insert', ('table1', {'col1': 'value1'}), {}),
  77. ('update', ('table2', {'col2': 'value2'}, 'id = %s', (1,)), {}),
  78. ('delete', ('table3', 'id = %s', (2,)), {})
  79. ]
  80. results = mysql_transaction.batch_operations(operations)
  81. """
  82. results = []
  83. with self.transaction(isolation_level) as connection:
  84. for operation in operations:
  85. method_name, args, kwargs = operation
  86. kwargs['connection'] = connection # 将连接传递给方法
  87. # 获取方法并执行
  88. method = getattr(self, method_name)
  89. result = method(*args, **kwargs)
  90. results.append(result)
  91. return results
  92. def savepoint_transaction(self, savepoint_name: str = "sp1"):
  93. """
  94. 保存点事务管理器
  95. Args:
  96. savepoint_name: 保存点名称
  97. Usage:
  98. with mysql_transaction.transaction():
  99. # 一些操作
  100. with mysql_transaction.savepoint_transaction("sp1"):
  101. # 需要保存点的操作
  102. pass
  103. """
  104. return SavepointManager(savepoint_name)
  105. class SavepointManager:
  106. """保存点管理器"""
  107. def __init__(self, savepoint_name: str):
  108. self.savepoint_name = savepoint_name
  109. self.connection = None
  110. self.logger = logging.getLogger(__name__)
  111. def __enter__(self):
  112. # 这里需要获取当前事务的连接,但由于上下文限制,暂时抛出异常
  113. # 实际使用中,需要将connection传入
  114. raise NotImplementedError("保存点功能需要在事务上下文中使用,请直接使用connection.execute")
  115. def __exit__(self, exc_type, exc_val, exc_tb):
  116. if exc_type is not None:
  117. try:
  118. self.connection.execute(f"ROLLBACK TO SAVEPOINT {self.savepoint_name}")
  119. self.logger.info(f"回滚到保存点: {self.savepoint_name}")
  120. except Exception as e:
  121. self.logger.error(f"回滚保存点失败: {e}")
  122. else:
  123. try:
  124. self.connection.execute(f"RELEASE SAVEPOINT {self.savepoint_name}")
  125. except Exception as e:
  126. self.logger.error(f"释放保存点失败: {e}")
  127. class TransactionHelper:
  128. """事务辅助工具"""
  129. @staticmethod
  130. def create_savepoint(connection: pymysql.Connection, savepoint_name: str):
  131. """创建保存点"""
  132. connection.execute(f"SAVEPOINT {savepoint_name}")
  133. @staticmethod
  134. def rollback_to_savepoint(connection: pymysql.Connection, savepoint_name: str):
  135. """回滚到保存点"""
  136. connection.execute(f"ROLLBACK TO SAVEPOINT {savepoint_name}")
  137. @staticmethod
  138. def release_savepoint(connection: pymysql.Connection, savepoint_name: str):
  139. """释放保存点"""
  140. connection.execute(f"RELEASE SAVEPOINT {savepoint_name}")
  141. @staticmethod
  142. def set_isolation_level(connection: pymysql.Connection, level: str):
  143. """设置事务隔离级别"""
  144. valid_levels = ['READ UNCOMMITTED', 'READ COMMITTED', 'REPEATABLE READ', 'SERIALIZABLE']
  145. if level not in valid_levels:
  146. raise ValueError(f"无效的隔离级别: {level}")
  147. connection.execute(f"SET SESSION TRANSACTION ISOLATION LEVEL {level}")
  148. @staticmethod
  149. def get_isolation_level(connection: pymysql.Connection) -> str:
  150. """获取当前事务隔离级别"""
  151. cursor = connection.cursor()
  152. cursor.execute("SELECT @@SESSION.transaction_isolation")
  153. result = cursor.fetchone()
  154. cursor.close()
  155. return result[0] if result else None
  156. # 全局实例
  157. mysql_transaction = MySQLTransaction()
  158. transaction_helper = TransactionHelper()