#!/usr/bin/env python3 """ 文本相似度计算模块 基于 similarities 库(真正的向量模型,不使用 LLM) """ from typing import Dict, Any, Optional import hashlib import json from pathlib import Path from datetime import datetime import threading from .config import get_cache_dir # 支持的模型列表 SUPPORTED_MODELS = { "chinese": "shibing624/text2vec-base-chinese", # 默认,中文通用 "multilingual": "shibing624/text2vec-base-multilingual", # 多语言(中英韩日德意等) "paraphrase": "shibing624/text2vec-base-chinese-paraphrase", # 中文长文本 "sentence": "shibing624/text2vec-base-chinese-sentence", # 中文短句子 } # 延迟导入 similarities,避免初始化时就加载模型 _similarity_models = {} # 存储多个模型实例 _model_lock = threading.Lock() # 线程锁,保护模型加载 def _get_default_cache_dir() -> str: """获取默认缓存目录(从配置中读取)""" return get_cache_dir("text_embedding") def _generate_cache_key(phrase_a: str, phrase_b: str, model_name: str) -> str: """ 生成缓存键(哈希值) Args: phrase_a: 第一个短语 phrase_b: 第二个短语 model_name: 模型名称 Returns: 32位MD5哈希值 """ cache_string = f"{phrase_a}||{phrase_b}||{model_name}" return hashlib.md5(cache_string.encode('utf-8')).hexdigest() def _sanitize_for_filename(text: str, max_length: int = 30) -> str: """ 将文本转换为安全的文件名部分 Args: text: 原始文本 max_length: 最大长度 Returns: 安全的文件名字符串 """ import re # 移除特殊字符,只保留中文、英文、数字、下划线 sanitized = re.sub(r'[^\w\u4e00-\u9fff]', '_', text) # 移除连续的下划线 sanitized = re.sub(r'_+', '_', sanitized) # 截断到最大长度 if len(sanitized) > max_length: sanitized = sanitized[:max_length] return sanitized.strip('_') def _get_cache_filepath( cache_key: str, phrase_a: str, phrase_b: str, model_name: str, cache_dir: Optional[str] = None ) -> Path: """ 获取缓存文件路径(可读文件名) Args: cache_key: 缓存键(哈希值) phrase_a: 第一个短语 phrase_b: 第二个短语 model_name: 模型名称 cache_dir: 缓存目录 Returns: 缓存文件的完整路径 文件名格式: {phrase_a}_vs_{phrase_b}_{model}_{hash[:8]}.json """ if cache_dir is None: cache_dir = _get_default_cache_dir() # 清理短语和模型名 clean_a = _sanitize_for_filename(phrase_a, max_length=20) clean_b = _sanitize_for_filename(phrase_b, max_length=20) # 简化模型名(提取关键部分) model_short = model_name.split('/')[-1] model_short = _sanitize_for_filename(model_short, max_length=20) # 使用哈希的前8位 hash_short = cache_key[:8] # 组合文件名 filename = f"{clean_a}_vs_{clean_b}_{model_short}_{hash_short}.json" return Path(cache_dir) / filename def _load_from_cache( cache_key: str, phrase_a: str, phrase_b: str, model_name: str, cache_dir: Optional[str] = None ) -> Optional[Dict[str, Any]]: """ 从缓存加载数据 Args: cache_key: 缓存键 phrase_a: 第一个短语 phrase_b: 第二个短语 model_name: 模型名称 cache_dir: 缓存目录 Returns: 缓存的结果字典,如果不存在则返回 None """ if cache_dir is None: cache_dir = _get_default_cache_dir() cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, cache_dir) # 如果文件不存在,尝试通过哈希匹配查找 if not cache_file.exists(): cache_path = Path(cache_dir) if cache_path.exists(): hash_short = cache_key[:8] matching_files = list(cache_path.glob(f"*_{hash_short}.json")) if matching_files: cache_file = matching_files[0] else: return None else: return None try: with open(cache_file, 'r', encoding='utf-8') as f: cached_data = json.load(f) return cached_data['output'] except (json.JSONDecodeError, IOError, KeyError): return None def _save_to_cache( cache_key: str, phrase_a: str, phrase_b: str, model_name: str, result: Dict[str, Any], cache_dir: Optional[str] = None ) -> None: """ 保存数据到缓存 Args: cache_key: 缓存键 phrase_a: 第一个短语 phrase_b: 第二个短语 model_name: 模型名称 result: 结果数据(字典格式) cache_dir: 缓存目录 """ if cache_dir is None: cache_dir = _get_default_cache_dir() cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, cache_dir) # 确保缓存目录存在 cache_file.parent.mkdir(parents=True, exist_ok=True) # 准备缓存数据 cache_data = { "input": { "phrase_a": phrase_a, "phrase_b": phrase_b, "model_name": model_name, }, "output": result, "metadata": { "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "cache_key": cache_key, "cache_file": str(cache_file.name) } } try: with open(cache_file, 'w', encoding='utf-8') as f: json.dump(cache_data, f, ensure_ascii=False, indent=2) except IOError: pass # 静默失败,不影响主流程 def _get_similarity_model(model_name: str = "shibing624/text2vec-base-chinese"): """ 获取或初始化相似度模型(支持多个模型,线程安全) Args: model_name: 模型名称 Returns: BertSimilarity 模型实例 """ global _similarity_models, _model_lock # 如果是简称,转换为完整名称 if model_name in SUPPORTED_MODELS: model_name = SUPPORTED_MODELS[model_name] # 快速路径:如果模型已加载,直接返回(无锁检查) if model_name in _similarity_models: return _similarity_models[model_name] # 慢速路径:需要加载模型(使用锁保护) with _model_lock: # 双重检查:可能在等待锁时其他线程已经加载了 if model_name in _similarity_models: return _similarity_models[model_name] # 加载新模型 try: from similarities import BertSimilarity print(f"正在加载模型: {model_name}...") _similarity_models[model_name] = BertSimilarity(model_name_or_path=model_name) print("模型加载完成!") return _similarity_models[model_name] except ImportError: raise ImportError( "请先安装 similarities 库: pip install -U similarities torch" ) def compare_phrases( phrase_a: str, phrase_b: str, model_name: str = "chinese", use_cache: bool = True, cache_dir: Optional[str] = None ) -> Dict[str, Any]: """ 比较两个短语的语义相似度(兼容 semantic_similarity.py 的接口) 返回格式与 semantic_similarity.compare_phrases() 一致: { "说明": "基于向量模型计算的语义相似度", "相似度": 0.85 } Args: phrase_a: 第一个短语 phrase_b: 第二个短语 model_name: 模型名称,可选: 简称: - "chinese" (默认) - 中文通用模型 - "multilingual" - 多语言模型(中英韩日德意等) - "paraphrase" - 中文长文本模型 - "sentence" - 中文短句子模型 完整名称: - "shibing624/text2vec-base-chinese" - "shibing624/text2vec-base-multilingual" - "shibing624/text2vec-base-chinese-paraphrase" - "shibing624/text2vec-base-chinese-sentence" use_cache: 是否使用缓存,默认 True cache_dir: 缓存目录,默认从配置读取(可通过 lib.config 设置) Returns: { "说明": str, # 相似度说明 "相似度": float # 0-1之间的相似度分数 } Examples: >>> # 使用默认模型 >>> result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡") >>> print(result['相似度']) # 0.855 >>> # 使用多语言模型 >>> result = compare_phrases("Hello", "Hi", model_name="multilingual") >>> # 使用长文本模型 >>> result = compare_phrases("长文本1...", "长文本2...", model_name="paraphrase") >>> # 禁用缓存 >>> result = compare_phrases("测试", "测试", use_cache=False) >>> # 自定义缓存目录 >>> result = compare_phrases("测试1", "测试2", cache_dir="/tmp/my_cache") """ if cache_dir is None: cache_dir = _get_default_cache_dir() # 转换简称为完整名称(用于缓存键) full_model_name = SUPPORTED_MODELS.get(model_name, model_name) # 生成缓存键 cache_key = _generate_cache_key(phrase_a, phrase_b, full_model_name) # 尝试从缓存加载 if use_cache: cached_result = _load_from_cache(cache_key, phrase_a, phrase_b, full_model_name, cache_dir) if cached_result is not None: return cached_result # 缓存未命中,计算相似度 model = _get_similarity_model(model_name) score = float(model.similarity(phrase_a, phrase_b)) # 生成说明 if score >= 0.9: level = "极高" elif score >= 0.7: level = "高" elif score >= 0.5: level = "中等" elif score >= 0.3: level = "较低" else: level = "低" explanation = f"基于向量模型计算的语义相似度为 {level} ({score:.2f})" result = { "说明": explanation, "相似度": score } # 保存到缓存 if use_cache: _save_to_cache(cache_key, phrase_a, phrase_b, full_model_name, result, cache_dir) return result if __name__ == "__main__": print("=" * 60) print("text_embedding - 文本相似度计算(带缓存)") print("=" * 60) print() # 示例 1: 默认模型(首次调用,会保存缓存) print("示例 1: 默认模型(chinese)- 首次调用") result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡") print(f"相似度: {result['相似度']:.3f}") print(f"说明: {result['说明']}") print() # 示例 2: 再次调用相同参数(从缓存读取) print("示例 2: 测试缓存 - 再次调用相同参数") result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡") print(f"相似度: {result['相似度']:.3f}") print(f"说明: {result['说明']}") print("(应该从缓存读取,速度更快)") print() # 示例 3: 短句子 print("示例 3: 使用默认模型") result = compare_phrases("深度学习", "神经网络") print(f"相似度: {result['相似度']:.3f}") print(f"说明: {result['说明']}") print() # 示例 4: 不相关 print("示例 4: 不相关的短语") result = compare_phrases("编程", "吃饭") print(f"相似度: {result['相似度']:.3f}") print(f"说明: {result['说明']}") print() # 示例 5: 多语言模型 print("示例 5: 多语言模型(multilingual)") result = compare_phrases("Hello", "Hi", model_name="multilingual") print(f"相似度: {result['相似度']:.3f}") print(f"说明: {result['说明']}") print() # 示例 6: 禁用缓存 print("示例 6: 禁用缓存") result = compare_phrases("测试", "测试", use_cache=False) print(f"相似度: {result['相似度']:.3f}") print(f"说明: {result['说明']}") print() print("=" * 60) print("支持的模型:") print("-" * 60) for key, value in SUPPORTED_MODELS.items(): print(f" {key:15s} -> {value}") print("=" * 60) print() print("缓存目录: cache/text_embedding/") print("缓存文件格式: {phrase_a}_vs_{phrase_b}_{model}_{hash[:8]}.json") print("=" * 60)