فهرست منبع

Merge branch 'feature/20241208-cache-embedding' of Server/LongArticleAlgServer into master-GPU

fengzhoutian 7 ماه پیش
والد
کامیت
149347555e
7 فایل تغییر یافته به همراه 165 افزوده شده و 21 حذف شده
  1. 3 2
      alg_app.py
  2. 131 0
      applications/embedding_manager.py
  3. 24 9
      applications/textSimilarity.py
  4. 2 1
      requirements.txt
  5. 2 3
      routes/__init__.py
  6. 1 3
      routes/accountServer.py
  7. 2 3
      routes/nlpServer.py

+ 3 - 2
alg_app.py

@@ -5,11 +5,11 @@ from quart import Quart
 from similarities import BertSimilarity
 from routes import AlgRoutes
 from applications import AsyncMySQLClient
+from applications.embedding_manager import EmbeddingManager
 
 app = Quart(__name__)
 AsyncMySQL = AsyncMySQLClient(app)
 
-
 @app.before_serving
 async def init():
     """
@@ -17,8 +17,9 @@ async def init():
     """
     await AsyncMySQL.init_pool()
     model = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5")
+    embedding_manager = EmbeddingManager(model)
     print("模型加载成功")
-    app_routes = AlgRoutes(AsyncMySQL, model)
+    app_routes = AlgRoutes(AsyncMySQL, model, embedding_manager)
     app.register_blueprint(app_routes)
 
 

+ 131 - 0
applications/embedding_manager.py

