#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Test all text-similarity models.

Extracts every test case from the existing cache (cache/text_embedding),
computes similarity concurrently with every supported model, and produces
a complete set of cached comparison data.
"""

import json
import sys
from pathlib import Path
from typing import List, Dict, Optional, Tuple
import time
import asyncio
from datetime import datetime

# Add the project root to the import path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from lib.text_embedding import compare_phrases, SUPPORTED_MODELS

# Global concurrency limit
MAX_CONCURRENT_REQUESTS = 100
semaphore = None


class ProgressTracker:
    """Progress tracker for long-running batches."""

    def __init__(self, total: int, description: str = ""):
        self.total = total
        self.completed = 0
        self.start_time = datetime.now()
        self.last_update_time = datetime.now()
        self.description = description

    def update(self, count: int = 1):
        """Advance the counter and redraw the bar when due."""
        self.completed += count
        current_time = datetime.now()
        # Redraw at most every 0.5 s, or when the batch is finished
        if (current_time - self.last_update_time).total_seconds() >= 0.5 or self.completed >= self.total:
            self.display()
            self.last_update_time = current_time

    def display(self):
        """Render the progress bar."""
        if self.total == 0:
            return

        percentage = (self.completed / self.total) * 100
        elapsed = (datetime.now() - self.start_time).total_seconds()

        # Compute throughput and estimated time remaining
        if elapsed > 0:
            speed = self.completed / elapsed
            if speed > 0:
                remaining = (self.total - self.completed) / speed
                eta_str = f", ETA: {int(remaining)}s"
            else:
                eta_str = ""
        else:
            eta_str = ""

        bar_length = 40
        filled_length = int(bar_length * self.completed / self.total)
        bar = '█' * filled_length + '░' * (bar_length - filled_length)

        desc = f"{self.description}: " if self.description else ""
        print(f"\r  {desc}[{bar}] {self.completed}/{self.total} ({percentage:.1f}%){eta_str}",
              end='', flush=True)

        # Newline once the batch is complete
        if self.completed >= self.total:
            print()


# Global progress tracker
progress_tracker = None


def get_semaphore():
    """Return the global semaphore, creating it lazily inside the running event loop."""
    global semaphore
    if semaphore is None:
        semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
    return semaphore


def extract_test_cases_from_cache(
    cache_dir: str = "cache/text_embedding"
) -> List[Tuple[str, str]]:
    """
    Extract all test cases from the existing cache files.

    Args:
        cache_dir: Cache directory.

    Returns:
        List of test cases, each a (phrase_a, phrase_b) tuple.
    """
    cache_path = Path(cache_dir)
    if not cache_path.exists():
        print(f"Cache directory does not exist: {cache_dir}")
        return []

    test_cases = []
    seen_pairs = set()  # For de-duplication

    # Walk every cache file
    cache_files = list(cache_path.glob("*.json"))
    print(f"Scanning cache files: {len(cache_files)} found")

    for cache_file in cache_files:
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Pull out the phrase pair
            phrase_a = data.get("input", {}).get("phrase_a")
            phrase_b = data.get("input", {}).get("phrase_b")

            if phrase_a and phrase_b:
                # Use a sorted tuple as the key so (A, B) and (B, A) are not counted twice
                pair_key = tuple(sorted([phrase_a, phrase_b]))
                if pair_key not in seen_pairs:
                    test_cases.append((phrase_a, phrase_b))
                    seen_pairs.add(pair_key)

        except (json.JSONDecodeError, IOError) as e:
            print(f"  Failed to read cache file: {cache_file.name} - {e}")
            continue

    return test_cases


async def test_single_case(
    phrase_a: str,
    phrase_b: str,
    model_key: str,
    use_cache: bool = True
) -> Dict:
    """
    Run a single test case (subject to the concurrency limit).

    Args:
        phrase_a: First phrase.
        phrase_b: Second phrase.
        model_key: Model key.
        use_cache: Whether to use the cache.

    Returns:
        Result dictionary. The "相似度" (similarity) and "说明" (explanation)
        keys mirror the return format of compare_phrases.
    """
    global progress_tracker

    sem = get_semaphore()
    async with sem:
        try:
            # Run the synchronous comparison in a worker thread
            result = await asyncio.to_thread(
                compare_phrases,
                phrase_a=phrase_a,
                phrase_b=phrase_b,
                model_name=model_key,
                use_cache=use_cache
            )

            # Update progress
            if progress_tracker:
                progress_tracker.update(1)

            return {
                "phrase_a": phrase_a,
                "phrase_b": phrase_b,
                "model": model_key,
                "相似度": result["相似度"],
                "说明": result["说明"],
                "status": "success"
            }

        except Exception as e:
            # Update progress even on failure
            if progress_tracker:
                progress_tracker.update(1)

            return {
                "phrase_a": phrase_a,
                "phrase_b": phrase_b,
                "model": model_key,
                "相似度": None,
                "说明": f"Computation failed: {str(e)}",
                "status": "error"
            }


async def test_all_models(
    test_cases: List[Tuple[str, str]],
    models: Optional[Dict[str, str]] = None,
    use_cache: bool = True
) -> Dict[str, List[Dict]]:
    """
    Run every test case against every model, concurrently.

    Args:
        test_cases: List of test cases.
        models: Model dictionary; defaults to all supported models.
        use_cache: Whether to use the cache.

    Returns:
        Results keyed by model, in the form:
        {
            "model_name": [
                {
                    "phrase_a": "xxx",
                    "phrase_b": "xxx",
                    "相似度": 0.85,
                    "说明": "xxx"
                },
                ...
            ]
        }
    """
    global progress_tracker

    if models is None:
        models = SUPPORTED_MODELS

    total_tests = len(test_cases) * len(models)
    print(f"\nTesting {len(models)} models against {len(test_cases)} test cases")
    print(f"Total comparisons: {total_tests:,}\n")

    # Warm up the first model (avoids concurrent model-loading races)
    print("Preloading model...")
    first_model = list(models.keys())[0]
    await asyncio.to_thread(compare_phrases, "测试", "测试", model_name=first_model)
    print("Preloading done!\n")

    # Initialise the progress tracker
    progress_tracker = ProgressTracker(total_tests, "Progress")

    # Build every test task
    tasks = []
    for model_key in models.keys():
        for phrase_a, phrase_b in test_cases:
            tasks.append(
                test_single_case(phrase_a, phrase_b, model_key, use_cache)
            )

    # Run all tests concurrently
    start_time = time.time()
    all_results = await asyncio.gather(*tasks)
    elapsed = time.time() - start_time

    print(f"\nAll tests finished! Total time: {elapsed:.2f}s")
    print(f"Average throughput: {total_tests/elapsed:.2f} comparisons/s\n")

    # Group results by model
    results = {model_key: [] for model_key in models.keys()}
    for result in all_results:
        model_key = result["model"]
        results[model_key].append({
            "phrase_a": result["phrase_a"],
            "phrase_b": result["phrase_b"],
            "相似度": result["相似度"],
            "说明": result["说明"],
            "status": result["status"]
        })

    # Summary statistics
    print("Summary:")
    for model_key in models.keys():
        model_results = results[model_key]
        successful = sum(1 for r in model_results if r["status"] == "success")
        failed = len(model_results) - successful
        print(f"  {model_key}: {successful} succeeded, {failed} failed")

    return results


def save_results(
    results: Dict[str, List[Dict]],
    output_file: str = "data/model_comparison_results.json"
) -> None:
    """
    Save test results to a JSON file.

    Args:
        results: Test results.
        output_file: Output file path.
    """
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print(f"Results saved to: {output_file}")


async def main():
    """Entry point."""
    # Configuration
    cache_dir = "cache/text_embedding"
    output_file = "data/model_comparison_results.json"

    # Step 1: extract test cases from the cache
    print("=" * 60)
    print("Step 1: extract all test cases from the cache")
    print("=" * 60)
    test_cases = extract_test_cases_from_cache(cache_dir)

    if not test_cases:
        print("No test cases found; run the main pipeline first to populate the cache")
        return

    print(f"Extracted {len(test_cases):,} unique test cases")

    # Show the first five test cases as a sample
    print("\nFirst 5 test cases:")
    for i, (phrase_a, phrase_b) in enumerate(test_cases[:5], 1):
        print(f"  {i}. {phrase_a} vs {phrase_b}")

    # Step 2: test every model
    print("\n" + "=" * 60)
    print("Step 2: test all models concurrently")
    print("=" * 60)
    results = await test_all_models(test_cases, use_cache=True)

    # Step 3: save the results
    print("\n" + "=" * 60)
    print("Step 3: save results")
    print("=" * 60)
    save_results(results, output_file)

    print("\n" + "=" * 60)
    print("All done!")
    print("=" * 60)


if __name__ == "__main__":
    asyncio.run(main())
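

# ---------------------------------------------------------------------------
# Optional, illustrative sketch (not called anywhere above): once the script
# has written data/model_comparison_results.json, the per-model results can
# be summarised roughly like this. The "相似度" and "status" keys follow the
# output format documented in test_all_models(); the helper name and the
# mean-based summary are assumptions, not part of the original pipeline.
# ---------------------------------------------------------------------------
def summarize_results(path: str = "data/model_comparison_results.json") -> None:
    """Print the number of successful comparisons and mean similarity per model."""
    with open(path, "r", encoding="utf-8") as f:
        results = json.load(f)

    for model_key, rows in results.items():
        scores = [
            r["相似度"]
            for r in rows
            if r["status"] == "success" and r["相似度"] is not None
        ]
        mean = sum(scores) / len(scores) if scores else float("nan")
        print(f"{model_key}: n={len(scores)}, mean similarity={mean:.4f}")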