text_embedding.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. #!/usr/bin/env python3
  2. """
  3. 文本相似度计算模块
  4. 基于 similarities 库(真正的向量模型,不使用 LLM)
  5. """
  6. from typing import Dict, Any
  7. # 支持的模型列表
  8. SUPPORTED_MODELS = {
  9. "chinese": "shibing624/text2vec-base-chinese", # 默认,中文通用
  10. "multilingual": "shibing624/text2vec-base-multilingual", # 多语言(中英韩日德意等)
  11. "paraphrase": "shibing624/text2vec-base-chinese-paraphrase", # 中文长文本
  12. "sentence": "shibing624/text2vec-base-chinese-sentence", # 中文短句子
  13. }
  14. # 延迟导入 similarities,避免初始化时就加载模型
  15. _similarity_models = {} # 存储多个模型实例
  16. def _get_similarity_model(model_name: str = "shibing624/text2vec-base-chinese"):
  17. """
  18. 获取或初始化相似度模型(支持多个模型)
  19. Args:
  20. model_name: 模型名称
  21. Returns:
  22. BertSimilarity 模型实例
  23. """
  24. global _similarity_models
  25. # 如果是简称,转换为完整名称
  26. if model_name in SUPPORTED_MODELS:
  27. model_name = SUPPORTED_MODELS[model_name]
  28. # 如果模型已加载,直接返回
  29. if model_name in _similarity_models:
  30. return _similarity_models[model_name]
  31. # 加载新模型
  32. try:
  33. from similarities import BertSimilarity
  34. print(f"正在加载模型: {model_name}...")
  35. _similarity_models[model_name] = BertSimilarity(model_name_or_path=model_name)
  36. print("模型加载完成!")
  37. return _similarity_models[model_name]
  38. except ImportError:
  39. raise ImportError(
  40. "请先安装 similarities 库: pip install -U similarities torch"
  41. )
  42. def compare_phrases(
  43. phrase_a: str,
  44. phrase_b: str,
  45. model_name: str = "chinese"
  46. ) -> Dict[str, Any]:
  47. """
  48. 比较两个短语的语义相似度(兼容 semantic_similarity.py 的接口)
  49. 返回格式与 semantic_similarity.compare_phrases() 一致:
  50. {
  51. "说明": "基于向量模型计算的语义相似度",
  52. "相似度": 0.85
  53. }
  54. Args:
  55. phrase_a: 第一个短语
  56. phrase_b: 第二个短语
  57. model_name: 模型名称,可选:
  58. 简称:
  59. - "chinese" (默认) - 中文通用模型
  60. - "multilingual" - 多语言模型(中英韩日德意等)
  61. - "paraphrase" - 中文长文本模型
  62. - "sentence" - 中文短句子模型
  63. 完整名称:
  64. - "shibing624/text2vec-base-chinese"
  65. - "shibing624/text2vec-base-multilingual"
  66. - "shibing624/text2vec-base-chinese-paraphrase"
  67. - "shibing624/text2vec-base-chinese-sentence"
  68. Returns:
  69. {
  70. "说明": str, # 相似度说明
  71. "相似度": float # 0-1之间的相似度分数
  72. }
  73. Examples:
  74. >>> # 使用默认模型
  75. >>> result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡")
  76. >>> print(result['相似度']) # 0.855
  77. >>> # 使用多语言模型
  78. >>> result = compare_phrases("Hello", "Hi", model_name="multilingual")
  79. >>> # 使用长文本模型
  80. >>> result = compare_phrases("长文本1...", "长文本2...", model_name="paraphrase")
  81. """
  82. model = _get_similarity_model(model_name)
  83. score = float(model.similarity(phrase_a, phrase_b))
  84. # 生成说明
  85. if score >= 0.9:
  86. level = "极高"
  87. elif score >= 0.7:
  88. level = "高"
  89. elif score >= 0.5:
  90. level = "中等"
  91. elif score >= 0.3:
  92. level = "较低"
  93. else:
  94. level = "低"
  95. explanation = f"基于向量模型计算的语义相似度为 {level} ({score:.2f})"
  96. return {
  97. "说明": explanation,
  98. "相似度": score
  99. }
  100. if __name__ == "__main__":
  101. print("=" * 60)
  102. print("text_embedding - 文本相似度计算")
  103. print("=" * 60)
  104. print()
  105. # 示例 1: 默认模型
  106. print("示例 1: 默认模型(chinese)")
  107. result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡")
  108. print(f"相似度: {result['相似度']:.3f}")
  109. print(f"说明: {result['说明']}")
  110. print()
  111. # 示例 2: 短句子
  112. print("示例 2: 使用默认模型")
  113. result = compare_phrases("深度学习", "神经网络")
  114. print(f"相似度: {result['相似度']:.3f}")
  115. print(f"说明: {result['说明']}")
  116. print()
  117. # 示例 3: 不相关
  118. print("示例 3: 不相关的短语")
  119. result = compare_phrases("编程", "吃饭")
  120. print(f"相似度: {result['相似度']:.3f}")
  121. print(f"说明: {result['说明']}")
  122. print()
  123. # 示例 4: 多语言模型
  124. print("示例 4: 多语言模型(multilingual)")
  125. result = compare_phrases("Hello", "Hi", model_name="multilingual")
  126. print(f"相似度: {result['相似度']:.3f}")
  127. print(f"说明: {result['说明']}")
  128. print()
  129. print("=" * 60)
  130. print("支持的模型:")
  131. print("-" * 60)
  132. for key, value in SUPPORTED_MODELS.items():
  133. print(f" {key:15s} -> {value}")
  134. print("=" * 60)