
Update embedding_manager: rename get_embedding to get_embeddings to fit the model API convention

StrayWarrior, 5 months ago
Commit fa0d4b20f9
2 files changed, 11 insertions(+), 11 deletions(-)
  1. applications/embedding_manager.py (+7 −7)
  2. applications/textSimilarity.py (+4 −4)

applications/embedding_manager.py (+7 −7)

@@ -65,7 +65,7 @@ class EmbeddingManager:
                 os.rename(tmp_cache_key_file, self.cache_key_file)
         print("[EmbeddingManager]cache dumped")
 
-    def get_embedding(self, text_list):
+    def get_embeddings(self, text_list):
         """
         Search embedding for a given text. If not found, generate using the model, save to cache, and return it.
         """
@@ -85,7 +85,7 @@ class EmbeddingManager:
                     new_texts.append(text)
                     new_texts_ori_idx.append(idx)
 
-        new_embeddings = self.model.get_embedding(new_texts)
+        new_embeddings = self.model.get_embeddings(new_texts)
 
         # Generate embedding if not found in cache
         with self.lock:  # Ensure thread-safe access
@@ -110,7 +110,7 @@ class DummyModel:
         text += text[:1024 - len(text)]
         return text
 
-    def get_embedding(self, text_list):
+    def get_embeddings(self, text_list):
         embeddings = np.zeros((len(text_list), 1024), np.float32)
         for idx, text in enumerate(text_list):
             text = self.padding_text(text)
@@ -121,9 +121,9 @@ class DummyModel:
 if __name__ == "__main__":
     model = DummyModel()
     manager = EmbeddingManager(model)
-    print(manager.get_embedding(["hello"]))
-    print(manager.get_embedding(["world"]))
-    print(manager.get_embedding(["hello world"]))
+    print(manager.get_embeddings(["hello"]))
+    print(manager.get_embeddings(["world"]))
+    print(manager.get_embeddings(["hello world"]))
     manager.dump_cache()
-    print(manager.get_embedding(["new", "word"]))
+    print(manager.get_embeddings(["new", "word"]))
     manager.dump_cache()
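
Any code that constructs an EmbeddingManager now has to call the plural name. A minimal caller-side sketch, assuming a package-style import path applications.embedding_manager (the path is an assumption, not shown in the diff) and the one-argument constructor used in the __main__ block above:

    # Caller-side usage of the renamed API (sketch; import path is assumed).
    from applications.embedding_manager import DummyModel, EmbeddingManager

    manager = EmbeddingManager(DummyModel())

    # get_embeddings takes a list of texts and returns one embedding per text;
    # texts already seen are served from the cache instead of the model.
    embeddings = manager.get_embeddings(["hello", "world"])
    cached = manager.get_embeddings(["hello"])

    manager.dump_cache()  # persist the cache, as in the __main__ block above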

applications/textSimilarity.py (+4 −4)

@@ -44,8 +44,8 @@ class NLPFunction(object):
             text_dict['text_b']
         )
         # test embedding manager functions
-        text_emb1 = self.embedding_manager.get_embedding(text_dict['text_a'])
-        text_emb2 = self.embedding_manager.get_embedding(text_dict['text_b'])
+        text_emb1 = self.embedding_manager.get_embeddings(text_dict['text_a'])
+        text_emb2 = self.embedding_manager.get_embeddings(text_dict['text_b'])
         score_function = self.model.score_functions['cos_sim']
         score_tensor_new = score_function(text_emb1, text_emb2)
         compare_tensor(score_tensor, score_tensor_new)
@@ -65,8 +65,8 @@ class NLPFunction(object):
             pair_list_dict['text_list_b']
         )
         # test embedding manager functions
-        text_emb1 = self.embedding_manager.get_embedding(pair_list_dict['text_list_a'])
-        text_emb2 = self.embedding_manager.get_embedding(pair_list_dict['text_list_a'])
+        text_emb1 = self.embedding_manager.get_embeddings(pair_list_dict['text_list_a'])
+        text_emb2 = self.embedding_manager.get_embeddings(pair_list_dict['text_list_b'])
         score_function = self.model.score_functions['cos_sim']
         score_tensor_new = score_function(text_emb1, text_emb2)
         compare_tensor(score_tensor, score_tensor_new)
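
If anything outside these two files still calls the old name, a thin deprecation alias can bridge the rename. This is purely a sketch, not part of the commit, and it reuses the same assumed import path:

    import warnings

    from applications.embedding_manager import EmbeddingManager  # assumed path


    def _get_embedding(self, text_list):
        """Deprecated alias that forwards to the renamed get_embeddings."""
        warnings.warn(
            "EmbeddingManager.get_embedding was renamed to get_embeddings",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.get_embeddings(text_list)


    # Attach the alias so call sites missed by this rename keep working.
    EmbeddingManager.get_embedding = _get_embedding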