| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
#!/usr/bin/env python3
"""
Text similarity computation module.

Similarity scores come from the embedding models of the ``similarities``
library (a real vector model; no LLM is involved).
"""
from typing import Dict, Any

# Short aliases accepted by ``compare_phrases`` / ``_get_similarity_model``,
# mapped to the full model names that are handed to ``BertSimilarity``.
SUPPORTED_MODELS: Dict[str, str] = {
    "chinese": "shibing624/text2vec-base-chinese",  # default: general-purpose Chinese
    "multilingual": "shibing624/text2vec-base-multilingual",  # multilingual (zh/en/ko/ja/de/it, ...)
    "paraphrase": "shibing624/text2vec-base-chinese-paraphrase",  # Chinese long text
    "sentence": "shibing624/text2vec-base-chinese-sentence",  # Chinese short sentences
}

# Loaded model instances, keyed by full model name. ``similarities`` itself is
# imported lazily inside the loader, so importing this module loads no model.
_similarity_models: Dict[str, Any] = {}
def _get_similarity_model(model_name: str = "shibing624/text2vec-base-chinese"):
    """Return the ``BertSimilarity`` instance for ``model_name``, loading it once.

    Several models can be loaded side by side. Each one is constructed on
    first request and then reused from the module-level
    ``_similarity_models`` cache.

    Args:
        model_name: Either a short alias from ``SUPPORTED_MODELS`` (e.g.
            ``"chinese"``) or a full model name/path, which is passed to
            ``BertSimilarity`` unchanged.

    Returns:
        The ``similarities.BertSimilarity`` instance for the resolved name.

    Raises:
        ImportError: If ``from similarities import BertSimilarity`` fails.
            The underlying import error is chained as ``__cause__``.
    """
    # Resolve short aliases; any other string is treated as a full name/path.
    model_name = SUPPORTED_MODELS.get(model_name, model_name)

    # Fast path: this model was already loaded by an earlier call.
    if model_name in _similarity_models:
        return _similarity_models[model_name]

    # Import lazily so that importing this module never loads a model. Only
    # the import is guarded: an ImportError raised while *building* the model
    # now propagates as-is instead of being misreported as a missing package.
    try:
        from similarities import BertSimilarity
    except ImportError as err:
        raise ImportError(
            "请先安装 similarities 库: pip install -U similarities torch"
        ) from err

    print(f"正在加载模型: {model_name}...")
    model = BertSimilarity(model_name_or_path=model_name)
    # Cache only after construction succeeded, so a failed load is retried
    # on the next call rather than leaving a broken entry behind.
    _similarity_models[model_name] = model
    print("模型加载完成!")
    return model
def compare_phrases(
    phrase_a: str,
    phrase_b: str,
    model_name: str = "chinese"
) -> Dict[str, Any]:
    """Score the semantic similarity of two phrases with an embedding model.

    Interface-compatible with ``semantic_similarity.compare_phrases()``: the
    result is a dict with the same two keys, for example::

        {"说明": "基于向量模型计算的语义相似度为 高 (0.85)", "相似度": 0.85}

    Args:
        phrase_a: The first phrase.
        phrase_b: The second phrase.
        model_name: Which model to use. Short aliases:
            ``"chinese"`` (default, general Chinese),
            ``"multilingual"`` (zh/en/ko/ja/de/it and more),
            ``"paraphrase"`` (Chinese long text),
            ``"sentence"`` (Chinese short sentences).
            Full names are also accepted, e.g.
            ``"shibing624/text2vec-base-chinese"``,
            ``"shibing624/text2vec-base-multilingual"``,
            ``"shibing624/text2vec-base-chinese-paraphrase"``,
            ``"shibing624/text2vec-base-chinese-sentence"``.

    Returns:
        A dict with two entries:

        * ``"说明"`` (str): a summary containing a level label
          (极高 >= 0.9, 高 >= 0.7, 中等 >= 0.5, 较低 >= 0.3, otherwise 低)
          and the score rounded to two decimals.
        * ``"相似度"`` (float): the model's score passed through ``float()``.
          The value is not clamped. It is usually described as lying in
          0-1; NOTE(review): confirm the model can never return a negative
          score.

    Example (not a doctest; the scores depend on the downloaded model)::

        result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡")
        result["相似度"]      # roughly 0.855 with the default model
        compare_phrases("Hello", "Hi", model_name="multilingual")
        compare_phrases("长文本1...", "长文本2...", model_name="paraphrase")
    """
    scorer = _get_similarity_model(model_name)
    similarity = float(scorer.similarity(phrase_a, phrase_b))

    # (lower bound, label) pairs, highest bound first; the first bound the
    # score reaches wins. Scores below 0.3 -- and NaN, which fails every
    # comparison -- fall through to the default label "低".
    bands = ((0.9, "极高"), (0.7, "高"), (0.5, "中等"), (0.3, "较低"))
    label = next((name for floor, name in bands if similarity >= floor), "低")

    return {
        "说明": f"基于向量模型计算的语义相似度为 {label} ({similarity:.2f})",
        "相似度": similarity,
    }
if __name__ == "__main__":
    # Demo script: run a few sample comparisons, then list the model aliases.
    rule = "=" * 60
    print(rule)
    print("text_embedding - 文本相似度计算")
    print(rule)
    print()

    # Each demo: (heading, first phrase, second phrase, extra kwargs).
    # Demos 1-3 use the default model; demo 4 switches to the multilingual one.
    demos = [
        ("示例 1: 默认模型(chinese)", "如何更换花呗绑定银行卡", "花呗更改绑定银行卡", {}),
        ("示例 2: 使用默认模型", "深度学习", "神经网络", {}),
        ("示例 3: 不相关的短语", "编程", "吃饭", {}),
        ("示例 4: 多语言模型(multilingual)", "Hello", "Hi", {"model_name": "multilingual"}),
    ]
    for heading, text_a, text_b, options in demos:
        print(heading)
        outcome = compare_phrases(text_a, text_b, **options)
        print(f"相似度: {outcome['相似度']:.3f}")
        print(f"说明: {outcome['说明']}")
        print()

    print(rule)
    print("支持的模型:")
    print("-" * 60)
    for alias, full_name in SUPPORTED_MODELS.items():
        print(f" {alias:15s} -> {full_name}")
    print(rule)
|