| 
														
															@@ -1,92 +1,66 @@ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 """ 
														 | 
														
														 | 
														
															 """ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 @author: luojunhui 
														 | 
														
														 | 
														
															 @author: luojunhui 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 """ 
														 | 
														
														 | 
														
															 """ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import time 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 import torch 
														 | 
														
														 | 
														
															 import torch 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 import numpy as np 
														 | 
														
														 | 
														
															 import numpy as np 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from similarities import BertSimilarity 
														 | 
														
														 | 
														
															 from similarities import BertSimilarity 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-model = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# bge_large_zh_v1_5 = 'bge_large_zh_v1_5' 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# text2vec_base_chinese = "text2vec_base_chinese" 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+# text2vec_bge_large_chinese = "text2vec_bge_large_chinese" 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-bge_large_zh_v1_5 = 'bge_large_zh_v1_5' 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-text2vec_base_chinese = "text2vec_base_chinese" 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-text2vec_bge_large_chinese = "text2vec_bge_large_chinese" 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+class NLPFunction(object): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    """ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    NLP Task 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    """ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-def get_sim_score_by_pair(model, pair): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    try: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_tensor = model.similarity(pair['text_a'], pair['text_b']) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        return score_tensor.squeeze().tolist() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    except Exception as e: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        raise 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-def get_sim_score_by_pair_list(model, pair_list): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    try: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        res = [get_sim_score_by_pair(model, pair) for pair in pair_list['text_pair_list']] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        return res 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    except Exception as e: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        raise 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    def __init__(self, model): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        self.model = model 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    def base_string_similarity(self, text_dict): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        """ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        基础功能,计算两个字符串的相似度 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        :param text_dict: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        :return: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        """ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        score_tensor = self.model.similarity( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            text_dict['text_a'], 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            text_dict['text_b'] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        ) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        return score_tensor.squeeze().tolist() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-def get_sim_score_by_list_pair(model, list_pair): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    try: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_tensor = model.similarity(list_pair['text_list_a'], list_pair['text_list_b']) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    def base_list_similarity(self, pair_list_dict): 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        """ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        计算两个list的相似度 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        :return: "score_list_b": [100, 1000, 500, 40], 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        """ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        score_tensor = self.model.similarity( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            pair_list_dict['text_list_a'], 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            pair_list_dict['text_list_b'] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        ) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         return score_tensor.tolist() 
														 | 
														
														 | 
														
															         return score_tensor.tolist() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    except Exception as e: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        raise 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-def get_sim_score_max(model, data): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    try: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_list_max = [] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        text_list_max = [] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_array = get_sim_score_by_list_pair(model, data) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        text_list_a, text_list_b = data['text_list_a'], data['text_list_b'] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        for i, row in enumerate(score_array): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            max_index = np.argmax(row) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            max_value = row[max_index] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            score_list_max.append(max_value) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            text_list_max.append(text_list_b[max_index]) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        return score_list_max, text_list_max, score_array 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    except Exception as e: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        logger.error(f"Error in get_sim_score_max: {e}") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        raise 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-def score_to_attention(score, symbol=1): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    try: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_pred = torch.FloatTensor(score).unsqueeze(0) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_norm = symbol * torch.nn.functional.normalize(score_pred, p=2, dim=1) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_attn = torch.nn.functional.softmax(score_norm, dim=1) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        return score_attn, score_norm, score_pred 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    except Exception as e: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        logger.error(f"Error in score_to_attention: {e}") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        raise 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-def get_sim_score_avg(model, data): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    try: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        text_list_a, text_list_b = data['text_list_a'], data['text_list_b'] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_list_b, symbol = data['score_list_b'], data['symbol'] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_list_max, text_list_max, score_array = get_sim_score_max(model, data) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_attn, score_norm, score_pred = score_to_attention(score_list_b, symbol=symbol) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_tensor = torch.tensor(score_array) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_res = torch.matmul(score_tensor, score_attn.transpose(0, 1)) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_list = score_res.squeeze(-1).tolist() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        return score_list, text_list_max, score_array 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    except Exception as e: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        logger.error(f"Error in get_sim_score_avg: {e}") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        raise 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-def get_sim_score_mean(model, data): 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    try: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        text_list_a, text_list_b = data['text_list_a'], data['text_list_b'] 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_list_max, text_list_max, score_array = get_sim_score_max(model, data) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_tensor = torch.tensor(score_array) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_res = torch.mean(score_tensor, dim=1) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        score_list = score_res.tolist() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        return score_list, text_list_max, score_array 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    except Exception as e: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        raise 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+if __name__ == '__main__': 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    a = time.time() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    m = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    b = time.time() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    print("模型加载时间:\t", b - a) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    NF = NLPFunction(m) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    td = { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        "text_a": "王者荣耀", 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        "text_b": "斗罗大陆" 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    } 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    tld = { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        "text_list_a": ["凯旋", "圣洁", "篮球"], 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        "text_list_b": ["胜利", "纯洁", "足球"] 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    } 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    # res = NF.base_string_similarity(text_dict=td) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    res = NF.base_list_similarity(pair_list_dict=tld) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    c = time.time() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    print("计算时间:\t", c - b) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+    for i in res: 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        print(i) 
														 |