text_embedding.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
#!/usr/bin/env python3
"""
Text similarity computation module.

Backed by the `similarities` library (a true vector/embedding model,
not an LLM).
"""
from typing import Dict, Any, Optional
import hashlib
import json
from pathlib import Path
from datetime import datetime
import threading
# Supported models: short alias -> full Hugging Face model id.
SUPPORTED_MODELS = {
    "chinese": "shibing624/text2vec-base-chinese",  # default; general-purpose Chinese
    "multilingual": "shibing624/text2vec-base-multilingual",  # multilingual (zh/en/ko/ja/de/it, ...)
    "paraphrase": "shibing624/text2vec-base-chinese-paraphrase",  # long Chinese text
    "sentence": "shibing624/text2vec-base-chinese-sentence",  # short Chinese sentences
}
# `similarities` is imported lazily so no model is loaded at import time.
_similarity_models = {}  # cache of loaded model instances, keyed by full model name
_model_lock = threading.Lock()  # guards model loading across threads
# Default on-disk cache directory for similarity results.
DEFAULT_CACHE_DIR = "cache/text_embedding"
  24. def _generate_cache_key(phrase_a: str, phrase_b: str, model_name: str) -> str:
  25. """
  26. 生成缓存键(哈希值)
  27. Args:
  28. phrase_a: 第一个短语
  29. phrase_b: 第二个短语
  30. model_name: 模型名称
  31. Returns:
  32. 32位MD5哈希值
  33. """
  34. cache_string = f"{phrase_a}||{phrase_b}||{model_name}"
  35. return hashlib.md5(cache_string.encode('utf-8')).hexdigest()
  36. def _sanitize_for_filename(text: str, max_length: int = 30) -> str:
  37. """
  38. 将文本转换为安全的文件名部分
  39. Args:
  40. text: 原始文本
  41. max_length: 最大长度
  42. Returns:
  43. 安全的文件名字符串
  44. """
  45. import re
  46. # 移除特殊字符,只保留中文、英文、数字、下划线
  47. sanitized = re.sub(r'[^\w\u4e00-\u9fff]', '_', text)
  48. # 移除连续的下划线
  49. sanitized = re.sub(r'_+', '_', sanitized)
  50. # 截断到最大长度
  51. if len(sanitized) > max_length:
  52. sanitized = sanitized[:max_length]
  53. return sanitized.strip('_')
  54. def _get_cache_filepath(
  55. cache_key: str,
  56. phrase_a: str,
  57. phrase_b: str,
  58. model_name: str,
  59. cache_dir: str = DEFAULT_CACHE_DIR
  60. ) -> Path:
  61. """
  62. 获取缓存文件路径(可读文件名)
  63. Args:
  64. cache_key: 缓存键(哈希值)
  65. phrase_a: 第一个短语
  66. phrase_b: 第二个短语
  67. model_name: 模型名称
  68. cache_dir: 缓存目录
  69. Returns:
  70. 缓存文件的完整路径
  71. 文件名格式: {phrase_a}_vs_{phrase_b}_{model}_{hash[:8]}.json
  72. """
  73. # 清理短语和模型名
  74. clean_a = _sanitize_for_filename(phrase_a, max_length=20)
  75. clean_b = _sanitize_for_filename(phrase_b, max_length=20)
  76. # 简化模型名(提取关键部分)
  77. model_short = model_name.split('/')[-1]
  78. model_short = _sanitize_for_filename(model_short, max_length=20)
  79. # 使用哈希的前8位
  80. hash_short = cache_key[:8]
  81. # 组合文件名
  82. filename = f"{clean_a}_vs_{clean_b}_{model_short}_{hash_short}.json"
  83. return Path(cache_dir) / filename
  84. def _load_from_cache(
  85. cache_key: str,
  86. phrase_a: str,
  87. phrase_b: str,
  88. model_name: str,
  89. cache_dir: str = DEFAULT_CACHE_DIR
  90. ) -> Optional[Dict[str, Any]]:
  91. """
  92. 从缓存加载数据
  93. Args:
  94. cache_key: 缓存键
  95. phrase_a: 第一个短语
  96. phrase_b: 第二个短语
  97. model_name: 模型名称
  98. cache_dir: 缓存目录
  99. Returns:
  100. 缓存的结果字典,如果不存在则返回 None
  101. """
  102. cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, cache_dir)
  103. # 如果文件不存在,尝试通过哈希匹配查找
  104. if not cache_file.exists():
  105. cache_path = Path(cache_dir)
  106. if cache_path.exists():
  107. hash_short = cache_key[:8]
  108. matching_files = list(cache_path.glob(f"*_{hash_short}.json"))
  109. if matching_files:
  110. cache_file = matching_files[0]
  111. else:
  112. return None
  113. else:
  114. return None
  115. try:
  116. with open(cache_file, 'r', encoding='utf-8') as f:
  117. cached_data = json.load(f)
  118. return cached_data['output']
  119. except (json.JSONDecodeError, IOError, KeyError):
  120. return None
  121. def _save_to_cache(
  122. cache_key: str,
  123. phrase_a: str,
  124. phrase_b: str,
  125. model_name: str,
  126. result: Dict[str, Any],
  127. cache_dir: str = DEFAULT_CACHE_DIR
  128. ) -> None:
  129. """
  130. 保存数据到缓存
  131. Args:
  132. cache_key: 缓存键
  133. phrase_a: 第一个短语
  134. phrase_b: 第二个短语
  135. model_name: 模型名称
  136. result: 结果数据(字典格式)
  137. cache_dir: 缓存目录
  138. """
  139. cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, cache_dir)
  140. # 确保缓存目录存在
  141. cache_file.parent.mkdir(parents=True, exist_ok=True)
  142. # 准备缓存数据
  143. cache_data = {
  144. "input": {
  145. "phrase_a": phrase_a,
  146. "phrase_b": phrase_b,
  147. "model_name": model_name,
  148. },
  149. "output": result,
  150. "metadata": {
  151. "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
  152. "cache_key": cache_key,
  153. "cache_file": str(cache_file.name)
  154. }
  155. }
  156. try:
  157. with open(cache_file, 'w', encoding='utf-8') as f:
  158. json.dump(cache_data, f, ensure_ascii=False, indent=2)
  159. except IOError:
  160. pass # 静默失败,不影响主流程
  161. def _get_similarity_model(model_name: str = "shibing624/text2vec-base-chinese"):
  162. """
  163. 获取或初始化相似度模型(支持多个模型,线程安全)
  164. Args:
  165. model_name: 模型名称
  166. Returns:
  167. BertSimilarity 模型实例
  168. """
  169. global _similarity_models, _model_lock
  170. # 如果是简称,转换为完整名称
  171. if model_name in SUPPORTED_MODELS:
  172. model_name = SUPPORTED_MODELS[model_name]
  173. # 快速路径:如果模型已加载,直接返回(无锁检查)
  174. if model_name in _similarity_models:
  175. return _similarity_models[model_name]
  176. # 慢速路径:需要加载模型(使用锁保护)
  177. with _model_lock:
  178. # 双重检查:可能在等待锁时其他线程已经加载了
  179. if model_name in _similarity_models:
  180. return _similarity_models[model_name]
  181. # 加载新模型
  182. try:
  183. from similarities import BertSimilarity
  184. print(f"正在加载模型: {model_name}...")
  185. _similarity_models[model_name] = BertSimilarity(model_name_or_path=model_name)
  186. print("模型加载完成!")
  187. return _similarity_models[model_name]
  188. except ImportError:
  189. raise ImportError(
  190. "请先安装 similarities 库: pip install -U similarities torch"
  191. )
  192. def compare_phrases(
  193. phrase_a: str,
  194. phrase_b: str,
  195. model_name: str = "chinese",
  196. use_cache: bool = True,
  197. cache_dir: str = DEFAULT_CACHE_DIR
  198. ) -> Dict[str, Any]:
  199. """
  200. 比较两个短语的语义相似度(兼容 semantic_similarity.py 的接口)
  201. 返回格式与 semantic_similarity.compare_phrases() 一致:
  202. {
  203. "说明": "基于向量模型计算的语义相似度",
  204. "相似度": 0.85
  205. }
  206. Args:
  207. phrase_a: 第一个短语
  208. phrase_b: 第二个短语
  209. model_name: 模型名称,可选:
  210. 简称:
  211. - "chinese" (默认) - 中文通用模型
  212. - "multilingual" - 多语言模型(中英韩日德意等)
  213. - "paraphrase" - 中文长文本模型
  214. - "sentence" - 中文短句子模型
  215. 完整名称:
  216. - "shibing624/text2vec-base-chinese"
  217. - "shibing624/text2vec-base-multilingual"
  218. - "shibing624/text2vec-base-chinese-paraphrase"
  219. - "shibing624/text2vec-base-chinese-sentence"
  220. use_cache: 是否使用缓存,默认 True
  221. cache_dir: 缓存目录,默认 'cache/text_embedding'
  222. Returns:
  223. {
  224. "说明": str, # 相似度说明
  225. "相似度": float # 0-1之间的相似度分数
  226. }
  227. Examples:
  228. >>> # 使用默认模型
  229. >>> result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡")
  230. >>> print(result['相似度']) # 0.855
  231. >>> # 使用多语言模型
  232. >>> result = compare_phrases("Hello", "Hi", model_name="multilingual")
  233. >>> # 使用长文本模型
  234. >>> result = compare_phrases("长文本1...", "长文本2...", model_name="paraphrase")
  235. >>> # 禁用缓存
  236. >>> result = compare_phrases("测试", "测试", use_cache=False)
  237. """
  238. # 转换简称为完整名称(用于缓存键)
  239. full_model_name = SUPPORTED_MODELS.get(model_name, model_name)
  240. # 生成缓存键
  241. cache_key = _generate_cache_key(phrase_a, phrase_b, full_model_name)
  242. # 尝试从缓存加载
  243. if use_cache:
  244. cached_result = _load_from_cache(cache_key, phrase_a, phrase_b, full_model_name, cache_dir)
  245. if cached_result is not None:
  246. return cached_result
  247. # 缓存未命中,计算相似度
  248. model = _get_similarity_model(model_name)
  249. score = float(model.similarity(phrase_a, phrase_b))
  250. # 生成说明
  251. if score >= 0.9:
  252. level = "极高"
  253. elif score >= 0.7:
  254. level = "高"
  255. elif score >= 0.5:
  256. level = "中等"
  257. elif score >= 0.3:
  258. level = "较低"
  259. else:
  260. level = "低"
  261. explanation = f"基于向量模型计算的语义相似度为 {level} ({score:.2f})"
  262. result = {
  263. "说明": explanation,
  264. "相似度": score
  265. }
  266. # 保存到缓存
  267. if use_cache:
  268. _save_to_cache(cache_key, phrase_a, phrase_b, full_model_name, result, cache_dir)
  269. return result
