|
@@ -1,38 +1,69 @@
|
|
|
import os
|
|
|
import threading
|
|
|
-import multiprocessing
|
|
|
+from filelock import FileLock
|
|
|
from time import sleep
|
|
|
import numpy as np
|
|
|
+import random
|
|
|
|
|
|
|
|
|
class EmbeddingManager:
|
|
|
- def __init__(self, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
|
|
|
- self.model = None
|
|
|
+ def __init__(self, model, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
|
|
|
+ self.model = model
|
|
|
self.emb_size = emb_size
|
|
|
self.cache_file = cache_file
|
|
|
+ self.cache_file_real = self.cache_file + ".npy"
|
|
|
self.cache_key_file = f'{self.cache_file}.keys'
|
|
|
- self.save_interval = save_interval
|
|
|
- self.cache = multiprocessing.Manager().dict() # Shared dictionary for multiprocess use
|
|
|
+ # avoid multiple process read and write at same time and wait for filelock
|
|
|
+ self.save_interval = save_interval + random.randint(0, save_interval)
|
|
|
+ self.cache = {}
|
|
|
self.lock = threading.Lock() # Thread-safe lock
|
|
|
+ self.filelock = FileLock(self.cache_file + ".lock")
|
|
|
|
|
|
- npy_filename = f'{self.cache_file}.npy'
|
|
|
- # Load cache from file if it exists
|
|
|
- if os.path.exists(npy_filename):
|
|
|
- embedding_data = np.load(npy_filename)
|
|
|
- embedding_keys = open(self.cache_key_file, "r").readlines()
|
|
|
- embedding_keys = [key.strip("\n") for key in embedding_keys]
|
|
|
- for idx, key in enumerate(embedding_keys):
|
|
|
- self.cache[key] = embedding_data[idx]
|
|
|
- print("cache loaded:")
|
|
|
- print(self.cache)
|
|
|
+ self.load_cache()
|
|
|
|
|
|
# Start the periodic saving thread
|
|
|
self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
|
|
|
self.saving_thread.start()
|
|
|
|
|
|
- def set_model(self, model):
|
|
|
- """Since each process has its own model, model must be set after process is forked"""
|
|
|
- self.model = model
|
|
|
+
|
|
|
+ def _load_cache_unsafe(self):
|
|
|
+ """inter-thread and inter-process safety must be guaranteed by caller"""
|
|
|
+ embedding_data = np.load(self.cache_file_real)
|
|
|
+ embedding_keys = open(self.cache_key_file, "r").readlines()
|
|
|
+ embedding_keys = [key.strip("\n") for key in embedding_keys]
|
|
|
+ for idx, key in enumerate(embedding_keys):
|
|
|
+ self.cache[key] = embedding_data[idx]
|
|
|
+
|
|
|
+ def load_cache(self):
|
|
|
+ with self.lock:
|
|
|
+ if os.path.exists(self.cache_file_real):
|
|
|
+ with self.filelock:
|
|
|
+ self._load_cache_unsafe()
|
|
|
+ print("[EmbeddingManager]cache loaded")
|
|
|
+
|
|
|
+ def dump_cache(self):
|
|
|
+ if os.path.dirname(self.cache_file):
|
|
|
+ os.makedirs(os.path.dirname(self.cache_file), 0o755, True)
|
|
|
+ tmp_cache_file = self.cache_file + ".tmp"
|
|
|
+ tmp_cache_key_file = self.cache_key_file + ".tmp"
|
|
|
+ with self.lock: # Ensure thread-safe access firstly
|
|
|
+ with self.filelock: # Ensure inter-process safety secondly
|
|
|
+ if os.path.exists(self.cache_file_real):
|
|
|
+ self._load_cache_unsafe()
|
|
|
+ keys = self.cache.keys()
|
|
|
+ cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
|
|
|
+ for idx, key in enumerate(keys):
|
|
|
+ cache_to_save[idx] = self.cache[key]
|
|
|
+ np.save(tmp_cache_file, cache_to_save)
|
|
|
+ with open(tmp_cache_key_file, 'w') as fp:
|
|
|
+ fp.write('\n'.join(keys))
|
|
|
+ if os.path.exists(self.cache_file + ".npy"):
|
|
|
+ os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
|
|
|
+ if os.path.exists(self.cache_key_file):
|
|
|
+ os.rename(self.cache_key_file, self.cache_key_file + ".bak")
|
|
|
+ os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
|
|
|
+ os.rename(tmp_cache_key_file, self.cache_key_file)
|
|
|
+ print("[EmbeddingManager]cache dumped")
|
|
|
|
|
|
def get_embedding(self, text_list):
|
|
|
"""
|
|
@@ -45,13 +76,14 @@ class EmbeddingManager:
|
|
|
return embedding_list
|
|
|
new_texts = []
|
|
|
new_texts_ori_idx = []
|
|
|
- for idx, text in enumerate(text_list):
|
|
|
- if text in self.cache:
|
|
|
- print(f"find {text} in cache")
|
|
|
- embedding_list[idx] = self.cache[text]
|
|
|
- else:
|
|
|
- new_texts.append(text)
|
|
|
- new_texts_ori_idx.append(idx)
|
|
|
+ with self.lock:
|
|
|
+ for idx, text in enumerate(text_list):
|
|
|
+ if text in self.cache:
|
|
|
+ print(f"find {text} in cache")
|
|
|
+ embedding_list[idx] = self.cache[text]
|
|
|
+ else:
|
|
|
+ new_texts.append(text)
|
|
|
+ new_texts_ori_idx.append(idx)
|
|
|
|
|
|
new_embeddings = self.model.get_embedding(new_texts)
|
|
|
|
|
@@ -67,28 +99,7 @@ class EmbeddingManager:
|
|
|
"""Periodically save the cache to disk."""
|
|
|
while True:
|
|
|
sleep(self.save_interval)
|
|
|
- self.save_now()
|
|
|
-
|
|
|
- def save_now(self):
|
|
|
- """Manually trigger a save to disk."""
|
|
|
- if os.path.dirname(self.cache_file):
|
|
|
- os.makedirs(os.path.dirname(self.cache_file), 0o755, True)
|
|
|
- tmp_cache_file = self.cache_file + ".tmp"
|
|
|
- tmp_cache_key_file = self.cache_key_file + ".tmp"
|
|
|
- with self.lock: # Ensure thread-safe access
|
|
|
- keys = self.cache.keys()
|
|
|
- cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
|
|
|
- for idx, key in enumerate(keys):
|
|
|
- cache_to_save[idx] = self.cache[key]
|
|
|
- np.save(tmp_cache_file, cache_to_save)
|
|
|
- with open(tmp_cache_key_file, 'w') as fp:
|
|
|
- fp.write('\n'.join(keys))
|
|
|
- if os.path.exists(self.cache_file + ".npy"):
|
|
|
- os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
|
|
|
- if os.path.exists(self.cache_key_file):
|
|
|
- os.rename(self.cache_key_file, self.cache_key_file + ".bak")
|
|
|
- os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
|
|
|
- os.rename(tmp_cache_key_file, self.cache_key_file)
|
|
|
+ self.dump_cache()
|
|
|
|
|
|
|
|
|
# Only for testing
|
|
@@ -109,11 +120,10 @@ class DummyModel:
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
model = DummyModel()
|
|
|
- manager = EmbeddingManager()
|
|
|
- manager.set_model(model)
|
|
|
+ manager = EmbeddingManager(model)
|
|
|
print(manager.get_embedding(["hello"]))
|
|
|
print(manager.get_embedding(["world"]))
|
|
|
print(manager.get_embedding(["hello world"]))
|
|
|
- manager.save_now()
|
|
|
+ manager.dump_cache()
|
|
|
print(manager.get_embedding(["new", "word"]))
|
|
|
- manager.save_now()
|
|
|
+ manager.dump_cache()
|