#!/usr/bin/env python3
"""
Text similarity computation module.
Built on the `similarities` library (a genuine embedding model, no LLM involved).
"""
from typing import Dict, Any, Optional
import hashlib
import json
from pathlib import Path
from datetime import datetime
import threading
# Supported models
SUPPORTED_MODELS = {
    "chinese": "shibing624/text2vec-base-chinese",                # default, general-purpose Chinese
    "multilingual": "shibing624/text2vec-base-multilingual",      # multilingual (Chinese, English, Korean, Japanese, German, Italian, ...)
    "paraphrase": "shibing624/text2vec-base-chinese-paraphrase",  # Chinese long text
    "sentence": "shibing624/text2vec-base-chinese-sentence",      # Chinese short sentences
}

# Import `similarities` lazily so that no model is loaded at import time
_similarity_models = {}         # loaded model instances, keyed by model name
_model_lock = threading.Lock()  # lock protecting model loading

# Default cache directory
DEFAULT_CACHE_DIR = "cache/text_embedding"
def _generate_cache_key(phrase_a: str, phrase_b: str, model_name: str) -> str:
    """
    Generate a cache key (hash).

    Args:
        phrase_a: first phrase
        phrase_b: second phrase
        model_name: model name

    Returns:
        32-character MD5 hex digest
    """
    cache_string = f"{phrase_a}||{phrase_b}||{model_name}"
    return hashlib.md5(cache_string.encode('utf-8')).hexdigest()
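# Illustration: the key is simply the MD5 of "phrase_a||phrase_b||model_name", so the
# same phrase pair evaluated with a different model produces a separate cache entry:
#   _generate_cache_key("深度学习", "神经网络", "shibing624/text2vec-base-chinese")
#   -> the 32-character hex digest of "深度学习||神经网络||shibing624/text2vec-base-chinese"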
def _sanitize_for_filename(text: str, max_length: int = 30) -> str:
    """
    Convert text into a filename-safe fragment.

    Args:
        text: original text
        max_length: maximum length

    Returns:
        a filename-safe string
    """
    import re
    # Strip special characters; keep only CJK characters, letters, digits and underscores
    sanitized = re.sub(r'[^\w\u4e00-\u9fff]', '_', text)
    # Collapse runs of underscores
    sanitized = re.sub(r'_+', '_', sanitized)
    # Truncate to the maximum length
    if len(sanitized) > max_length:
        sanitized = sanitized[:max_length]
    return sanitized.strip('_')
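# Illustrative example: _sanitize_for_filename("hello, world!") -> "hello_world".
# CJK characters match both \w and the explicit \u4e00-\u9fff range, so Chinese phrases
# usually pass through unchanged (apart from truncation).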
def _get_cache_filepath(
    cache_key: str,
    phrase_a: str,
    phrase_b: str,
    model_name: str,
    cache_dir: str = DEFAULT_CACHE_DIR
) -> Path:
    """
    Build the cache file path (human-readable filename).

    Args:
        cache_key: cache key (hash)
        phrase_a: first phrase
        phrase_b: second phrase
        model_name: model name
        cache_dir: cache directory

    Returns:
        full path of the cache file

    Filename format: {phrase_a}_vs_{phrase_b}_{model}_{hash[:8]}.json
    """
    # Sanitize the phrases and the model name
    clean_a = _sanitize_for_filename(phrase_a, max_length=20)
    clean_b = _sanitize_for_filename(phrase_b, max_length=20)
    # Shorten the model name (keep the part after the last '/')
    model_short = model_name.split('/')[-1]
    model_short = _sanitize_for_filename(model_short, max_length=20)
    # Use the first 8 characters of the hash
    hash_short = cache_key[:8]
    # Assemble the filename
    filename = f"{clean_a}_vs_{clean_b}_{model_short}_{hash_short}.json"
    return Path(cache_dir) / filename
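# Illustrative example: for "深度学习" vs "神经网络" with the default model this yields
# something like cache/text_embedding/深度学习_vs_神经网络_text2vec_base_chines_<hash8>.json
# (hyphens in the model name become underscores, and each part is capped at 20 characters).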
def _load_from_cache(
    cache_key: str,
    phrase_a: str,
    phrase_b: str,
    model_name: str,
    cache_dir: str = DEFAULT_CACHE_DIR
) -> Optional[Dict[str, Any]]:
    """
    Load a result from the cache.

    Args:
        cache_key: cache key
        phrase_a: first phrase
        phrase_b: second phrase
        model_name: model name
        cache_dir: cache directory

    Returns:
        the cached result dict, or None if there is no cache entry
    """
    cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, cache_dir)
    # If the expected file does not exist, fall back to matching by hash suffix
    if not cache_file.exists():
        cache_path = Path(cache_dir)
        if cache_path.exists():
            hash_short = cache_key[:8]
            matching_files = list(cache_path.glob(f"*_{hash_short}.json"))
            if matching_files:
                cache_file = matching_files[0]
            else:
                return None
        else:
            return None
    try:
        with open(cache_file, 'r', encoding='utf-8') as f:
            cached_data = json.load(f)
        return cached_data['output']
    except (json.JSONDecodeError, IOError, KeyError):
        return None
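# Note: the glob fallback above means a cache hit still works if a file was renamed by
# hand, as long as the "_{first 8 hash chars}.json" suffix is preserved.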
def _save_to_cache(
    cache_key: str,
    phrase_a: str,
    phrase_b: str,
    model_name: str,
    result: Dict[str, Any],
    cache_dir: str = DEFAULT_CACHE_DIR
) -> None:
    """
    Save a result to the cache.

    Args:
        cache_key: cache key
        phrase_a: first phrase
        phrase_b: second phrase
        model_name: model name
        result: result data (dict)
        cache_dir: cache directory
    """
    cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, cache_dir)
    # Make sure the cache directory exists
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    # Assemble the cache payload
    cache_data = {
        "input": {
            "phrase_a": phrase_a,
            "phrase_b": phrase_b,
            "model_name": model_name,
        },
        "output": result,
        "metadata": {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "cache_key": cache_key,
            "cache_file": str(cache_file.name)
        }
    }
    try:
        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(cache_data, f, ensure_ascii=False, indent=2)
    except IOError:
        pass  # fail silently; caching must never break the main flow
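# On disk a cache entry looks roughly like:
# {
#   "input":    {"phrase_a": "...", "phrase_b": "...", "model_name": "..."},
#   "output":   {"说明": "...", "相似度": 0.85},
#   "metadata": {"timestamp": "...", "cache_key": "...", "cache_file": "..."}
# }
# _load_from_cache() reads only the "output" field back; "input" and "metadata" exist
# so that cached files are easy to inspect by hand.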
def _get_similarity_model(model_name: str = "shibing624/text2vec-base-chinese"):
    """
    Get or lazily initialize a similarity model (multiple models supported, thread-safe).

    Args:
        model_name: model name

    Returns:
        a BertSimilarity model instance
    """
    global _similarity_models, _model_lock
    # Resolve a short name to its full model name
    if model_name in SUPPORTED_MODELS:
        model_name = SUPPORTED_MODELS[model_name]
    # Fast path: the model is already loaded, return it without taking the lock
    if model_name in _similarity_models:
        return _similarity_models[model_name]
    # Slow path: load the model under the lock
    with _model_lock:
        # Double-checked locking: another thread may have loaded it while we waited
        if model_name in _similarity_models:
            return _similarity_models[model_name]
        # Load the new model
        try:
            from similarities import BertSimilarity
            print(f"Loading model: {model_name}...")
            _similarity_models[model_name] = BertSimilarity(model_name_or_path=model_name)
            print("Model loaded.")
            return _similarity_models[model_name]
        except ImportError:
            raise ImportError(
                "Please install the similarities library first: pip install -U similarities torch"
            )
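# Note: callers in this module only rely on the returned object exposing a
# similarity(text_a, text_b) method whose result can be converted to float,
# as used in compare_phrases() below.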
def compare_phrases(
    phrase_a: str,
    phrase_b: str,
    model_name: str = "chinese",
    use_cache: bool = True,
    cache_dir: str = DEFAULT_CACHE_DIR
) -> Dict[str, Any]:
    """
    Compare the semantic similarity of two phrases (interface-compatible with semantic_similarity.py).

    The return format matches semantic_similarity.compare_phrases():
        {
            "说明": "Embedding-based semantic similarity is high (0.85)",
            "相似度": 0.85
        }

    Args:
        phrase_a: first phrase
        phrase_b: second phrase
        model_name: model name, one of:
            short names:
                - "chinese" (default) - general-purpose Chinese model
                - "multilingual" - multilingual model (Chinese, English, Korean, Japanese, German, Italian, ...)
                - "paraphrase" - Chinese long-text model
                - "sentence" - Chinese short-sentence model
            full names:
                - "shibing624/text2vec-base-chinese"
                - "shibing624/text2vec-base-multilingual"
                - "shibing624/text2vec-base-chinese-paraphrase"
                - "shibing624/text2vec-base-chinese-sentence"
        use_cache: whether to use the cache, default True
        cache_dir: cache directory, default 'cache/text_embedding'

    Returns:
        {
            "说明": str,    # explanation of the similarity
            "相似度": float  # similarity score between 0 and 1
        }

    Examples:
        >>> # Default model
        >>> result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡")
        >>> print(result['相似度'])  # 0.855
        >>> # Multilingual model
        >>> result = compare_phrases("Hello", "Hi", model_name="multilingual")
        >>> # Long-text model
        >>> result = compare_phrases("长文本1...", "长文本2...", model_name="paraphrase")
        >>> # Disable the cache
        >>> result = compare_phrases("测试", "测试", use_cache=False)
    """
    # Resolve a short name to the full model name (used for the cache key)
    full_model_name = SUPPORTED_MODELS.get(model_name, model_name)
    # Build the cache key
    cache_key = _generate_cache_key(phrase_a, phrase_b, full_model_name)
    # Try the cache first
    if use_cache:
        cached_result = _load_from_cache(cache_key, phrase_a, phrase_b, full_model_name, cache_dir)
        if cached_result is not None:
            return cached_result
    # Cache miss: compute the similarity
    model = _get_similarity_model(model_name)
    score = float(model.similarity(phrase_a, phrase_b))
    # Build the explanation
    if score >= 0.9:
        level = "extremely high"
    elif score >= 0.7:
        level = "high"
    elif score >= 0.5:
        level = "moderate"
    elif score >= 0.3:
        level = "fairly low"
    else:
        level = "low"
    explanation = f"Embedding-based semantic similarity is {level} ({score:.2f})"
    result = {
        "说明": explanation,
        "相似度": score
    }
    # Save to the cache
    if use_cache:
        _save_to_cache(cache_key, phrase_a, phrase_b, full_model_name, result, cache_dir)
    return result
if __name__ == "__main__":
    print("=" * 60)
    print("text_embedding - text similarity computation (with caching)")
    print("=" * 60)
    print()
    # Example 1: default model (first call, result gets cached)
    print("Example 1: default model (chinese) - first call")
    result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡")
    print(f"Similarity: {result['相似度']:.3f}")
    print(f"Explanation: {result['说明']}")
    print()
    # Example 2: same arguments again (served from the cache)
    print("Example 2: cache test - same arguments again")
    result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡")
    print(f"Similarity: {result['相似度']:.3f}")
    print(f"Explanation: {result['说明']}")
    print("(should be read from the cache, hence faster)")
    print()
    # Example 3: short phrases
    print("Example 3: default model")
    result = compare_phrases("深度学习", "神经网络")
    print(f"Similarity: {result['相似度']:.3f}")
    print(f"Explanation: {result['说明']}")
    print()
    # Example 4: unrelated phrases
    print("Example 4: unrelated phrases")
    result = compare_phrases("编程", "吃饭")
    print(f"Similarity: {result['相似度']:.3f}")
    print(f"Explanation: {result['说明']}")
    print()
    # Example 5: multilingual model
    print("Example 5: multilingual model")
    result = compare_phrases("Hello", "Hi", model_name="multilingual")
    print(f"Similarity: {result['相似度']:.3f}")
    print(f"Explanation: {result['说明']}")
    print()
    # Example 6: cache disabled
    print("Example 6: cache disabled")
    result = compare_phrases("测试", "测试", use_cache=False)
    print(f"Similarity: {result['相似度']:.3f}")
    print(f"Explanation: {result['说明']}")
    print()
    print("=" * 60)
    print("Supported models:")
    print("-" * 60)
    for key, value in SUPPORTED_MODELS.items():
        print(f"  {key:15s} -> {value}")
    print("=" * 60)
    print()
    print("Cache directory: cache/text_embedding/")
    print("Cache filename format: {phrase_a}_vs_{phrase_b}_{model}_{hash[:8]}.json")
    print("=" * 60)