database.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. from sqlalchemy import create_engine
  2. from sqlalchemy.orm import sessionmaker, scoped_session
  3. from sqlalchemy.exc import SQLAlchemyError
  4. from urllib.parse import quote_plus
  5. import configs
  6. from core.config import logger
  7. # 配置数据库连接池
  8. def create_sql_engine(config):
  9. user = config['user']
  10. passwd = quote_plus(config['password'])
  11. host = config['host']
  12. db_name = config['database']
  13. charset = config.get('charset', 'utf8mb4')
  14. # 配置连接池
  15. engine = create_engine(
  16. f'mysql+mysqlconnector://{user}:{passwd}@{host}/{db_name}?charset={charset}',
  17. pool_size=50, # 连接池大小
  18. max_overflow=10, # 超过连接池大小后可以创建的最大连接数
  19. pool_timeout=30, # 获取连接的超时时间,单位为秒
  20. pool_recycle=3600, # 连接最大复用时间,超过这个时间将被关闭并重新创建连接
  21. )
  22. return engine
  23. def create_rag_db_engine():
  24. config = configs.get()['database']['rag']
  25. return create_sql_engine(config)
  26. # 创建数据库引擎
  27. engine = create_rag_db_engine()
  28. # 创建会话
  29. Session = sessionmaker(bind=engine)
  30. class DBHelper:
  31. def __init__(self):
  32. """初始化数据库连接"""
  33. self.session = Session()
  34. def add(self, entity):
  35. """插入实体对象"""
  36. try:
  37. self.session.add(entity)
  38. self.session.commit()
  39. return entity
  40. except SQLAlchemyError as e:
  41. self.session.rollback()
  42. logger.error(f"添加失败: {e}")
  43. raise
  44. def get(self, model, **filters):
  45. """根据过滤条件获取实体对象"""
  46. try:
  47. entity = self.session.query(model).filter_by(**filters).first()
  48. return entity
  49. except SQLAlchemyError as e:
  50. logger.error(f"查询失败: {e}")
  51. raise
  52. def update(self, model, filters, updates):
  53. """更新实体对象"""
  54. try:
  55. entity = self.session.query(model).filter_by(**filters).first()
  56. if entity:
  57. for key, value in updates.items():
  58. setattr(entity, key, value)
  59. self.session.commit()
  60. return
  61. else:
  62. logger.warning(f"未找到符合条件的实体: {filters}")
  63. return None
  64. except SQLAlchemyError as e:
  65. self.session.rollback()
  66. logger.error(f"更新失败: {e}")
  67. raise
  68. def delete(self, model, **filters):
  69. """删除实体对象"""
  70. try:
  71. entity = self.session.query(model).filter_by(**filters).first()
  72. if entity:
  73. self.session.delete(entity)
  74. self.session.commit()
  75. return entity
  76. else:
  77. logger.warning(f"未找到符合条件的实体: {filters}")
  78. return None
  79. except SQLAlchemyError as e:
  80. self.session.rollback()
  81. logger.error(f"删除失败: {e}")
  82. raise
  83. def get_all(self, model, limit=None, **filters):
  84. """获取所有符合条件的实体对象,支持更复杂的查询条件,并可限制最大返回条数"""
  85. try:
  86. query = self.session.query(model)
  87. # 处理特殊条件如 __in
  88. actual_filters = {}
  89. for key, value in filters.items():
  90. if key.endswith('__in'):
  91. # 处理 IN 查询
  92. field_name = key[:-4]
  93. field = getattr(model, field_name)
  94. query = query.filter(field.in_(value))
  95. else:
  96. actual_filters[key] = value
  97. # 应用其他过滤条件
  98. if actual_filters:
  99. query = query.filter_by(**actual_filters)
  100. # 如果传入了 limit 参数,则限制返回的最大条数
  101. if limit is not None:
  102. query = query.limit(limit)
  103. # 执行查询
  104. entities = query.all()
  105. return entities
  106. except SQLAlchemyError as e:
  107. logger.error(f"查询失败: {e}")
  108. raise