| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 测试所有文本相似度模型
- 从现有缓存(cache/text_embedding)中提取所有测试用例,
- 使用所有支持的模型并发计算相似度,生成完整的缓存数据。
- """
- import json
- import sys
- from pathlib import Path
- from typing import List, Dict, Tuple
- import time
- import asyncio
- from datetime import datetime
- # 添加项目根目录到路径
- project_root = Path(__file__).parent.parent.parent
- sys.path.insert(0, str(project_root))
- from lib.config import get_cache_dir
- from lib.text_embedding import compare_phrases, SUPPORTED_MODELS
# Upper bound on the number of comparisons allowed in flight at once.
MAX_CONCURRENT_REQUESTS: int = 100
# Shared semaphore; starts as None and is created on first use by get_semaphore().
semaphore: "asyncio.Semaphore | None" = None
# Progress tracking
class ProgressTracker:
    """Throttled single-line console progress bar.

    Counts finished work items against a fixed total and redraws a
    carriage-return progress line at most once every 0.5 seconds, plus on
    every update once the total has been reached.
    """

    def __init__(self, total: int, description: str = ""):
        """Start tracking.

        Args:
            total: Number of work items expected.
            description: Optional label printed in front of the bar.
        """
        self.total = total
        self.completed = 0
        self.start_time = datetime.now()
        self.last_update_time = datetime.now()
        self.description = description

    def update(self, count: int = 1):
        """Record ``count`` more finished items and redraw if a redraw is due."""
        self.completed += count
        now = datetime.now()
        finished = self.completed >= self.total
        # Redraw at most every 0.5 s, but always once the total is reached.
        if finished or (now - self.last_update_time).total_seconds() >= 0.5:
            self.display()
            self.last_update_time = now

    def display(self):
        """Print the progress line; does nothing when ``total`` is 0."""
        if self.total == 0:
            return

        percentage = (self.completed / self.total) * 100
        elapsed = (datetime.now() - self.start_time).total_seconds()

        # Throughput-based ETA; omitted until there is a measurable rate.
        eta_str = ""
        if elapsed > 0:
            speed = self.completed / elapsed
            if speed > 0:
                # Clamp so an overshoot (completed > total) never yields a
                # negative remaining time.
                remaining = max(self.total - self.completed, 0) / speed
                eta_str = f", 预计剩余: {int(remaining)}秒"

        bar_length = 40
        # Clamp the filled portion to [0, bar_length] so the bar keeps a
        # fixed width even when completed exceeds total.
        filled_length = int(bar_length * self.completed / self.total)
        filled_length = max(0, min(bar_length, filled_length))
        bar = '█' * filled_length + '░' * (bar_length - filled_length)

        desc = f"{self.description}: " if self.description else ""
        print(f"\r {desc}[{bar}] {self.completed}/{self.total} ({percentage:.1f}%){eta_str}", end='', flush=True)

        # Terminate the line once everything is done.
        if self.completed >= self.total:
            print()
# Module-wide progress tracker; None until test_all_models() installs one.
progress_tracker: "ProgressTracker | None" = None
def get_semaphore():
    """Return the shared semaphore, creating it on the first call.

    The semaphore lives in the module-level ``semaphore`` variable and is
    sized by ``MAX_CONCURRENT_REQUESTS``; later calls return the same object.
    """
    global semaphore
    if semaphore is not None:
        return semaphore
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
    return semaphore
def extract_test_cases_from_cache(
    cache_dir: str = None
) -> List[Tuple[str, str]]:
    """Collect the unique phrase pairs recorded in existing cache files.

    Every ``*.json`` file directly inside ``cache_dir`` is parsed and the
    strings at ``input.phrase_a`` / ``input.phrase_b`` are read from it.
    Pairs are de-duplicated regardless of order, so (A, B) and (B, A) count
    as one; the orientation found first is the one returned. Files are
    visited in sorted path order so the result is reproducible.

    Files that cannot be read, are not valid UTF-8 JSON, or do not have the
    expected object structure are skipped instead of aborting the scan.

    Args:
        cache_dir: Directory to scan. When None,
            ``get_cache_dir("text_embedding")`` is used.

    Returns:
        List of ``(phrase_a, phrase_b)`` tuples; empty when the directory
        does not exist.
    """
    if cache_dir is None:
        cache_dir = get_cache_dir("text_embedding")

    cache_path = Path(cache_dir)
    if not cache_path.exists():
        print(f"缓存目录不存在: {cache_dir}")
        return []

    test_cases: List[Tuple[str, str]] = []
    seen_pairs = set()  # order-insensitive keys of pairs already emitted

    # Sorted so that which orientation of a duplicate pair survives does not
    # depend on filesystem enumeration order.
    cache_files = sorted(cache_path.glob("*.json"))
    print(f"扫描缓存文件: {len(cache_files)} 个")

    for cache_file in cache_files:
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both json.JSONDecodeError and
            # UnicodeDecodeError (file is not valid UTF-8).
            print(f" 读取缓存文件失败: {cache_file.name} - {e}")
            continue

        # Valid JSON may still be a list/scalar, or have a null "input".
        payload = data.get("input") if isinstance(data, dict) else None
        if not isinstance(payload, dict):
            continue

        phrase_a = payload.get("phrase_a")
        phrase_b = payload.get("phrase_b")
        # Only non-empty strings form a usable pair; this also keeps the
        # sorted() call below from comparing mixed types.
        if not (isinstance(phrase_a, str) and isinstance(phrase_b, str)):
            continue
        if not phrase_a or not phrase_b:
            continue

        pair_key = tuple(sorted((phrase_a, phrase_b)))
        if pair_key in seen_pairs:
            continue
        seen_pairs.add(pair_key)
        test_cases.append((phrase_a, phrase_b))

    return test_cases
async def test_single_case(
    phrase_a: str,
    phrase_b: str,
    model_key: str,
    use_cache: bool = True
) -> Dict:
    """Compare one phrase pair with one model, under the shared concurrency limit.

    The synchronous ``compare_phrases`` call is run via ``asyncio.to_thread``
    while holding the semaphore returned by ``get_semaphore()``. Failures are
    reported in the returned dict rather than raised, and the module-level
    ``progress_tracker`` (if set) is advanced exactly once per call.

    Args:
        phrase_a: First phrase.
        phrase_b: Second phrase.
        model_key: Model key passed to ``compare_phrases`` as ``model_name``.
        use_cache: Forwarded to ``compare_phrases``.

    Returns:
        Dict with keys ``phrase_a``, ``phrase_b``, ``model``, ``相似度``,
        ``说明`` and ``status`` (``"success"`` or ``"error"``). On error
        ``相似度`` is None and ``说明`` carries the failure message.
    """
    async with get_semaphore():
        try:
            result = await asyncio.to_thread(
                compare_phrases,
                phrase_a=phrase_a,
                phrase_b=phrase_b,
                model_name=model_key,
                use_cache=use_cache
            )
            outcome = {
                "phrase_a": phrase_a,
                "phrase_b": phrase_b,
                "model": model_key,
                "相似度": result["相似度"],
                "说明": result["说明"],
                "status": "success"
            }
        except Exception as e:
            # Deliberate best-effort: one failing comparison must not abort
            # the whole batch, so the error is recorded in the result.
            outcome = {
                "phrase_a": phrase_a,
                "phrase_b": phrase_b,
                "model": model_key,
                "相似度": None,
                "说明": f"计算失败: {str(e)}",
                "status": "error"
            }

        # Advance progress once, after the outcome is known. Updating inside
        # the try block would count a case twice if building the success
        # dict failed (e.g. a missing key in the result).
        if progress_tracker is not None:
            progress_tracker.update(1)
        return outcome
async def test_all_models(
    test_cases: List[Tuple[str, str]],
    models: Dict[str, str] = None,
    use_cache: bool = True
) -> Dict[str, List[Dict]]:
    """Run every test case against every model concurrently.

    Each model is first warmed up with one sequential ``compare_phrases``
    call, then one ``test_single_case`` coroutine per (model, pair) is
    gathered. Progress is reported through the module-level
    ``progress_tracker``, which this function replaces.

    Args:
        test_cases: List of ``(phrase_a, phrase_b)`` pairs.
        models: Mapping whose keys are model keys; defaults to
            ``SUPPORTED_MODELS`` when None.
        use_cache: Forwarded to ``test_single_case``.

    Returns:
        Results grouped by model key, in the form:
        {
            "model_name": [
                {
                    "phrase_a": "xxx",
                    "phrase_b": "xxx",
                    "相似度": 0.85,
                    "说明": "xxx",
                    "status": "success"
                },
                ...
            ]
        }
    """
    global progress_tracker

    if models is None:
        models = SUPPORTED_MODELS

    total_tests = len(test_cases) * len(models)
    print(f"\n开始测试 {len(models)} 个模型,共 {len(test_cases)} 个测试用例")
    print(f"总测试数: {total_tests:,}\n")

    # Warm up every model one at a time, before the concurrent phase, to
    # avoid several worker threads loading the same model at once.
    # NOTE(review): this call does not pass use_cache, so it runs with
    # compare_phrases' default — confirm it cannot leave a ("测试", "测试")
    # entry in the cache directory that is later scanned for test cases.
    print("预加载所有模型...")
    for i, model_key in enumerate(models, 1):
        print(f" [{i}/{len(models)}] 加载模型: {model_key}")
        await asyncio.to_thread(compare_phrases, "测试", "测试", model_name=model_key)
    print("所有模型预加载完成!\n")

    progress_tracker = ProgressTracker(total_tests, "测试进度")

    # One coroutine per (model, pair); test_single_case enforces the
    # concurrency limit itself.
    tasks = [
        test_single_case(phrase_a, phrase_b, model_key, use_cache)
        for model_key in models
        for phrase_a, phrase_b in test_cases
    ]

    # perf_counter is monotonic and high-resolution, unlike wall-clock time.
    start_time = time.perf_counter()
    all_results = await asyncio.gather(*tasks)
    elapsed = time.perf_counter() - start_time

    # An empty workload (or a coarse clock) can yield elapsed == 0; report a
    # rate of 0 instead of raising ZeroDivisionError.
    speed = total_tests / elapsed if elapsed > 0 else 0.0
    print(f"\n所有测试完成! 总耗时: {elapsed:.2f}秒")
    print(f"平均速度: {speed:.2f} 条/秒\n")

    # Group by model, dropping the now-redundant "model" field.
    results = {model_key: [] for model_key in models}
    for result in all_results:
        results[result["model"]].append({
            "phrase_a": result["phrase_a"],
            "phrase_b": result["phrase_b"],
            "相似度": result["相似度"],
            "说明": result["说明"],
            "status": result["status"]
        })

    print("统计信息:")
    for model_key, model_results in results.items():
        successful = sum(1 for r in model_results if r["status"] == "success")
        failed = len(model_results) - successful
        print(f" {model_key}: {successful} 成功, {failed} 失败")

    return results
def save_results(
    results: Dict[str, List[Dict]],
    output_file: str = "data/model_comparison_results.json"
) -> None:
    """Write the test results to a UTF-8 JSON file.

    Missing parent directories of ``output_file`` are created. The JSON is
    indented by two spaces and non-ASCII characters are written unescaped.

    Args:
        results: Results to serialize.
        output_file: Destination path.
    """
    destination = Path(output_file)
    destination.parent.mkdir(parents=True, exist_ok=True)

    with destination.open('w', encoding='utf-8') as handle:
        json.dump(results, handle, ensure_ascii=False, indent=2)

    print(f"测试结果已保存到: {output_file}")
async def main():
    """Run the pipeline: extract cached test cases, test all models, save results."""
    rule = "=" * 60

    def banner(title, blank_before=True):
        # Title framed by two separator rules, optionally after a blank line.
        print(("\n" if blank_before else "") + rule)
        print(title)
        print(rule)

    # Settings (cache location comes from the config module).
    cache_dir = get_cache_dir("text_embedding")
    output_file = "data/model_comparison_results.json"

    # Step 1: extract the test cases from the cache.
    banner("步骤 1: 从缓存提取所有测试用例", blank_before=False)
    test_cases = extract_test_cases_from_cache(cache_dir)
    if not test_cases:
        print("未找到测试用例,请先运行主流程生成缓存数据")
        return
    print(f"提取到 {len(test_cases):,} 个唯一测试用例")

    # Show the first five pairs as a sample.
    print("\n前5个测试用例示例:")
    for index, pair in enumerate(test_cases[:5], 1):
        first, second = pair
        print(f" {index}. {first} vs {second}")

    # Step 2: test every model.
    banner("步骤 2: 使用所有模型并发测试")
    results = await test_all_models(test_cases, use_cache=True)

    # Step 3: save the results.
    banner("步骤 3: 保存结果")
    save_results(results, output_file)

    banner("全部完成!")
if __name__ == "__main__":
    # Executed as a script: run the async pipeline to completion.
    asyncio.run(main())
|