|
@@ -36,11 +36,24 @@ engine = create_rag_db_engine()
|
|
|
# 创建会话
|
|
# 创建会话
|
|
|
Session = sessionmaker(bind=engine)
|
|
Session = sessionmaker(bind=engine)
|
|
|
|
|
|
|
|
|
|
+# 使用 scoped_session 来确保每个线程使用独立的 session
|
|
|
|
|
+scoped_session_factory = scoped_session(sessionmaker(bind=engine))
|
|
|
|
|
+
|
|
|
|
|
|
|
|
class DBHelper:
|
|
class DBHelper:
|
|
|
def __init__(self):
|
|
def __init__(self):
|
|
|
- """初始化数据库连接"""
|
|
|
|
|
- self.session = Session()
|
|
|
|
|
|
|
+ """初始化数据库连接,使用 scoped session 管理会话"""
|
|
|
|
|
+ self.session = scoped_session_factory()
|
|
|
|
|
+
|
|
|
|
|
+ def close(self):
|
|
|
|
|
+ """显式关闭会话"""
|
|
|
|
|
+ self.session.remove()
|
|
|
|
|
+
|
|
|
|
|
+ def _handle_error(self, error, operation):
|
|
|
|
|
+ """处理 SQLAlchemy 错误,回滚事务并记录日志"""
|
|
|
|
|
+ self.session.rollback()
|
|
|
|
|
+ logger.error(f"{operation}失败: {error}")
|
|
|
|
|
+
|
|
|
|
|
|
|
|
def add(self, entity):
|
|
def add(self, entity):
|
|
|
"""插入实体对象"""
|
|
"""插入实体对象"""
|
|
@@ -49,18 +62,45 @@ class DBHelper:
|
|
|
self.session.commit()
|
|
self.session.commit()
|
|
|
return entity
|
|
return entity
|
|
|
except SQLAlchemyError as e:
|
|
except SQLAlchemyError as e:
|
|
|
- self.session.rollback()
|
|
|
|
|
- logger.error(f"添加失败: {e}")
|
|
|
|
|
- raise
|
|
|
|
|
|
|
+ self._handle_error(e, "添加")
|
|
|
|
|
|
|
|
def get(self, model, **filters):
|
|
def get(self, model, **filters):
|
|
|
- """根据过滤条件获取实体对象"""
|
|
|
|
|
|
|
+ """根据过滤条件获取单个实体对象"""
|
|
|
try:
|
|
try:
|
|
|
entity = self.session.query(model).filter_by(**filters).first()
|
|
entity = self.session.query(model).filter_by(**filters).first()
|
|
|
return entity
|
|
return entity
|
|
|
except SQLAlchemyError as e:
|
|
except SQLAlchemyError as e:
|
|
|
- logger.error(f"查询失败: {e}")
|
|
|
|
|
- raise
|
|
|
|
|
|
|
+ self._handle_error(e, "查询")
|
|
|
|
|
+
|
|
|
|
|
+ def get_all(self, model, limit=None, **filters):
|
|
|
|
|
+ """获取所有符合条件的实体对象,支持更复杂的查询条件"""
|
|
|
|
|
+ try:
|
|
|
|
|
+ query = self.session.query(model)
|
|
|
|
|
+
|
|
|
|
|
+ # 处理特殊条件如 __in
|
|
|
|
|
+ actual_filters = {}
|
|
|
|
|
+ for key, value in filters.items():
|
|
|
|
|
+ if key.endswith('__in'):
|
|
|
|
|
+ # 处理 IN 查询
|
|
|
|
|
+ field_name = key[:-4]
|
|
|
|
|
+ field = getattr(model, field_name)
|
|
|
|
|
+ query = query.filter(field.in_(value))
|
|
|
|
|
+ else:
|
|
|
|
|
+ actual_filters[key] = value
|
|
|
|
|
+
|
|
|
|
|
+ # 应用其他过滤条件
|
|
|
|
|
+ if actual_filters:
|
|
|
|
|
+ query = query.filter_by(**actual_filters)
|
|
|
|
|
+
|
|
|
|
|
+ # 如果传入了 limit 参数,则限制返回的最大条数
|
|
|
|
|
+ if limit is not None:
|
|
|
|
|
+ query = query.limit(limit)
|
|
|
|
|
+
|
|
|
|
|
+ # 执行查询
|
|
|
|
|
+ entities = query.all()
|
|
|
|
|
+ return entities
|
|
|
|
|
+ except SQLAlchemyError as e:
|
|
|
|
|
+ self._handle_error(e, "查询")
|
|
|
|
|
|
|
|
def update(self, model, filters, updates):
|
|
def update(self, model, filters, updates):
|
|
|
"""更新实体对象"""
|
|
"""更新实体对象"""
|
|
@@ -70,14 +110,10 @@ class DBHelper:
|
|
|
for key, value in updates.items():
|
|
for key, value in updates.items():
|
|
|
setattr(entity, key, value)
|
|
setattr(entity, key, value)
|
|
|
self.session.commit()
|
|
self.session.commit()
|
|
|
- return
|
|
|
|
|
else:
|
|
else:
|
|
|
logger.warning(f"未找到符合条件的实体: {filters}")
|
|
logger.warning(f"未找到符合条件的实体: {filters}")
|
|
|
- return None
|
|
|
|
|
except SQLAlchemyError as e:
|
|
except SQLAlchemyError as e:
|
|
|
- self.session.rollback()
|
|
|
|
|
- logger.error(f"更新失败: {e}")
|
|
|
|
|
- raise
|
|
|
|
|
|
|
+ self._handle_error(e, "更新")
|
|
|
|
|
|
|
|
def delete(self, model, **filters):
|
|
def delete(self, model, **filters):
|
|
|
"""删除实体对象"""
|
|
"""删除实体对象"""
|
|
@@ -91,39 +127,39 @@ class DBHelper:
|
|
|
logger.warning(f"未找到符合条件的实体: {filters}")
|
|
logger.warning(f"未找到符合条件的实体: {filters}")
|
|
|
return None
|
|
return None
|
|
|
except SQLAlchemyError as e:
|
|
except SQLAlchemyError as e:
|
|
|
- self.session.rollback()
|
|
|
|
|
- logger.error(f"删除失败: {e}")
|
|
|
|
|
- raise
|
|
|
|
|
|
|
+ self._handle_error(e, "删除")
|
|
|
|
|
|
|
|
- def get_all(self, model, limit=None, **filters):
|
|
|
|
|
- """获取所有符合条件的实体对象,支持更复杂的查询条件,并可限制最大返回条数"""
|
|
|
|
|
|
|
+ def commit(self):
|
|
|
|
|
+ """显式提交事务"""
|
|
|
try:
|
|
try:
|
|
|
- query = self.session.query(model)
|
|
|
|
|
|
|
+ self.session.commit()
|
|
|
|
|
+ except SQLAlchemyError as e:
|
|
|
|
|
+ self._handle_error(e, "提交")
|
|
|
|
|
|
|
|
- # 处理特殊条件如 __in
|
|
|
|
|
- actual_filters = {}
|
|
|
|
|
- for key, value in filters.items():
|
|
|
|
|
- if key.endswith('__in'):
|
|
|
|
|
- # 处理 IN 查询
|
|
|
|
|
- field_name = key[:-4]
|
|
|
|
|
- field = getattr(model, field_name)
|
|
|
|
|
- query = query.filter(field.in_(value))
|
|
|
|
|
- else:
|
|
|
|
|
- actual_filters[key] = value
|
|
|
|
|
|
|
+ def rollback(self):
|
|
|
|
|
+ """显式回滚事务"""
|
|
|
|
|
+ self.session.rollback()
|
|
|
|
|
|
|
|
- # 应用其他过滤条件
|
|
|
|
|
- if actual_filters:
|
|
|
|
|
- query = query.filter_by(**actual_filters)
|
|
|
|
|
|
|
|
|
|
- # 如果传入了 limit 参数,则限制返回的最大条数
|
|
|
|
|
- if limit is not None:
|
|
|
|
|
- query = query.limit(limit)
|
|
|
|
|
|
|
+# 使用示例:
|
|
|
|
|
|
|
|
- # 执行查询
|
|
|
|
|
- entities = query.all()
|
|
|
|
|
- return entities
|
|
|
|
|
|
|
+# 创建 DBHelper 实例
|
|
|
|
|
+db_helper = DBHelper()
|
|
|
|
|
|
|
|
- except SQLAlchemyError as e:
|
|
|
|
|
- logger.error(f"查询失败: {e}")
|
|
|
|
|
- raise
|
|
|
|
|
|
|
+# 添加数据
|
|
|
|
|
+# db_helper.add(YourModel(name="example"))
|
|
|
|
|
+
|
|
|
|
|
+# 查询数据
|
|
|
|
|
+# result = db_helper.get(YourModel, id=1)
|
|
|
|
|
+
|
|
|
|
|
+# 获取所有数据
|
|
|
|
|
+# results = db_helper.get_all(YourModel, limit=100, status=1)
|
|
|
|
|
+
|
|
|
|
|
+# 更新数据
|
|
|
|
|
+# db_helper.update(YourModel, {"id": 1}, {"status": 2})
|
|
|
|
|
+
|
|
|
|
|
+# 删除数据
|
|
|
|
|
+# db_helper.delete(YourModel, id=1)
|
|
|
|
|
|
|
|
|
|
+# 关闭会话
|
|
|
|
|
+# db_helper.close()
|