#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Test all text similarity models (test_all_models.py).

Extracts every test case from the existing cache (cache/text_embedding),
computes similarity with all supported models concurrently, and produces
a complete set of cached results.
"""
import json
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import time
import asyncio
from datetime import datetime

# Add the project root to the import path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from lib.text_embedding import compare_phrases, SUPPORTED_MODELS

# Global concurrency limit
MAX_CONCURRENT_REQUESTS = 100
semaphore = None


# Progress tracking
class ProgressTracker:
    """Console progress tracker with a bar, throughput and ETA."""

    def __init__(self, total: int, description: str = ""):
        self.total = total
        self.completed = 0
        self.start_time = datetime.now()
        self.last_update_time = datetime.now()
        self.description = description

    def update(self, count: int = 1):
        """Advance the progress counter."""
        self.completed += count
        current_time = datetime.now()
        # Redraw at most every 0.5 seconds, or once the total is reached
        if (current_time - self.last_update_time).total_seconds() >= 0.5 or self.completed >= self.total:
            self.display()
            self.last_update_time = current_time

    def display(self):
        """Render the progress bar."""
        if self.total == 0:
            return
        percentage = (self.completed / self.total) * 100
        elapsed = (datetime.now() - self.start_time).total_seconds()
        # Compute throughput and estimated time remaining
        eta_str = ""
        if elapsed > 0:
            speed = self.completed / elapsed
            if speed > 0:
                remaining = (self.total - self.completed) / speed
                eta_str = f", ETA: {int(remaining)}s"
        bar_length = 40
        filled_length = int(bar_length * self.completed / self.total)
        bar = '█' * filled_length + '░' * (bar_length - filled_length)
        desc = f"{self.description}: " if self.description else ""
        print(f"\r {desc}[{bar}] {self.completed}/{self.total} ({percentage:.1f}%){eta_str}", end='', flush=True)
        # Newline once finished
        if self.completed >= self.total:
            print()
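
# Minimal usage sketch for ProgressTracker (illustrative only; the numbers
# below are made up and not part of the real pipeline):
#
#     tracker = ProgressTracker(total=200, description="demo")
#     for _ in range(200):
#         tracker.update(1)   # redraws at most every 0.5 s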


# Global progress tracker
progress_tracker = None


def get_semaphore():
    """Return the global semaphore, creating it lazily."""
    global semaphore
    if semaphore is None:
        semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
    return semaphore


def extract_test_cases_from_cache(
    cache_dir: str = "cache/text_embedding"
) -> List[Tuple[str, str]]:
    """
    Extract all test cases from the existing cache files.

    Args:
        cache_dir: Cache directory.

    Returns:
        List of test cases, each a (phrase_a, phrase_b) tuple.
    """
    cache_path = Path(cache_dir)
    if not cache_path.exists():
        print(f"Cache directory does not exist: {cache_dir}")
        return []
    test_cases = []
    seen_pairs = set()  # For de-duplication
    # Walk every cache file
    cache_files = list(cache_path.glob("*.json"))
    print(f"Scanning cache files: {len(cache_files)} found")
    for cache_file in cache_files:
        try:
            with open(cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Extract the phrase pair
            phrase_a = data.get("input", {}).get("phrase_a")
            phrase_b = data.get("input", {}).get("phrase_b")
            if phrase_a and phrase_b:
                # Key on the sorted tuple so (A, B) and (B, A) are not counted twice
                pair_key = tuple(sorted([phrase_a, phrase_b]))
                if pair_key not in seen_pairs:
                    test_cases.append((phrase_a, phrase_b))
                    seen_pairs.add(pair_key)
        except (json.JSONDecodeError, IOError) as e:
            print(f" Failed to read cache file: {cache_file.name} - {e}")
            continue
    return test_cases
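
# For reference, the cache-file shape this parser assumes (inferred only from
# the keys read above; the real files written by lib.text_embedding may carry
# additional fields):
#
#     {
#         "input": {"phrase_a": "...", "phrase_b": "..."},
#         ...
#     }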


async def test_single_case(
    phrase_a: str,
    phrase_b: str,
    model_key: str,
    use_cache: bool = True
) -> Dict:
    """
    Run a single test case (with concurrency limiting).

    Args:
        phrase_a: First phrase.
        phrase_b: Second phrase.
        model_key: Model key name.
        use_cache: Whether to use the cache.

    Returns:
        Test result dictionary.
    """
    global progress_tracker
    sem = get_semaphore()
    async with sem:
        try:
            # Run the synchronous compare_phrases in a worker thread
            result = await asyncio.to_thread(
                compare_phrases,
                phrase_a=phrase_a,
                phrase_b=phrase_b,
                model_name=model_key,
                use_cache=use_cache
            )
            # Update progress
            if progress_tracker:
                progress_tracker.update(1)
            # "相似度" (similarity) and "说明" (explanation) are the keys
            # returned by compare_phrases, so they are kept verbatim.
            return {
                "phrase_a": phrase_a,
                "phrase_b": phrase_b,
                "model": model_key,
                "相似度": result["相似度"],
                "说明": result["说明"],
                "status": "success"
            }
        except Exception as e:
            # Update progress
            if progress_tracker:
                progress_tracker.update(1)
            return {
                "phrase_a": phrase_a,
                "phrase_b": phrase_b,
                "model": model_key,
                "相似度": None,
                "说明": f"Computation failed: {e}",
                "status": "error"
            }
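
# Hypothetical standalone use of test_single_case (the phrases are made-up
# examples and the model key is taken from SUPPORTED_MODELS rather than
# hard-coded, since the available keys depend on lib.text_embedding):
#
#     some_model = next(iter(SUPPORTED_MODELS))
#     result = asyncio.run(test_single_case("机器学习", "深度学习", some_model))
#     print(result["相似度"], result["status"])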


async def test_all_models(
    test_cases: List[Tuple[str, str]],
    models: Optional[Dict[str, str]] = None,
    use_cache: bool = True
) -> Dict[str, List[Dict]]:
    """
    Concurrently run every test case against every model.

    Args:
        test_cases: List of test cases.
        models: Model dictionary; defaults to all supported models.
        use_cache: Whether to use the cache.

    Returns:
        Test results keyed by model, in the form:
        {
            "model_name": [
                {
                    "phrase_a": "xxx",
                    "phrase_b": "xxx",
                    "相似度": 0.85,
                    "说明": "xxx"
                },
                ...
            ]
        }
    """
    global progress_tracker
    if models is None:
        models = SUPPORTED_MODELS
    total_tests = len(test_cases) * len(models)
    print(f"\nTesting {len(models)} models against {len(test_cases)} test cases")
    print(f"Total tests: {total_tests:,}\n")
    # Preload every model (avoids loading them concurrently from multiple threads);
    # "测试" ("test") is just a dummy phrase used to force the load
    print("Preloading all models...")
    for i, model_key in enumerate(models.keys(), 1):
        print(f" [{i}/{len(models)}] Loading model: {model_key}")
        await asyncio.to_thread(compare_phrases, "测试", "测试", model_name=model_key)
    print("All models preloaded!\n")
    # Initialize the progress tracker
    progress_tracker = ProgressTracker(total_tests, "Test progress")
    # Create all test tasks
    tasks = []
    for model_key in models.keys():
        for phrase_a, phrase_b in test_cases:
            tasks.append(
                test_single_case(phrase_a, phrase_b, model_key, use_cache)
            )
    # Run all tests concurrently
    start_time = time.time()
    all_results = await asyncio.gather(*tasks)
    elapsed = time.time() - start_time
    print(f"\nAll tests finished! Total time: {elapsed:.2f}s")
    print(f"Average throughput: {total_tests/elapsed:.2f} tests/s\n")
    # Group results by model
    results = {model_key: [] for model_key in models.keys()}
    for result in all_results:
        model_key = result["model"]
        results[model_key].append({
            "phrase_a": result["phrase_a"],
            "phrase_b": result["phrase_b"],
            "相似度": result["相似度"],
            "说明": result["说明"],
            "status": result["status"]
        })
    # Statistics
    print("Statistics:")
    for model_key in models.keys():
        model_results = results[model_key]
        successful = sum(1 for r in model_results if r["status"] == "success")
        failed = len(model_results) - successful
        print(f" {model_key}: {successful} succeeded, {failed} failed")
    return results


def save_results(
    results: Dict[str, List[Dict]],
    output_file: str = "data/model_comparison_results.json"
) -> None:
    """
    Save the test results to a JSON file.

    Args:
        results: Test results.
        output_file: Output file path.
    """
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    print(f"Results saved to: {output_file}")
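
# Sketch of how the saved file can be consumed afterwards (illustrative; the
# path and per-model record keys mirror what save_results/test_all_models
# write above):
#
#     with open("data/model_comparison_results.json", encoding="utf-8") as f:
#         comparison = json.load(f)
#     for model_name, rows in comparison.items():
#         scores = [r["相似度"] for r in rows if r["status"] == "success"]
#         print(model_name, sum(scores) / len(scores) if scores else "n/a")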


async def main():
    """Entry point."""
    # Configuration
    cache_dir = "cache/text_embedding"
    output_file = "data/model_comparison_results.json"
    # Step 1: extract test cases from the cache
    print("=" * 60)
    print("Step 1: extract all test cases from the cache")
    print("=" * 60)
    test_cases = extract_test_cases_from_cache(cache_dir)
    if not test_cases:
        print("No test cases found; run the main pipeline first to populate the cache")
        return
    print(f"Extracted {len(test_cases):,} unique test cases")
    # Show the first five test cases as examples
    print("\nFirst 5 test cases:")
    for i, (phrase_a, phrase_b) in enumerate(test_cases[:5], 1):
        print(f" {i}. {phrase_a} vs {phrase_b}")
    # Step 2: test every model
    print("\n" + "=" * 60)
    print("Step 2: run all models concurrently")
    print("=" * 60)
    results = await test_all_models(test_cases, use_cache=True)
    # Step 3: save the results
    print("\n" + "=" * 60)
    print("Step 3: save the results")
    print("=" * 60)
    save_results(results, output_file)
    print("\n" + "=" * 60)
    print("All done!")
    print("=" * 60)


if __name__ == "__main__":
    asyncio.run(main())