|
@@ -66,15 +66,20 @@ class EmbeddingManager:
|
|
|
|
|
|
def save_now(self):
|
|
|
"""Manually trigger a save to disk."""
|
|
|
- # TODO: double-buffer
|
|
|
+ tmp_cache_file = self.cache_file + ".tmp"
|
|
|
+ tmp_cache_key_file = self.cache_key_file + ".tmp"
|
|
|
with self.lock: # Ensure thread-safe access
|
|
|
keys = self.cache.keys()
|
|
|
cache_to_save = np.zeros((len(keys), 1024), np.float32)
|
|
|
for idx, key in enumerate(keys):
|
|
|
cache_to_save[idx] = self.cache[key]
|
|
|
- np.save(self.cache_file, cache_to_save)
|
|
|
- with open(self.cache_key_file, 'w') as fp:
|
|
|
+ np.save(tmp_cache_file, cache_to_save)
|
|
|
+ with open(tmp_cache_key_file, 'w') as fp:
|
|
|
fp.write('\n'.join(keys))
|
|
|
+ os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
|
|
|
+ os.rename(self.cache_key_file, self.cache_key_file + ".bak")
|
|
|
+ os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
|
|
|
+ os.rename(tmp_cache_key_file, self.cache_key_file)
|
|
|
|
|
|
|
|
|
# Only for testing
|
|
@@ -100,3 +105,5 @@ if __name__ == "__main__":
|
|
|
print(manager.get_embedding(["world"]))
|
|
|
print(manager.get_embedding(["hello world"]))
|
|
|
manager.save_now()
|
|
|
+ print(manager.get_embedding(["new", "word"]))
|
|
|
+ manager.save_now()
|