123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- import os
- import threading
- import multiprocessing
- from time import sleep
- import numpy as np
class EmbeddingManager:
    """Thread-safe, multiprocess-shared cache of text embeddings.

    Embeddings live in a ``multiprocessing.Manager`` dict so that forked
    worker processes all see one shared cache.  A daemon thread persists
    the cache to disk every ``save_interval`` seconds as two files:
    ``<cache_file>.npy`` (the embedding matrix) and ``<cache_file>.keys``
    (one cached text per line, in matching row order).
    """

    def __init__(self, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
        """
        Args:
            emb_size: Dimensionality of each embedding vector.
            cache_file: Base path for the on-disk cache; ``.npy`` and
                ``.keys`` suffixes are appended.
            save_interval: Seconds between automatic background saves.
        """
        self.model = None  # set per-process via set_model() after the fork
        self.emb_size = emb_size
        self.cache_file = cache_file
        self.cache_key_file = f'{self.cache_file}.keys'
        self.save_interval = save_interval
        self.cache = multiprocessing.Manager().dict()  # shared dict for multiprocess use
        self.lock = threading.Lock()  # guards cache writes and disk saves

        npy_filename = f'{self.cache_file}.npy'
        # Load a previously saved cache, if one exists on disk.
        if os.path.exists(npy_filename):
            embedding_data = np.load(npy_filename)
            # FIX: close the key file deterministically (the original left
            # the handle open).
            with open(self.cache_key_file, "r") as fp:
                embedding_keys = [key.strip("\n") for key in fp.readlines()]
            for idx, key in enumerate(embedding_keys):
                self.cache[key] = embedding_data[idx]
            print("cache loaded:")
            print(self.cache)

        # Background daemon thread that periodically flushes the cache.
        self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
        self.saving_thread.start()

    def set_model(self, model):
        """Since each process has its own model, model must be set after process is forked"""
        self.model = model

    def get_embedding(self, text_list):
        """Return a (len(text_list), emb_size) float32 array of embeddings.

        Cached texts are served from the shared dict; misses are batched
        into a single model call, cached, and returned.

        Args:
            text_list: List of texts to embed (may be empty).

        Raises:
            Exception: If ``text_list`` is not a list.
        """
        if not isinstance(text_list, list):
            raise Exception(f"Invalid parameter type: text_list {type(text_list)}")
        embedding_list = np.zeros((len(text_list), self.emb_size), np.float32)
        if not text_list:
            return embedding_list

        # Partition into cache hits (filled immediately) and misses.
        new_texts = []
        new_texts_ori_idx = []  # original positions of the misses
        for idx, text in enumerate(text_list):
            if text in self.cache:
                print(f"find {text} in cache")
                embedding_list[idx] = self.cache[text]
            else:
                new_texts.append(text)
                new_texts_ori_idx.append(idx)

        # FIX: skip the model call entirely on a full cache hit — the
        # original always called the model, even with an empty batch.
        if not new_texts:
            return embedding_list

        new_embeddings = self.model.get_embedding(new_texts)
        with self.lock:  # ensure thread-safe cache insertion
            for idx, text in enumerate(new_texts):
                if text not in self.cache:  # re-check: another thread may have added it
                    self.cache[text] = new_embeddings[idx]
                embedding_list[new_texts_ori_idx[idx]] = new_embeddings[idx]
        return embedding_list

    def _periodic_save(self):
        """Periodically save the cache to disk (runs in the daemon thread)."""
        while True:
            sleep(self.save_interval)
            self.save_now()

    def save_now(self):
        """Manually trigger a save to disk.

        Writes to ``.tmp`` files first, keeps ``.bak`` copies of the
        previous files, then renames the new files into place so a crash
        mid-save never clobbers the last good cache.
        """
        if os.path.dirname(self.cache_file):
            # FIX: keyword args instead of opaque positionals.
            os.makedirs(os.path.dirname(self.cache_file), mode=0o755, exist_ok=True)
        tmp_cache_file = self.cache_file + ".tmp"
        tmp_cache_key_file = self.cache_key_file + ".tmp"
        with self.lock:  # ensure a consistent keys/matrix snapshot
            keys = list(self.cache.keys())
            cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
            for idx, key in enumerate(keys):
                cache_to_save[idx] = self.cache[key]
            np.save(tmp_cache_file, cache_to_save)  # np.save appends ".npy"
            with open(tmp_cache_key_file, 'w') as fp:
                fp.write('\n'.join(keys))
            # Rotate the previous files to ".bak" before swapping in the new ones.
            if os.path.exists(self.cache_file + ".npy"):
                os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
            if os.path.exists(self.cache_key_file):
                os.rename(self.cache_key_file, self.cache_key_file + ".bak")
            os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
            os.rename(tmp_cache_key_file, self.cache_key_file)
# Only for testing
class DummyModel:
    """Stand-in model: each text is encoded as the ord() values of the
    text repeated cyclically to exactly 1024 characters."""

    def padding_text(self, text):
        """Return ``text`` repeated/truncated to exactly 1024 characters.

        FIX: the original divided by ``len(text)`` (ZeroDivisionError on
        empty input) and produced an empty string for texts longer than
        1024 chars, which then crashed the row assignment in
        ``get_embedding``.  For all previously-working inputs the result
        is byte-identical to the original's repeat-then-top-up logic.
        """
        if not text:
            # NUL padding -> an all-zero embedding for the empty string.
            return "\0" * 1024
        repeats = -(-1024 // len(text))  # ceiling division
        return (text * repeats)[:1024]

    def get_embedding(self, text_list):
        """Return a (len(text_list), 1024) float32 array of char codes."""
        embeddings = np.zeros((len(text_list), 1024), np.float32)
        for idx, text in enumerate(text_list):
            padded = self.padding_text(text)
            embeddings[idx] = np.array([ord(c) for c in padded], np.float32)
        return embeddings
if __name__ == "__main__":
    # Smoke test: exercise cache misses, cache hits, and persistence.
    dummy_model = DummyModel()
    embedding_manager = EmbeddingManager()
    embedding_manager.set_model(dummy_model)
    for query in ("hello", "world", "hello world"):
        print(embedding_manager.get_embedding([query]))
    embedding_manager.save_now()
    print(embedding_manager.get_embedding(["new", "word"]))
    embedding_manager.save_now()
|