|  | @@ -6,8 +6,9 @@ import numpy as np
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  class EmbeddingManager:
 | 
	
		
			
				|  |  | -    def __init__(self, model, cache_file="embedding_cache", save_interval=600):
 | 
	
		
			
				|  |  | +    def __init__(self, model, emb_size = 1024, cache_file="embedding_cache", save_interval=600):
 | 
	
		
			
				|  |  |          self.model = model
 | 
	
		
			
				|  |  | +        self.emb_size = emb_size
 | 
	
		
			
				|  |  |          self.cache_file = cache_file
 | 
	
		
			
				|  |  |          self.cache_key_file = f'{self.cache_file}.keys'
 | 
	
		
			
				|  |  |          self.save_interval = save_interval
 | 
	
	
		
			
				|  | @@ -35,7 +36,7 @@ class EmbeddingManager:
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          if not isinstance(text_list, list):
 | 
	
		
			
				|  |  |              raise Exception(f"Invalid parameter type: text_list {type(text_list)}")
 | 
	
		
			
				|  |  | -        embedding_list = np.zeros((len(text_list), 1024), np.float32)
 | 
	
		
			
				|  |  | +        embedding_list = np.zeros((len(text_list), self.emb_size), np.float32)
 | 
	
		
			
				|  |  |          if not text_list:
 | 
	
		
			
				|  |  |              return embedding_list
 | 
	
		
			
				|  |  |          new_texts = []
 | 
	
	
		
			
				|  | @@ -70,14 +71,16 @@ class EmbeddingManager:
 | 
	
		
			
				|  |  |          tmp_cache_key_file = self.cache_key_file + ".tmp"
 | 
	
		
			
				|  |  |          with self.lock:  # Ensure thread-safe access
 | 
	
		
			
				|  |  |              keys = self.cache.keys()
 | 
	
		
			
				|  |  | -            cache_to_save = np.zeros((len(keys), 1024), np.float32)
 | 
	
		
			
				|  |  | +            cache_to_save = np.zeros((len(keys), self.emb_size), np.float32)
 | 
	
		
			
				|  |  |              for idx, key in enumerate(keys):
 | 
	
		
			
				|  |  |                  cache_to_save[idx] = self.cache[key]
 | 
	
		
			
				|  |  |              np.save(tmp_cache_file, cache_to_save)
 | 
	
		
			
				|  |  |              with open(tmp_cache_key_file, 'w') as fp:
 | 
	
		
			
				|  |  |                  fp.write('\n'.join(keys))
 | 
	
		
			
				|  |  | -        os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
 | 
	
		
			
				|  |  | -        os.rename(self.cache_key_file, self.cache_key_file + ".bak")
 | 
	
		
			
				|  |  | +        if os.path.exists(self.cache_file + ".npy"):
 | 
	
		
			
				|  |  | +            os.rename(self.cache_file + ".npy", self.cache_file + ".npy.bak")
 | 
	
		
			
				|  |  | +        if os.path.exists(self.cache_key_file):
 | 
	
		
			
				|  |  | +            os.rename(self.cache_key_file, self.cache_key_file + ".bak")
 | 
	
		
			
				|  |  |          os.rename(tmp_cache_file + ".npy", self.cache_file + ".npy")
 | 
	
		
			
				|  |  |          os.rename(tmp_cache_key_file, self.cache_key_file)
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -100,7 +103,7 @@ class DummyModel:
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  if __name__ == "__main__":
 | 
	
		
			
				|  |  |      model = DummyModel()
 | 
	
		
			
				|  |  | -    manager = EmbeddingManager(model)
 | 
	
		
			
				|  |  | +    manager = EmbeddingManager(model, 1024)
 | 
	
		
			
				|  |  |      print(manager.get_embedding(["hello"]))
 | 
	
		
			
				|  |  |      print(manager.get_embedding(["world"]))
 | 
	
		
			
				|  |  |      print(manager.get_embedding(["hello world"]))
 |