Kaynağa Gözat

Merge branch 'dev-xym-update' of algorithm/RAG into master

xueyiming 2 ay önce
ebeveyn
işleme
e2889d91b9
2 değiştirilmiş dosya ile 85 ekleme ve 41 silme
  1. 82 38
      core/database.py
  2. 3 3
      utils/keywords_utils.py

+ 82 - 38
core/database.py

@@ -17,28 +17,43 @@ def create_sql_engine(config):
     # 配置连接池
     engine = create_engine(
         f'mysql+mysqlconnector://{user}:{passwd}@{host}/{db_name}?charset={charset}',
-        pool_size=30,  # 连接池大小
+        pool_size=50,  # 连接池大小
         max_overflow=10,  # 超过连接池大小后可以创建的最大连接数
         pool_timeout=30,  # 获取连接的超时时间,单位为秒
         pool_recycle=3600,  # 连接最大复用时间,超过这个时间将被关闭并重新创建连接
     )
     return engine
 
+
 def create_rag_db_engine():
     config = configs.get()['database']['rag']
     return create_sql_engine(config)
 
+
 # 创建数据库引擎
 engine = create_rag_db_engine()
 
 # 创建会话
 Session = sessionmaker(bind=engine)
 
+# 使用 scoped_session 来确保每个线程使用独立的 session
+scoped_session_factory = scoped_session(sessionmaker(bind=engine))
+
 
 class DBHelper:
     def __init__(self):
-        """初始化数据库连接"""
-        self.session = Session()
+        """初始化数据库连接,使用 scoped session 管理会话"""
+        self.session = scoped_session_factory()
+
+    def close(self):
+        """显式关闭会话"""
+        self.session.remove()
+
+    def _handle_error(self, error, operation):
+        """处理 SQLAlchemy 错误,回滚事务并记录日志"""
+        self.session.rollback()
+        logger.error(f"{operation}失败: {error}")
+
 
     def add(self, entity):
         """插入实体对象"""
@@ -47,18 +62,45 @@ class DBHelper:
             self.session.commit()
             return entity
         except SQLAlchemyError as e:
-            self.session.rollback()
-            logger.error(f"添加失败: {e}")
-            raise
+            self._handle_error(e, "添加")
 
     def get(self, model, **filters):
-        """根据过滤条件获取实体对象"""
+        """根据过滤条件获取单个实体对象"""
         try:
             entity = self.session.query(model).filter_by(**filters).first()
             return entity
         except SQLAlchemyError as e:
-            logger.error(f"查询失败: {e}")
-            raise
+            self._handle_error(e, "查询")
+
+    def get_all(self, model, limit=None, **filters):
+        """获取所有符合条件的实体对象,支持更复杂的查询条件"""
+        try:
+            query = self.session.query(model)
+
+            # 处理特殊条件如 __in
+            actual_filters = {}
+            for key, value in filters.items():
+                if key.endswith('__in'):
+                    # 处理 IN 查询
+                    field_name = key[:-4]
+                    field = getattr(model, field_name)
+                    query = query.filter(field.in_(value))
+                else:
+                    actual_filters[key] = value
+
+            # 应用其他过滤条件
+            if actual_filters:
+                query = query.filter_by(**actual_filters)
+
+            # 如果传入了 limit 参数,则限制返回的最大条数
+            if limit is not None:
+                query = query.limit(limit)
+
+            # 执行查询
+            entities = query.all()
+            return entities
+        except SQLAlchemyError as e:
+            self._handle_error(e, "查询")
 
     def update(self, model, filters, updates):
         """更新实体对象"""
@@ -68,14 +110,10 @@ class DBHelper:
                 for key, value in updates.items():
                     setattr(entity, key, value)
                 self.session.commit()
-                return
             else:
                 logger.warning(f"未找到符合条件的实体: {filters}")
-                return None
         except SQLAlchemyError as e:
-            self.session.rollback()
-            logger.error(f"更新失败: {e}")
-            raise
+            self._handle_error(e, "更新")
 
     def delete(self, model, **filters):
         """删除实体对象"""
@@ -89,33 +127,39 @@ class DBHelper:
                 logger.warning(f"未找到符合条件的实体: {filters}")
                 return None
         except SQLAlchemyError as e:
-            self.session.rollback()
-            logger.error(f"删除失败: {e}")
-            raise
+            self._handle_error(e, "删除")
 
-    def get_all(self, model, **filters):
-        """获取所有符合条件的实体对象,支持更复杂的查询条件"""
+    def commit(self):
+        """显式提交事务"""
         try:
-            query = self.session.query(model)
+            self.session.commit()
+        except SQLAlchemyError as e:
+            self._handle_error(e, "提交")
 
-            # 处理特殊条件如 __in
-            actual_filters = {}
-            for key, value in filters.items():
-                if key.endswith('__in'):
-                    # 处理 IN 查询
-                    field_name = key[:-4]
-                    field = getattr(model, field_name)
-                    query = query.filter(field.in_(value))
-                else:
-                    actual_filters[key] = value
+    def rollback(self):
+        """显式回滚事务"""
+        self.session.rollback()
 
-            # 应用其他过滤条件
-            if actual_filters:
-                query = query.filter_by(**actual_filters)
 
-            entities = query.all()
-            return entities
-        except SQLAlchemyError as e:
-            logger.error(f"查询失败: {e}")
-            raise
+# 使用示例:
+
+# 创建 DBHelper 实例
+db_helper = DBHelper()
+
+# 添加数据
+# db_helper.add(YourModel(name="example"))
+
+# 查询数据
+# result = db_helper.get(YourModel, id=1)
+
+# 获取所有数据
+# results = db_helper.get_all(YourModel, limit=100, status=1)
+
+# 更新数据
+# db_helper.update(YourModel, {"id": 1}, {"status": 2})
+
+# 删除数据
+# db_helper.delete(YourModel, id=1)
 
+# 关闭会话
+# db_helper.close()

+ 3 - 3
utils/keywords_utils.py

@@ -76,7 +76,7 @@ class KeywordSummaryTask:
         print('process_texts_concurrently start')
         db_helper = DBHelper()
         while True:
-            content_chunks = db_helper.get_all(ContentChunks, chunk_status=2, keywords_status=0)
+            content_chunks = db_helper.get_all(ContentChunks, limit=200, chunk_status=2, keywords_status=0)
             if len(content_chunks) == 0:
                 logger.info('sleep')
                 print('sleep')
@@ -101,5 +101,5 @@ class KeywordSummaryTask:
 
 
 if __name__ == '__main__':
-    keyword_summary_task = KeywordSummaryTask()
-    keyword_summary_task.process_texts_concurrently()
+    db_helper = DBHelper()
+    print(db_helper.get(KeywordData, keyword='短视频'))