浏览代码

Update textSimilarity: enable embedding manager

StrayWarrior 5 月之前
父节点
当前提交
5ab5ab5593
共有 1 个文件被更改,包括 16 次插入22 次删除
  1. 16 22
      applications/textSimilarity.py

+ 16 - 22
applications/textSimilarity.py

@@ -33,23 +33,25 @@ class NLPFunction(object):
         self.model = model
         self.embedding_manager = embedding_manager
 
+    def direct_similarity(self, a, b):
+        return self.model.similarity(a, b)
+
+    def cached_similarity(self, a, b):
+        text_emb1 = self.embedding_manager.get_embeddings(a)
+        text_emb2 = self.embedding_manager.get_embeddings(b)
+        score_function = self.model.score_functions['cos_sim']
+        score_tensor = score_function(text_emb1, text_emb2)
+        return score_tensor
+
     def base_string_similarity(self, text_dict):
         """
         基础功能,计算两个字符串的相似度
         :param text_dict:
         :return:
         """
-        score_tensor = self.model.similarity(
-            text_dict['text_a'],
-            text_dict['text_b']
-        )
-        # test embedding manager functions
-        text_emb1 = self.embedding_manager.get_embeddings(text_dict['text_a'])
-        text_emb2 = self.embedding_manager.get_embeddings(text_dict['text_b'])
-        score_function = self.model.score_functions['cos_sim']
-        score_tensor_new = score_function(text_emb1, text_emb2)
-        compare_tensor(score_tensor, score_tensor_new)
-
+        text_a = text_dict['text_a']
+        text_b = text_dict['text_b']
+        score_tensor = self.cached_similarity(text_a, text_b)
         response = {
             "score": score_tensor.squeeze().tolist()
         }
@@ -60,17 +62,9 @@ class NLPFunction(object):
         计算两个list的相似度
         :return:
         """
-        score_tensor = self.model.similarity(
-            pair_list_dict['text_list_a'],
-            pair_list_dict['text_list_b']
-        )
-        # test embedding manager functions
-        text_emb1 = self.embedding_manager.get_embeddings(pair_list_dict['text_list_a'])
-        text_emb2 = self.embedding_manager.get_embeddings(pair_list_dict['text_list_b'])
-        score_function = self.model.score_functions['cos_sim']
-        score_tensor_new = score_function(text_emb1, text_emb2)
-        compare_tensor(score_tensor, score_tensor_new)
-
+        text_a = pair_list_dict['text_list_a']
+        text_b = pair_list_dict['text_list_b']
+        score_tensor = self.cached_similarity(text_a, text_b)
         response = {
             "score_list_list": score_tensor.tolist()
         }