from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.exc import SQLAlchemyError from urllib.parse import quote_plus import configs from core.config import logger # 配置数据库连接池 def create_sql_engine(config): user = config['user'] passwd = quote_plus(config['password']) host = config['host'] db_name = config['database'] charset = config.get('charset', 'utf8mb4') # 配置连接池 engine = create_engine( f'mysql+mysqlconnector://{user}:{passwd}@{host}/{db_name}?charset={charset}', pool_size=30, # 连接池大小 max_overflow=10, # 超过连接池大小后可以创建的最大连接数 pool_timeout=30, # 获取连接的超时时间,单位为秒 pool_recycle=3600, # 连接最大复用时间,超过这个时间将被关闭并重新创建连接 ) return engine def create_rag_db_engine(): config = configs.get()['database']['rag'] return create_sql_engine(config) # 创建数据库引擎 engine = create_rag_db_engine() # 创建会话 Session = sessionmaker(bind=engine) class DBHelper: def __init__(self): """初始化数据库连接""" self.session = Session() def add(self, entity): """插入实体对象""" try: self.session.add(entity) self.session.commit() return entity except SQLAlchemyError as e: self.session.rollback() logger.error(f"添加失败: {e}") raise def get(self, model, **filters): """根据过滤条件获取实体对象""" try: entity = self.session.query(model).filter_by(**filters).first() return entity except SQLAlchemyError as e: logger.error(f"查询失败: {e}") raise def update(self, model, filters, updates): """更新实体对象""" try: entity = self.session.query(model).filter_by(**filters).first() if entity: for key, value in updates.items(): setattr(entity, key, value) self.session.commit() return else: logger.warning(f"未找到符合条件的实体: {filters}") return None except SQLAlchemyError as e: self.session.rollback() logger.error(f"更新失败: {e}") raise def delete(self, model, **filters): """删除实体对象""" try: entity = self.session.query(model).filter_by(**filters).first() if entity: self.session.delete(entity) self.session.commit() return entity else: logger.warning(f"未找到符合条件的实体: {filters}") return None except SQLAlchemyError as e: self.session.rollback() logger.error(f"删除失败: {e}") raise def get_all(self, model, **filters): """获取所有符合条件的实体对象,支持更复杂的查询条件""" try: query = self.session.query(model) # 处理特殊条件如 __in actual_filters = {} for key, value in filters.items(): if key.endswith('__in'): # 处理 IN 查询 field_name = key[:-4] field = getattr(model, field_name) query = query.filter(field.in_(value)) else: actual_filters[key] = value # 应用其他过滤条件 if actual_filters: query = query.filter_by(**actual_filters) entities = query.all() return entities except SQLAlchemyError as e: logger.error(f"查询失败: {e}") raise