textSimilarity.py

  1. """
  2. @author: luojunhui
  3. """
  4. import torch
  5. import numpy as np
  6. from similarities import BertSimilarity
  7. model = BertSimilarity(model_name_or_path="BAAI/bge-large-zh-v1.5")
  8. bge_large_zh_v1_5 = 'bge_large_zh_v1_5'
  9. text2vec_base_chinese = "text2vec_base_chinese"
  10. text2vec_bge_large_chinese = "text2vec_bge_large_chinese"
def get_sim_score_by_pair(model, pair):
    """Score a single pair: {'text_a': str, 'text_b': str} -> float."""
    try:
        score_tensor = model.similarity(pair['text_a'], pair['text_b'])
        return score_tensor.squeeze().tolist()
    except Exception as e:
        logger.error(f"Error in get_sim_score_by_pair: {e}")
        raise
def get_sim_score_by_pair_list(model, pair_list):
    """Score every pair in {'text_pair_list': [{'text_a': ..., 'text_b': ...}, ...]}."""
    try:
        return [get_sim_score_by_pair(model, pair) for pair in pair_list['text_pair_list']]
    except Exception as e:
        logger.error(f"Error in get_sim_score_by_pair_list: {e}")
        raise
def get_sim_score_by_list_pair(model, list_pair):
    """Score 'text_list_a' against 'text_list_b'; returns a (len_a, len_b) score matrix as nested lists."""
    try:
        score_tensor = model.similarity(list_pair['text_list_a'], list_pair['text_list_b'])
        return score_tensor.tolist()
    except Exception as e:
        logger.error(f"Error in get_sim_score_by_list_pair: {e}")
        raise
def get_sim_score_max(model, data):
    """For each text in 'text_list_a', return the best score, the best-matching text in 'text_list_b', and the full score matrix."""
    try:
        score_list_max = []
        text_list_max = []
        score_array = get_sim_score_by_list_pair(model, data)
        text_list_a, text_list_b = data['text_list_a'], data['text_list_b']
        for i, row in enumerate(score_array):
            max_index = np.argmax(row)
            max_value = row[max_index]
            score_list_max.append(max_value)
            text_list_max.append(text_list_b[max_index])
        return score_list_max, text_list_max, score_array
    except Exception as e:
        logger.error(f"Error in get_sim_score_max: {e}")
        raise
def score_to_attention(score, symbol=1):
    """Turn a list of reference scores into softmax attention weights; symbol=-1 inverts the weighting."""
    try:
        score_pred = torch.FloatTensor(score).unsqueeze(0)
        score_norm = symbol * torch.nn.functional.normalize(score_pred, p=2, dim=1)
        score_attn = torch.nn.functional.softmax(score_norm, dim=1)
        return score_attn, score_norm, score_pred
    except Exception as e:
        logger.error(f"Error in score_to_attention: {e}")
        raise
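

# Illustrative note on score_to_attention (example values, not from the original
# project): with the default symbol=1, higher reference scores get larger weights,
# e.g. score_to_attention([3.0, 1.0])[0] is roughly tensor([[0.65, 0.35]]); with
# symbol=-1 the L2-normalized scores are negated before the softmax, so the
# weighting flips to roughly [0.35, 0.65]. The weights always sum to 1.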
def get_sim_score_avg(model, data):
    """Attention-weighted score: weight each text_a's similarities to text_list_b by attention derived from 'score_list_b'."""
    try:
        text_list_a, text_list_b = data['text_list_a'], data['text_list_b']
        score_list_b, symbol = data['score_list_b'], data['symbol']
        score_list_max, text_list_max, score_array = get_sim_score_max(model, data)
        score_attn, score_norm, score_pred = score_to_attention(score_list_b, symbol=symbol)
        score_tensor = torch.tensor(score_array)
        # (len_a, len_b) @ (len_b, 1) -> one weighted score per text in text_list_a.
        score_res = torch.matmul(score_tensor, score_attn.transpose(0, 1))
        score_list = score_res.squeeze(-1).tolist()
        return score_list, text_list_max, score_array
    except Exception as e:
        logger.error(f"Error in get_sim_score_avg: {e}")
        raise
def get_sim_score_mean(model, data):
    """Mean score: average each text_a's similarities over all texts in 'text_list_b'."""
    try:
        text_list_a, text_list_b = data['text_list_a'], data['text_list_b']
        score_list_max, text_list_max, score_array = get_sim_score_max(model, data)
        score_tensor = torch.tensor(score_array)
        score_res = torch.mean(score_tensor, dim=1)
        score_list = score_res.tolist()
        return score_list, text_list_max, score_array
    except Exception as e:
        logger.error(f"Error in get_sim_score_mean: {e}")
        raise
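

if __name__ == '__main__':
    # Minimal usage sketch. The texts, reference scores, and variable names below
    # are made-up illustrations; only the dict keys come from the functions above.
    demo_data = {
        'text_list_a': ["如何办理北京市居住证", "明天的天气怎么样"],
        'text_list_b': ["北京居住证办理流程", "上海落户政策", "天气预报"],
        'score_list_b': [0.9, 0.5, 0.7],  # reference weights for text_list_b
        'symbol': 1,                      # 1: favor high reference scores, -1: favor low ones
    }
    max_scores, best_matches, score_matrix = get_sim_score_max(model, demo_data)
    avg_scores, _, _ = get_sim_score_avg(model, demo_data)
    mean_scores, _, _ = get_sim_score_mean(model, demo_data)
    print(max_scores, best_matches)
    print(avg_scores)
    print(mean_scores)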