textSimilarity.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. """
  2. @author: luojunhui
  3. """
  4. import torch
  5. import numpy as np
  6. def score_to_attention(score, symbol=1):
  7. """
  8. :param score:
  9. :param symbol:
  10. :return:
  11. """
  12. score_pred = torch.FloatTensor(score).unsqueeze(0)
  13. score_norm = symbol * torch.nn.functional.normalize(score_pred, p=2)
  14. score_attn = torch.nn.functional.softmax(score_norm, dim=1)
  15. return score_attn, score_norm, score_pred
  16. def compare_tensor(tensor1, tensor2):
  17. if tensor1.shape != tensor2.shape:
  18. print(f"[compare_tensor]shape error: {tensor1.shape} vs {tensor2.shape}")
  19. return
  20. if not torch.allclose(tensor1, tensor2):
  21. print("[compare_tensor]value error: tensor1 not close to tensor2")
  22. class NLPFunction(object):
  23. """
  24. NLP Task
  25. """
  26. def __init__(self, model, embedding_manager):
  27. self.model = model
  28. self.embedding_manager = embedding_manager
  29. def base_string_similarity(self, text_dict):
  30. """
  31. 基础功能,计算两个字符串的相似度
  32. :param text_dict:
  33. :return:
  34. """
  35. score_tensor = self.model.similarity(
  36. text_dict['text_a'],
  37. text_dict['text_b']
  38. )
  39. # test embedding manager functions
  40. text_emb1 = self.embedding_manager.get_embeddings(text_dict['text_a'])
  41. text_emb2 = self.embedding_manager.get_embeddings(text_dict['text_b'])
  42. score_function = self.model.score_functions['cos_sim']
  43. score_tensor_new = score_function(text_emb1, text_emb2)
  44. compare_tensor(score_tensor, score_tensor_new)
  45. response = {
  46. "score": score_tensor.squeeze().tolist()
  47. }
  48. return response
  49. def base_list_similarity(self, pair_list_dict):
  50. """
  51. 计算两个list的相似度
  52. :return:
  53. """
  54. score_tensor = self.model.similarity(
  55. pair_list_dict['text_list_a'],
  56. pair_list_dict['text_list_b']
  57. )
  58. # test embedding manager functions
  59. text_emb1 = self.embedding_manager.get_embeddings(pair_list_dict['text_list_a'])
  60. text_emb2 = self.embedding_manager.get_embeddings(pair_list_dict['text_list_b'])
  61. score_function = self.model.score_functions['cos_sim']
  62. score_tensor_new = score_function(text_emb1, text_emb2)
  63. compare_tensor(score_tensor, score_tensor_new)
  64. response = {
  65. "score_list_list": score_tensor.tolist()
  66. }
  67. return response
  68. def max_cross_similarity(self, data):
  69. """
  70. max
  71. :param data:
  72. :return:
  73. """
  74. score_list_max = []
  75. text_list_max = []
  76. score_array = self.base_list_similarity(data)['score_list_list']
  77. text_list_a, text_list_b = data['text_list_a'], data['text_list_b']
  78. for i, row in enumerate(score_array):
  79. max_index = np.argmax(row)
  80. max_value = row[max_index]
  81. score_list_max.append(max_value)
  82. text_list_max.append(text_list_b[max_index])
  83. response = {
  84. 'score_list_max': score_list_max,
  85. 'text_list_max': text_list_max,
  86. 'score_list_list': score_array,
  87. }
  88. return response
  89. def mean_cross_similarity(self, data):
  90. """
  91. :param data:
  92. :return:
  93. """
  94. resp = self.max_cross_similarity(data)
  95. score_list_max, text_list_max, score_array = resp['score_list_max'], resp['text_list_max'], resp['score_list_list']
  96. score_tensor = torch.tensor(score_array)
  97. score_res = torch.mean(score_tensor, dim=1)
  98. score_list = score_res.tolist()
  99. response = {
  100. 'score_list_mean': score_list,
  101. 'text_list_max': text_list_max,
  102. 'score_list_list': score_array,
  103. }
  104. return response
  105. def avg_cross_similarity(self, data):
  106. """
  107. :param data:
  108. :return:
  109. """
  110. score_list_b = data['score_list_b']
  111. symbol = data['symbol']
  112. # score_list_max, text_list_max, score_array = self.max_cross_similarity(data)
  113. resp = self.max_cross_similarity(data)
  114. score_list_max, text_list_max, score_array = resp['score_list_max'], resp['text_list_max'], resp[
  115. 'score_list_list']
  116. score_attn, score_norm, score_pred = score_to_attention(score_list_b, symbol=symbol)
  117. score_tensor = torch.tensor(score_array)
  118. score_res = torch.matmul(score_tensor, score_attn.transpose(0, 1))
  119. score_list = score_res.squeeze(-1).tolist()
  120. response = {
  121. 'score_list_avg': score_list,
  122. 'text_list_max': text_list_max,
  123. 'score_list_list': score_array,
  124. }
  125. return response