|
@@ -1,15 +1,21 @@
|
|
|
"""
|
|
|
@author: luojunhui
|
|
|
"""
|
|
|
-import time
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
-from similarities import BertSimilarity
|
|
|
|
|
|
|
|
|
-# bge_large_zh_v1_5 = 'bge_large_zh_v1_5'
|
|
|
-# text2vec_base_chinese = "text2vec_base_chinese"
|
|
|
-# text2vec_bge_large_chinese = "text2vec_bge_large_chinese"
|
|
|
def score_to_attention(score, symbol=1):
    """
    Turn a sequence of raw scores into softmax attention weights.

    The scores are L2-normalized, multiplied by ``symbol`` and passed through
    a softmax. For a 1-D input of ``n`` scores every returned tensor has
    shape ``(1, n)`` because a leading axis is added before normalizing.

    :param score: numeric scores (list, tuple, numpy array or tensor)
    :param symbol: multiplier applied to the normalized scores before the
        softmax; a negative value reverses the ordering of the weights
    :return: tuple ``(score_attn, score_norm, score_pred)`` holding the
        softmax weights, the scaled normalized scores and the raw float32
        scores, in that order
    """
    # torch.as_tensor with an explicit dtype always treats the argument as
    # data. The legacy torch.FloatTensor constructor interprets a bare int as
    # a size (returning uninitialized memory) and rejects tensors whose dtype
    # is not already float32.
    score_pred = torch.as_tensor(score, dtype=torch.float32).unsqueeze(0)
    # Normalize along the score axis; dim=1 is spelled out so that it visibly
    # matches the axis used by the softmax below.
    score_norm = symbol * torch.nn.functional.normalize(score_pred, p=2, dim=1)
    score_attn = torch.nn.functional.softmax(score_norm, dim=1)
    return score_attn, score_norm, score_pred
|
|
|
|
|
|
|
|
|
class NLPFunction(object):
|
|
@@ -35,7 +41,7 @@ class NLPFunction(object):
|
|
|
def base_list_similarity(self, pair_list_dict):
|
|
|
"""
|
|
|
计算两个list的相似度
|
|
|
- :return: "score_list_b": [100, 1000, 500, 40],
|
|
|
+ :return:
|
|
|
"""
|
|
|
score_tensor = self.model.similarity(
|
|
|
pair_list_dict['text_list_a'],
|
|
@@ -43,24 +49,44 @@ class NLPFunction(object):
|
|
|
)
|
|
|
return score_tensor.tolist()
|
|
|
|
|
|
def max_cross_similarity(self, data):
    """
    For every row of the similarity matrix, pick the best-scoring entry of
    ``text_list_b``.

    :param data: dict with the key ``text_list_b``; it is also forwarded
        unchanged to ``self.base_list_similarity``, which produces the
        score matrix
    :return: tuple ``(score_list_max, text_list_max, score_array)`` holding
        the highest score of each row, the ``text_list_b`` entry at that
        position, and the full score matrix as returned by
        ``base_list_similarity``
    """
    score_array = self.base_list_similarity(data)
    text_list_b = data['text_list_b']
    score_list_max = []
    text_list_max = []
    for row in score_array:
        # np.argmax returns the first position when several scores tie.
        max_index = int(np.argmax(row))
        score_list_max.append(row[max_index])
        text_list_max.append(text_list_b[max_index])
    return score_list_max, text_list_max, score_array
|
|
|
+
|
|
|
def mean_cross_similarity(self, data):
    """
    Average each row of the similarity matrix.

    :param data: dict forwarded unchanged to ``self.max_cross_similarity``
    :return: tuple ``(score_list, text_list_max, score_array)`` holding the
        mean score of each row, the best-matching texts reported by
        ``max_cross_similarity``, and the full score matrix
    """
    _, text_list_max, score_array = self.max_cross_similarity(data)
    # An empty matrix converts to a 1-D tensor, for which mean(dim=1) raises
    # IndexError; return an empty score list instead.
    if len(score_array) == 0:
        return [], text_list_max, score_array
    score_list = torch.tensor(score_array).mean(dim=1).tolist()
    return score_list, text_list_max, score_array
|
|
|
|
|
|
-if __name__ == '__main__':
|
|
|
- a = time.time()
|
|
|
- m = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5")
|
|
|
- b = time.time()
|
|
|
- print("模型加载时间:\t", b - a)
|
|
|
- NF = NLPFunction(m)
|
|
|
- td = {
|
|
|
- "text_a": "王者荣耀",
|
|
|
- "text_b": "斗罗大陆"
|
|
|
- }
|
|
|
- tld = {
|
|
|
- "text_list_a": ["凯旋", "圣洁", "篮球"],
|
|
|
- "text_list_b": ["胜利", "纯洁", "足球"]
|
|
|
- }
|
|
|
- # res = NF.base_string_similarity(text_dict=td)
|
|
|
- res = NF.base_list_similarity(pair_list_dict=tld)
|
|
|
- c = time.time()
|
|
|
- print("计算时间:\t", c - b)
|
|
|
- for i in res:
|
|
|
- print(i)
|
|
|
def avg_cross_similarity(self, data):
    """
    Combine each row of the similarity matrix into one score, weighting the
    columns by attention weights derived from ``score_list_b``.

    :param data: dict with the keys ``score_list_b`` (one raw score per
        column of the similarity matrix) and ``symbol`` (passed to
        ``score_to_attention``); it is also forwarded unchanged to
        ``self.max_cross_similarity``
    :return: tuple ``(score_list, text_list_max, score_array)`` holding the
        weighted score of each row, the best-matching texts reported by
        ``max_cross_similarity``, and the full score matrix
    """
    score_list_b = data['score_list_b']
    symbol = data['symbol']
    _, text_list_max, score_array = self.max_cross_similarity(data)
    # An empty matrix converts to a 1-D tensor that matmul cannot combine
    # with the weight column; return an empty score list instead.
    if len(score_array) == 0:
        return [], text_list_max, score_array
    score_attn, _, _ = score_to_attention(score_list_b, symbol=symbol)
    # score_to_attention produces float32 weights, so the matrix is built
    # with the same dtype to keep matmul from failing on a dtype mismatch.
    score_tensor = torch.tensor(score_array, dtype=torch.float32)
    # (rows, cols) @ (cols, 1) -> (rows, 1); squeeze drops the trailing axis.
    score_res = torch.matmul(score_tensor, score_attn.transpose(0, 1))
    score_list = score_res.squeeze(-1).tolist()
    return score_list, text_list_max, score_array
|