|  | @@ -0,0 +1,102 @@
 | 
											
												
													
														|  | 
 |  | +import os
 | 
											
												
													
														|  | 
 |  | +import threading
 | 
											
												
													
														|  | 
 |  | +import multiprocessing
 | 
											
												
													
														|  | 
 |  | +from time import sleep
 | 
											
												
													
														|  | 
 |  | +import numpy as np
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +class EmbeddingManager:
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, model, cache_file="embedding_cache", save_interval=600):
 | 
											
												
													
														|  | 
 |  | +        self.model = model
 | 
											
												
													
														|  | 
 |  | +        self.cache_file = cache_file
 | 
											
												
													
														|  | 
 |  | +        self.cache_key_file = f'{self.cache_file}.keys'
 | 
											
												
													
														|  | 
 |  | +        self.save_interval = save_interval
 | 
											
												
													
														|  | 
 |  | +        self.cache = multiprocessing.Manager().dict()  # Shared dictionary for multiprocess use
 | 
											
												
													
														|  | 
 |  | +        self.lock = threading.Lock()  # Thread-safe lock
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        npy_filename = f'{self.cache_file}.npy'
 | 
											
												
													
														|  | 
 |  | +        # Load cache from file if it exists
 | 
											
												
													
														|  | 
 |  | +        if os.path.exists(npy_filename):
 | 
											
												
													
														|  | 
 |  | +            embedding_data = np.load(npy_filename)
 | 
											
												
													
														|  | 
 |  | +            embedding_keys = open(self.cache_key_file, "r").readlines()
 | 
											
												
													
														|  | 
 |  | +            embedding_keys = [key.strip("\n") for key in embedding_keys]
 | 
											
												
													
														|  | 
 |  | +            for idx, key in enumerate(embedding_keys):
 | 
											
												
													
														|  | 
 |  | +                self.cache[key] = embedding_data[idx]
 | 
											
												
													
														|  | 
 |  | +        print("cache loaded:")
 | 
											
												
													
														|  | 
 |  | +        print(self.cache)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Start the periodic saving thread
 | 
											
												
													
														|  | 
 |  | +        self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
 | 
											
												
													
														|  | 
 |  | +        self.saving_thread.start()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def get_embedding(self, text_list):
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        Search embedding for a given text. If not found, generate using the model, save to cache, and return it.
 | 
											
												
													
														|  | 
 |  | +        """
 | 
											
												
													
														|  | 
 |  | +        if not isinstance(text_list, list):
 | 
											
												
													
														|  | 
 |  | +            raise Exception(f"Invalid parameter type: text_list {type(text_list)}")
 | 
											
												
													
														|  | 
 |  | +        embedding_list = np.zeros((len(text_list), 1024), np.float32)
 | 
											
												
													
														|  | 
 |  | +        if not text_list:
 | 
											
												
													
														|  | 
 |  | +            return embedding_list
 | 
											
												
													
														|  | 
 |  | +        new_texts = []
 | 
											
												
													
														|  | 
 |  | +        new_texts_ori_idx = []
 | 
											
												
													
														|  | 
 |  | +        for idx, text in enumerate(text_list):
 | 
											
												
													
														|  | 
 |  | +            if text in self.cache:
 | 
											
												
													
														|  | 
 |  | +                print(f"find {text} in cache")
 | 
											
												
													
														|  | 
 |  | +                embedding_list[idx] = self.cache[text]
 | 
											
												
													
														|  | 
 |  | +            else:
 | 
											
												
													
														|  | 
 |  | +                new_texts.append(text)
 | 
											
												
													
														|  | 
 |  | +                new_texts_ori_idx.append(idx)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        new_embeddings = self.model.get_embedding(new_texts)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +        # Generate embedding if not found in cache
 | 
											
												
													
														|  | 
 |  | +        with self.lock:  # Ensure thread-safe access
 | 
											
												
													
														|  | 
 |  | +            for idx, text in enumerate(new_texts):
 | 
											
												
													
														|  | 
 |  | +                if text not in self.cache:  # Re-check in case another thread added it
 | 
											
												
													
														|  | 
 |  | +                    self.cache[text] = new_embeddings[idx]
 | 
											
												
													
														|  | 
 |  | +                embedding_list[new_texts_ori_idx[idx]] = new_embeddings[idx]
 | 
											
												
													
														|  | 
 |  | +        return embedding_list
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def _periodic_save(self):
 | 
											
												
													
														|  | 
 |  | +        """Periodically save the cache to disk."""
 | 
											
												
													
														|  | 
 |  | +        while True:
 | 
											
												
													
														|  | 
 |  | +            sleep(self.save_interval)
 | 
											
												
													
														|  | 
 |  | +            self.save_now()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def save_now(self):
 | 
											
												
													
														|  | 
 |  | +        """Manually trigger a save to disk."""
 | 
											
												
													
														|  | 
 |  | +        # TODO: double-buffer
 | 
											
												
													
														|  | 
 |  | +        with self.lock:  # Ensure thread-safe access
 | 
											
												
													
														|  | 
 |  | +            keys = self.cache.keys()
 | 
											
												
													
														|  | 
 |  | +            cache_to_save = np.zeros((len(keys), 1024), np.float32)
 | 
											
												
													
														|  | 
 |  | +            for idx, key in enumerate(keys):
 | 
											
												
													
														|  | 
 |  | +                cache_to_save[idx] = self.cache[key]
 | 
											
												
													
														|  | 
 |  | +            np.save(self.cache_file, cache_to_save)
 | 
											
												
													
														|  | 
 |  | +            with open(self.cache_key_file, 'w') as fp:
 | 
											
												
													
														|  | 
 |  | +                fp.write('\n'.join(keys))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +# Only for testing
 | 
											
												
													
														|  | 
 |  | +class DummyModel:
 | 
											
												
													
														|  | 
 |  | +    def padding_text(self, text):
 | 
											
												
													
														|  | 
 |  | +        padding_factor = 1024 // len(text)
 | 
											
												
													
														|  | 
 |  | +        text = text * padding_factor
 | 
											
												
													
														|  | 
 |  | +        text += text[:1024 - len(text)]
 | 
											
												
													
														|  | 
 |  | +        return text
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +    def get_embedding(self, text_list):
 | 
											
												
													
														|  | 
 |  | +        embeddings = np.zeros((len(text_list), 1024), np.float32)
 | 
											
												
													
														|  | 
 |  | +        for idx, text in enumerate(text_list):
 | 
											
												
													
														|  | 
 |  | +            text = self.padding_text(text)
 | 
											
												
													
														|  | 
 |  | +            embedding = np.array([ord(c) for c in text], np.float32)
 | 
											
												
													
														|  | 
 |  | +            embeddings[idx] = embedding
 | 
											
												
													
														|  | 
 |  | +        return embeddings
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +if __name__ == "__main__":
 | 
											
												
													
														|  | 
 |  | +    model = DummyModel()
 | 
											
												
													
														|  | 
 |  | +    manager = EmbeddingManager(model)
 | 
											
												
													
														|  | 
 |  | +    print(manager.get_embedding(["hello"]))
 | 
											
												
													
														|  | 
 |  | +    print(manager.get_embedding(["world"]))
 | 
											
												
													
														|  | 
 |  | +    print(manager.get_embedding(["hello world"]))
 | 
											
												
													
														|  | 
 |  | +    manager.save_now()
 |