test_relation_analyzer.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. """
  2. 测试 relation_analyzer 模块
  3. """
  4. import asyncio
  5. from lib.relation_analyzer import analyze_relation
  6. async def test_all_relations():
  7. """测试所有7种关系类型"""
  8. # 测试用例:每种关系类型的典型例子
  9. test_cases = [
  10. # 1. same(同义)
  11. {
  12. "phrase_a": "医生",
  13. "phrase_b": "大夫",
  14. "expected_relation": "same",
  15. "description": "完全同义"
  16. },
  17. {
  18. "phrase_a": "计算机",
  19. "phrase_b": "电脑",
  20. "expected_relation": "same",
  21. "description": "完全同义"
  22. },
  23. # 2. coordinate(同级)
  24. {
  25. "phrase_a": "轿车",
  26. "phrase_b": "SUV",
  27. "expected_relation": "coordinate",
  28. "description": "都是汽车的子类"
  29. },
  30. {
  31. "phrase_a": "苹果",
  32. "phrase_b": "香蕉",
  33. "expected_relation": "coordinate",
  34. "description": "都是水果"
  35. },
  36. # 3. contains(包含)
  37. {
  38. "phrase_a": "水果",
  39. "phrase_b": "苹果",
  40. "expected_relation": "contains",
  41. "description": "水果包含苹果"
  42. },
  43. {
  44. "phrase_a": "汽车",
  45. "phrase_b": "轿车",
  46. "expected_relation": "contains",
  47. "description": "汽车包含轿车"
  48. },
  49. # 4. contained_by(被包含)
  50. {
  51. "phrase_a": "苹果",
  52. "phrase_b": "水果",
  53. "expected_relation": "contained_by",
  54. "description": "苹果被水果包含"
  55. },
  56. {
  57. "phrase_a": "轿车",
  58. "phrase_b": "交通工具",
  59. "expected_relation": "contained_by",
  60. "description": "轿车被交通工具包含"
  61. },
  62. # 5. overlap(部分重叠)
  63. {
  64. "phrase_a": "红苹果",
  65. "phrase_b": "大苹果",
  66. "expected_relation": "overlap",
  67. "description": "有交集(又红又大的苹果)"
  68. },
  69. {
  70. "phrase_a": "学生",
  71. "phrase_b": "运动员",
  72. "expected_relation": "overlap",
  73. "description": "有交集(学生运动员)"
  74. },
  75. # 6. related(相关)
  76. {
  77. "phrase_a": "医生",
  78. "phrase_b": "医院",
  79. "expected_relation": "related",
  80. "description": "工作场所关系"
  81. },
  82. {
  83. "phrase_a": "阅读",
  84. "phrase_b": "书籍",
  85. "expected_relation": "related",
  86. "description": "动作-对象关系"
  87. },
  88. # 7. unrelated(无关)
  89. {
  90. "phrase_a": "医生",
  91. "phrase_b": "石头",
  92. "expected_relation": "unrelated",
  93. "description": "完全无关"
  94. },
  95. {
  96. "phrase_a": "苹果",
  97. "phrase_b": "数学",
  98. "expected_relation": "unrelated",
  99. "description": "完全无关"
  100. },
  101. ]
  102. # 模型选择(根据你的配置调整)
  103. model_name = "google/gemini-2.5-flash" # 默认模型
  104. print(f"=" * 80)
  105. print(f"开始测试 relation_analyzer 模块")
  106. print(f"使用模型: {model_name}")
  107. print(f"测试用例数量: {len(test_cases)}")
  108. print(f"=" * 80)
  109. print()
  110. results = []
  111. for i, test_case in enumerate(test_cases, 1):
  112. phrase_a = test_case["phrase_a"]
  113. phrase_b = test_case["phrase_b"]
  114. expected = test_case["expected_relation"]
  115. description = test_case["description"]
  116. print(f"[{i}/{len(test_cases)}] 测试: \"{phrase_a}\" <-> \"{phrase_b}\"")
  117. print(f" 说明: {description}")
  118. print(f" 期望关系: {expected}")
  119. # 调用分析函数
  120. result = await analyze_relation(
  121. phrase_a=phrase_a,
  122. phrase_b=phrase_b,
  123. model_name=model_name
  124. )
  125. relation = result.get("relation", "unknown")
  126. score = result.get("score", 0.0)
  127. explanation = result.get("explanation", "")
  128. # 判断是否符合预期
  129. is_correct = (relation == expected)
  130. status = "✓" if is_correct else "✗"
  131. print(f" 实际关系: {relation} (score: {score:.2f}) {status}")
  132. print(f" 解释: {explanation}")
  133. print()
  134. results.append({
  135. "test_case": test_case,
  136. "result": result,
  137. "is_correct": is_correct
  138. })
  139. # 统计结果
  140. correct_count = sum(1 for r in results if r["is_correct"])
  141. total_count = len(results)
  142. accuracy = correct_count / total_count * 100
  143. print(f"=" * 80)
  144. print(f"测试完成")
  145. print(f"正确: {correct_count}/{total_count} ({accuracy:.1f}%)")
  146. print(f"=" * 80)
  147. # 显示错误的测试用例
  148. errors = [r for r in results if not r["is_correct"]]
  149. if errors:
  150. print()
  151. print("错误的测试用例:")
  152. for error in errors:
  153. tc = error["test_case"]
  154. result = error["result"]
  155. print(f" - \"{tc['phrase_a']}\" <-> \"{tc['phrase_b']}\"")
  156. print(f" 期望: {tc['expected_relation']}, 实际: {result['relation']}")
  157. return results
  158. async def test_single_example():
  159. """测试单个例子"""
  160. print("测试单个例子:")
  161. print()
  162. result = await analyze_relation(
  163. phrase_a="水果",
  164. phrase_b="苹果",
  165. model_name="google/gemini-2.5-flash" # 默认模型
  166. )
  167. print(f"短语A: 水果")
  168. print(f"短语B: 苹果")
  169. print(f"关系: {result['relation']}")
  170. print(f"分数: {result['score']}")
  171. print(f"解释: {result['explanation']}")
  172. if __name__ == "__main__":
  173. # 选择测试方式:
  174. # 方式1:测试单个例子(快速验证)
  175. # asyncio.run(test_single_example())
  176. # 方式2:测试所有关系类型(完整测试)
  177. asyncio.run(test_all_relations())