@@ -6,8 +6,8 @@ import numpy as np
 
 
 class EmbeddingManager:
-    def __init__(self, model, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
-        self.model = model
+    def __init__(self, emb_size=1024, cache_file="cache/embedding_cache", save_interval=600):
+        self.model = None
         self.emb_size = emb_size
         self.cache_file = cache_file
         self.cache_key_file = f'{self.cache_file}.keys'
@@ -30,6 +30,10 @@ class EmbeddingManager:
         self.saving_thread = threading.Thread(target=self._periodic_save, daemon=True)
         self.saving_thread.start()
 
+    def set_model(self, model):
+        """Since each process has its own model, the model must be set after the process is forked."""
+        self.model = model
+
     def get_embedding(self, text_list):
         """
         Search embedding for a given text. If not found, generate using the model, save to cache, and return it.
@@ -105,7 +109,8 @@ class DummyModel:
 
 if __name__ == "__main__":
     model = DummyModel()
-    manager = EmbeddingManager(model, 1024)
+    manager = EmbeddingManager()
+    manager.set_model(model)
     print(manager.get_embedding(["hello"]))
     print(manager.get_embedding(["world"]))
    print(manager.get_embedding(["hello world"]))
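
The pattern this patch introduces (construct the manager without a model, fork, then attach a per-process model via set_model) might be exercised roughly as in the sketch below. The sketch is not part of the patch: it assumes the patched file is importable as embedding_manager, uses the fork start method as the set_model docstring implies, and the worker function and pool size are illustrative only.

# Hypothetical usage sketch: the manager is built without a model in the
# parent, inherited by forked workers, and each worker attaches its own
# model instance after the fork.
import multiprocessing as mp

from embedding_manager import EmbeddingManager, DummyModel  # module name assumed

manager = EmbeddingManager()  # no model yet, so the object is cheap to fork

def worker(texts):
    # Each forked process creates (or loads) its own model and hands it to
    # the inherited manager; the parent's manager keeps model=None.
    manager.set_model(DummyModel())
    return manager.get_embedding(texts)

if __name__ == "__main__":
    mp.set_start_method("fork")  # fork-based start, as the docstring assumes
    with mp.Pool(processes=2) as pool:
        for embeddings in pool.map(worker, [["hello"], ["world"]]):
            print(embeddings)

Note that with fork each worker receives its own copy of the manager, so cache entries written inside a worker are not visible to the parent or to other workers.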