|
@@ -9,6 +9,7 @@ import hashlib
|
|
|
import json
|
|
import json
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
from datetime import datetime
|
|
from datetime import datetime
|
|
|
|
|
+import threading
|
|
|
|
|
|
|
|
# 支持的模型列表
|
|
# 支持的模型列表
|
|
|
SUPPORTED_MODELS = {
|
|
SUPPORTED_MODELS = {
|
|
@@ -20,6 +21,7 @@ SUPPORTED_MODELS = {
|
|
|
|
|
|
|
|
# 延迟导入 similarities,避免初始化时就加载模型
|
|
# 延迟导入 similarities,避免初始化时就加载模型
|
|
|
_similarity_models = {} # 存储多个模型实例
|
|
_similarity_models = {} # 存储多个模型实例
|
|
|
|
|
+_model_lock = threading.Lock() # 线程锁,保护模型加载
|
|
|
|
|
|
|
|
# 默认缓存目录
|
|
# 默认缓存目录
|
|
|
DEFAULT_CACHE_DIR = "cache/text_embedding"
|
|
DEFAULT_CACHE_DIR = "cache/text_embedding"
|
|
@@ -193,7 +195,7 @@ def _save_to_cache(
|
|
|
|
|
|
|
|
def _get_similarity_model(model_name: str = "shibing624/text2vec-base-chinese"):
|
|
def _get_similarity_model(model_name: str = "shibing624/text2vec-base-chinese"):
|
|
|
"""
|
|
"""
|
|
|
- 获取或初始化相似度模型(支持多个模型)
|
|
|
|
|
|
|
+ 获取或初始化相似度模型(支持多个模型,线程安全)
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
model_name: 模型名称
|
|
model_name: 模型名称
|
|
@@ -201,27 +203,33 @@ def _get_similarity_model(model_name: str = "shibing624/text2vec-base-chinese"):
|
|
|
Returns:
|
|
Returns:
|
|
|
BertSimilarity 模型实例
|
|
BertSimilarity 模型实例
|
|
|
"""
|
|
"""
|
|
|
- global _similarity_models
|
|
|
|
|
|
|
+ global _similarity_models, _model_lock
|
|
|
|
|
|
|
|
# 如果是简称,转换为完整名称
|
|
# 如果是简称,转换为完整名称
|
|
|
if model_name in SUPPORTED_MODELS:
|
|
if model_name in SUPPORTED_MODELS:
|
|
|
model_name = SUPPORTED_MODELS[model_name]
|
|
model_name = SUPPORTED_MODELS[model_name]
|
|
|
|
|
|
|
|
- # 如果模型已加载,直接返回
|
|
|
|
|
|
|
+ # 快速路径:如果模型已加载,直接返回(无锁检查)
|
|
|
if model_name in _similarity_models:
|
|
if model_name in _similarity_models:
|
|
|
return _similarity_models[model_name]
|
|
return _similarity_models[model_name]
|
|
|
|
|
|
|
|
- # 加载新模型
|
|
|
|
|
- try:
|
|
|
|
|
- from similarities import BertSimilarity
|
|
|
|
|
- print(f"正在加载模型: {model_name}...")
|
|
|
|
|
- _similarity_models[model_name] = BertSimilarity(model_name_or_path=model_name)
|
|
|
|
|
- print("模型加载完成!")
|
|
|
|
|
- return _similarity_models[model_name]
|
|
|
|
|
- except ImportError:
|
|
|
|
|
- raise ImportError(
|
|
|
|
|
- "请先安装 similarities 库: pip install -U similarities torch"
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ # 慢速路径:需要加载模型(使用锁保护)
|
|
|
|
|
+ with _model_lock:
|
|
|
|
|
+ # 双重检查:可能在等待锁时其他线程已经加载了
|
|
|
|
|
+ if model_name in _similarity_models:
|
|
|
|
|
+ return _similarity_models[model_name]
|
|
|
|
|
+
|
|
|
|
|
+ # 加载新模型
|
|
|
|
|
+ try:
|
|
|
|
|
+ from similarities import BertSimilarity
|
|
|
|
|
+ print(f"正在加载模型: {model_name}...")
|
|
|
|
|
+ _similarity_models[model_name] = BertSimilarity(model_name_or_path=model_name)
|
|
|
|
|
+ print("模型加载完成!")
|
|
|
|
|
+ return _similarity_models[model_name]
|
|
|
|
|
+ except ImportError:
|
|
|
|
|
+ raise ImportError(
|
|
|
|
|
+ "请先安装 similarities 库: pip install -U similarities torch"
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
|
|
|
|
|
def compare_phrases(
|
|
def compare_phrases(
|