#!/usr/bin/env python3
"""
hybrid_similarity.py — 混合相似度计算模块

结合向量模型(text_embedding)和LLM模型(semantic_similarity)的结果。
Combines the similarity scores of the embedding model (text_embedding) and
the LLM model (semantic_similarity) into one weighted score.
"""
import asyncio
import math
from typing import Dict, Any, Optional

from lib.text_embedding import compare_phrases as compare_phrases_embedding
from lib.semantic_similarity import compare_phrases as compare_phrases_semantic
from lib.config import get_cache_dir
  11. async def compare_phrases(
  12. phrase_a: str,
  13. phrase_b: str,
  14. weight_embedding: float = 0.5,
  15. weight_semantic: float = 0.5,
  16. embedding_model: str = "chinese",
  17. semantic_model: str = 'openai/gpt-4.1-mini',
  18. use_cache: bool = True,
  19. cache_dir_embedding: Optional[str] = None,
  20. cache_dir_semantic: Optional[str] = None,
  21. **semantic_kwargs
  22. ) -> Dict[str, Any]:
  23. """
  24. 混合相似度计算:同时使用向量模型和LLM模型,按权重组合结果
  25. Args:
  26. phrase_a: 第一个短语
  27. phrase_b: 第二个短语
  28. weight_embedding: 向量模型权重,默认 0.5
  29. weight_semantic: LLM模型权重,默认 0.5
  30. embedding_model: 向量模型名称,默认 "chinese"
  31. semantic_model: LLM模型名称,默认 'openai/gpt-4.1-mini'
  32. use_cache: 是否使用缓存,默认 True
  33. cache_dir_embedding: 向量模型缓存目录,默认从配置读取
  34. cache_dir_semantic: LLM模型缓存目录,默认从配置读取
  35. **semantic_kwargs: 其他传递给semantic_similarity的参数
  36. - temperature: 温度参数,默认 0.0
  37. - max_tokens: 最大token数,默认 65536
  38. - prompt_template: 自定义提示词模板
  39. - instructions: Agent系统指令
  40. - tools: Agent工具列表
  41. - name: Agent名称
  42. Returns:
  43. {
  44. "相似度": float, # 加权平均后的相似度 (0-1)
  45. "说明": str # 综合说明(包含各模型的分数和说明)
  46. }
  47. Examples:
  48. >>> # 使用默认权重 (0.5:0.5)
  49. >>> result = await compare_phrases("深度学习", "神经网络")
  50. >>> print(result['相似度']) # 加权平均后的相似度
  51. 0.82
  52. >>> # 自定义权重,更倾向向量模型
  53. >>> result = await compare_phrases(
  54. ... "深度学习", "神经网络",
  55. ... weight_embedding=0.7,
  56. ... weight_semantic=0.3
  57. ... )
  58. >>> # 使用不同的模型
  59. >>> result = await compare_phrases(
  60. ... "深度学习", "神经网络",
  61. ... embedding_model="multilingual",
  62. ... semantic_model="anthropic/claude-sonnet-4.5"
  63. ... )
  64. """
  65. # 验证权重
  66. total_weight = weight_embedding + weight_semantic
  67. if abs(total_weight - 1.0) > 0.001:
  68. raise ValueError(f"权重之和必须为1.0,当前为: {total_weight}")
  69. # 使用配置的缓存目录(如果未指定)
  70. if cache_dir_embedding is None:
  71. cache_dir_embedding = get_cache_dir("text_embedding")
  72. if cache_dir_semantic is None:
  73. cache_dir_semantic = get_cache_dir("semantic_similarity")
  74. # 并发调用两个模型
  75. embedding_task = asyncio.to_thread(
  76. compare_phrases_embedding,
  77. phrase_a=phrase_a,
  78. phrase_b=phrase_b,
  79. model_name=embedding_model,
  80. use_cache=use_cache,
  81. cache_dir=cache_dir_embedding
  82. )
  83. semantic_task = compare_phrases_semantic(
  84. phrase_a=phrase_a,
  85. phrase_b=phrase_b,
  86. model_name=semantic_model,
  87. use_cache=use_cache,
  88. cache_dir=cache_dir_semantic,
  89. **semantic_kwargs
  90. )
  91. # 等待两个任务完成
  92. embedding_result, semantic_result = await asyncio.gather(
  93. embedding_task,
  94. semantic_task
  95. )
  96. # 提取相似度分数
  97. score_embedding = embedding_result.get("相似度", 0.0)
  98. score_semantic = semantic_result.get("相似度", 0.0)
  99. # 计算加权平均
  100. final_score = (
  101. score_embedding * weight_embedding +
  102. score_semantic * weight_semantic
  103. )
  104. # 生成综合说明(格式化为清晰的结构)
  105. explanation = (
  106. f"【混合相似度】{final_score:.3f}(向量模型权重{weight_embedding},LLM模型权重{weight_semantic})\n\n"
  107. f"【向量模型】相似度={score_embedding:.3f}\n"
  108. f"{embedding_result.get('说明', 'N/A')}\n\n"
  109. f"【LLM模型】相似度={score_semantic:.3f}\n"
  110. f"{semantic_result.get('说明', 'N/A')}"
  111. )
  112. # 构建返回结果(与原接口完全一致)
  113. return {
  114. "相似度": final_score,
  115. "说明": explanation
  116. }
  117. def compare_phrases_sync(
  118. phrase_a: str,
  119. phrase_b: str,
  120. weight_embedding: float = 0.5,
  121. weight_semantic: float = 0.5,
  122. **kwargs
  123. ) -> Dict[str, Any]:
  124. """
  125. 混合相似度计算的同步版本(内部创建事件循环)
  126. Args:
  127. phrase_a: 第一个短语
  128. phrase_b: 第二个短语
  129. weight_embedding: 向量模型权重,默认 0.5
  130. weight_semantic: LLM模型权重,默认 0.5
  131. **kwargs: 其他参数(同 compare_phrases)
  132. Returns:
  133. 同 compare_phrases
  134. Examples:
  135. >>> result = compare_phrases_sync("深度学习", "神经网络")
  136. >>> print(result['相似度'])
  137. """
  138. return asyncio.run(
  139. compare_phrases(
  140. phrase_a=phrase_a,
  141. phrase_b=phrase_b,
  142. weight_embedding=weight_embedding,
  143. weight_semantic=weight_semantic,
  144. **kwargs
  145. )
  146. )
  147. if __name__ == "__main__":
  148. async def main():
  149. print("=" * 80)
  150. print("混合相似度计算示例")
  151. print("=" * 80)
  152. print()
  153. # 示例 1: 默认权重 (0.5:0.5)
  154. print("示例 1: 默认权重 (0.5:0.5)")
  155. print("-" * 80)
  156. result = await compare_phrases("深度学习", "神经网络")
  157. print(f"相似度: {result['相似度']:.3f}")
  158. print(f"说明:\n{result['说明']}")
  159. print()
  160. # 示例 2: 不相关的短语
  161. print("示例 2: 不相关的短语")
  162. print("-" * 80)
  163. result = await compare_phrases("编程", "吃饭")
  164. print(f"相似度: {result['相似度']:.3f}")
  165. print(f"说明:\n{result['说明']}")
  166. print()
  167. # 示例 3: 自定义权重,更倾向向量模型
  168. print("示例 3: 自定义权重 (向量:0.7, LLM:0.3)")
  169. print("-" * 80)
  170. result = await compare_phrases(
  171. "人工智能", "机器学习",
  172. weight_embedding=0.7,
  173. weight_semantic=0.3
  174. )
  175. print(f"相似度: {result['相似度']:.3f}")
  176. print(f"说明:\n{result['说明']}")
  177. print()
  178. # 示例 4: 完整输出示例
  179. print("示例 4: 完整输出示例")
  180. print("-" * 80)
  181. result = await compare_phrases("宿命感", "余华的小说")
  182. print(f"相似度: {result['相似度']:.3f}")
  183. print(f"说明:\n{result['说明']}")
  184. print()
  185. # 示例 5: 同步版本
  186. print("示例 5: 同步版本调用")
  187. print("-" * 80)
  188. result = compare_phrases_sync("Python", "编程语言")
  189. print(f"相似度: {result['相似度']:.3f}")
  190. print(f"说明:\n{result['说明']}")
  191. print()
  192. print("=" * 80)
  193. asyncio.run(main())