if __name__ == "__main__":
    print("=" * 60)
    print("text_embedding - 文本相似度计算(带缓存)")
    print("=" * 60)
    print()
    # Example 1: default model (first call; result is written to the cache)
    print("示例 1: 默认模型(chinese)- 首次调用")
    result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡")
    print(f"相似度: {result['相似度']:.3f}")
    print(f"说明: {result['说明']}")
    print()
    # Example 2: same arguments again (served from the cache)
    print("示例 2: 测试缓存 - 再次调用相同参数")
    result = compare_phrases("如何更换花呗绑定银行卡", "花呗更改绑定银行卡")
    print(f"相似度: {result['相似度']:.3f}")
    print(f"说明: {result['说明']}")
    print("(应该从缓存读取,速度更快)")
    print()
    # Example 3: short phrases with the default model
    print("示例 3: 使用默认模型")
    result = compare_phrases("深度学习", "神经网络")
    print(f"相似度: {result['相似度']:.3f}")
    print(f"说明: {result['说明']}")
    print()
    # Example 4: unrelated phrases
    print("示例 4: 不相关的短语")
    result = compare_phrases("编程", "吃饭")
    print(f"相似度: {result['相似度']:.3f}")
    print(f"说明: {result['说明']}")
    print()
    # Example 5: multilingual model
    print("示例 5: 多语言模型(multilingual)")
    result = compare_phrases("Hello", "Hi", model_name="multilingual")
    print(f"相似度: {result['相似度']:.3f}")
    print(f"说明: {result['说明']}")
    print()
    # Example 6: cache disabled
    print("示例 6: 禁用缓存")
    result = compare_phrases("测试", "测试", use_cache=False)
    print(f"相似度: {result['相似度']:.3f}")
    print(f"说明: {result['说明']}")
    print()
    print("=" * 60)
    print("支持的模型:")
    print("-" * 60)
    for key, value in SUPPORTED_MODELS.items():
        print(f" {key:15s} -> {value}")
    print("=" * 60)
    print()
    print("缓存目录: cache/text_embedding/")
    print("缓存文件格式: {phrase_a}_vs_{phrase_b}_{model}_{hash[:8]}.json")
    print("=" * 60)