|  | @@ -6,8 +6,8 @@ import numpy as np
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  class EmbeddingManager:
 |  |  class EmbeddingManager:
 | 
											
												
													
														|  | -    def __init__(self, model, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
 |  | 
 | 
											
												
													
														|  | -        self.model = model
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
 | 
											
												
													
														|  | 
 |  | +        self.model = None
 | 
											
												
													
														|  |          self.emb_size = emb_size
 |  |          self.emb_size = emb_size
 | 
											
												
													
														|  |          self.cache_file = cache_file
 |  |          self.cache_file = cache_file
 | 
											
												
													
														|  |          self.cache_key_file = f'{self.cache_file}.keys'
 |  |          self.cache_key_file = f'{self.cache_file}.keys'
 | 
											
										
											
												
													
														|  | @@ -30,6 +30,10 @@ class EmbeddingManager:
 | 
											
												
													
														|  |          self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
 |  |          self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
 | 
											
												
													
														|  |          self.saving_thread.start()
 |  |          self.saving_thread.start()
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | 
 |  | +    def set_model(self, model):
 | 
											
												
													
														|  | 
 |  | +        """Since each process has its own model, model must be set after process is forked"""
 | 
											
												
													
														|  | 
 |  | +        self.model = model
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |      def get_embedding(self, text_list):
 |  |      def get_embedding(self, text_list):
 | 
											
												
													
														|  |          """
 |  |          """
 | 
											
												
													
														|  |          Search embedding for a given text. If not found, generate using the model, save to cache, and return it.
 |  |          Search embedding for a given text. If not found, generate using the model, save to cache, and return it.
 | 
											
										
											
												
													
														|  | @@ -105,7 +109,8 @@ class DummyModel:
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  if __name__ == "__main__":
 |  |  if __name__ == "__main__":
 | 
											
												
													
														|  |      model = DummyModel()
 |  |      model = DummyModel()
 | 
											
												
													
														|  | -    manager = EmbeddingManager(model, 1024)
 |  | 
 | 
											
												
													
														|  | 
 |  | +    manager = EmbeddingManager()
 | 
											
												
													
														|  | 
 |  | +    manager.set_model(model)
 | 
											
												
													
														|  |      print(manager.get_embedding(["hello"]))
 |  |      print(manager.get_embedding(["hello"]))
 | 
											
												
													
														|  |      print(manager.get_embedding(["world"]))
 |  |      print(manager.get_embedding(["world"]))
 | 
											
												
													
														|  |      print(manager.get_embedding(["hello world"]))
 |  |      print(manager.get_embedding(["hello world"]))
 |