|
|
@@ -4,7 +4,11 @@
|
|
|
基于 similarities 库(真正的向量模型,不使用 LLM)
|
|
|
"""
|
|
|
|
|
|
-from typing import Dict, Any
|
|
|
+from typing import Dict, Any, Optional
|
|
|
+import hashlib
|
|
|
+import json
|
|
|
+from pathlib import Path
|
|
|
+from datetime import datetime
|
|
|
|
|
|
# 支持的模型列表
|
|
|
SUPPORTED_MODELS = {
|
|
|
@@ -17,6 +21,175 @@ SUPPORTED_MODELS = {
|
|
|
# 延迟导入 similarities,避免初始化时就加载模型
|
|
|
_similarity_models = {} # 存储多个模型实例
|
|
|
|
|
|
+# 默认缓存目录
|
|
|
+DEFAULT_CACHE_DIR = "cache/text_embedding"
|
|
|
+
|
|
|
+
|
|
|
+def _generate_cache_key(phrase_a: str, phrase_b: str, model_name: str) -> str:
|
|
|
+ """
|
|
|
+ 生成缓存键(哈希值)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ phrase_a: 第一个短语
|
|
|
+ phrase_b: 第二个短语
|
|
|
+ model_name: 模型名称
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 32位MD5哈希值
|
|
|
+ """
|
|
|
+ cache_string = f"{phrase_a}||{phrase_b}||{model_name}"
|
|
|
+ return hashlib.md5(cache_string.encode('utf-8')).hexdigest()
|
|
|
+
|
|
|
+
|
|
|
+def _sanitize_for_filename(text: str, max_length: int = 30) -> str:
|
|
|
+ """
|
|
|
+ 将文本转换为安全的文件名部分
|
|
|
+
|
|
|
+ Args:
|
|
|
+ text: 原始文本
|
|
|
+ max_length: 最大长度
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 安全的文件名字符串
|
|
|
+ """
|
|
|
+ import re
|
|
|
+ # 移除特殊字符,只保留中文、英文、数字、下划线
|
|
|
+ sanitized = re.sub(r'[^\w\u4e00-\u9fff]', '_', text)
|
|
|
+ # 移除连续的下划线
|
|
|
+ sanitized = re.sub(r'_+', '_', sanitized)
|
|
|
+ # 截断到最大长度
|
|
|
+ if len(sanitized) > max_length:
|
|
|
+ sanitized = sanitized[:max_length]
|
|
|
+ return sanitized.strip('_')
|
|
|
+
|
|
|
+
|
|
|
def _get_cache_filepath(
    cache_key: str,
    phrase_a: str,
    phrase_b: str,
    model_name: str,
    cache_dir: str = DEFAULT_CACHE_DIR
) -> Path:
    """Build a human-readable cache file path for a phrase pair.

    Args:
        cache_key: Cache key (full hash).
        phrase_a: First phrase.
        phrase_b: Second phrase.
        model_name: Model name (may contain a '/' namespace prefix).
        cache_dir: Cache directory.

    Returns:
        Full path of the cache file.

    Filename format: {phrase_a}_vs_{phrase_b}_{model}_{hash[:8]}.json
    """
    clean_a = _sanitize_for_filename(phrase_a, max_length=20)
    clean_b = _sanitize_for_filename(phrase_b, max_length=20)
    # Keep only the last path segment of the model name (drops e.g. the
    # "shibing624/" namespace), then sanitize that too.
    model_part = _sanitize_for_filename(model_name.split('/')[-1], max_length=20)
    # The 8-char hash suffix disambiguates names that sanitize identically.
    filename = "{}_vs_{}_{}_{}.json".format(clean_a, clean_b, model_part, cache_key[:8])
    return Path(cache_dir) / filename
|
|
|
+
|
|
|
+
|
|
|
def _load_from_cache(
    cache_key: str,
    phrase_a: str,
    phrase_b: str,
    model_name: str,
    cache_dir: str = DEFAULT_CACHE_DIR
) -> Optional[Dict[str, Any]]:
    """Return the cached result for this phrase pair, or None on a miss.

    First tries the exact readable filename; if that file is absent, falls
    back to globbing for any file whose name ends in the key's 8-character
    hash prefix (covers entries whose sanitized phrase part differs).

    Args:
        cache_key: Cache key.
        phrase_a: First phrase.
        phrase_b: Second phrase.
        model_name: Model name.
        cache_dir: Cache directory.

    Returns:
        The cached result dict, or None when no usable entry exists.
    """
    cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, cache_dir)

    if not cache_file.exists():
        cache_root = Path(cache_dir)
        if not cache_root.exists():
            return None
        # Fallback: locate the entry by its short-hash suffix.
        candidates = list(cache_root.glob(f"*_{cache_key[:8]}.json"))
        if not candidates:
            return None
        cache_file = candidates[0]

    try:
        with open(cache_file, 'r', encoding='utf-8') as f:
            return json.load(f)['output']
    except (json.JSONDecodeError, IOError, KeyError):
        # Corrupt, unreadable or malformed entries count as a cache miss.
        return None
|
|
|
+
|
|
|
+
|
|
|
def _save_to_cache(
    cache_key: str,
    phrase_a: str,
    phrase_b: str,
    model_name: str,
    result: Dict[str, Any],
    cache_dir: str = DEFAULT_CACHE_DIR
) -> None:
    """Persist a similarity result as a JSON cache entry.

    Write failures are swallowed deliberately: caching is best-effort and
    must never break the main similarity computation.

    Args:
        cache_key: Cache key.
        phrase_a: First phrase.
        phrase_b: Second phrase.
        model_name: Model name.
        result: Result data (dict).
        cache_dir: Cache directory.
    """
    cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, cache_dir)
    # Create the cache directory lazily on first write.
    cache_file.parent.mkdir(parents=True, exist_ok=True)

    payload = {
        "input": {
            "phrase_a": phrase_a,
            "phrase_b": phrase_b,
            "model_name": model_name,
        },
        "output": result,
        "metadata": {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "cache_key": cache_key,
            "cache_file": str(cache_file.name),
        },
    }

    try:
        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
    except IOError:
        pass  # silent failure by design; see docstring
|
|
|
+
|
|
|
|
|
|
def _get_similarity_model(model_name: str = "shibing624/text2vec-base-chinese"):
|
|
|
"""
|
|
|
@@ -54,7 +227,9 @@ def _get_similarity_model(model_name: str = "shibing624/text2vec-base-chinese"):
|
|
|
def compare_phrases(
|
|
|
phrase_a: str,
|
|
|
phrase_b: str,
|
|
|
- model_name: str = "chinese"
|
|
|
+ model_name: str = "chinese",
|
|
|
+ use_cache: bool = True,
|
|
|
+ cache_dir: str = DEFAULT_CACHE_DIR
|
|
|
) -> Dict[str, Any]:
|
|
|
"""
|
|
|
比较两个短语的语义相似度(兼容 semantic_similarity.py 的接口)
|
|
|
@@ -80,6 +255,8 @@ def compare_phrases(
|
|
|
- "shibing624/text2vec-base-multilingual"
|
|
|
- "shibing624/text2vec-base-chinese-paraphrase"
|
|
|
- "shibing624/text2vec-base-chinese-sentence"
|
|
|
+ use_cache: 是否使用缓存,默认 True
|
|
|
+ cache_dir: 缓存目录,默认 'cache/text_embedding'
|
|
|
|
|
|
Returns:
|
|
|
{
|
|
|
@@ -97,7 +274,23 @@ def compare_phrases(
|
|
|
|
|
|
>>> # 使用长文本模型
|
|
|
>>> result = compare_phrases("长文本1...", "长文本2...", model_name="paraphrase")
|
|
|
+
|
|
|
+ >>> # 禁用缓存
|
|
|
+ >>> result = compare_phrases("测试", "测试", use_cache=False)
|
|
|
"""
|
|
|
+ # 转换简称为完整名称(用于缓存键)
|
|
|
+ full_model_name = SUPPORTED_MODELS.get(model_name, model_name)
|
|
|
+
|
|
|
+ # 生成缓存键
|
|
|
+ cache_key = _generate_cache_key(phrase_a, phrase_b, full_model_name)
|
|
|
+
|
|
|
+ # 尝试从缓存加载
|
|
|
+ if use_cache:
|
|
|
+ cached_result = _load_from_cache(cache_key, phrase_a, phrase_b, full_model_name, cache_dir)
|
|
|
+ if cached_result is not None:
|
|
|
+ return cached_result
|
|
|
+
|
|
|
+ # 缓存未命中,计算相似度
|
|
|
model = _get_similarity_model(model_name)
|
|
|
score = float(model.similarity(phrase_a, phrase_b))
|
|
|
|
|
|
@@ -115,49 +308,74 @@ def compare_phrases(
|
|
|
|
|
|
explanation = f"基于向量模型计算的语义相似度为 {level} ({score:.2f})"
|
|
|
|
|
|
- return {
|
|
|
+ result = {
|
|
|
"说明": explanation,
|
|
|
"相似度": score
|
|
|
}
|
|
|
|
|
|
+ # 保存到缓存
|
|
|
+ if use_cache:
|
|
|
+ _save_to_cache(cache_key, phrase_a, phrase_b, full_model_name, result, cache_dir)
|
|
|
+
|
|
|
+ return result
|
|
|
+
|
|
|
|
|
|
if __name__ == "__main__":
    def _show(result):
        # Print the two result fields in the fixed demo format.
        print(f"相似度: {result['相似度']:.3f}")
        print(f"说明: {result['说明']}")

    banner = "=" * 60
    print(banner)
    print("text_embedding - 文本相似度计算(带缓存)")
    print(banner)
    print()

    # Example 1: default model — first call computes and populates the cache.
    print("示例 1: 默认模型(chinese)- 首次调用")
    _show(compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡"))
    print()

    # Example 2: identical arguments — should be served from the cache.
    print("示例 2: 测试缓存 - 再次调用相同参数")
    _show(compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡"))
    print("(应该从缓存读取,速度更快)")
    print()

    # Example 3: short, semantically related phrases.
    print("示例 3: 使用默认模型")
    _show(compare_phrases("深度学习", "神经网络"))
    print()

    # Example 4: unrelated phrases.
    print("示例 4: 不相关的短语")
    _show(compare_phrases("编程", "吃饭"))
    print()

    # Example 5: multilingual model.
    print("示例 5: 多语言模型(multilingual)")
    _show(compare_phrases("Hello", "Hi", model_name="multilingual"))
    print()

    # Example 6: bypass the cache entirely.
    print("示例 6: 禁用缓存")
    _show(compare_phrases("测试", "测试", use_cache=False))
    print()

    print(banner)
    print("支持的模型:")
    print("-" * 60)
    for alias, full_name in SUPPORTED_MODELS.items():
        print(f"  {alias:15s} -> {full_name}")
    print(banner)
    print()
    print("缓存目录: cache/text_embedding/")
    print("缓存文件格式: {phrase_a}_vs_{phrase_b}_{model}_{hash[:8]}.json")
    print(banner)
|