瀏覽代碼

Add embedding_manager

StrayWarrior 5 月之前
父節點
當前提交
3eee2ba72c
共有 1 個文件被更改,包括 102 次插入和 0 次删除
  1. 102 0
      applications/embedding_manager.py

+ 102 - 0
applications/embedding_manager.py

@@ -0,0 +1,102 @@
+import os
+import threading
+import multiprocessing
+from time import sleep
+import numpy as np
+
+
+class EmbeddingManager:
+    """Cache text embeddings in memory and persist them to disk periodically.
+
+    Embeddings live in a ``multiprocessing.Manager`` dict keyed by the input
+    text. On disk the cache is split into two files: ``<cache_file>.npy``
+    holds the embedding matrix and ``<cache_file>.keys`` holds one key per
+    line, in the same row order as the matrix.
+
+    NOTE: ``self.lock`` is a ``threading.Lock``, so it only serializes threads
+    of the process that created this object, not other processes that may
+    share the manager dict.
+    """
+
+    def __init__(self, model, cache_file="embedding_cache", save_interval=600,
+                 embedding_dim=1024):
+        """Load an existing cache from disk (if any) and start the saver thread.
+
+        Args:
+            model: Object exposing ``get_embedding(text_list)`` that returns
+                one embedding row per input text.
+            cache_file: Path prefix of the on-disk cache, without extension.
+            save_interval: Seconds to sleep between two automatic saves.
+            embedding_dim: Length of a single embedding vector.
+        """
+        self.model = model
+        self.cache_file = cache_file
+        self.cache_key_file = f'{self.cache_file}.keys'
+        self.save_interval = save_interval
+        self.embedding_dim = embedding_dim
+        self.cache = multiprocessing.Manager().dict()  # Shared dictionary for multiprocess use
+        self.lock = threading.Lock()  # Thread-safe lock
+
+        npy_filename = f'{self.cache_file}.npy'
+        # Load cache from file if it exists
+        if os.path.exists(npy_filename):
+            embedding_data = np.load(npy_filename)
+            # Use a context manager so the key file handle is always closed.
+            with open(self.cache_key_file, "r") as fp:
+                embedding_keys = [key.strip("\n") for key in fp]
+            # Fill the shared dict with one bulk call instead of one call per key.
+            self.cache.update(
+                {key: embedding_data[idx] for idx, key in enumerate(embedding_keys)}
+            )
+        print("cache loaded:")
+        print(self.cache)
+
+        # Start the periodic saving thread
+        self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
+        self.saving_thread.start()
+
+    def get_embedding(self, text_list):
+        """Return the embeddings of ``text_list`` as a float32 matrix.
+
+        Texts found in the cache are served from it; the remaining texts are
+        sent to the model in a single batch and then added to the cache.
+
+        Args:
+            text_list: List of texts to embed.
+
+        Returns:
+            ``numpy.ndarray`` of shape ``(len(text_list), embedding_dim)`` whose
+            rows follow the order of ``text_list``.
+
+        Raises:
+            TypeError: If ``text_list`` is not a list.
+        """
+        if not isinstance(text_list, list):
+            raise TypeError(f"Invalid parameter type: text_list {type(text_list)}")
+        embedding_list = np.zeros((len(text_list), self.embedding_dim), np.float32)
+        if not text_list:
+            return embedding_list
+        new_texts = []
+        new_texts_ori_idx = []
+        for idx, text in enumerate(text_list):
+            # A single ``get`` replaces the separate membership test and lookup.
+            cached = self.cache.get(text)
+            if cached is not None:
+                print(f"find {text} in cache")
+                embedding_list[idx] = cached
+            else:
+                new_texts.append(text)
+                new_texts_ori_idx.append(idx)
+
+        # Everything was cached: do not call the model with an empty batch.
+        if not new_texts:
+            return embedding_list
+
+        new_embeddings = self.model.get_embedding(new_texts)
+
+        with self.lock:  # Ensure thread-safe access
+            for ori_idx, text, embedding in zip(new_texts_ori_idx, new_texts, new_embeddings):
+                # Keep the existing entry if another thread added it meanwhile.
+                self.cache.setdefault(text, embedding)
+                embedding_list[ori_idx] = embedding
+        return embedding_list
+
+    def _periodic_save(self):
+        """Save the cache to disk every ``save_interval`` seconds, forever."""
+        while True:
+            sleep(self.save_interval)
+            self.save_now()
+
+    def save_now(self):
+        """Write the current cache to disk immediately.
+
+        Both files are first written under temporary names and then moved into
+        place with ``os.replace``, so an interrupted save does not leave a
+        half-written cache file behind. Entries whose key contains a line
+        break are not persisted: the key file stores one key per line, and
+        such a key would shift every following row when the cache is reloaded.
+        """
+        with self.lock:  # Ensure thread-safe access
+            entries = [
+                (key, value)
+                for key, value in list(self.cache.items())
+                if "\n" not in key and "\r" not in key
+            ]
+            cache_to_save = np.zeros((len(entries), self.embedding_dim), np.float32)
+            for idx, (_, value) in enumerate(entries):
+                cache_to_save[idx] = value
+
+            npy_filename = f'{self.cache_file}.npy'
+            # The temporary name ends in ".npy" so np.save does not append a suffix.
+            tmp_npy_filename = f'{self.cache_file}.tmp.npy'
+            tmp_key_filename = f'{self.cache_key_file}.tmp'
+            np.save(tmp_npy_filename, cache_to_save)
+            with open(tmp_key_filename, 'w') as fp:
+                fp.write('\n'.join(key for key, _ in entries))
+            os.replace(tmp_npy_filename, npy_filename)
+            os.replace(tmp_key_filename, self.cache_key_file)
+
+
+# Only for testing
+class DummyModel:
+    """Deterministic stand-in for a real embedding model (testing only).
+
+    Each text is mapped to a vector of ``EMBEDDING_DIM`` character code
+    points, obtained by repeating or truncating the text to that length.
+    """
+
+    EMBEDDING_DIM = 1024  # Length of every embedding vector produced
+
+    def padding_text(self, text):
+        """Return ``text`` repeated or truncated to exactly ``EMBEDDING_DIM`` characters.
+
+        An empty text has nothing to repeat, so it is padded with NUL
+        characters instead; its embedding is therefore all zeros.
+        """
+        dim = self.EMBEDDING_DIM
+        if not text:
+            return "\0" * dim
+        repeats = -(-dim // len(text))  # Ceiling division
+        return (text * repeats)[:dim]
+
+    def get_embedding(self, text_list):
+        """Return a float32 matrix with one row of code points per text.
+
+        Args:
+            text_list: Texts to embed.
+
+        Returns:
+            ``numpy.ndarray`` of shape ``(len(text_list), EMBEDDING_DIM)``.
+        """
+        embeddings = np.zeros((len(text_list), self.EMBEDDING_DIM), np.float32)
+        for idx, text in enumerate(text_list):
+            embeddings[idx] = [ord(c) for c in self.padding_text(text)]
+        return embeddings
+
+if __name__ == "__main__":
+    # Smoke test: embed a few sample texts with the dummy model, printing each
+    # result, then force a save of the cache to disk.
+    demo_manager = EmbeddingManager(DummyModel())
+    for sample_text in ("hello", "world", "hello world"):
+        print(demo_manager.get_embedding([sample_text]))
+    demo_manager.save_now()