semantic_similarity.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  1. #!/usr/bin/env python3
  2. """
  3. 语义相似度分析模块
  4. 使用 AI Agent 判断两个短语之间的语义相似度
  5. """
  6. from agents import Agent, Runner, ModelSettings
  7. from lib.client import get_model
  8. from lib.utils import parse_json_from_text
  9. from typing import Dict, Any, Optional
  10. import hashlib
  11. import json
  12. import os
  13. from datetime import datetime
  14. from pathlib import Path
  15. # 默认提示词模板
  16. DEFAULT_PROMPT_TEMPLATE = """
  17. 从语意角度,判断【{phrase_a}】和【{phrase_b}】的相似度,从0-1打分,输出json格式
  18. ```json
  19. {{
  20. "说明": "简明扼要说明理由",
  21. "相似度": 0.0,
  22. }}
  23. ```
  24. """.strip()
  25. # 默认缓存目录
  26. DEFAULT_CACHE_DIR = "cache/semantic_similarity"
  27. def _generate_cache_key(
  28. phrase_a: str,
  29. phrase_b: str,
  30. model_name: str,
  31. temperature: float,
  32. max_tokens: int,
  33. prompt_template: str,
  34. instructions: str = None,
  35. tools: str = "[]"
  36. ) -> str:
  37. """
  38. 生成缓存键(哈希值)
  39. Args:
  40. phrase_a: 第一个短语
  41. phrase_b: 第二个短语
  42. model_name: 模型名称
  43. temperature: 温度参数
  44. max_tokens: 最大token数
  45. prompt_template: 提示词模板
  46. instructions: Agent 系统指令
  47. tools: 工具列表的 JSON 字符串
  48. Returns:
  49. 32位MD5哈希值
  50. """
  51. # 创建包含所有参数的字符串
  52. cache_string = f"{phrase_a}||{phrase_b}||{model_name}||{temperature}||{max_tokens}||{prompt_template}||{instructions}||{tools}"
  53. # 生成MD5哈希
  54. return hashlib.md5(cache_string.encode('utf-8')).hexdigest()
  55. def _sanitize_for_filename(text: str, max_length: int = 30) -> str:
  56. """
  57. 将文本转换为安全的文件名部分
  58. Args:
  59. text: 原始文本
  60. max_length: 最大长度
  61. Returns:
  62. 安全的文件名字符串
  63. """
  64. import re
  65. # 移除特殊字符,只保留中文、英文、数字、下划线
  66. sanitized = re.sub(r'[^\w\u4e00-\u9fff]', '_', text)
  67. # 移除连续的下划线
  68. sanitized = re.sub(r'_+', '_', sanitized)
  69. # 截断到最大长度
  70. if len(sanitized) > max_length:
  71. sanitized = sanitized[:max_length]
  72. return sanitized.strip('_')
  73. def _get_cache_filepath(
  74. cache_key: str,
  75. phrase_a: str,
  76. phrase_b: str,
  77. model_name: str,
  78. temperature: float,
  79. cache_dir: str = DEFAULT_CACHE_DIR
  80. ) -> Path:
  81. """
  82. 获取缓存文件路径(可读文件名)
  83. Args:
  84. cache_key: 缓存键(哈希值)
  85. phrase_a: 第一个短语
  86. phrase_b: 第二个短语
  87. model_name: 模型名称
  88. temperature: 温度参数
  89. cache_dir: 缓存目录
  90. Returns:
  91. 缓存文件的完整路径
  92. 文件名格式: {phrase_a}_vs_{phrase_b}_{model}_t{temp}_{hash[:8]}.json
  93. 示例: 宿命感_vs_余华的小说_gpt-4.1-mini_t0.0_a7f3e2d9.json
  94. """
  95. # 清理短语和模型名
  96. clean_a = _sanitize_for_filename(phrase_a, max_length=20)
  97. clean_b = _sanitize_for_filename(phrase_b, max_length=20)
  98. # 简化模型名(提取关键部分)
  99. model_short = model_name.split('/')[-1] # 例如: openai/gpt-4.1-mini -> gpt-4.1-mini
  100. model_short = _sanitize_for_filename(model_short, max_length=20)
  101. # 格式化温度参数
  102. temp_str = f"t{temperature:.1f}"
  103. # 使用哈希的前8位
  104. hash_short = cache_key[:8]
  105. # 组合文件名
  106. filename = f"{clean_a}_vs_{clean_b}_{model_short}_{temp_str}_{hash_short}.json"
  107. return Path(cache_dir) / filename
  108. def _load_from_cache(
  109. cache_key: str,
  110. phrase_a: str,
  111. phrase_b: str,
  112. model_name: str,
  113. temperature: float,
  114. cache_dir: str = DEFAULT_CACHE_DIR
  115. ) -> Optional[str]:
  116. """
  117. 从缓存加载数据
  118. Args:
  119. cache_key: 缓存键
  120. phrase_a: 第一个短语
  121. phrase_b: 第二个短语
  122. model_name: 模型名称
  123. temperature: 温度参数
  124. cache_dir: 缓存目录
  125. Returns:
  126. 缓存的结果字符串,如果不存在则返回 None
  127. """
  128. cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, temperature, cache_dir)
  129. # 如果文件不存在,尝试通过哈希匹配查找
  130. if not cache_file.exists():
  131. # 查找所有以该哈希结尾的文件
  132. cache_path = Path(cache_dir)
  133. if cache_path.exists():
  134. hash_short = cache_key[:8]
  135. matching_files = list(cache_path.glob(f"*_{hash_short}.json"))
  136. if matching_files:
  137. cache_file = matching_files[0]
  138. else:
  139. return None
  140. else:
  141. return None
  142. try:
  143. with open(cache_file, 'r', encoding='utf-8') as f:
  144. cached_data = json.load(f)
  145. return cached_data['output']['raw']
  146. except (json.JSONDecodeError, IOError, KeyError):
  147. return None
  148. def _save_to_cache(
  149. cache_key: str,
  150. phrase_a: str,
  151. phrase_b: str,
  152. model_name: str,
  153. temperature: float,
  154. max_tokens: int,
  155. prompt_template: str,
  156. instructions: str,
  157. tools: str,
  158. result: str,
  159. cache_dir: str = DEFAULT_CACHE_DIR
  160. ) -> None:
  161. """
  162. 保存数据到缓存
  163. Args:
  164. cache_key: 缓存键
  165. phrase_a: 第一个短语
  166. phrase_b: 第二个短语
  167. model_name: 模型名称
  168. temperature: 温度参数
  169. max_tokens: 最大token数
  170. prompt_template: 提示词模板
  171. instructions: Agent 系统指令
  172. tools: 工具列表的 JSON 字符串
  173. result: 结果数据(原始字符串)
  174. cache_dir: 缓存目录
  175. """
  176. cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, temperature, cache_dir)
  177. # 确保缓存目录存在
  178. cache_file.parent.mkdir(parents=True, exist_ok=True)
  179. # 尝试解析 result 为 JSON
  180. parsed_result = parse_json_from_text(result)
  181. # 准备缓存数据(包含完整的输入输出信息)
  182. cache_data = {
  183. "input": {
  184. "phrase_a": phrase_a,
  185. "phrase_b": phrase_b,
  186. "model_name": model_name,
  187. "temperature": temperature,
  188. "max_tokens": max_tokens,
  189. "prompt_template": prompt_template,
  190. "instructions": instructions,
  191. "tools": tools
  192. },
  193. "output": {
  194. "raw": result, # 保留原始响应
  195. "parsed": parsed_result # 解析后的JSON对象
  196. },
  197. "metadata": {
  198. "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
  199. "cache_key": cache_key,
  200. "cache_file": str(cache_file.name)
  201. }
  202. }
  203. try:
  204. with open(cache_file, 'w', encoding='utf-8') as f:
  205. json.dump(cache_data, f, ensure_ascii=False, indent=2)
  206. except IOError:
  207. pass # 静默失败,不影响主流程
  208. async def _difference_between_phrases(
  209. phrase_a: str,
  210. phrase_b: str,
  211. model_name: str = 'openai/gpt-4.1-mini',
  212. temperature: float = 0.0,
  213. max_tokens: int = 65536,
  214. prompt_template: str = None,
  215. instructions: str = None,
  216. tools: list = None,
  217. name: str = "Semantic Similarity Analyzer",
  218. use_cache: bool = True,
  219. cache_dir: str = DEFAULT_CACHE_DIR
  220. ) -> str:
  221. """
  222. 从语义角度判断两个短语的相似度
  223. Args:
  224. phrase_a: 第一个短语
  225. phrase_b: 第二个短语
  226. model_name: 使用的模型名称,可选值:
  227. - 'google/gemini-2.5-pro'
  228. - 'anthropic/claude-sonnet-4.5'
  229. - 'google/gemini-2.0-flash-001'
  230. - 'openai/gpt-5-mini'
  231. - 'anthropic/claude-haiku-4.5'
  232. - 'openai/gpt-4.1-mini' (默认)
  233. temperature: 模型温度参数,控制输出随机性,默认 0.0(确定性输出)
  234. max_tokens: 最大生成token数,默认 65536
  235. prompt_template: 自定义提示词模板,使用 {phrase_a} 和 {phrase_b} 作为占位符
  236. 如果为 None,使用默认模板
  237. instructions: Agent 的系统指令,默认为 None
  238. tools: Agent 可用的工具列表,默认为 []
  239. name: Agent 的名称,默认为 "Semantic Similarity Analyzer"(不参与缓存key构建)
  240. use_cache: 是否使用缓存,默认 True
  241. cache_dir: 缓存目录,默认 'cache/semantic_similarity'
  242. Returns:
  243. JSON 格式的相似度分析结果字符串
  244. Examples:
  245. >>> # 使用默认模板和缓存
  246. >>> result = await difference_between_phrases("宿命感", "余华的小说")
  247. >>> print(result)
  248. {
  249. "说明": "简明扼要说明理由",
  250. "相似度": 0.0
  251. }
  252. >>> # 禁用缓存
  253. >>> result = await difference_between_phrases(
  254. ... "宿命感", "余华的小说",
  255. ... use_cache=False
  256. ... )
  257. >>> # 使用自定义模板
  258. >>> custom_template = '''
  259. ... 请分析【{phrase_a}】和【{phrase_b}】的语义关联度
  260. ... 输出格式:{{"score": 0.0, "reason": "..."}}
  261. ... '''
  262. >>> result = await difference_between_phrases(
  263. ... "宿命感", "余华的小说",
  264. ... prompt_template=custom_template
  265. ... )
  266. """
  267. # 使用自定义模板或默认模板
  268. if prompt_template is None:
  269. prompt_template = DEFAULT_PROMPT_TEMPLATE
  270. # 默认tools为空列表
  271. if tools is None:
  272. tools = []
  273. # 生成缓存键(tools转为JSON字符串以便哈希)
  274. tools_str = json.dumps(tools, sort_keys=True) if tools else "[]"
  275. cache_key = _generate_cache_key(
  276. phrase_a, phrase_b, model_name, temperature, max_tokens, prompt_template, instructions, tools_str
  277. )
  278. # 尝试从缓存加载
  279. if use_cache:
  280. cached_result = _load_from_cache(cache_key, phrase_a, phrase_b, model_name, temperature, cache_dir)
  281. if cached_result is not None:
  282. return cached_result
  283. # 缓存未命中,调用 API
  284. agent = Agent(
  285. name=name,
  286. model=get_model(model_name),
  287. model_settings=ModelSettings(
  288. temperature=temperature,
  289. max_tokens=max_tokens,
  290. ),
  291. instructions=instructions,
  292. tools=tools,
  293. )
  294. # 格式化提示词
  295. prompt = prompt_template.format(phrase_a=phrase_a, phrase_b=phrase_b)
  296. result = await Runner.run(agent, input=prompt)
  297. final_output = result.final_output
  298. # 保存到缓存
  299. if use_cache:
  300. _save_to_cache(
  301. cache_key, phrase_a, phrase_b, model_name,
  302. temperature, max_tokens, prompt_template,
  303. instructions, tools_str, final_output, cache_dir
  304. )
  305. return final_output
  306. async def _difference_between_phrases_parsed(
  307. phrase_a: str,
  308. phrase_b: str,
  309. model_name: str = 'openai/gpt-4.1-mini',
  310. temperature: float = 0.0,
  311. max_tokens: int = 65536,
  312. prompt_template: str = None,
  313. instructions: str = None,
  314. tools: list = None,
  315. name: str = "Semantic Similarity Analyzer",
  316. use_cache: bool = True,
  317. cache_dir: str = DEFAULT_CACHE_DIR
  318. ) -> Dict[str, Any]:
  319. """
  320. 从语义角度判断两个短语的相似度,并解析返回结果为字典
  321. Args:
  322. phrase_a: 第一个短语
  323. phrase_b: 第二个短语
  324. model_name: 使用的模型名称
  325. temperature: 模型温度参数,控制输出随机性,默认 0.0(确定性输出)
  326. max_tokens: 最大生成token数,默认 65536
  327. prompt_template: 自定义提示词模板,使用 {phrase_a} 和 {phrase_b} 作为占位符
  328. instructions: Agent 的系统指令,默认为 None
  329. tools: Agent 可用的工具列表,默认为 []
  330. name: Agent 的名称,默认为 "Semantic Similarity Analyzer"
  331. use_cache: 是否使用缓存,默认 True
  332. cache_dir: 缓存目录,默认 'cache/semantic_similarity'
  333. Returns:
  334. 解析后的字典,包含:
  335. - 说明: 相似度判断的理由
  336. - 相似度: 0-1之间的浮点数
  337. Examples:
  338. >>> result = await difference_between_phrases_parsed("宿命感", "余华的小说")
  339. >>> print(result['相似度'])
  340. 0.3
  341. >>> print(result['说明'])
  342. "两个概念有一定关联..."
  343. """
  344. raw_result = await _difference_between_phrases(
  345. phrase_a, phrase_b, model_name, temperature, max_tokens,
  346. prompt_template, instructions, tools, name, use_cache, cache_dir
  347. )
  348. # 使用 utils.parse_json_from_text 解析结果
  349. parsed_result = parse_json_from_text(raw_result)
  350. # 如果解析失败(返回空字典),返回带错误信息的结果
  351. if not parsed_result:
  352. return {
  353. "说明": "解析失败: 无法从响应中提取有效的 JSON",
  354. "相似度": 0.0,
  355. "raw_response": raw_result
  356. }
  357. return parsed_result
  358. # ========== V1 版本(默认版本) ==========
  359. # 对外接口 - V1
  360. async def compare_phrases(
  361. phrase_a: str,
  362. phrase_b: str,
  363. model_name: str = 'openai/gpt-4.1-mini',
  364. temperature: float = 0.0,
  365. max_tokens: int = 65536,
  366. prompt_template: str = None,
  367. instructions: str = None,
  368. tools: list = None,
  369. name: str = "Semantic Similarity Analyzer",
  370. use_cache: bool = True,
  371. cache_dir: str = DEFAULT_CACHE_DIR
  372. ) -> Dict[str, Any]:
  373. """
  374. 比较两个短语的语义相似度(对外唯一接口)
  375. Args:
  376. phrase_a: 第一个短语
  377. phrase_b: 第二个短语
  378. model_name: 使用的模型名称
  379. temperature: 模型温度参数,控制输出随机性,默认 0.0(确定性输出)
  380. max_tokens: 最大生成token数,默认 65536
  381. prompt_template: 自定义提示词模板,使用 {phrase_a} 和 {phrase_b} 作为占位符
  382. instructions: Agent 的系统指令,默认为 None
  383. tools: Agent 可用的工具列表,默认为 []
  384. name: Agent 的名称,默认为 "Semantic Similarity Analyzer"
  385. use_cache: 是否使用缓存,默认 True
  386. cache_dir: 缓存目录,默认 'cache/semantic_similarity'
  387. Returns:
  388. 解析后的字典
  389. """
  390. return await _difference_between_phrases_parsed(
  391. phrase_a, phrase_b, model_name, temperature, max_tokens,
  392. prompt_template, instructions, tools, name, use_cache, cache_dir
  393. )
  394. if __name__ == "__main__":
  395. import asyncio
  396. async def main():
  397. """示例使用"""
  398. # 示例 1: 基本使用(使用缓存)
  399. print("示例 1: 基本使用")
  400. result = await compare_phrases("宿命感", "余华的小说")
  401. print(f"相似度: {result.get('相似度')}")
  402. print(f"说明: {result.get('说明')}")
  403. print()
  404. # 示例 2: 再次调用相同参数(应该从缓存读取)
  405. print("示例 2: 测试缓存")
  406. result = await compare_phrases("宿命感", "余华的小说")
  407. print(f"相似度: {result.get('相似度')}")
  408. print()
  409. # 示例 3: 自定义温度
  410. print("示例 3: 自定义温度(创意性输出)")
  411. result = await compare_phrases(
  412. "创意写作", "AI生成",
  413. temperature=0.7
  414. )
  415. print(f"相似度: {result.get('相似度')}")
  416. print(f"说明: {result.get('说明')}")
  417. print()
  418. # 示例 4: 自定义 Agent 名称
  419. print("示例 4: 自定义 Agent 名称")
  420. result = await compare_phrases(
  421. "人工智能", "机器学习",
  422. name="AI语义分析专家"
  423. )
  424. print(f"相似度: {result.get('相似度')}")
  425. print(f"说明: {result.get('说明')}")
  426. print()
  427. # 示例 5: 使用不同的模型
  428. print("示例 5: 使用 Claude 模型")
  429. result = await compare_phrases(
  430. "深度学习", "神经网络",
  431. model_name='anthropic/claude-haiku-4.5'
  432. )
  433. print(f"相似度: {result.get('相似度')}")
  434. print(f"说明: {result.get('说明')}")
  435. asyncio.run(main())
  436. # ========== V2 版本(示例:详细分析版本) ==========
  437. # V2 默认提示词模板(更详细的分析)
  438. DEFAULT_PROMPT_TEMPLATE_V2 = """
  439. 请深入分析【{phrase_a}】和【{phrase_b}】的语义关系,包括:
  440. 1. 语义相似度(0-1)
  441. 2. 关系类型(如:包含、相关、对立、无关等)
  442. 3. 详细说明
  443. 输出格式:
  444. ```json
  445. {{
  446. "相似度": 0.0,
  447. "关系类型": "相关/包含/对立/无关",
  448. "详细说明": "详细分析两者的语义关系...",
  449. "应用场景": "该关系在实际应用中的意义..."
  450. }}
  451. ```
  452. """.strip()
  453. # 对外接口 - V2
  454. async def compare_phrases_v2(
  455. phrase_a: str,
  456. phrase_b: str,
  457. model_name: str = 'anthropic/claude-sonnet-4.5', # V2 默认使用更强的模型
  458. temperature: float = 0.0,
  459. max_tokens: int = 65536,
  460. prompt_template: str = None,
  461. instructions: str = None,
  462. tools: list = None,
  463. name: str = "Advanced Semantic Analyzer",
  464. use_cache: bool = True,
  465. cache_dir: str = DEFAULT_CACHE_DIR
  466. ) -> Dict[str, Any]:
  467. """
  468. 比较两个短语的语义相似度 - V2 版本(详细分析)
  469. V2 特点:
  470. - 默认使用更强的模型(Claude Sonnet 4.5)
  471. - 更详细的分析输出(包含关系类型和应用场景)
  472. - 适合需要深入分析的场景
  473. Args:
  474. phrase_a: 第一个短语
  475. phrase_b: 第二个短语
  476. model_name: 使用的模型名称,默认 'anthropic/claude-sonnet-4.5'
  477. temperature: 模型温度参数,默认 0.0
  478. max_tokens: 最大生成token数,默认 65536
  479. prompt_template: 自定义提示词模板,默认使用 V2 详细模板
  480. instructions: Agent 的系统指令,默认为 None
  481. tools: Agent 可用的工具列表,默认为 []
  482. name: Agent 的名称,默认 "Advanced Semantic Analyzer"
  483. use_cache: 是否使用缓存,默认 True
  484. cache_dir: 缓存目录,默认 'cache/semantic_similarity'
  485. Returns:
  486. 解析后的字典,包含:
  487. - 相似度: 0-1之间的浮点数
  488. - 关系类型: 关系分类
  489. - 详细说明: 详细分析
  490. - 应用场景: 应用建议
  491. Examples:
  492. >>> result = await compare_phrases_v2("深度学习", "神经网络")
  493. >>> print(result['相似度'])
  494. 0.9
  495. >>> print(result['关系类型'])
  496. "包含"
  497. >>> print(result['详细说明'])
  498. "深度学习是基于人工神经网络的机器学习方法..."
  499. """
  500. # 使用 V2 默认模板(如果未指定)
  501. if prompt_template is None:
  502. prompt_template = DEFAULT_PROMPT_TEMPLATE_V2
  503. return await _difference_between_phrases_parsed(
  504. phrase_a, phrase_b, model_name, temperature, max_tokens,
  505. prompt_template, instructions, tools, name, use_cache, cache_dir
  506. )