|
@@ -0,0 +1,151 @@
|
|
|
+"""
|
|
|
+@author: luojunhui
|
|
|
+"""
|
|
|
+
|
|
|
+import pymysql
|
|
|
+from contextlib import contextmanager
|
|
|
+from .exceptions import QueryError, TransactionError
|
|
|
+
|
|
|
+
|
|
|
+class DatabaseConnector:
|
|
|
+ """
|
|
|
+ 数据库连接器,使用 pymysql 进行 MySQL 数据库操作。
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, db_config):
|
|
|
+ """
|
|
|
+ 初始化数据库连接配置。
|
|
|
+
|
|
|
+ :param db_config:
|
|
|
+ """
|
|
|
+ self.db_config = db_config
|
|
|
+ self.connection = None
|
|
|
+
|
|
|
+ def connect(self):
|
|
|
+ """
|
|
|
+ 建立数据库连接。
|
|
|
+
|
|
|
+ :raises ConnectionError: 如果无法连接到数据库。
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ self.connection = pymysql.connect(
|
|
|
+ host=self.db_config.get('host', 'localhost'),
|
|
|
+ user=self.db_config['user'],
|
|
|
+ password=self.db_config['password'],
|
|
|
+ db=self.db_config['db'],
|
|
|
+ port=self.db_config.get('port', 3306),
|
|
|
+ charset=self.db_config.get('charset', 'utf8mb4')
|
|
|
+ )
|
|
|
+ except pymysql.MySQLError as e:
|
|
|
+ raise ConnectionError(f"无法连接到数据库: {e}")
|
|
|
+
|
|
|
+ def close(self):
|
|
|
+ """
|
|
|
+ 关闭数据库连接。
|
|
|
+ """
|
|
|
+ if self.connection:
|
|
|
+ self.connection.close()
|
|
|
+ self.connection = None
|
|
|
+
|
|
|
+ def execute_query(self, query, params=None):
|
|
|
+ """
|
|
|
+ 执行单条查询语句,并返回结果。
|
|
|
+
|
|
|
+ :param query: SQL 查询语句。
|
|
|
+ :param params: 可选的参数,用于参数化查询。
|
|
|
+ :return: 查询结果列表。
|
|
|
+ :raises QueryError: 如果执行查询时出错。
|
|
|
+ """
|
|
|
+ if not self.connection:
|
|
|
+ self.connect()
|
|
|
+
|
|
|
+ try:
|
|
|
+ with self.connection.cursor() as cursor:
|
|
|
+ cursor.execute(query, params)
|
|
|
+ result = cursor.fetchall()
|
|
|
+ return result
|
|
|
+ except pymysql.MySQLError as e:
|
|
|
+ self.connection.rollback()
|
|
|
+ raise QueryError(f"查询执行失败: {e}")
|
|
|
+
|
|
|
+ def execute_many(self, query, params_list):
|
|
|
+ """
|
|
|
+ 执行多条查询语句。
|
|
|
+
|
|
|
+ :param query: SQL 查询语句。
|
|
|
+ :param params_list: 包含多个参数的列表。
|
|
|
+ :raises QueryError: 如果执行查询时出错。
|
|
|
+ """
|
|
|
+ if not self.connection:
|
|
|
+ self.connect()
|
|
|
+
|
|
|
+ try:
|
|
|
+ with self.connection.cursor() as cursor:
|
|
|
+ cursor.executemany(query, params_list)
|
|
|
+ except pymysql.MySQLError as e:
|
|
|
+ self.connection.rollback()
|
|
|
+ raise QueryError(f"批量查询执行失败: {e}")
|
|
|
+
|
|
|
+ def commit(self):
|
|
|
+ """
|
|
|
+ 提交当前事务。
|
|
|
+
|
|
|
+ :raises TransactionError: 如果提交事务时出错。
|
|
|
+ """
|
|
|
+ if not self.connection:
|
|
|
+ raise TransactionError("没有活动的数据库连接。")
|
|
|
+ try:
|
|
|
+ self.connection.commit()
|
|
|
+ print("事务提交成功。")
|
|
|
+ except pymysql.MySQLError as e:
|
|
|
+ self.connection.rollback()
|
|
|
+ raise TransactionError(f"提交事务失败: {e}")
|
|
|
+
|
|
|
+ def rollback(self):
|
|
|
+ """
|
|
|
+ 回滚当前事务。
|
|
|
+
|
|
|
+ :raises TransactionError: 如果回滚事务时出错。
|
|
|
+ """
|
|
|
+ if not self.connection:
|
|
|
+ raise TransactionError("没有活动的数据库连接。")
|
|
|
+ try:
|
|
|
+ self.connection.rollback()
|
|
|
+ print("事务已回滚。")
|
|
|
+ except pymysql.MySQLError as e:
|
|
|
+ raise TransactionError(f"回滚事务失败: {e}")
|
|
|
+
|
|
|
+ @contextmanager
|
|
|
+ def transaction(self):
|
|
|
+ """
|
|
|
+ 上下文管理器,用于处理事务。
|
|
|
+
|
|
|
+ 使用示例:
|
|
|
+ with db.transaction():
|
|
|
+ db.execute_query("INSERT INTO ...")
|
|
|
+ db.execute_query("UPDATE ...")
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ yield
|
|
|
+ self.commit()
|
|
|
+ except Exception as e:
|
|
|
+ self.rollback()
|
|
|
+ raise e
|
|
|
+
|
|
|
+ def __enter__(self):
|
|
|
+ """
|
|
|
+ 支持 with 语句,进入上下文时建立连接。
|
|
|
+ """
|
|
|
+ self.connect()
|
|
|
+ return self
|
|
|
+
|
|
|
+ def __exit__(self, exc_type, exc_val, exc_tb):
|
|
|
+ """
|
|
|
+ 支持 with 语句,退出上下文时关闭连接。
|
|
|
+ """
|
|
|
+ if exc_type:
|
|
|
+ self.rollback()
|
|
|
+ self.close()
|
|
|
+
|
|
|
+
|
|
|
+
|