
Update embedding_manager: fit hypercorn's process model

StrayWarrior committed 5 months ago
commit f39e4d06cb
6 changed files with 74 additions and 60 deletions
  1. alg_app.py (+1 -2)
  2. applications/embedding_manager.py (+61 -51)
  3. applications/textSimilarity.py (+9 -2)
  4. requirements.txt (+2 -1)
  5. routes/__init__.py (+1 -2)
  6. routes/nlpServer.py (+0 -2)
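
Hypercorn starts each worker as a separate process, so the old `multiprocessing.Manager().dict()` created at import time was never actually shared with the workers. Each worker now keeps a plain in-process dict and synchronizes through the cache files on disk, guarded by a `filelock.FileLock`; save intervals are randomized so the workers don't all contend for the lock at once. A minimal sketch of the merge-then-write pattern (paths and names are illustrative, not the exact code below):

    import numpy as np
    from filelock import FileLock

    CACHE = "cache/embedding_cache"        # illustrative path
    lock = FileLock(CACHE + ".lock")       # one lock file shared by all workers

    def merge_and_write(local_cache):
        """Merge the on-disk cache into this worker's dict, then write both back."""
        with lock:                         # serializes access across processes
            try:
                keys = open(CACHE + ".keys").read().splitlines()
                data = np.load(CACHE + ".npy")
                for i, k in enumerate(keys):
                    local_cache.setdefault(k, data[i])
            except FileNotFoundError:
                pass                       # first writer creates the files
            keys = list(local_cache)
            if keys:
                np.save(CACHE + ".npy", np.stack([local_cache[k] for k in keys]))
                with open(CACHE + ".keys", "w") as fp:
                    fp.write("\n".join(keys))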

+ 1 - 2
alg_app.py

@@ -9,7 +9,6 @@ from applications.embedding_manager import EmbeddingManager
 
 app = Quart(__name__)
 AsyncMySQL = AsyncMySQLClient(app)
-embedding_manager = EmbeddingManager()
 
 @app.before_serving
 async def init():
@@ -18,8 +17,8 @@ async def init():
     """
     await AsyncMySQL.init_pool()
     model = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5")
+    embedding_manager = EmbeddingManager(model)
     print("模型加载成功")
-    embedding_manager.set_model(model)
     app_routes = AlgRoutes(AsyncMySQL, model, embedding_manager)
     app.register_blueprint(app_routes)
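
Since `before_serving` runs once in every worker, each Hypercorn process now builds its own model and `EmbeddingManager`, coordinating only through the cache files on disk. For example, `hypercorn alg_app:app --workers 4` runs `init()` four times, producing one model and one manager per process.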
 

+ 61 - 51
applications/embedding_manager.py

@@ -1,38 +1,69 @@
 import os
 import threading
-import multiprocessing
+from filelock import FileLock
 from time import sleep
 import numpy as np
+import random
 
 
 class EmbeddingManager:
-    def __init__(self, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
-        self.model = None
+    def __init__(self, model, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
+        self.model = model
         self.emb_size = emb_size
         self.cache_file = cache_file
+        self.cache_file_real = self.cache_file + ".npy"
         self.cache_key_file = f'{self.cache_file}.keys'
-        self.save_interval = save_interval
-        self.cache = multiprocessing.Manager().dict()  # Shared dictionary for multiprocess use
+        # Randomize the interval so workers don't all contend for the file lock at once
+        self.save_interval = save_interval + random.randint(0, save_interval)
+        self.cache = {}
         self.lock = threading.Lock()  # Thread-safe lock
+        self.filelock = FileLock(self.cache_file + ".lock")
 
-        npy_filename = f'{self.cache_file}.npy'
-        # Load cache from file if it exists
-        if os.path.exists(npy_filename):
-            embedding_data = np.load(npy_filename)
-            embedding_keys = open(self.cache_key_file, "r").readlines()
-            embedding_keys = [key.strip("\n") for key in embedding_keys]
-            for idx, key in enumerate(embedding_keys):
-                self.cache[key] = embedding_data[idx]
-        print("cache loaded:")
-        print(self.cache)
+        self.load_cache()
 
         # Start the periodic saving thread
         self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
         self.saving_thread.start()
 
-    def set_model(self, model):
-        """Since each process has its own model, model must be set after process is forked"""
-        self.model = model
+
+    def _load_cache_unsafe(self):
+        """inter-thread and inter-process safety must be guaranteed by caller"""
+        embedding_data = np.load(self.cache_file_real)
+        with open(self.cache_key_file, "r") as fp:
+            embedding_keys = [key.strip("\n") for key in fp]
+        for idx, key in enumerate(embedding_keys):
+            self.cache[key] = embedding_data[idx]
+
+    def load_cache(self):
+        with self.lock:
+            if os.path.exists(self.cache_file_real):
+                with self.filelock:
+                    self._load_cache_unsafe()
+            print("[EmbeddingManager]cache loaded")
+
+    def dump_cache(self):
+        if os.path.dirname(self.cache_file):
+            os.makedirs(os.path.dirname(self.cache_file), mode=0o755, exist_ok=True)
+        tmp_cache_file = self.cache_file + ".tmp"
+        tmp_cache_key_file = self.cache_key_file + ".tmp"
+        with self.lock:  # Acquire the thread lock first...
+            with self.filelock:  # ...then the inter-process file lock
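+                # Merge what another worker may have dumped since our last load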
+                if os.path.exists(self.cache_file_real):
+                    self._load_cache_unsafe()
+                keys = self.cache.keys()
+                cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
+                for idx, key in enumerate(keys):
+                    cache_to_save[idx] = self.cache[key]
+                np.save(tmp_cache_file, cache_to_save)
+                with open(tmp_cache_key_file, 'w') as fp:
+                    fp.write('\n'.join(keys))
+                if os.path.exists(self.cache_file_real):
+                    os.rename(self.cache_file_real, self.cache_file_real + ".bak")
+                if os.path.exists(self.cache_key_file):
+                    os.rename(self.cache_key_file, self.cache_key_file + ".bak")
+                os.rename(tmp_cache_file + ".npy", self.cache_file_real)
+                os.rename(tmp_cache_key_file, self.cache_key_file)
+        print("[EmbeddingManager] cache dumped")
 
     def get_embedding(self, text_list):
         """
@@ -45,13 +76,14 @@ class EmbeddingManager:
             return embedding_list
         new_texts = []
         new_texts_ori_idx = []
-        for idx, text in enumerate(text_list):
-            if text in self.cache:
-                print(f"find {text} in cache")
-                embedding_list[idx] = self.cache[text]
-            else:
-                new_texts.append(text)
-                new_texts_ori_idx.append(idx)
+        with self.lock:
+            for idx, text in enumerate(text_list):
+                if text in self.cache:
+                    print(f"find {text} in cache")
+                    embedding_list[idx] = self.cache[text]
+                else:
+                    new_texts.append(text)
+                    new_texts_ori_idx.append(idx)
 
         new_embeddings = self.model.get_embedding(new_texts)
 
@@ -67,28 +99,7 @@ class EmbeddingManager:
         """Periodically save the cache to disk."""
         while True:
             sleep(self.save_interval)
-            self.save_now()
-
-    def save_now(self):
-        """Manually trigger a save to disk."""
-        if os.path.dirname(self.cache_file):
-            os.makedirs(os.path.dirname(self.cache_file), 0o755, True)
-        tmp_cache_file = self.cache_file + ".tmp"
-        tmp_cache_key_file = self.cache_key_file + ".tmp"
-        with self.lock:  # Ensure thread-safe access
-            keys = self.cache.keys()
-            cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
-            for idx, key in enumerate(keys):
-                cache_to_save[idx] = self.cache[key]
-            np.save(tmp_cache_file, cache_to_save)
-            with open(tmp_cache_key_file, 'w') as fp:
-                fp.write('\n'.join(keys))
-        if os.path.exists(self.cache_file + ".npy"):
-            os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
-        if os.path.exists(self.cache_key_file):
-            os.rename(self.cache_key_file, self.cache_key_file + ".bak")
-        os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
-        os.rename(tmp_cache_key_file, self.cache_key_file)
+            self.dump_cache()
 
 
 # Only for testing
@@ -109,11 +120,10 @@ class DummyModel:
 
 if __name__ == "__main__":
     model = DummyModel()
-    manager = EmbeddingManager()
-    manager.set_model(model)
+    manager = EmbeddingManager(model)
     print(manager.get_embedding(["hello"]))
     print(manager.get_embedding(["world"]))
     print(manager.get_embedding(["hello world"]))
-    manager.save_now()
+    manager.dump_cache()
     print(manager.get_embedding(["new", "word"]))
-    manager.save_now()
+    manager.dump_cache()
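
Because every save first re-reads and merges the file under the lock, two independent instances can share one cache file. A quick single-process illustration using the DummyModel above (the path is illustrative):

    m1 = EmbeddingManager(DummyModel(), cache_file="cache/demo_cache")
    m2 = EmbeddingManager(DummyModel(), cache_file="cache/demo_cache")
    m1.get_embedding(["hello"])     # computed by m1 and cached in memory
    m1.dump_cache()                 # merged into the shared file under the FileLock
    m2.load_cache()                 # m2 picks up "hello" without recomputing
    print("hello" in m2.cache)      # True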

+ 9 - 2
applications/textSimilarity.py

@@ -19,10 +19,10 @@ def score_to_attention(score, symbol=1):
 
 def compare_tensor(tensor1, tensor2):
     if tensor1.shape != tensor2.shape:
-        print(f"[EmbeddingManager]shape error: {tensor1.shape} vs {tensor2.shape}")
+        print(f"[compare_tensor]shape error: {tensor1.shape} vs {tensor2.shape}")
         return
     if not torch.allclose(tensor1, tensor2):
-        print("[EmbeddingManager]value error: tensor1 not close to tensor2")
+        print("[compare_tensor]value error: tensor1 not close to tensor2")
 
 class NLPFunction(object):
     """
@@ -64,6 +64,13 @@ class NLPFunction(object):
             pair_list_dict['text_list_a'],
             pair_list_dict['text_list_b']
         )
+        # sanity check: scores recomputed from cached embeddings must match the model's
+        text_emb1 = self.embedding_manager.get_embedding(pair_list_dict['text_list_a'])
+        text_emb2 = self.embedding_manager.get_embedding(pair_list_dict['text_list_b'])
+        score_function = self.model.score_functions['cos_sim']
+        score_tensor_new = score_function(text_emb1, text_emb2)
+        compare_tensor(score_tensor, score_tensor_new)
+
         response = {
             "score_list_list": score_tensor.tolist()
         }
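
The check above only passes if scores recomputed from cached embeddings match the model's scores exactly. A self-contained version of the same idea in plain torch (`cos_sim` here is the usual normalized dot product, assumed to match the library's `score_functions['cos_sim']`):

    import torch

    def cos_sim(a, b):
        # pairwise cosine-similarity matrix between two embedding batches
        a = torch.nn.functional.normalize(a, dim=-1)
        b = torch.nn.functional.normalize(b, dim=-1)
        return a @ b.T

    emb_a = torch.randn(3, 1024)    # stand-ins for cached embeddings
    emb_b = torch.randn(5, 1024)
    scores = cos_sim(emb_a, emb_b)  # shape (3, 5)
    assert torch.allclose(scores, cos_sim(emb_a, emb_b))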

+ 2 - 1
requirements.txt

@@ -26,4 +26,5 @@ torch~=2.3.1
 tqdm~=4.66.4
 transformers
 pydantic~=2.6.4
-similarities~=1.1.7
+similarities~=1.1.7
+filelock

+ 1 - 2
routes/__init__.py

@@ -10,7 +10,6 @@ from .articleDBServer import ArticleSpider
 from .accountServer import AccountServer
 from applications.articleTools import ArticleDBTools
 
-
 def AlgRoutes(mysql_client, model, embedding_manager):
     """
     ALG ROUTES
@@ -46,7 +45,7 @@ def AlgRoutes(mysql_client, model, embedding_manager):
         :return:
         """
         params = await request.get_json()
-        nlpS = NLPServer(params=params, model=model)
+        nlpS = NLPServer(params=params, model=model, embedding_manager=embedding_manager)
         response = nlpS.deal()
         return jsonify(response)
 

+ 0 - 2
routes/nlpServer.py

@@ -3,7 +3,6 @@
 """
 from applications.textSimilarity import NLPFunction
 
-
 class NLPServer(object):
     """
     nlp_server
@@ -15,7 +14,6 @@ class NLPServer(object):
         self.data = None
         self.function = None
         self.params = params
-        self.embedding_manager = embedding_manager
         self.nlp = NLPFunction(model=model, embedding_manager=embedding_manager)
 
     def check_params(self):