Bladeren bron

处理连接数问题

xueyiming 2 maanden geleden
bovenliggende
commit
da07d88dcb
2 gewijzigde bestanden met toevoegingen van 79 en 43 verwijderingen
  1. 77 41
      core/database.py
  2. 2 2
      utils/keywords_utils.py

+ 77 - 41
core/database.py

@@ -36,11 +36,24 @@ engine = create_rag_db_engine()
 # 创建会话
 Session = sessionmaker(bind=engine)
 
+# 使用 scoped_session 来确保每个线程使用独立的 session
+scoped_session_factory = scoped_session(sessionmaker(bind=engine))
+
 
 class DBHelper:
     def __init__(self):
-        """初始化数据库连接"""
-        self.session = Session()
+        """初始化数据库连接,使用 scoped session 管理会话"""
+        self.session = scoped_session_factory()
+
+    def close(self):
+        """显式关闭会话"""
+        self.session.remove()
+
+    def _handle_error(self, error, operation):
+        """处理 SQLAlchemy 错误,回滚事务并记录日志"""
+        self.session.rollback()
+        logger.error(f"{operation}失败: {error}")
+
 
     def add(self, entity):
         """插入实体对象"""
@@ -49,18 +62,45 @@ class DBHelper:
             self.session.commit()
             return entity
         except SQLAlchemyError as e:
-            self.session.rollback()
-            logger.error(f"添加失败: {e}")
-            raise
+            self._handle_error(e, "添加")
 
     def get(self, model, **filters):
-        """根据过滤条件获取实体对象"""
+        """根据过滤条件获取单个实体对象"""
         try:
             entity = self.session.query(model).filter_by(**filters).first()
             return entity
         except SQLAlchemyError as e:
-            logger.error(f"查询失败: {e}")
-            raise
+            self._handle_error(e, "查询")
+
+    def get_all(self, model, limit=None, **filters):
+        """获取所有符合条件的实体对象,支持更复杂的查询条件"""
+        try:
+            query = self.session.query(model)
+
+            # 处理特殊条件如 __in
+            actual_filters = {}
+            for key, value in filters.items():
+                if key.endswith('__in'):
+                    # 处理 IN 查询
+                    field_name = key[:-4]
+                    field = getattr(model, field_name)
+                    query = query.filter(field.in_(value))
+                else:
+                    actual_filters[key] = value
+
+            # 应用其他过滤条件
+            if actual_filters:
+                query = query.filter_by(**actual_filters)
+
+            # 如果传入了 limit 参数,则限制返回的最大条数
+            if limit is not None:
+                query = query.limit(limit)
+
+            # 执行查询
+            entities = query.all()
+            return entities
+        except SQLAlchemyError as e:
+            self._handle_error(e, "查询")
 
     def update(self, model, filters, updates):
         """更新实体对象"""
@@ -70,14 +110,10 @@ class DBHelper:
                 for key, value in updates.items():
                     setattr(entity, key, value)
                 self.session.commit()
-                return
             else:
                 logger.warning(f"未找到符合条件的实体: {filters}")
-                return None
         except SQLAlchemyError as e:
-            self.session.rollback()
-            logger.error(f"更新失败: {e}")
-            raise
+            self._handle_error(e, "更新")
 
     def delete(self, model, **filters):
         """删除实体对象"""
@@ -91,39 +127,39 @@ class DBHelper:
                 logger.warning(f"未找到符合条件的实体: {filters}")
                 return None
         except SQLAlchemyError as e:
-            self.session.rollback()
-            logger.error(f"删除失败: {e}")
-            raise
+            self._handle_error(e, "删除")
 
-    def get_all(self, model, limit=None, **filters):
-        """获取所有符合条件的实体对象,支持更复杂的查询条件,并可限制最大返回条数"""
+    def commit(self):
+        """显式提交事务"""
         try:
-            query = self.session.query(model)
+            self.session.commit()
+        except SQLAlchemyError as e:
+            self._handle_error(e, "提交")
 
-            # 处理特殊条件如 __in
-            actual_filters = {}
-            for key, value in filters.items():
-                if key.endswith('__in'):
-                    # 处理 IN 查询
-                    field_name = key[:-4]
-                    field = getattr(model, field_name)
-                    query = query.filter(field.in_(value))
-                else:
-                    actual_filters[key] = value
+    def rollback(self):
+        """显式回滚事务"""
+        self.session.rollback()
 
-            # 应用其他过滤条件
-            if actual_filters:
-                query = query.filter_by(**actual_filters)
 
-            # 如果传入了 limit 参数,则限制返回的最大条数
-            if limit is not None:
-                query = query.limit(limit)
+# 使用示例:
 
-            # 执行查询
-            entities = query.all()
-            return entities
+# 创建 DBHelper 实例
+db_helper = DBHelper()
 
-        except SQLAlchemyError as e:
-            logger.error(f"查询失败: {e}")
-            raise
+# 添加数据
+# db_helper.add(YourModel(name="example"))
+
+# 查询数据
+# result = db_helper.get(YourModel, id=1)
+
+# 获取所有数据
+# results = db_helper.get_all(YourModel, limit=100, status=1)
+
+# 更新数据
+# db_helper.update(YourModel, {"id": 1}, {"status": 2})
+
+# 删除数据
+# db_helper.delete(YourModel, id=1)
 
+# 关闭会话
+# db_helper.close()

+ 2 - 2
utils/keywords_utils.py

@@ -101,5 +101,5 @@ class KeywordSummaryTask:
 
 
 if __name__ == '__main__':
-    keyword_summary_task = KeywordSummaryTask()
-    keyword_summary_task.process_texts_concurrently()
+    db_helper = DBHelper()
+    print(db_helper.get(KeywordData, keyword='短视频'))