Explorar el Código

feat: 添加模型加载线程安全机制并优化预加载策略

- 为text_embedding模块添加线程锁,确保多线程环境下模型加载的安全性
- 使用双重检查锁定模式优化性能
- 更新test_all_models脚本,预加载所有模型避免并发加载冲突

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
yangxiaohui hace 1 semana
padre
commit
ffad7c210a
Se han modificado 2 ficheros con 28 adiciones y 19 borrados
  1. 22 14
      lib/text_embedding.py
  2. 6 5
      script/analysis/test_all_models.py

+ 22 - 14
lib/text_embedding.py

@@ -9,6 +9,7 @@ import hashlib
 import json
 from pathlib import Path
 from datetime import datetime
+import threading
 
 # 支持的模型列表
 SUPPORTED_MODELS = {
@@ -20,6 +21,7 @@ SUPPORTED_MODELS = {
 
 # 延迟导入 similarities,避免初始化时就加载模型
 _similarity_models = {}  # 存储多个模型实例
+_model_lock = threading.Lock()  # 线程锁,保护模型加载
 
 # 默认缓存目录
 DEFAULT_CACHE_DIR = "cache/text_embedding"
@@ -193,7 +195,7 @@ def _save_to_cache(
 
 def _get_similarity_model(model_name: str = "shibing624/text2vec-base-chinese"):
     """
-    获取或初始化相似度模型(支持多个模型)
+    获取或初始化相似度模型(支持多个模型,线程安全
 
     Args:
         model_name: 模型名称
@@ -201,27 +203,33 @@ def _get_similarity_model(model_name: str = "shibing624/text2vec-base-chinese"):
     Returns:
         BertSimilarity 模型实例
     """
-    global _similarity_models
+    global _similarity_models, _model_lock
 
     # 如果是简称,转换为完整名称
     if model_name in SUPPORTED_MODELS:
         model_name = SUPPORTED_MODELS[model_name]
 
-    # 如果模型已加载,直接返回
+    # 快速路径:如果模型已加载,直接返回(无锁检查)
     if model_name in _similarity_models:
         return _similarity_models[model_name]
 
-    # 加载新模型
-    try:
-        from similarities import BertSimilarity
-        print(f"正在加载模型: {model_name}...")
-        _similarity_models[model_name] = BertSimilarity(model_name_or_path=model_name)
-        print("模型加载完成!")
-        return _similarity_models[model_name]
-    except ImportError:
-        raise ImportError(
-            "请先安装 similarities 库: pip install -U similarities torch"
-        )
+    # 慢速路径:需要加载模型(使用锁保护)
+    with _model_lock:
+        # 双重检查:可能在等待锁时其他线程已经加载了
+        if model_name in _similarity_models:
+            return _similarity_models[model_name]
+
+        # 加载新模型
+        try:
+            from similarities import BertSimilarity
+            print(f"正在加载模型: {model_name}...")
+            _similarity_models[model_name] = BertSimilarity(model_name_or_path=model_name)
+            print("模型加载完成!")
+            return _similarity_models[model_name]
+        except ImportError:
+            raise ImportError(
+                "请先安装 similarities 库: pip install -U similarities torch"
+            )
 
 
 def compare_phrases(

+ 6 - 5
script/analysis/test_all_models.py

@@ -233,11 +233,12 @@ async def test_all_models(
     print(f"\n开始测试 {len(models)} 个模型,共 {len(test_cases)} 个测试用例")
     print(f"总测试数: {total_tests:,}\n")
 
-    # 预加载第一个模型(避免多线程加载冲突)
-    print("预加载模型...")
-    first_model = list(models.keys())[0]
-    await asyncio.to_thread(compare_phrases, "测试", "测试", model_name=first_model)
-    print("预加载完成!\n")
+    # 预加载所有模型(避免多线程加载冲突)
+    print("预加载所有模型...")
+    for i, model_key in enumerate(models.keys(), 1):
+        print(f"  [{i}/{len(models)}] 加载模型: {model_key}")
+        await asyncio.to_thread(compare_phrases, "测试", "测试", model_name=model_key)
+    print("所有模型预加载完成!\n")
 
     # 初始化进度跟踪器
     progress_tracker = ProgressTracker(total_tests, "测试进度")