@@ -0,0 +1,131 @@
+import os
+import threading
+from filelock import FileLock
+from time import sleep
+import numpy as np
+import random
+
+
+class EmbeddingManager:
+    """Cache text embeddings in memory and persist them to disk periodically.
+
+    Two files are derived from ``cache_file``: ``<cache_file>.npy`` holds one
+    float32 row per cached text and ``<cache_file>.keys`` holds the texts, one
+    per line, in the same order. ``self.lock`` guards ``self.cache`` between
+    threads; ``self.filelock`` guards the files between processes.
+    """
+
+    def __init__(self, model, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
+        """Load any existing cache and start the background saving thread.
+
+        :param model: object exposing ``get_embeddings(list_of_texts)``.
+        :param emb_size: expected number of columns of every embedding.
+        :param cache_file: path prefix (without extension) of the cache files.
+        :param save_interval: base number of seconds between two dumps; a
+            random extra delay of up to ``save_interval`` seconds is added.
+        """
+        self.model = model
+        self.emb_size = emb_size
+        self.cache_file = cache_file
+        self.cache_file_real = self.cache_file + ".npy"
+        self.cache_key_file = f'{self.cache_file}.keys'
+        # avoid multiple process read and write at same time and wait for filelock
+        self.save_interval = save_interval + random.randint(0, save_interval)
+        self.cache = {}
+        self.lock = threading.Lock()  # Thread-safe lock
+        self.filelock = FileLock(self.cache_file + ".lock")
+
+        self.load_cache()
+
+        # Start the periodic saving thread
+        self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
+        self.saving_thread.start()
+
+    def _load_cache_unsafe(self):
+        """Merge the on-disk cache into ``self.cache``.
+
+        Inter-thread and inter-process safety must be guaranteed by the caller.
+        A missing or inconsistent pair of files is ignored (with a message)
+        instead of raising, so a damaged cache can never block start-up or
+        kill the saving thread; the next dump rewrites it.
+        """
+        if not (os.path.exists(self.cache_file_real) and os.path.exists(self.cache_key_file)):
+            return
+        embedding_data = np.load(self.cache_file_real)
+        # Encoding is left at the platform default so that key files written
+        # by earlier versions of this class stay readable.
+        with open(self.cache_key_file, "r") as fp:
+            content = fp.read()
+        row_count = len(embedding_data)
+        # An empty key file means "no keys" only when there are no rows;
+        # otherwise it is a single empty-string key.
+        embedding_keys = content.split("\n") if row_count else []
+        consistent = (
+            embedding_data.ndim == 2
+            and embedding_data.shape[1] == self.emb_size
+            and len(embedding_keys) == row_count
+        )
+        if not consistent:
+            print(
+                f"[EmbeddingManager]cache ignored: {len(embedding_keys)} keys "
+                f"vs embeddings of shape {embedding_data.shape}"
+            )
+            return
+        for key, row in zip(embedding_keys, embedding_data):
+            self.cache[key] = row
+
+    def load_cache(self):
+        """Load the on-disk cache, if present, into memory."""
+        with self.lock:
+            if os.path.exists(self.cache_file_real):
+                with self.filelock:
+                    self._load_cache_unsafe()
+            print("[EmbeddingManager]cache loaded")
+
+    def dump_cache(self):
+        """Merge the on-disk cache into memory and write the result back.
+
+        The data is written to temporary files first and then moved into
+        place; the previous files are kept with a ``.bak`` suffix.
+        """
+        cache_dir = os.path.dirname(self.cache_file)
+        if cache_dir:
+            os.makedirs(cache_dir, 0o755, True)
+        tmp_cache_file = self.cache_file + ".tmp"
+        tmp_cache_key_file = self.cache_key_file + ".tmp"
+        with self.lock:  # Ensure thread-safe access firstly
+            with self.filelock:  # Ensure inter-process safety secondly
+                # Pick up entries written by other processes since the last load.
+                self._load_cache_unsafe()
+                # The key file stores one key per line, so a key containing a
+                # line break cannot be read back and would shift every later
+                # key against its row. Such keys stay in memory only.
+                keys = [
+                    key for key in self.cache
+                    if isinstance(key, str) and "\n" not in key and "\r" not in key
+                ]
+                cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
+                for idx, key in enumerate(keys):
+                    cache_to_save[idx] = self.cache[key]
+                # np.save appends ".npy" to a file name that lacks it.
+                np.save(tmp_cache_file, cache_to_save)
+                with open(tmp_cache_key_file, 'w') as fp:
+                    fp.write('\n'.join(keys))
+                # os.replace overwrites an existing destination on every
+                # platform, so a stale ".bak" never makes the dump fail.
+                if os.path.exists(self.cache_file_real):
+                    os.replace(self.cache_file_real, self.cache_file_real + ".bak")
+                if os.path.exists(self.cache_key_file):
+                    os.replace(self.cache_key_file, self.cache_key_file + ".bak")
+                os.replace(tmp_cache_file + ".npy", self.cache_file_real)
+                os.replace(tmp_cache_key_file, self.cache_key_file)
+        print("[EmbeddingManager]cache dumped")
+
+    def get_embeddings(self, text_list):
+        """Return one embedding row per text, using the cache where possible.
+
+        Texts missing from the cache are embedded with ``self.model`` in a
+        single call and added to the cache.
+
+        :param text_list: list of texts.
+        :return: float32 array of shape ``(len(text_list), emb_size)``.
+        :raises TypeError: if ``text_list`` is not a list.
+        :raises ValueError: if the model output does not have one row of
+            ``emb_size`` values per requested text.
+        """
+        if not isinstance(text_list, list):
+            raise TypeError(f"Invalid parameter type: text_list {type(text_list)}")
+        embedding_list = np.zeros((len(text_list), self.emb_size), np.float32)
+        if not text_list:
+            return embedding_list
+        new_texts = []
+        new_texts_ori_idx = []
+        with self.lock:
+            for idx, text in enumerate(text_list):
+                if text in self.cache:
+                    embedding_list[idx] = self.cache[text]
+                else:
+                    new_texts.append(text)
+                    new_texts_ori_idx.append(idx)
+
+        # Everything was served from the cache: do not call the model at all.
+        if not new_texts:
+            return embedding_list
+
+        # The model runs outside the lock so other threads are not blocked.
+        new_embeddings = self.model.get_embeddings(new_texts)
+        if len(new_embeddings) != len(new_texts) or new_embeddings.shape[1] != self.emb_size:
+            raise ValueError("Embedding size mismatch")
+
+        with self.lock:  # Ensure thread-safe access
+            for idx, text in enumerate(new_texts):
+                if text not in self.cache:  # Re-check in case another thread added it
+                    self.cache[text] = new_embeddings[idx]
+                embedding_list[new_texts_ori_idx[idx]] = new_embeddings[idx]
+        return embedding_list
+
+    def _periodic_save(self):
+        """Periodically save the cache to disk; runs in the daemon thread."""
+        while True:
+            sleep(self.save_interval)
+            try:
+                self.dump_cache()
+            except Exception as err:
+                # Top-level boundary of the saving thread: report and retry on
+                # the next interval instead of silently ending all future saves.
+                print(f"[EmbeddingManager]cache dump failed: {err!r}")
+
+
+# Only for testing
+class DummyModel:
+    """Deterministic stand-in model: an embedding is the code points of the text.
+
+    The text is repeated cyclically (or truncated) to exactly ``emb_size``
+    characters and every character becomes one float32 value.
+    """
+
+    def __init__(self, emb_size=1024):
+        """:param emb_size: number of values in every produced embedding."""
+        self.emb_size = emb_size
+
+    def padding_text(self, text):
+        """Return ``text`` repeated cyclically and cut to ``emb_size`` characters.
+
+        An empty text yields NUL characters (an all-zero embedding) instead
+        of dividing by zero, and a text longer than ``emb_size`` is truncated
+        instead of collapsing to an empty string.
+        """
+        size = self.emb_size
+        if not text:
+            return "\0" * size
+        repeats = -(-size // len(text))  # ceiling division
+        return (text * repeats)[:size]
+
+    def get_embeddings(self, text_list):
+        """Return a float32 array with one ``emb_size``-wide row per text."""
+        embeddings = np.zeros((len(text_list), self.emb_size), np.float32)
+        for idx, text in enumerate(text_list):
+            embeddings[idx] = [ord(c) for c in self.padding_text(text)]
+        return embeddings
+
+if __name__ == "__main__":
+    # Manual smoke run: embed a few texts, dump the cache, embed more, dump again.
+    model = DummyModel()
+    manager = EmbeddingManager(model)
+    for first_round in (["hello"], ["world"], ["hello world"]):
+        print(manager.get_embeddings(first_round))
+    manager.dump_cache()
+    second_round = ["new", "word"]
+    print(manager.get_embeddings(second_round))
+    manager.dump_cache()

+ 24 - 9
applications/textSimilarity.py

@@ -17,14 +17,31 @@ def score_to_attention(score, symbol=1):
     score_attn = torch.nn.functional.softmax(score_norm, dim=1)
     return score_attn, score_norm, score_pred
 
+def compare_tensor(tensor1, tensor2):
+    """Print a diagnostic when two tensors differ in shape or in values.
+
+    Nothing is returned; a shape mismatch is reported first and stops the
+    comparison, otherwise values are compared with ``torch.allclose``.
+    """
+    shape1, shape2 = tensor1.shape, tensor2.shape
+    if shape1 != shape2:
+        print(f"[compare_tensor]shape error: {shape1} vs {shape2}")
+        return
+    if torch.allclose(tensor1, tensor2):
+        return
+    print("[compare_tensor]value error: tensor1 not close to tensor2")
 
 class NLPFunction(object):
     """
     NLP Task
     """
 
-    def __init__(self, model):
+    def __init__(self, model, embedding_manager):
         self.model = model
+        self.embedding_manager = embedding_manager
+
+    def direct_similarity(self, a, b):
+        """Score ``a`` against ``b`` by delegating to ``self.model.similarity``."""
+        similarity = self.model.similarity
+        return similarity(a, b)
+
+    def cached_similarity(self, a, b):
+        """Score ``a`` against ``b`` using embeddings from the embedding manager.
+
+        :param a: a text or a list of texts.
+        :param b: a text or a list of texts.
+        :return: result of the model's ``cos_sim`` score function applied to
+            the two embedding arrays.
+        """
+        # The embedding manager is called with lists only, so a bare string
+        # is wrapped here.
+        # NOTE(review): single-text callers appear to pass plain strings, as
+        # the direct model call accepted - confirm.
+        texts_a = [a] if isinstance(a, str) else a
+        texts_b = [b] if isinstance(b, str) else b
+        text_emb1 = self.embedding_manager.get_embeddings(texts_a)
+        text_emb2 = self.embedding_manager.get_embeddings(texts_b)
+        score_function = self.model.score_functions['cos_sim']
+        return score_function(text_emb1, text_emb2)
 
     def base_string_similarity(self, text_dict):
         """
@@ -32,10 +49,9 @@ class NLPFunction(object):
         :param text_dict:
         :return:
         """
-        score_tensor = self.model.similarity(
-            text_dict['text_a'],
-            text_dict['text_b']
-        )
+        text_a = text_dict['text_a']
+        text_b = text_dict['text_b']
+        score_tensor = self.cached_similarity(text_a, text_b)
         response = {
             "score": score_tensor.squeeze().tolist()
         }
@@ -46,10 +62,9 @@ class NLPFunction(object):
         计算两个list的相似度
         :return:
         """
-        score_tensor = self.model.similarity(
-            pair_list_dict['text_list_a'],
-            pair_list_dict['text_list_b']
-        )
+        text_a = pair_list_dict['text_list_a']
+        text_b = pair_list_dict['text_list_b']
+        score_tensor = self.cached_similarity(text_a, text_b)
         response = {
             "score_list_list": score_tensor.tolist()
         }

+ 2 - 1
requirements.txt

@@ -26,4 +26,5 @@ torch~=2.3.1
 tqdm~=4.66.4
 transformers
 pydantic~=2.6.4
-similarities~=1.1.7
+similarities~=1.1.7
+filelock

+ 2 - 3
routes/__init__.py

@@ -10,8 +10,7 @@ from .articleDBServer import ArticleSpider
 from .accountServer import AccountServer
 from applications.articleTools import ArticleDBTools
 
-
-def AlgRoutes(mysql_client, model):
+def AlgRoutes(mysql_client, model, embedding_manager):
     """
     ALG ROUTES
     :return:
@@ -46,7 +45,7 @@ def AlgRoutes(mysql_client, model):
         :return:
         """
         params = await request.get_json()
-        nlpS = NLPServer(params=params, model=model)
+        nlpS = NLPServer(params=params, model=model, embedding_manager=embedding_manager)
         response = nlpS.deal()
         return jsonify(response)
 

+ 1 - 3
routes/accountServer.py

@@ -44,7 +44,7 @@ class AccountServer(object):
         async with aiohttp.ClientSession() as session:
             async with session.post(url, headers=headers, json=body) as response:
                 response_text = await response.text()
-                print("结果:\t", response_text)
+                # print("结果:\t", response_text)
                 if response_text:
                     return await response.json()
                 else:
@@ -112,8 +112,6 @@ class AccountServer(object):
                 (good_df["show_view_count"] / good_df["view_count_avg"]).values.tolist()
 
         account_interest = good_df["title"].values.tolist()
-        print(account_interest)
-        print(extend_dicts)
         return account_interest, extend_dicts
 
     async def get_each_account_score_list(self, gh_id):

+ 2 - 3
routes/nlpServer.py

@@ -3,19 +3,18 @@
 """
 from applications.textSimilarity import NLPFunction
 
-
 class NLPServer(object):
     """
     nlp_server
     """
-    def __init__(self, params, model):
+    def __init__(self, params, model, embedding_manager):
         """
         :param params:
         """
         self.data = None
         self.function = None
         self.params = params
-        self.nlp = NLPFunction(model=model)
+        self.nlp = NLPFunction(model=model, embedding_manager=embedding_manager)
 
     def check_params(self):
         """