Procházet zdrojové kódy

Update embedding_manager: double-buffer

StrayWarrior před 6 měsíci
rodič
revize
0d64100f9b
Změněn 1 soubor: 10 přidání a 3 odebrání
  1. 10 3
      applications/embedding_manager.py

+ 10 - 3
applications/embedding_manager.py

@@ -66,15 +66,20 @@ class EmbeddingManager:
 
 
     def save_now(self):
     def save_now(self):
         """Manually trigger a save to disk."""
         """Manually trigger a save to disk."""
-        # TODO: double-buffer
+        tmp_cache_file = self.cache_file + ".tmp"
+        tmp_cache_key_file = self.cache_key_file + ".tmp"
         with self.lock:  # Ensure thread-safe access
         with self.lock:  # Ensure thread-safe access
             keys = self.cache.keys()
             keys = self.cache.keys()
             cache_to_save = np.zeros((len(keys), 1024), np.float32)
             cache_to_save = np.zeros((len(keys), 1024), np.float32)
             for idx, key in enumerate(keys):
             for idx, key in enumerate(keys):
                 cache_to_save[idx] = self.cache[key]
                 cache_to_save[idx] = self.cache[key]
-            np.save(self.cache_file, cache_to_save)
-            with open(self.cache_key_file, 'w') as fp:
+            np.save(tmp_cache_file, cache_to_save)
+            with open(tmp_cache_key_file, 'w') as fp:
                 fp.write('\n'.join(keys))
                 fp.write('\n'.join(keys))
+        os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
+        os.rename(self.cache_key_file, self.cache_key_file + ".bak")
+        os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
+        os.rename(tmp_cache_key_file, self.cache_key_file)
 
 
 
 
 # Only for testing
 # Only for testing
@@ -100,3 +105,5 @@ if __name__ == "__main__":
     print(manager.get_embedding(["world"]))
     print(manager.get_embedding(["world"]))
     print(manager.get_embedding(["hello world"]))
     print(manager.get_embedding(["hello world"]))
     manager.save_now()
     manager.save_now()
+    print(manager.get_embedding(["new", "word"]))
+    manager.save_now()