浏览代码

Update embedding_manager: configurable emb size

StrayWarrior 5 个月前
父节点
当前提交
c9cad1dbbc
共有 1 个文件被更改，包括 9 次插入和 6 次删除
  1. 9 6
      applications/embedding_manager.py

+ 9 - 6
applications/embedding_manager.py

@@ -6,8 +6,9 @@ import numpy as np
 
 
 
 
 class EmbeddingManager:
 class EmbeddingManager:
-    def __init__(self, model, cache_file="embedding_cache", save_interval=600):
+    def __init__(self, model, emb_size = 1024, cache_file="embedding_cache", save_interval=600):
         self.model = model
         self.model = model
+        self.emb_size = emb_size
         self.cache_file = cache_file
         self.cache_file = cache_file
         self.cache_key_file = f'{self.cache_file}.keys'
         self.cache_key_file = f'{self.cache_file}.keys'
         self.save_interval = save_interval
         self.save_interval = save_interval
@@ -35,7 +36,7 @@ class EmbeddingManager:
         """
         """
         if not isinstance(text_list, list):
         if not isinstance(text_list, list):
             raise Exception(f"Invalid parameter type: text_list {type(text_list)}")
             raise Exception(f"Invalid parameter type: text_list {type(text_list)}")
-        embedding_list = np.zeros((len(text_list), 1024), np.float32)
+        embedding_list = np.zeros((len(text_list), self.emb_size), np.float32)
         if not text_list:
         if not text_list:
             return embedding_list
             return embedding_list
         new_texts = []
         new_texts = []
@@ -70,14 +71,16 @@ class EmbeddingManager:
         tmp_cache_key_file = self.cache_key_file + ".tmp"
         tmp_cache_key_file = self.cache_key_file + ".tmp"
         with self.lock:  # Ensure thread-safe access
         with self.lock:  # Ensure thread-safe access
             keys = self.cache.keys()
             keys = self.cache.keys()
-            cache_to_save = np.zeros((len(keys), 1024), np.float32)
+            cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
             for idx, key in enumerate(keys):
             for idx, key in enumerate(keys):
                 cache_to_save[idx] = self.cache[key]
                 cache_to_save[idx] = self.cache[key]
             np.save(tmp_cache_file, cache_to_save)
             np.save(tmp_cache_file, cache_to_save)
             with open(tmp_cache_key_file, 'w') as fp:
             with open(tmp_cache_key_file, 'w') as fp:
                 fp.write('\n'.join(keys))
                 fp.write('\n'.join(keys))
-        os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
-        os.rename(self.cache_key_file, self.cache_key_file + ".bak")
+        if os.path.exists(self.cache_file + ".npy"):
+            os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
+        if os.path.exists(self.cache_key_file):
+            os.rename(self.cache_key_file, self.cache_key_file + ".bak")
         os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
         os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
         os.rename(tmp_cache_key_file, self.cache_key_file)
         os.rename(tmp_cache_key_file, self.cache_key_file)
 
 
@@ -100,7 +103,7 @@ class DummyModel:
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     model = DummyModel()
     model = DummyModel()
-    manager = EmbeddingManager(model)
+    manager = EmbeddingManager(model, 1024)
     print(manager.get_embedding(["hello"]))
     print(manager.get_embedding(["hello"]))
     print(manager.get_embedding(["world"]))
     print(manager.get_embedding(["world"]))
     print(manager.get_embedding(["hello world"]))
     print(manager.get_embedding(["hello world"]))