from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.exc import SQLAlchemyError from urllib.parse import quote_plus import configs from core.config import logger # 配置数据库连接池 def create_sql_engine(config): user = config['user'] passwd = quote_plus(config['password']) host = config['host'] db_name = config['database'] charset = config.get('charset', 'utf8mb4') # 配置连接池 engine = create_engine( f'mysql+mysqlconnector://{user}:{passwd}@{host}/{db_name}?charset={charset}', pool_size=50, # 连接池大小 max_overflow=10, # 超过连接池大小后可以创建的最大连接数 pool_timeout=30, # 获取连接的超时时间,单位为秒 pool_recycle=3600, # 连接最大复用时间,超过这个时间将被关闭并重新创建连接 ) return engine def create_rag_db_engine(): config = configs.get()['database']['rag'] return create_sql_engine(config) # 创建数据库引擎 engine = create_rag_db_engine() # 创建会话 Session = sessionmaker(bind=engine) # 使用 scoped_session 来确保每个线程使用独立的 session scoped_session_factory = scoped_session(sessionmaker(bind=engine)) class DBHelper: def __init__(self): """初始化数据库连接,使用 scoped session 管理会话""" self.session = scoped_session_factory() def close(self): """显式关闭会话""" self.session.remove() def _handle_error(self, error, operation): """处理 SQLAlchemy 错误,回滚事务并记录日志""" self.session.rollback() logger.error(f"{operation}失败: {error}") def add(self, entity): """插入实体对象""" try: self.session.add(entity) self.session.commit() return entity except SQLAlchemyError as e: self._handle_error(e, "添加") def get(self, model, **filters): """根据过滤条件获取单个实体对象""" try: entity = self.session.query(model).filter_by(**filters).first() return entity except SQLAlchemyError as e: self._handle_error(e, "查询") def get_all(self, model, limit=None, **filters): """获取所有符合条件的实体对象,支持更复杂的查询条件""" try: query = self.session.query(model) # 处理特殊条件如 __in actual_filters = {} for key, value in filters.items(): if key.endswith('__in'): # 处理 IN 查询 field_name = key[:-4] field = getattr(model, field_name) query = query.filter(field.in_(value)) else: actual_filters[key] = value # 应用其他过滤条件 if actual_filters: query = query.filter_by(**actual_filters) # 如果传入了 limit 参数,则限制返回的最大条数 if limit is not None: query = query.limit(limit) # 执行查询 entities = query.all() return entities except SQLAlchemyError as e: self._handle_error(e, "查询") def update(self, model, filters, updates): """更新实体对象""" try: entity = self.session.query(model).filter_by(**filters).first() if entity: for key, value in updates.items(): setattr(entity, key, value) self.session.commit() else: logger.warning(f"未找到符合条件的实体: {filters}") except SQLAlchemyError as e: self._handle_error(e, "更新") def delete(self, model, **filters): """删除实体对象""" try: entity = self.session.query(model).filter_by(**filters).first() if entity: self.session.delete(entity) self.session.commit() return entity else: logger.warning(f"未找到符合条件的实体: {filters}") return None except SQLAlchemyError as e: self._handle_error(e, "删除") def commit(self): """显式提交事务""" try: self.session.commit() except SQLAlchemyError as e: self._handle_error(e, "提交") def rollback(self): """显式回滚事务""" self.session.rollback() # 使用示例: # 创建 DBHelper 实例 db_helper = DBHelper() # 添加数据 # db_helper.add(YourModel(name="example")) # 查询数据 # result = db_helper.get(YourModel, id=1) # 获取所有数据 # results = db_helper.get_all(YourModel, limit=100, status=1) # 更新数据 # db_helper.update(YourModel, {"id": 1}, {"status": 2}) # 删除数据 # db_helper.delete(YourModel, id=1) # 关闭会话 # db_helper.close()