#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Test all text-similarity models.

Extracts every test case from the existing cache (cache/text_embedding),
computes similarity concurrently with every supported model, and produces
a complete set of cached results.
"""
import json
import sys
from pathlib import Path
from typing import List, Dict, Tuple
import time
import asyncio
from datetime import datetime

# Add the project root to the import path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from lib.config import get_cache_dir
from lib.text_embedding import compare_phrases, SUPPORTED_MODELS

# Global concurrency limit
MAX_CONCURRENT_REQUESTS = 100
semaphore = None


# Progress tracking
class ProgressTracker:
    """Progress tracker."""

    def __init__(self, total: int, description: str = ""):
        self.total = total
        self.completed = 0
        self.start_time = datetime.now()
        self.last_update_time = datetime.now()
        self.description = description

    def update(self, count: int = 1):
        """Advance the progress counter."""
        self.completed += count
        current_time = datetime.now()
        # Refresh at most every 0.5 seconds, or when the total is reached
        if (current_time - self.last_update_time).total_seconds() >= 0.5 or self.completed >= self.total:
            self.display()
            self.last_update_time = current_time

    def display(self):
        """Render the progress bar."""
        if self.total == 0:
            return
        percentage = (self.completed / self.total) * 100
        elapsed = (datetime.now() - self.start_time).total_seconds()
        # Compute throughput and estimated time remaining
        if elapsed > 0:
            speed = self.completed / elapsed
            if speed > 0:
                remaining = (self.total - self.completed) / speed
                eta_str = f", ETA: {int(remaining)}s"
            else:
                eta_str = ""
        else:
            eta_str = ""
        bar_length = 40
        filled_length = int(bar_length * self.completed / self.total)
        bar = '█' * filled_length + '░' * (bar_length - filled_length)
        desc = f"{self.description}: " if self.description else ""
        print(f"\r {desc}[{bar}] {self.completed}/{self.total} ({percentage:.1f}%){eta_str}", end='', flush=True)
        # Print a newline once finished
        if self.completed >= self.total:
            print()
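
# A minimal usage sketch for the ProgressTracker defined above (hypothetical
# values, do_one_unit_of_work is a placeholder):
#
#   tracker = ProgressTracker(total=200, description="demo")
#   for _ in range(200):
#       do_one_unit_of_work()
#       tracker.update(1)
#
# The bar is redrawn in place (carriage return) at most every 0.5 seconds.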


# Global progress tracker
progress_tracker = None


def get_semaphore():
    """Return the global semaphore, creating it on first use."""
    global semaphore
    if semaphore is None:
        semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
    return semaphore
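
# Note: the semaphore is created lazily rather than at import time, presumably so
# it is only instantiated once an event loop is about to run. All coroutines share
# this single Semaphore, so at most MAX_CONCURRENT_REQUESTS compare_phrases calls
# are in flight at once (the actual thread count is further bounded by asyncio's
# default thread pool used by asyncio.to_thread).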


def extract_test_cases_from_cache(
    cache_dir: str = None
) -> List[Tuple[str, str]]:
    """
    Extract all test cases from the existing cache files.

    Args:
        cache_dir: Cache directory.

    Returns:
        List of test cases, each a (phrase_a, phrase_b) tuple.
    """
    if cache_dir is None:
        cache_dir = get_cache_dir("text_embedding")
    cache_path = Path(cache_dir)
    if not cache_path.exists():
        print(f"Cache directory does not exist: {cache_dir}")
        return []
    test_cases = []
    seen_pairs = set()  # Used for de-duplication
    # Walk every cache file
    cache_files = list(cache_path.glob("*.json"))
    print(f"Scanning cache files: {len(cache_files)} found")
    for cache_file in cache_files:
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Extract the phrase pair
            phrase_a = data.get("input", {}).get("phrase_a")
            phrase_b = data.get("input", {}).get("phrase_b")
            if phrase_a and phrase_b:
                # Key on the sorted tuple so (A, B) and (B, A) are not duplicated
                pair_key = tuple(sorted([phrase_a, phrase_b]))
                if pair_key not in seen_pairs:
                    test_cases.append((phrase_a, phrase_b))
                    seen_pairs.add(pair_key)
        except (json.JSONDecodeError, IOError) as e:
            print(f" Failed to read cache file: {cache_file.name} - {e}")
            continue
    return test_cases
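
# The scan above only relies on each cache file containing an "input" object with
# the two phrases; a minimal sketch of the expected layout (other fields are ignored):
#
#   {
#       "input": {"phrase_a": "...", "phrase_b": "..."},
#       ...
#   }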


async def test_single_case(
    phrase_a: str,
    phrase_b: str,
    model_key: str,
    use_cache: bool = True
) -> Dict:
    """
    Test a single case (with the concurrency limit applied).

    Args:
        phrase_a: First phrase.
        phrase_b: Second phrase.
        model_key: Model key name.
        use_cache: Whether to use the cache.

    Returns:
        Result dictionary for this test.
    """
    global progress_tracker
    sem = get_semaphore()
    async with sem:
        try:
            # Run the synchronous function in a worker thread via asyncio.to_thread
            result = await asyncio.to_thread(
                compare_phrases,
                phrase_a=phrase_a,
                phrase_b=phrase_b,
                model_name=model_key,
                use_cache=use_cache
            )
            # Update progress
            if progress_tracker:
                progress_tracker.update(1)
            # The "相似度" (similarity) and "说明" (explanation) keys mirror the
            # compare_phrases result and are kept verbatim in the output
            return {
                "phrase_a": phrase_a,
                "phrase_b": phrase_b,
                "model": model_key,
                "相似度": result["相似度"],
                "说明": result["说明"],
                "status": "success"
            }
        except Exception as e:
            # Update progress
            if progress_tracker:
                progress_tracker.update(1)
            return {
                "phrase_a": phrase_a,
                "phrase_b": phrase_b,
                "model": model_key,
                "相似度": None,
                "说明": f"Comparison failed: {str(e)}",
                "status": "error"
            }
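
# Failures inside compare_phrases are captured as records with status "error" rather
# than re-raised, so the asyncio.gather call in test_all_models below never aborts
# because of a single bad pair or model.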


async def test_all_models(
    test_cases: List[Tuple[str, str]],
    models: Dict[str, str] = None,
    use_cache: bool = True
) -> Dict[str, List[Dict]]:
    """
    Run every test case against every model concurrently.

    Args:
        test_cases: List of test cases.
        models: Model dictionary; defaults to all supported models.
        use_cache: Whether to use the cache.

    Returns:
        Results grouped by model, in the format:
        {
            "model_name": [
                {
                    "phrase_a": "xxx",
                    "phrase_b": "xxx",
                    "相似度": 0.85,
                    "说明": "xxx"
                },
                ...
            ]
        }
    """
    global progress_tracker
    if models is None:
        models = SUPPORTED_MODELS
    total_tests = len(test_cases) * len(models)
    print(f"\nTesting {len(models)} models against {len(test_cases)} test cases")
    print(f"Total comparisons: {total_tests:,}\n")
    # Preload every model (avoids concurrent loading conflicts across threads);
    # "测试" ("test") is just a throwaway phrase used to trigger model loading
    print("Preloading all models...")
    for i, model_key in enumerate(models.keys(), 1):
        print(f" [{i}/{len(models)}] Loading model: {model_key}")
        await asyncio.to_thread(compare_phrases, "测试", "测试", model_name=model_key)
    print("All models preloaded!\n")
    # Initialise the progress tracker
    progress_tracker = ProgressTracker(total_tests, "Testing progress")
    # Build every test task
    tasks = []
    for model_key in models.keys():
        for phrase_a, phrase_b in test_cases:
            tasks.append(
                test_single_case(phrase_a, phrase_b, model_key, use_cache)
            )
    # Run all tests concurrently
    start_time = time.time()
    all_results = await asyncio.gather(*tasks)
    elapsed = time.time() - start_time
    print(f"\nAll tests finished! Total time: {elapsed:.2f}s")
    print(f"Average throughput: {total_tests/elapsed:.2f} comparisons/s\n")
    # Group results by model
    results = {model_key: [] for model_key in models.keys()}
    for result in all_results:
        model_key = result["model"]
        results[model_key].append({
            "phrase_a": result["phrase_a"],
            "phrase_b": result["phrase_b"],
            "相似度": result["相似度"],
            "说明": result["说明"],
            "status": result["status"]
        })
    # Summary statistics
    print("Summary:")
    for model_key in models.keys():
        model_results = results[model_key]
        successful = sum(1 for r in model_results if r["status"] == "success")
        failed = len(model_results) - successful
        print(f" {model_key}: {successful} succeeded, {failed} failed")
    return results
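
# A minimal sketch (hypothetical phrases, assuming SUPPORTED_MODELS is a dict) of
# running test_all_models directly against a subset of models:
#
#   subset = dict(list(SUPPORTED_MODELS.items())[:2])
#   results = asyncio.run(test_all_models([("你好", "您好")], models=subset))
#   print(results.keys())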


def save_results(
    results: Dict[str, List[Dict]],
    output_file: str = "data/model_comparison_results.json"
) -> None:
    """
    Save the test results to a JSON file.

    Args:
        results: Test results.
        output_file: Output file path.
    """
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"Results saved to: {output_file}")


async def main():
    """Entry point."""
    # Configuration (taken from the config module)
    cache_dir = get_cache_dir("text_embedding")
    output_file = "data/model_comparison_results.json"
    # Step 1: extract test cases from the cache
    print("=" * 60)
    print("Step 1: extract all test cases from the cache")
    print("=" * 60)
    test_cases = extract_test_cases_from_cache(cache_dir)
    if not test_cases:
        print("No test cases found; run the main pipeline first to populate the cache")
        return
    print(f"Extracted {len(test_cases):,} unique test cases")
    # Show the first five test cases as a sample
    print("\nFirst 5 test cases:")
    for i, (phrase_a, phrase_b) in enumerate(test_cases[:5], 1):
        print(f" {i}. {phrase_a} vs {phrase_b}")
    # Step 2: test every model
    print("\n" + "=" * 60)
    print("Step 2: test all models concurrently")
    print("=" * 60)
    results = await test_all_models(test_cases, use_cache=True)
    # Step 3: save the results
    print("\n" + "=" * 60)
    print("Step 3: save results")
    print("=" * 60)
    save_results(results, output_file)
    print("\n" + "=" * 60)
    print("All done!")
    print("=" * 60)


if __name__ == "__main__":
    asyncio.run(main())