import os
import threading
import multiprocessing
from time import sleep

import numpy as np


class EmbeddingManager:
    """Cache of text embeddings shared across processes, periodically persisted.

    On-disk format: a ``<cache_file>.npy`` matrix of embeddings plus a parallel
    ``<cache_file>.keys`` text file with one key per line (same row order).
    """

    def __init__(self, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
        """
        Args:
            emb_size: dimensionality of each embedding vector.
            cache_file: path prefix for the on-disk cache; ".npy" and ".keys"
                files are derived from it.
            save_interval: seconds between automatic background saves.
        """
        self.model = None
        self.emb_size = emb_size
        self.cache_file = cache_file
        self.cache_key_file = f'{self.cache_file}.keys'
        self.save_interval = save_interval
        # Manager dict so forked worker processes all see the same cache.
        self.cache = multiprocessing.Manager().dict()
        # NOTE(review): threading.Lock only serializes threads within one
        # process; cross-process consistency relies on the Manager proxy.
        self.lock = threading.Lock()
        npy_filename = f'{self.cache_file}.npy'
        # Load cache from file if it exists.
        if os.path.exists(npy_filename):
            embedding_data = np.load(npy_filename)
            # FIX: use a context manager (the original leaked the file handle).
            with open(self.cache_key_file, "r") as fp:
                embedding_keys = [line.rstrip("\n") for line in fp]
            # FIX: zip() tolerates a key/row count mismatch instead of
            # raising IndexError on a truncated or stale key file.
            for key, row in zip(embedding_keys, embedding_data):
                self.cache[key] = row
            print("cache loaded:")
            print(self.cache)
        # Daemon thread: dies with the main process, never blocks interpreter exit.
        self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
        self.saving_thread.start()

    def set_model(self, model):
        """Since each process has its own model, model must be set after process is forked"""
        self.model = model

    def get_embedding(self, text_list):
        """
        Search embedding for a given text. If not found, generate using the model,
        save to cache, and return it.

        Args:
            text_list: list of strings to embed.

        Returns:
            np.ndarray of shape (len(text_list), emb_size), dtype float32.

        Raises:
            TypeError: if ``text_list`` is not a list.
            RuntimeError: if there are cache misses but no model was set.
        """
        if not isinstance(text_list, list):
            raise TypeError(f"Invalid parameter type: text_list {type(text_list)}")
        embedding_list = np.zeros((len(text_list), self.emb_size), np.float32)
        if not text_list:
            return embedding_list
        new_texts = []
        new_texts_ori_idx = []
        for idx, text in enumerate(text_list):
            if text in self.cache:
                print(f"find {text} in cache")
                embedding_list[idx] = self.cache[text]
            else:
                new_texts.append(text)
                new_texts_ori_idx.append(idx)
        # FIX: only invoke the model when there are cache misses; the original
        # always called it — even with an empty list — and crashed when all
        # texts were cached but the per-process model had not been set yet.
        if new_texts:
            if self.model is None:
                raise RuntimeError("model is not set; call set_model() after forking")
            new_embeddings = self.model.get_embedding(new_texts)
            with self.lock:  # thread-safe publication into the shared cache
                for idx, text in enumerate(new_texts):
                    if text not in self.cache:  # re-check: another thread may have won the race
                        self.cache[text] = new_embeddings[idx]
                    # FIX: always fill the output row — the original only did
                    # this when the key was absent, so a concurrent insert by
                    # another thread left the caller with an all-zero row.
                    embedding_list[new_texts_ori_idx[idx]] = new_embeddings[idx]
        return embedding_list

    def _periodic_save(self):
        """Periodically save the cache to disk."""
        while True:
            sleep(self.save_interval)
            self.save_now()

    def save_now(self):
        """Manually trigger a save to disk.

        Writes to ".tmp" files first, keeps one ".bak" generation of the
        previous files, then renames the new files into place.
        """
        cache_dir = os.path.dirname(self.cache_file)
        if cache_dir:
            os.makedirs(cache_dir, 0o755, exist_ok=True)
        tmp_cache_file = self.cache_file + ".tmp"
        tmp_cache_key_file = self.cache_key_file + ".tmp"
        with self.lock:  # freeze a consistent key/value snapshot
            keys = list(self.cache.keys())
            cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
            for idx, key in enumerate(keys):
                cache_to_save[idx] = self.cache[key]
            np.save(tmp_cache_file, cache_to_save)  # np.save appends ".npy"
            with open(tmp_cache_key_file, 'w') as fp:
                fp.write('\n'.join(keys))
            # Rotate the previous generation to ".bak", then swap in the new files.
            if os.path.exists(self.cache_file + ".npy"):
                os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
            if os.path.exists(self.cache_key_file):
                os.rename(self.cache_key_file, self.cache_key_file + ".bak")
            os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
            os.rename(tmp_cache_key_file, self.cache_key_file)


# Only for testing
class DummyModel:
    """Stand-in model: each embedding is the ordinals of the text repeated to 1024 chars."""

    def padding_text(self, text):
        """Repeat ``text`` until it is exactly 1024 characters long.

        FIX: the original crashed with ZeroDivisionError on empty input and
        returned an empty string (wrong embedding shape) for inputs longer
        than 1024 characters; this version handles both.
        """
        if not text:
            return "\0" * 1024  # empty input -> all-zero embedding
        repeats = 1024 // len(text) + 1
        return (text * repeats)[:1024]

    def get_embedding(self, text_list):
        """Return a (len(text_list), 1024) float32 matrix of character ordinals."""
        embeddings = np.zeros((len(text_list), 1024), np.float32)
        for idx, text in enumerate(text_list):
            padded = self.padding_text(text)
            embeddings[idx] = np.array([ord(c) for c in padded], np.float32)
        return embeddings


if __name__ == "__main__":
    model = DummyModel()
    manager = EmbeddingManager()
    manager.set_model(model)
    print(manager.get_embedding(["hello"]))
    print(manager.get_embedding(["world"]))
    print(manager.get_embedding(["hello world"]))
    manager.save_now()
    print(manager.get_embedding(["new", "word"]))
    manager.save_now()