ソースを参照

Fix embedding_manager: isolate model in multi-process environment

StrayWarrior 5 ヶ月 前
コミット
ab3a66fd93
2 ファイル変更、10 行追加、5 行削除
  1. alg_app.py (+2 -2)
  2. applications/embedding_manager.py (+8 -3)

+ 2 - 2
alg_app.py

@@ -9,7 +9,7 @@ from applications.embedding_manager import EmbeddingManager
 
 app = Quart(__name__)
 AsyncMySQL = AsyncMySQLClient(app)
-
+embedding_manager = EmbeddingManager()
 
 @app.before_serving
 async def init():
@@ -18,8 +18,8 @@ async def init():
     """
     await AsyncMySQL.init_pool()
     model = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5")
-    embedding_manager = EmbeddingManager(model)
     print("模型加载成功")
+    embedding_manager.set_model(model)
     app_routes = AlgRoutes(AsyncMySQL, model, embedding_manager)
     app.register_blueprint(app_routes)
 

+ 8 - 3
applications/embedding_manager.py

@@ -6,8 +6,8 @@ import numpy as np
 
 
 class EmbeddingManager:
-    def __init__(self, model, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
-        self.model = model
+    def __init__(self, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
+        self.model = None
         self.emb_size = emb_size
         self.cache_file = cache_file
         self.cache_key_file = f'{self.cache_file}.keys'
@@ -30,6 +30,10 @@ class EmbeddingManager:
         self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
         self.saving_thread.start()
 
+    def set_model(self, model):
+        """Since each process has its own model, model must be set after process is forked"""
+        self.model = model
+
     def get_embedding(self, text_list):
         """
         Search embedding for a given text. If not found, generate using the model, save to cache, and return it.
@@ -105,7 +109,8 @@ class DummyModel:
 
 if __name__ == "__main__":
     model = DummyModel()
-    manager = EmbeddingManager(model, 1024)
+    manager = EmbeddingManager()
+    manager.set_model(model)
     print(manager.get_embedding(["hello"]))
     print(manager.get_embedding(["world"]))
     print(manager.get_embedding(["hello world"]))