from urllib.parse import quote_plus

from sqlalchemy import create_engine, Column, Integer, String, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import SQLAlchemyError

import configs
# Logging setup
from core.config import logger

# Declarative base class
Base = declarative_base()


def create_sql_engine(config):
    """Build a SQLAlchemy engine for a MySQL database via mysql-connector.

    Args:
        config: Mapping-style object read with ``config[...]`` for the keys
            ``user``, ``password``, ``host`` and ``database``, and with
            ``config.get`` for the optional ``charset`` (default ``utf8mb4``).

    Returns:
        The engine produced by ``create_engine`` for the assembled URL.
    """
    user = config['user']
    # Percent-encode the password so reserved characters (e.g. '@', '/')
    # cannot corrupt the connection URL.
    passwd = quote_plus(config['password'])
    host = config['host']
    db_name = config['database']
    charset = config.get('charset', 'utf8mb4')

    engine = create_engine(
        f'mysql+mysqlconnector://{user}:{passwd}@{host}/{db_name}?charset={charset}'
    )
    return engine


def create_rag_db_engine():
    """Create the engine described by the ``database.rag`` configuration entry."""
    config = configs.get()['database']['rag']
    return create_sql_engine(config)


# Create the database engine
engine = create_rag_db_engine()

# Create the session factory
Session = sessionmaker(bind=engine)


class DBHelper:
    """CRUD helper wrapping a single SQLAlchemy session.

    Every method logs its outcome. On ``SQLAlchemyError`` the session is
    rolled back, the error is logged, and the exception is re-raised.

    The helper may be used as a context manager so that the session is
    closed when the ``with`` block ends::

        with DBHelper() as db:
            db.get(Model, id=1)
    """

    def __init__(self):
        """Open a new session from the module-level session factory."""
        self.session = Session()

    def close(self):
        """Close the underlying session."""
        self.session.close()

    def __enter__(self):
        """Return the helper itself for use inside a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the session; exceptions from the block are not suppressed."""
        self.close()
        return False

    def add(self, entity):
        """Insert an entity and commit.

        Args:
            entity: Mapped object to add to the session.

        Returns:
            The same ``entity`` after a successful commit.

        Raises:
            SQLAlchemyError: Re-raised after the session is rolled back.
        """
        try:
            self.session.add(entity)
            self.session.commit()
            logger.info(f"添加成功: {entity}")
            return entity
        except SQLAlchemyError as e:
            self.session.rollback()
            logger.error(f"添加失败: {e}")
            raise

    def get(self, model, **filters):
        """Fetch the first entity matching the equality filters.

        Args:
            model: Mapped class to query.
            **filters: Keyword arguments forwarded to ``filter_by``.

        Returns:
            The first matching entity, or ``None`` when nothing matches.

        Raises:
            SQLAlchemyError: Re-raised after the session is rolled back.
        """
        try:
            entity = self.session.query(model).filter_by(**filters).first()
            logger.info(f"查询成功: {entity}")
            return entity
        except SQLAlchemyError as e:
            # Roll back so a failed query does not leave the session stuck
            # in a failed transaction (same policy as add/update/delete).
            self.session.rollback()
            logger.error(f"查询失败: {e}")
            raise

    def update(self, model, filters, updates):
        """Update the first entity matching ``filters`` and commit.

        Args:
            model: Mapped class to query.
            filters: Dict of keyword arguments forwarded to ``filter_by``.
            updates: Dict mapping attribute names to their new values.

        Returns:
            The updated entity, or ``None`` when nothing matches.

        Raises:
            SQLAlchemyError: Re-raised after the session is rolled back.
        """
        try:
            entity = self.session.query(model).filter_by(**filters).first()
            if not entity:
                logger.warning(f"未找到符合条件的实体: {filters}")
                return None

            for key, value in updates.items():
                setattr(entity, key, value)
            self.session.commit()
            logger.info(f"更新成功: {entity}")
            return entity
        except SQLAlchemyError as e:
            self.session.rollback()
            logger.error(f"更新失败: {e}")
            raise

    def delete(self, model, **filters):
        """Delete the first entity matching the equality filters and commit.

        Args:
            model: Mapped class to query.
            **filters: Keyword arguments forwarded to ``filter_by``.

        Returns:
            The deleted entity, or ``None`` when nothing matches.

        Raises:
            SQLAlchemyError: Re-raised after the session is rolled back.
        """
        try:
            entity = self.session.query(model).filter_by(**filters).first()
            if not entity:
                logger.warning(f"未找到符合条件的实体: {filters}")
                return None

            self.session.delete(entity)
            self.session.commit()
            logger.info(f"删除成功: {entity}")
            return entity
        except SQLAlchemyError as e:
            self.session.rollback()
            logger.error(f"删除失败: {e}")
            raise

    def get_all(self, model, **filters):
        """Fetch every entity matching the filters.

        A keyword ending in ``__in`` (e.g. ``id__in=[1, 2]``) becomes an
        ``IN`` condition on the attribute named by the part before the
        suffix. All other keywords are equality filters passed to
        ``filter_by``.

        Args:
            model: Mapped class to query.
            **filters: Equality filters and/or ``<field>__in`` filters.

        Returns:
            List of matching entities (empty when nothing matches).

        Raises:
            SQLAlchemyError: Re-raised after the session is rolled back.
        """
        try:
            query = self.session.query(model)

            # Split the special "__in" conditions from plain equality filters.
            actual_filters = {}
            for key, value in filters.items():
                if key.endswith('__in'):
                    # IN query: strip the 4-character "__in" suffix to get
                    # the attribute name on the model.
                    field_name = key[:-4]
                    field = getattr(model, field_name)
                    query = query.filter(field.in_(value))
                else:
                    actual_filters[key] = value

            # Apply the remaining equality filters.
            if actual_filters:
                query = query.filter_by(**actual_filters)

            entities = query.all()
            logger.info(f"查询成功: {entities}")
            return entities
        except SQLAlchemyError as e:
            # Roll back so a failed query does not leave the session stuck
            # in a failed transaction (same policy as add/update/delete).
            self.session.rollback()
            logger.error(f"查询失败: {e}")
            raise


# Create tables.
# NOTE(review): this runs at import time and only covers models already
# registered on Base.metadata when this line executes - confirm that model
# modules are imported before this one if their tables must be created here.
Base.metadata.create_all(engine)