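"""A thread-safe, process-safe embedding cache.

EmbeddingManager keeps embeddings in an in-memory dict, backed on disk by a
.npy matrix plus a .keys text file. A threading.Lock guards the dict, a
FileLock guards the on-disk files, and a daemon thread saves the cache at a
jittered interval.
"""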
import os
import random
import threading
from time import sleep

import numpy as np
from filelock import FileLock


class EmbeddingManager:
    """In-memory embedding cache, persisted to disk and shared across processes."""

    def __init__(self, model, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
        self.model = model
        self.emb_size = emb_size
        self.cache_file = cache_file
        self.cache_file_real = self.cache_file + ".npy"
        self.cache_key_file = f"{self.cache_file}.keys"
        # Jitter the save interval so that multiple processes do not all try to
        # read and write the cache files (and wait on the file lock) at once.
        self.save_interval = save_interval + random.randint(0, save_interval)
        self.cache = {}
        self.lock = threading.Lock()  # Guards self.cache across threads
        self.filelock = FileLock(self.cache_file + ".lock")  # Guards the cache files across processes

        self.load_cache()

        # Start the periodic saving thread
        self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
        self.saving_thread.start()

    def _load_cache_unsafe(self):
        """Load cached embeddings from disk; the caller must guarantee both inter-thread and inter-process safety."""
        embedding_data = np.load(self.cache_file_real)
        with open(self.cache_key_file, "r") as fp:
            embedding_keys = [key.rstrip("\n") for key in fp]
        for idx, key in enumerate(embedding_keys):
            self.cache[key] = embedding_data[idx]

    def load_cache(self):
        with self.lock:
            if os.path.exists(self.cache_file_real):
                with self.filelock:
                    self._load_cache_unsafe()
            print("[EmbeddingManager] cache loaded")
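
    # dump_cache first merges what is already on disk into the in-memory cache
    # (another process may have saved newer entries), then writes the merged
    # cache to temporary files, rotates the previous files to .bak, and renames
    # the temporary files into place, all while holding both locks.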
    def dump_cache(self):
        if os.path.dirname(self.cache_file):
            os.makedirs(os.path.dirname(self.cache_file), mode=0o755, exist_ok=True)
        tmp_cache_file = self.cache_file + ".tmp"
        tmp_cache_key_file = self.cache_key_file + ".tmp"
        with self.lock:  # Ensure thread safety first
            with self.filelock:  # Then ensure inter-process safety
                if os.path.exists(self.cache_file_real):
                    self._load_cache_unsafe()  # Merge entries saved by other processes
                keys = list(self.cache.keys())
                cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
                for idx, key in enumerate(keys):
                    cache_to_save[idx] = self.cache[key]
                np.save(tmp_cache_file, cache_to_save)  # np.save appends ".npy"
                with open(tmp_cache_key_file, "w") as fp:
                    fp.write("\n".join(keys))  # Assumes keys contain no newlines
                if os.path.exists(self.cache_file_real):
                    os.rename(self.cache_file_real, self.cache_file_real + ".bak")
                if os.path.exists(self.cache_key_file):
                    os.rename(self.cache_key_file, self.cache_key_file + ".bak")
                os.rename(tmp_cache_file + ".npy", self.cache_file_real)
                os.rename(tmp_cache_key_file, self.cache_key_file)
        print("[EmbeddingManager] cache dumped")

    def get_embeddings(self, text_list):
        """
        Look up the embedding for each text. Any text missing from the cache is
        embedded with the model, added to the cache, and returned.
        """
        if not isinstance(text_list, list):
            raise TypeError(f"Invalid parameter type: text_list {type(text_list)}")
        embedding_list = np.zeros((len(text_list), self.emb_size), np.float32)
        if not text_list:
            return embedding_list
        new_texts = []
        new_texts_ori_idx = []
        with self.lock:
            for idx, text in enumerate(text_list):
                if text in self.cache:
                    embedding_list[idx] = self.cache[text]
                else:
                    new_texts.append(text)
                    new_texts_ori_idx.append(idx)

        # Run the model outside the lock so other threads are not blocked
        new_embeddings = self.model.get_embeddings(new_texts)
        if new_embeddings.shape[0] > 0 and new_embeddings.shape[1] != self.emb_size:
            raise ValueError("Embedding size mismatch")

        # Store the newly generated embeddings in the cache
        with self.lock:  # Ensure thread-safe access
            for idx, text in enumerate(new_texts):
                if text not in self.cache:  # Re-check in case another thread added it
                    self.cache[text] = new_embeddings[idx]
                embedding_list[new_texts_ori_idx[idx]] = new_embeddings[idx]
        return embedding_list

    def _periodic_save(self):
        """Periodically save the cache to disk."""
        while True:
            sleep(self.save_interval)
            self.dump_cache()


# Only for testing
class DummyModel:
    def padding_text(self, text):
        # Repeat the text so that it is exactly 1024 characters long
        repeats = 1024 // len(text) + 1
        return (text * repeats)[:1024]

    def get_embeddings(self, text_list):
        # Encode each character's code point as one dimension of the embedding
        embeddings = np.zeros((len(text_list), 1024), np.float32)
        for idx, text in enumerate(text_list):
            text = self.padding_text(text)
            embeddings[idx] = np.array([ord(c) for c in text], np.float32)
        return embeddings
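
# A model only needs a get_embeddings(list_of_str) -> np.ndarray method that
# returns an array of shape (len(texts), emb_size). As a sketch (not part of
# this file), an adapter around sentence-transformers might look like the
# following, assuming that package is installed; SentenceTransformerModel is
# a hypothetical name:
#
#     from sentence_transformers import SentenceTransformer
#
#     class SentenceTransformerModel:
#         def __init__(self, name="all-MiniLM-L6-v2"):
#             self.model = SentenceTransformer(name)
#
#         def get_embeddings(self, text_list):
#             # encode() returns one vector per text; all-MiniLM-L6-v2 has dim 384
#             return self.model.encode(text_list)
#
#     manager = EmbeddingManager(SentenceTransformerModel(), emb_size=384)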

if __name__ == "__main__":
    model = DummyModel()
    manager = EmbeddingManager(model)
    print(manager.get_embeddings(["hello"]))
    print(manager.get_embeddings(["world"]))
    print(manager.get_embeddings(["hello world"]))
    manager.dump_cache()
    print(manager.get_embeddings(["new", "word"]))
    manager.dump_cache()