|
@@ -65,7 +65,7 @@ class EmbeddingManager:
|
|
|
os.rename(tmp_cache_key_file, self.cache_key_file)
|
|
|
print("[EmbeddingManager]cache dumped")
|
|
|
|
|
|
- def get_embedding(self, text_list):
|
|
|
+ def get_embeddings(self, text_list):
|
|
|
"""
|
|
|
Search embedding for a given text. If not found, generate using the model, save to cache, and return it.
|
|
|
"""
|
|
@@ -85,7 +85,7 @@ class EmbeddingManager:
|
|
|
new_texts.append(text)
|
|
|
new_texts_ori_idx.append(idx)
|
|
|
|
|
|
- new_embeddings = self.model.get_embedding(new_texts)
|
|
|
+ new_embeddings = self.model.get_embeddings(new_texts)
|
|
|
|
|
|
# Generate embedding if not found in cache
|
|
|
with self.lock: # Ensure thread-safe access
|
|
@@ -110,7 +110,7 @@ class DummyModel:
|
|
|
text += text[:1024 - len(text)]
|
|
|
return text
|
|
|
|
|
|
- def get_embedding(self, text_list):
|
|
|
+ def get_embeddings(self, text_list):
|
|
|
embeddings = np.zeros((len(text_list), 1024), np.float32)
|
|
|
for idx, text in enumerate(text_list):
|
|
|
text = self.padding_text(text)
|
|
@@ -121,9 +121,9 @@ class DummyModel:
|
|
|
if __name__ == "__main__":
|
|
|
model = DummyModel()
|
|
|
manager = EmbeddingManager(model)
|
|
|
- print(manager.get_embedding(["hello"]))
|
|
|
- print(manager.get_embedding(["world"]))
|
|
|
- print(manager.get_embedding(["hello world"]))
|
|
|
+ print(manager.get_embeddings(["hello"]))
|
|
|
+ print(manager.get_embeddings(["world"]))
|
|
|
+ print(manager.get_embeddings(["hello world"]))
|
|
|
manager.dump_cache()
|
|
|
- print(manager.get_embedding(["new", "word"]))
|
|
|
+ print(manager.get_embeddings(["new", "word"]))
|
|
|
manager.dump_cache()
|