database.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. from sqlalchemy import create_engine
  2. from sqlalchemy.orm import sessionmaker, scoped_session
  3. from sqlalchemy.exc import SQLAlchemyError
  4. from urllib.parse import quote_plus
  5. import configs
  6. from core.config import logger
  7. # 配置数据库连接池
  8. def create_sql_engine(config):
  9. user = config['user']
  10. passwd = quote_plus(config['password'])
  11. host = config['host']
  12. db_name = config['database']
  13. charset = config.get('charset', 'utf8mb4')
  14. # 配置连接池
  15. engine = create_engine(
  16. f'mysql+mysqlconnector://{user}:{passwd}@{host}/{db_name}?charset={charset}',
  17. pool_size=50, # 连接池大小
  18. max_overflow=10, # 超过连接池大小后可以创建的最大连接数
  19. pool_timeout=30, # 获取连接的超时时间,单位为秒
  20. pool_recycle=3600, # 连接最大复用时间,超过这个时间将被关闭并重新创建连接
  21. )
  22. return engine
  23. def create_rag_db_engine():
  24. config = configs.get()['database']['rag']
  25. return create_sql_engine(config)
  26. # 创建数据库引擎
  27. engine = create_rag_db_engine()
  28. # 创建会话
  29. Session = sessionmaker(bind=engine)
  30. # 使用 scoped_session 来确保每个线程使用独立的 session
  31. scoped_session_factory = scoped_session(sessionmaker(bind=engine))
  32. class DBHelper:
  33. def __init__(self):
  34. """初始化数据库连接,使用 scoped session 管理会话"""
  35. self.session = scoped_session_factory()
  36. def close(self):
  37. """显式关闭会话"""
  38. self.session.remove()
  39. def _handle_error(self, error, operation):
  40. """处理 SQLAlchemy 错误,回滚事务并记录日志"""
  41. self.session.rollback()
  42. logger.error(f"{operation}失败: {error}")
  43. def add(self, entity):
  44. """插入实体对象"""
  45. try:
  46. self.session.add(entity)
  47. self.session.commit()
  48. return entity
  49. except SQLAlchemyError as e:
  50. self._handle_error(e, "添加")
  51. def get(self, model, **filters):
  52. """根据过滤条件获取单个实体对象"""
  53. try:
  54. entity = self.session.query(model).filter_by(**filters).first()
  55. return entity
  56. except SQLAlchemyError as e:
  57. self._handle_error(e, "查询")
  58. def get_all(self, model, limit=None, order_by=None, **filters):
  59. """获取所有符合条件的实体对象,支持更复杂的查询条件"""
  60. try:
  61. query = self.session.query(model)
  62. # 处理特殊条件如 __in
  63. actual_filters = {}
  64. for key, value in filters.items():
  65. if key.endswith('__in'):
  66. # 处理 IN 查询
  67. field_name = key[:-4]
  68. field = getattr(model, field_name)
  69. query = query.filter(field.in_(value))
  70. else:
  71. actual_filters[key] = value
  72. # 应用其他过滤条件
  73. if actual_filters:
  74. query = query.filter_by(**actual_filters)
  75. # 添加排序条件
  76. if order_by:
  77. # order_by 是一个字典,形如 {'field_name': 'asc' 或 'desc'}
  78. for field_name, direction in order_by.items():
  79. field = getattr(model, field_name)
  80. if direction == 'desc':
  81. query = query.order_by(field.desc())
  82. else:
  83. query = query.order_by(field.asc())
  84. # 如果传入了 limit 参数,则限制返回的最大条数
  85. if limit is not None:
  86. query = query.limit(limit)
  87. # 执行查询
  88. entities = query.all()
  89. return entities
  90. except SQLAlchemyError as e:
  91. self._handle_error(e, "查询")
  92. def get_paginated(self, model, page=1, page_size=10, order_by=None, **filters):
  93. """分页查询符合条件的实体对象,支持排序"""
  94. try:
  95. query = self.session.query(model)
  96. # 处理特殊条件如 __in
  97. actual_filters = {}
  98. for key, value in filters.items():
  99. if key.endswith('__in'):
  100. # 处理 IN 查询
  101. field_name = key[:-4]
  102. field = getattr(model, field_name)
  103. query = query.filter(field.in_(value))
  104. else:
  105. actual_filters[key] = value
  106. # 应用其他过滤条件
  107. if actual_filters:
  108. query = query.filter_by(**actual_filters)
  109. # 添加排序条件
  110. if order_by:
  111. # order_by 是一个字典,形如 {'field_name': 'asc' 或 'desc'}
  112. for field_name, direction in order_by.items():
  113. field = getattr(model, field_name)
  114. if direction == 'desc':
  115. query = query.order_by(field.desc())
  116. else:
  117. query = query.order_by(field.asc())
  118. # 计算总记录数
  119. total_count = query.count()
  120. # 分页查询,计算偏移量
  121. offset = (page - 1) * page_size
  122. query = query.offset(offset).limit(page_size)
  123. # 执行查询
  124. entities = query.all()
  125. # 返回分页结果:当前页数据和总记录数
  126. return {
  127. "entities": entities,
  128. "total_count": total_count,
  129. "page": page,
  130. "page_size": page_size,
  131. "total_pages": (total_count + page_size - 1) // page_size # 向上取整计算总页数
  132. }
  133. except SQLAlchemyError as e:
  134. self._handle_error(e, "查询")
  135. def count(self, model, **filters):
  136. """查询符合条件的记录条数"""
  137. try:
  138. query = self.session.query(model)
  139. # 处理特殊条件如 __in
  140. actual_filters = {}
  141. for key, value in filters.items():
  142. if key.endswith('__in'):
  143. # 处理 IN 查询
  144. field_name = key[:-4]
  145. field = getattr(model, field_name)
  146. query = query.filter(field.in_(value))
  147. else:
  148. actual_filters[key] = value
  149. # 应用其他过滤条件
  150. if actual_filters:
  151. query = query.filter_by(**actual_filters)
  152. # 执行查询并获取总记录数
  153. count = query.count()
  154. return count
  155. except SQLAlchemyError as e:
  156. self._handle_error(e, "查询条数")
  157. def update(self, model, filters, updates):
  158. """更新实体对象"""
  159. try:
  160. entity = self.session.query(model).filter_by(**filters).first()
  161. if entity:
  162. for key, value in updates.items():
  163. setattr(entity, key, value)
  164. self.session.commit()
  165. else:
  166. logger.warning(f"未找到符合条件的实体: {filters}")
  167. except SQLAlchemyError as e:
  168. self._handle_error(e, "更新")
  169. def delete(self, model, **filters):
  170. """删除实体对象"""
  171. try:
  172. entity = self.session.query(model).filter_by(**filters).first()
  173. if entity:
  174. self.session.delete(entity)
  175. self.session.commit()
  176. return entity
  177. else:
  178. logger.warning(f"未找到符合条件的实体: {filters}")
  179. return None
  180. except SQLAlchemyError as e:
  181. self._handle_error(e, "删除")
  182. def commit(self):
  183. """显式提交事务"""
  184. try:
  185. self.session.commit()
  186. except SQLAlchemyError as e:
  187. self._handle_error(e, "提交")
  188. def rollback(self):
  189. """显式回滚事务"""
  190. self.session.rollback()
  191. # 使用示例:
  192. # 创建 DBHelper 实例
  193. db_helper = DBHelper()
  194. # 添加数据
  195. # db_helper.add(YourModel(name="example"))
  196. # 查询数据
  197. # result = db_helper.get(YourModel, id=1)
  198. # 获取所有数据
  199. # results = db_helper.get_all(YourModel, limit=100, status=1)
  200. # 更新数据
  201. # db_helper.update(YourModel, {"id": 1}, {"status": 2})
  202. # 删除数据
  203. # db_helper.delete(YourModel, id=1)
  204. # 关闭会话
  205. # db_helper.close()