|
@@ -0,0 +1,102 @@
|
|
|
|
+import os
|
|
|
|
+import threading
|
|
|
|
+import multiprocessing
|
|
|
|
+from time import sleep
|
|
|
|
+import numpy as np
|
|
|
|
+
|
|
|
|
+
|
|
|
|
class EmbeddingManager:
    """Cache text embeddings in memory and persist them to disk.

    Embeddings are looked up by their exact text. Texts that are not cached
    are embedded in a single batch through ``model.get_embedding`` and then
    stored. The cache is written to ``<cache_file>.npy`` (one row per
    embedding) and ``<cache_file>.keys`` (one text per line, in the same order
    as the rows), both by a daemon thread every ``save_interval`` seconds and
    on demand via ``save_now``.

    The key file is newline-delimited, so texts that contain newline
    characters cannot be restored faithfully from disk.
    """

    def __init__(self, model, cache_file="embedding_cache", save_interval=600,
                 embedding_dim=1024):
        """Load an existing cache from disk, if any, and start the saver thread.

        Args:
            model: Object exposing ``get_embedding(list_of_texts)``. Its result
                is indexed by position and each item is written into a row of
                ``embedding_dim`` float32 values.
            cache_file: Path prefix of the ``.npy`` and ``.keys`` files.
            save_interval: Seconds to sleep between automatic saves.
            embedding_dim: Length of every embedding vector.

        Raises:
            ValueError: If an existing ``.npy`` file does not contain a 2-D
                array with ``embedding_dim`` columns.
        """
        self.model = model
        self.cache_file = cache_file
        self.cache_key_file = f'{self.cache_file}.keys'
        self.save_interval = save_interval
        self.embedding_dim = embedding_dim
        # Manager-backed dict so the mapping can be handed to other processes.
        self.cache = multiprocessing.Manager().dict()
        # threading.Lock only serialises threads inside this process.
        self.lock = threading.Lock()

        npy_filename = f'{self.cache_file}.npy'
        if os.path.exists(npy_filename):
            embedding_data = np.load(npy_filename)
            if embedding_data.ndim != 2 or embedding_data.shape[1] != embedding_dim:
                # Fail early with a clear message instead of a broadcast error
                # on the first cache hit.
                raise ValueError(
                    f"Cache file {npy_filename} has shape {embedding_data.shape}, "
                    f"expected (n, {embedding_dim})"
                )
            with open(self.cache_key_file, "r", encoding="utf-8") as fp:
                embedding_keys = [line.strip("\n") for line in fp]
            # One bulk update instead of one manager round trip per key.
            self.cache.update(
                {key: embedding_data[idx] for idx, key in enumerate(embedding_keys)}
            )
            print("cache loaded:")
            print(self.cache)

        # Daemon thread: it must not keep the interpreter alive at exit.
        self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
        self.saving_thread.start()

    def get_embedding(self, text_list):
        """Return the embeddings of ``text_list``, computing and caching misses.

        Args:
            text_list: List of texts to embed.

        Returns:
            A float32 array of shape ``(len(text_list), embedding_dim)`` whose
            row ``i`` is the embedding of ``text_list[i]``.

        Raises:
            TypeError: If ``text_list`` is not a list.
        """
        if not isinstance(text_list, list):
            raise TypeError(f"Invalid parameter type: text_list {type(text_list)}")
        embedding_list = np.zeros((len(text_list), self.embedding_dim), np.float32)
        if not text_list:
            return embedding_list

        new_texts = []
        new_texts_ori_idx = []
        for idx, text in enumerate(text_list):
            # A single ``get`` replaces the membership test plus item lookup.
            cached = self.cache.get(text)
            if cached is not None:
                print(f"find {text} in cache")
                embedding_list[idx] = cached
            else:
                new_texts.append(text)
                new_texts_ori_idx.append(idx)

        # Everything was cached: do not call the model with an empty batch.
        if not new_texts:
            return embedding_list

        # The model runs outside the lock so slow inference does not block
        # other threads.
        new_embeddings = self.model.get_embedding(new_texts)

        with self.lock:
            for idx, text in enumerate(new_texts):
                # setdefault keeps an entry another thread stored meanwhile.
                self.cache.setdefault(text, new_embeddings[idx])
                embedding_list[new_texts_ori_idx[idx]] = new_embeddings[idx]
        return embedding_list

    def _periodic_save(self):
        """Save the cache every ``save_interval`` seconds, forever."""
        while True:
            sleep(self.save_interval)
            self.save_now()

    def save_now(self):
        """Write the current cache to the ``.npy`` and ``.keys`` files."""
        # TODO: double-buffer
        with self.lock:
            # Take keys and values in one call so they stay aligned.
            items = list(self.cache.items())
            cache_to_save = np.zeros((len(items), self.embedding_dim), np.float32)
            for idx, (_, embedding) in enumerate(items):
                cache_to_save[idx] = embedding
            # np.save appends ".npy" to a path that lacks the extension.
            np.save(self.cache_file, cache_to_save)
            with open(self.cache_key_file, 'w', encoding="utf-8") as fp:
                fp.write('\n'.join(key for key, _ in items))
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# Only for testing
|
|
|
|
class DummyModel:
    """Deterministic stand-in for an embedding model, intended for tests.

    Each text is cyclically repeated, or truncated, to ``embedding_dim``
    characters, and every character is mapped to its code point.
    """

    def __init__(self, embedding_dim=1024):
        """Create a model that produces vectors of length ``embedding_dim``."""
        self.embedding_dim = embedding_dim

    def padding_text(self, text):
        """Return ``text`` repeated cyclically to exactly ``embedding_dim`` chars.

        Text longer than ``embedding_dim`` is truncated. Empty text has nothing
        to repeat, so it is padded with NUL characters, which embed to zeros.
        """
        dim = self.embedding_dim
        if not text:
            return "\0" * dim
        # Ceiling division: enough repeats to cover dim before slicing.
        repeats = -(-dim // len(text))
        return (text * repeats)[:dim]

    def get_embedding(self, text_list):
        """Return a float32 array of shape ``(len(text_list), embedding_dim)``.

        Row ``i`` holds the code points of the padded ``text_list[i]``.
        """
        embeddings = np.zeros((len(text_list), self.embedding_dim), np.float32)
        for idx, text in enumerate(text_list):
            padded = self.padding_text(text)
            embeddings[idx] = np.array([ord(c) for c in padded], np.float32)
        return embeddings
|
|
|
|
+
|
|
|
|
if __name__ == "__main__":
    # Smoke run: embed three sample texts one at a time, then save the cache.
    demo_manager = EmbeddingManager(DummyModel())
    for sample in ("hello", "world", "hello world"):
        print(demo_manager.get_embedding([sample]))
    demo_manager.save_now()
|