"""
@author: luojunhui
"""
import torch
import numpy as np


def score_to_attention(score, symbol=1):
    """
    Turn a sequence of raw scores into softmax attention weights.

    The scores are converted to a float tensor with a leading axis added,
    L2-normalized along dim 1, multiplied by ``symbol`` and passed through
    a softmax over dim 1.

    :param score: scores to convert; presumably a flat list of numbers, in
        which case every returned tensor has shape ``(1, len(score))``
    :param symbol: multiplier applied to the normalized scores; a negative
        value reverses which scores receive the larger weights
    :return: tuple ``(score_attn, score_norm, score_pred)`` holding the
        softmax weights, the scaled normalized scores and the raw tensor
    """
    score_pred = torch.FloatTensor(score).unsqueeze(0)
    # F.normalize defaults to dim=1, i.e. the score axis after unsqueeze(0).
    score_norm = symbol * torch.nn.functional.normalize(score_pred, p=2)
    score_attn = torch.nn.functional.softmax(score_norm, dim=1)
    return score_attn, score_norm, score_pred


class NLPFunction(object):
    """
    NLP similarity helpers built on top of a model object.

    The model only needs a ``similarity(a, b)`` method whose result supports
    ``squeeze()`` and ``tolist()`` (presumably a torch tensor of pairwise
    similarity scores -- confirm against the concrete model in use).
    """

    def __init__(self, model):
        self.model = model

    def base_string_similarity(self, text_dict):
        """
        Basic function: compute the similarity of two strings.

        :param text_dict: dict with the keys ``'text_a'`` and ``'text_b'``
        :return: the model's similarity result, squeezed and converted with
            ``tolist()``
        """
        score_tensor = self.model.similarity(
            text_dict['text_a'],
            text_dict['text_b'],
        )
        return score_tensor.squeeze().tolist()

    def base_list_similarity(self, pair_list_dict):
        """
        Compute the similarity between two lists of texts.

        :param pair_list_dict: dict with the keys ``'text_list_a'`` and
            ``'text_list_b'``
        :return: the model's similarity result converted with ``tolist()``;
            the other methods of this class iterate it as one row per entry
            of ``text_list_a`` with one column per entry of ``text_list_b``
        """
        score_tensor = self.model.similarity(
            pair_list_dict['text_list_a'],
            pair_list_dict['text_list_b'],
        )
        return score_tensor.tolist()

    def max_cross_similarity(self, data):
        """
        For every score row, pick the best-scoring text of ``text_list_b``.

        :param data: dict with the keys ``'text_list_a'`` and
            ``'text_list_b'``
        :return: tuple ``(score_list_max, text_list_max, score_array)``: the
            highest score of each row, the ``text_list_b`` entry at that
            position, and the full score matrix as nested lists
        """
        score_array = self.base_list_similarity(data)
        text_list_b = data['text_list_b']

        score_list_max = []
        text_list_max = []
        for row in score_array:
            # Plain int so the index works for any sequence type.
            max_index = int(np.argmax(row))
            score_list_max.append(row[max_index])
            text_list_max.append(text_list_b[max_index])
        return score_list_max, text_list_max, score_array

    def mean_cross_similarity(self, data):
        """
        Average each score row over all texts of ``text_list_b``.

        :param data: dict with the keys ``'text_list_a'`` and
            ``'text_list_b'``
        :return: tuple ``(score_list, text_list_max, score_array)``: the
            per-row mean score, the best-matching ``text_list_b`` entry per
            row, and the full score matrix as nested lists
        """
        # Only the best-match texts and the raw matrix are needed here.
        _, text_list_max, score_array = self.max_cross_similarity(data)
        score_list = torch.mean(torch.tensor(score_array), dim=1).tolist()
        return score_list, text_list_max, score_array

    def avg_cross_similarity(self, data):
        """
        Weighted average of each score row, weighted by ``score_list_b``.

        The weights come from ``score_to_attention(score_list_b, symbol)``,
        so they are softmax weights over the entries of ``score_list_b``.

        :param data: dict with the keys ``'text_list_a'``,
            ``'text_list_b'``, ``'score_list_b'`` and ``'symbol'``;
            ``score_list_b`` needs one score per column of the score matrix
            for the matrix product to be defined
        :return: tuple ``(score_list, text_list_max, score_array)``: the
            weighted score per row, the best-matching ``text_list_b`` entry
            per row, and the full score matrix as nested lists
        """
        score_list_b = data['score_list_b']
        symbol = data['symbol']
        _, text_list_max, score_array = self.max_cross_similarity(data)

        score_attn, _, _ = score_to_attention(score_list_b, symbol=symbol)
        score_tensor = torch.tensor(score_array)
        # (rows, cols) @ (cols, 1) -> (rows, 1); drop the trailing axis.
        score_res = torch.matmul(score_tensor, score_attn.transpose(0, 1))
        score_list = score_res.squeeze(-1).tolist()
        return score_list, text_list_max, score_array