textSimilarity.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. """
  2. @author: luojunhui
  3. """
  4. import torch
  5. import numpy as np
  6. def score_to_attention(score, symbol=1):
  7. """
  8. :param score:
  9. :param symbol:
  10. :return:
  11. """
  12. score_pred = torch.FloatTensor(score).unsqueeze(0)
  13. score_norm = symbol * torch.nn.functional.normalize(score_pred, p=2)
  14. score_attn = torch.nn.functional.softmax(score_norm, dim=1)
  15. return score_attn, score_norm, score_pred
  16. class NLPFunction(object):
  17. """
  18. NLP Task
  19. """
  20. def __init__(self, model):
  21. self.model = model
  22. def base_string_similarity(self, text_dict):
  23. """
  24. 基础功能,计算两个字符串的相似度
  25. :param text_dict:
  26. :return:
  27. """
  28. score_tensor = self.model.similarity(
  29. text_dict['text_a'],
  30. text_dict['text_b']
  31. )
  32. response = {
  33. "score": score_tensor.squeeze().tolist()
  34. }
  35. return response
  36. def base_list_similarity(self, pair_list_dict):
  37. """
  38. 计算两个list的相似度
  39. :return:
  40. """
  41. score_tensor = self.model.similarity(
  42. pair_list_dict['text_list_a'],
  43. pair_list_dict['text_list_b']
  44. )
  45. response = {
  46. "score_list_list": score_tensor.tolist()
  47. }
  48. return response
  49. def max_cross_similarity(self, data):
  50. """
  51. max
  52. :param data:
  53. :return:
  54. """
  55. score_list_max = []
  56. text_list_max = []
  57. score_array = self.base_list_similarity(data)['score_list_list']
  58. text_list_a, text_list_b = data['text_list_a'], data['text_list_b']
  59. for i, row in enumerate(score_array):
  60. max_index = np.argmax(row)
  61. max_value = row[max_index]
  62. score_list_max.append(max_value)
  63. text_list_max.append(text_list_b[max_index])
  64. response = {
  65. 'score_list_max': score_list_max,
  66. 'text_list_max': text_list_max,
  67. 'score_list_list': score_array,
  68. }
  69. return response
  70. def mean_cross_similarity(self, data):
  71. """
  72. :param data:
  73. :return:
  74. """
  75. resp = self.max_cross_similarity(data)
  76. score_list_max, text_list_max, score_array = resp['score_list_max'], resp['text_list_max'], resp['score_list_list']
  77. score_tensor = torch.tensor(score_array)
  78. score_res = torch.mean(score_tensor, dim=1)
  79. score_list = score_res.tolist()
  80. response = {
  81. 'score_list_mean': score_list,
  82. 'text_list_max': text_list_max,
  83. 'score_list_list': score_array,
  84. }
  85. return response
  86. def avg_cross_similarity(self, data):
  87. """
  88. :param data:
  89. :return:
  90. """
  91. score_list_b = data['score_list_b']
  92. symbol = data['symbol']
  93. # score_list_max, text_list_max, score_array = self.max_cross_similarity(data)
  94. resp = self.max_cross_similarity(data)
  95. score_list_max, text_list_max, score_array = resp['score_list_max'], resp['text_list_max'], resp[
  96. 'score_list_list']
  97. score_attn, score_norm, score_pred = score_to_attention(score_list_b, symbol=symbol)
  98. score_tensor = torch.tensor(score_array)
  99. score_res = torch.matmul(score_tensor, score_attn.transpose(0, 1))
  100. score_list = score_res.squeeze(-1).tolist()
  101. response = {
  102. 'score_list_avg': score_list,
  103. 'text_list_max': text_list_max,
  104. 'score_list_list': score_array,
  105. }
  106. return response