#!/usr/bin/env python3 """ 文本相似度计算模块 基于 similarities 库(真正的向量模型,不使用 LLM) """ from typing import Dict, Any, Optional import hashlib import json from pathlib import Path from datetime import datetime import threading # 支持的模型列表 SUPPORTED_MODELS = { "chinese": "shibing624/text2vec-base-chinese", # 默认,中文通用 "multilingual": "shibing624/text2vec-base-multilingual", # 多语言(中英韩日德意等) "paraphrase": "shibing624/text2vec-base-chinese-paraphrase", # 中文长文本 "sentence": "shibing624/text2vec-base-chinese-sentence", # 中文短句子 } # 延迟导入 similarities,避免初始化时就加载模型 _similarity_models = {} # 存储多个模型实例 _model_lock = threading.Lock() # 线程锁,保护模型加载 # 默认缓存目录 DEFAULT_CACHE_DIR = "cache/text_embedding" def _generate_cache_key(phrase_a: str, phrase_b: str, model_name: str) -> str: """ 生成缓存键(哈希值) Args: phrase_a: 第一个短语 phrase_b: 第二个短语 model_name: 模型名称 Returns: 32位MD5哈希值 """ cache_string = f"{phrase_a}||{phrase_b}||{model_name}" return hashlib.md5(cache_string.encode('utf-8')).hexdigest() def _sanitize_for_filename(text: str, max_length: int = 30) -> str: """ 将文本转换为安全的文件名部分 Args: text: 原始文本 max_length: 最大长度 Returns: 安全的文件名字符串 """ import re # 移除特殊字符,只保留中文、英文、数字、下划线 sanitized = re.sub(r'[^\w\u4e00-\u9fff]', '_', text) # 移除连续的下划线 sanitized = re.sub(r'_+', '_', sanitized) # 截断到最大长度 if len(sanitized) > max_length: sanitized = sanitized[:max_length] return sanitized.strip('_') def _get_cache_filepath( cache_key: str, phrase_a: str, phrase_b: str, model_name: str, cache_dir: str = DEFAULT_CACHE_DIR ) -> Path: """ 获取缓存文件路径(可读文件名) Args: cache_key: 缓存键(哈希值) phrase_a: 第一个短语 phrase_b: 第二个短语 model_name: 模型名称 cache_dir: 缓存目录 Returns: 缓存文件的完整路径 文件名格式: {phrase_a}_vs_{phrase_b}_{model}_{hash[:8]}.json """ # 清理短语和模型名 clean_a = _sanitize_for_filename(phrase_a, max_length=20) clean_b = _sanitize_for_filename(phrase_b, max_length=20) # 简化模型名(提取关键部分) model_short = model_name.split('/')[-1] model_short = _sanitize_for_filename(model_short, max_length=20) # 使用哈希的前8位 hash_short = cache_key[:8] # 组合文件名 filename = f"{clean_a}_vs_{clean_b}_{model_short}_{hash_short}.json" return Path(cache_dir) / filename def _load_from_cache( cache_key: str, phrase_a: str, phrase_b: str, model_name: str, cache_dir: str = DEFAULT_CACHE_DIR ) -> Optional[Dict[str, Any]]: """ 从缓存加载数据 Args: cache_key: 缓存键 phrase_a: 第一个短语 phrase_b: 第二个短语 model_name: 模型名称 cache_dir: 缓存目录 Returns: 缓存的结果字典,如果不存在则返回 None """ cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, cache_dir) # 如果文件不存在,尝试通过哈希匹配查找 if not cache_file.exists(): cache_path = Path(cache_dir) if cache_path.exists(): hash_short = cache_key[:8] matching_files = list(cache_path.glob(f"*_{hash_short}.json")) if matching_files: cache_file = matching_files[0] else: return None else: return None try: with open(cache_file, 'r', encoding='utf-8') as f: cached_data = json.load(f) return cached_data['output'] except (json.JSONDecodeError, IOError, KeyError): return None def _save_to_cache( cache_key: str, phrase_a: str, phrase_b: str, model_name: str, result: Dict[str, Any], cache_dir: str = DEFAULT_CACHE_DIR ) -> None: """ 保存数据到缓存 Args: cache_key: 缓存键 phrase_a: 第一个短语 phrase_b: 第二个短语 model_name: 模型名称 result: 结果数据(字典格式) cache_dir: 缓存目录 """ cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, cache_dir) # 确保缓存目录存在 cache_file.parent.mkdir(parents=True, exist_ok=True) # 准备缓存数据 cache_data = { "input": { "phrase_a": phrase_a, "phrase_b": phrase_b, "model_name": model_name, }, "output": result, "metadata": { "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "cache_key": cache_key, "cache_file": str(cache_file.name) } } try: with open(cache_file, 'w', encoding='utf-8') as f: json.dump(cache_data, f, ensure_ascii=False, indent=2) except IOError: pass # 静默失败,不影响主流程 def _get_similarity_model(model_name: str = "shibing624/text2vec-base-chinese"): """ 获取或初始化相似度模型(支持多个模型,线程安全) Args: model_name: 模型名称 Returns: BertSimilarity 模型实例 """ global _similarity_models, _model_lock # 如果是简称,转换为完整名称 if model_name in SUPPORTED_MODELS: model_name = SUPPORTED_MODELS[model_name] # 快速路径:如果模型已加载,直接返回(无锁检查) if model_name in _similarity_models: return _similarity_models[model_name] # 慢速路径:需要加载模型(使用锁保护) with _model_lock: # 双重检查:可能在等待锁时其他线程已经加载了 if model_name in _similarity_models: return _similarity_models[model_name] # 加载新模型 try: from similarities import BertSimilarity print(f"正在加载模型: {model_name}...") _similarity_models[model_name] = BertSimilarity(model_name_or_path=model_name) print("模型加载完成!") return _similarity_models[model_name] except ImportError: raise ImportError( "请先安装 similarities 库: pip install -U similarities torch" ) def compare_phrases( phrase_a: str, phrase_b: str, model_name: str = "chinese", use_cache: bool = True, cache_dir: str = DEFAULT_CACHE_DIR ) -> Dict[str, Any]: """ 比较两个短语的语义相似度(兼容 semantic_similarity.py 的接口) 返回格式与 semantic_similarity.compare_phrases() 一致: { "说明": "基于向量模型计算的语义相似度", "相似度": 0.85 } Args: phrase_a: 第一个短语 phrase_b: 第二个短语 model_name: 模型名称,可选: 简称: - "chinese" (默认) - 中文通用模型 - "multilingual" - 多语言模型(中英韩日德意等) - "paraphrase" - 中文长文本模型 - "sentence" - 中文短句子模型 完整名称: - "shibing624/text2vec-base-chinese" - "shibing624/text2vec-base-multilingual" - "shibing624/text2vec-base-chinese-paraphrase" - "shibing624/text2vec-base-chinese-sentence" use_cache: 是否使用缓存,默认 True cache_dir: 缓存目录,默认 'cache/text_embedding' Returns: { "说明": str, # 相似度说明 "相似度": float # 0-1之间的相似度分数 } Examples: >>> # 使用默认模型 >>> result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡") >>> print(result['相似度']) # 0.855 >>> # 使用多语言模型 >>> result = compare_phrases("Hello", "Hi", model_name="multilingual") >>> # 使用长文本模型 >>> result = compare_phrases("长文本1...", "长文本2...", model_name="paraphrase") >>> # 禁用缓存 >>> result = compare_phrases("测试", "测试", use_cache=False) """ # 转换简称为完整名称(用于缓存键) full_model_name = SUPPORTED_MODELS.get(model_name, model_name) # 生成缓存键 cache_key = _generate_cache_key(phrase_a, phrase_b, full_model_name) # 尝试从缓存加载 if use_cache: cached_result = _load_from_cache(cache_key, phrase_a, phrase_b, full_model_name, cache_dir) if cached_result is not None: return cached_result # 缓存未命中,计算相似度 model = _get_similarity_model(model_name) score = float(model.similarity(phrase_a, phrase_b)) # 生成说明 if score >= 0.9: level = "极高" elif score >= 0.7: level = "高" elif score >= 0.5: level = "中等" elif score >= 0.3: level = "较低" else: level = "低" explanation = f"基于向量模型计算的语义相似度为 {level} ({score:.2f})" result = { "说明": explanation, "相似度": score } # 保存到缓存 if use_cache: _save_to_cache(cache_key, phrase_a, phrase_b, full_model_name, result, cache_dir) return result if __name__ == "__main__": print("=" * 60) print("text_embedding - 文本相似度计算(带缓存)") print("=" * 60) print() # 示例 1: 默认模型(首次调用,会保存缓存) print("示例 1: 默认模型(chinese)- 首次调用") result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡") print(f"相似度: {result['相似度']:.3f}") print(f"说明: {result['说明']}") print() # 示例 2: 再次调用相同参数(从缓存读取) print("示例 2: 测试缓存 - 再次调用相同参数") result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡") print(f"相似度: {result['相似度']:.3f}") print(f"说明: {result['说明']}") print("(应该从缓存读取,速度更快)") print() # 示例 3: 短句子 print("示例 3: 使用默认模型") result = compare_phrases("深度学习", "神经网络") print(f"相似度: {result['相似度']:.3f}") print(f"说明: {result['说明']}") print() # 示例 4: 不相关 print("示例 4: 不相关的短语") result = compare_phrases("编程", "吃饭") print(f"相似度: {result['相似度']:.3f}") print(f"说明: {result['说明']}") print() # 示例 5: 多语言模型 print("示例 5: 多语言模型(multilingual)") result = compare_phrases("Hello", "Hi", model_name="multilingual") print(f"相似度: {result['相似度']:.3f}") print(f"说明: {result['说明']}") print() # 示例 6: 禁用缓存 print("示例 6: 禁用缓存") result = compare_phrases("测试", "测试", use_cache=False) print(f"相似度: {result['相似度']:.3f}") print(f"说明: {result['说明']}") print() print("=" * 60) print("支持的模型:") print("-" * 60) for key, value in SUPPORTED_MODELS.items(): print(f" {key:15s} -> {value}") print("=" * 60) print() print("缓存目录: cache/text_embedding/") print("缓存文件格式: {phrase_a}_vs_{phrase_b}_{model}_{hash[:8]}.json") print("=" * 60)