database.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. from sqlalchemy import create_engine
  2. from sqlalchemy.orm import sessionmaker, scoped_session
  3. from sqlalchemy.exc import SQLAlchemyError
  4. from urllib.parse import quote_plus
  5. import configs
  6. from core.config import logger
  7. # 配置数据库连接池
  8. def create_sql_engine(config):
  9. user = config['user']
  10. passwd = quote_plus(config['password'])
  11. host = config['host']
  12. db_name = config['database']
  13. charset = config.get('charset', 'utf8mb4')
  14. # 配置连接池
  15. engine = create_engine(
  16. f'mysql+mysqlconnector://{user}:{passwd}@{host}/{db_name}?charset={charset}',
  17. pool_size=50, # 连接池大小
  18. max_overflow=10, # 超过连接池大小后可以创建的最大连接数
  19. pool_timeout=30, # 获取连接的超时时间,单位为秒
  20. pool_recycle=3600, # 连接最大复用时间,超过这个时间将被关闭并重新创建连接
  21. )
  22. return engine
  23. def create_rag_db_engine():
  24. config = configs.get()['database']['rag']
  25. return create_sql_engine(config)
  26. # 创建数据库引擎
  27. engine = create_rag_db_engine()
  28. # 创建会话
  29. Session = sessionmaker(bind=engine)
  30. # 使用 scoped_session 来确保每个线程使用独立的 session
  31. scoped_session_factory = scoped_session(sessionmaker(bind=engine))
  32. class DBHelper:
  33. def __init__(self):
  34. """初始化数据库连接,使用 scoped session 管理会话"""
  35. self.session = scoped_session_factory()
  36. def close(self):
  37. """显式关闭会话"""
  38. self.session.remove()
  39. def _handle_error(self, error, operation):
  40. """处理 SQLAlchemy 错误,回滚事务并记录日志"""
  41. self.session.rollback()
  42. logger.error(f"{operation}失败: {error}")
  43. def add(self, entity):
  44. """插入实体对象"""
  45. try:
  46. self.session.add(entity)
  47. self.session.commit()
  48. return entity
  49. except SQLAlchemyError as e:
  50. self._handle_error(e, "添加")
  51. def get(self, model, **filters):
  52. """根据过滤条件获取单个实体对象"""
  53. try:
  54. entity = self.session.query(model).filter_by(**filters).first()
  55. return entity
  56. except SQLAlchemyError as e:
  57. self._handle_error(e, "查询")
  58. def get_all(self, model, limit=None, **filters):
  59. """获取所有符合条件的实体对象,支持更复杂的查询条件"""
  60. try:
  61. query = self.session.query(model)
  62. # 处理特殊条件如 __in
  63. actual_filters = {}
  64. for key, value in filters.items():
  65. if key.endswith('__in'):
  66. # 处理 IN 查询
  67. field_name = key[:-4]
  68. field = getattr(model, field_name)
  69. query = query.filter(field.in_(value))
  70. else:
  71. actual_filters[key] = value
  72. # 应用其他过滤条件
  73. if actual_filters:
  74. query = query.filter_by(**actual_filters)
  75. # 如果传入了 limit 参数,则限制返回的最大条数
  76. if limit is not None:
  77. query = query.limit(limit)
  78. # 执行查询
  79. entities = query.all()
  80. return entities
  81. except SQLAlchemyError as e:
  82. self._handle_error(e, "查询")
  83. def update(self, model, filters, updates):
  84. """更新实体对象"""
  85. try:
  86. entity = self.session.query(model).filter_by(**filters).first()
  87. if entity:
  88. for key, value in updates.items():
  89. setattr(entity, key, value)
  90. self.session.commit()
  91. else:
  92. logger.warning(f"未找到符合条件的实体: {filters}")
  93. except SQLAlchemyError as e:
  94. self._handle_error(e, "更新")
  95. def delete(self, model, **filters):
  96. """删除实体对象"""
  97. try:
  98. entity = self.session.query(model).filter_by(**filters).first()
  99. if entity:
  100. self.session.delete(entity)
  101. self.session.commit()
  102. return entity
  103. else:
  104. logger.warning(f"未找到符合条件的实体: {filters}")
  105. return None
  106. except SQLAlchemyError as e:
  107. self._handle_error(e, "删除")
  108. def commit(self):
  109. """显式提交事务"""
  110. try:
  111. self.session.commit()
  112. except SQLAlchemyError as e:
  113. self._handle_error(e, "提交")
  114. def rollback(self):
  115. """显式回滚事务"""
  116. self.session.rollback()
  117. # 使用示例:
  118. # 创建 DBHelper 实例
  119. db_helper = DBHelper()
  120. # 添加数据
  121. # db_helper.add(YourModel(name="example"))
  122. # 查询数据
  123. # result = db_helper.get(YourModel, id=1)
  124. # 获取所有数据
  125. # results = db_helper.get_all(YourModel, limit=100, status=1)
  126. # 更新数据
  127. # db_helper.update(YourModel, {"id": 1}, {"status": 2})
  128. # 删除数据
  129. # db_helper.delete(YourModel, id=1)
  130. # 关闭会话
  131. # db_helper.close()