semantic_similarity.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646
  1. #!/usr/bin/env python3
  2. """
  3. 语义相似度分析模块
  4. 使用 AI Agent 判断两个短语之间的语义相似度
  5. """
  6. from agents import Agent, Runner, ModelSettings
  7. from lib.client import get_model
  8. from lib.utils import parse_json_from_text
  9. from lib.config import get_cache_dir
  10. from typing import Dict, Any, Optional
  11. import hashlib
  12. import json
  13. import os
  14. from datetime import datetime
  15. from pathlib import Path
  16. # 默认提示词模板
  17. DEFAULT_PROMPT_TEMPLATE = """
  18. 从语意角度,判断【{phrase_a}】和【{phrase_b}】的相似度,从0-1打分,输出json格式
  19. ```json
  20. {{
  21. "说明": "简明扼要说明理由",
  22. "相似度": 0.0,
  23. }}
  24. ```
  25. """.strip()
  26. def _get_default_cache_dir() -> str:
  27. """获取默认缓存目录(从配置中读取)"""
  28. return get_cache_dir("semantic_similarity")
  29. def _generate_cache_key(
  30. phrase_a: str,
  31. phrase_b: str,
  32. model_name: str,
  33. temperature: float,
  34. max_tokens: int,
  35. prompt_template: str,
  36. instructions: str = None,
  37. tools: str = "[]"
  38. ) -> str:
  39. """
  40. 生成缓存键(哈希值)
  41. Args:
  42. phrase_a: 第一个短语
  43. phrase_b: 第二个短语
  44. model_name: 模型名称
  45. temperature: 温度参数
  46. max_tokens: 最大token数
  47. prompt_template: 提示词模板
  48. instructions: Agent 系统指令
  49. tools: 工具列表的 JSON 字符串
  50. Returns:
  51. 32位MD5哈希值
  52. """
  53. # 创建包含所有参数的字符串
  54. cache_string = f"{phrase_a}||{phrase_b}||{model_name}||{temperature}||{max_tokens}||{prompt_template}||{instructions}||{tools}"
  55. # 生成MD5哈希
  56. return hashlib.md5(cache_string.encode('utf-8')).hexdigest()
  57. def _sanitize_for_filename(text: str, max_length: int = 30) -> str:
  58. """
  59. 将文本转换为安全的文件名部分
  60. Args:
  61. text: 原始文本
  62. max_length: 最大长度
  63. Returns:
  64. 安全的文件名字符串
  65. """
  66. import re
  67. # 移除特殊字符,只保留中文、英文、数字、下划线
  68. sanitized = re.sub(r'[^\w\u4e00-\u9fff]', '_', text)
  69. # 移除连续的下划线
  70. sanitized = re.sub(r'_+', '_', sanitized)
  71. # 截断到最大长度
  72. if len(sanitized) > max_length:
  73. sanitized = sanitized[:max_length]
  74. return sanitized.strip('_')
  75. def _get_cache_filepath(
  76. cache_key: str,
  77. phrase_a: str,
  78. phrase_b: str,
  79. model_name: str,
  80. temperature: float,
  81. cache_dir: Optional[str] = None
  82. ) -> Path:
  83. """
  84. 获取缓存文件路径(可读文件名)
  85. Args:
  86. cache_key: 缓存键(哈希值)
  87. phrase_a: 第一个短语
  88. phrase_b: 第二个短语
  89. model_name: 模型名称
  90. temperature: 温度参数
  91. cache_dir: 缓存目录
  92. Returns:
  93. 缓存文件的完整路径
  94. 文件名格式: {phrase_a}_vs_{phrase_b}_{model}_t{temp}_{hash[:8]}.json
  95. 示例: 宿命感_vs_余华的小说_gpt-4.1-mini_t0.0_a7f3e2d9.json
  96. """
  97. if cache_dir is None:
  98. cache_dir = _get_default_cache_dir()
  99. # 清理短语和模型名
  100. clean_a = _sanitize_for_filename(phrase_a, max_length=20)
  101. clean_b = _sanitize_for_filename(phrase_b, max_length=20)
  102. # 简化模型名(提取关键部分)
  103. model_short = model_name.split('/')[-1] # 例如: openai/gpt-4.1-mini -> gpt-4.1-mini
  104. model_short = _sanitize_for_filename(model_short, max_length=20)
  105. # 格式化温度参数
  106. temp_str = f"t{temperature:.1f}"
  107. # 使用哈希的前8位
  108. hash_short = cache_key[:8]
  109. # 组合文件名
  110. filename = f"{clean_a}_vs_{clean_b}_{model_short}_{temp_str}_{hash_short}.json"
  111. return Path(cache_dir) / filename
  112. def _load_from_cache(
  113. cache_key: str,
  114. phrase_a: str,
  115. phrase_b: str,
  116. model_name: str,
  117. temperature: float,
  118. cache_dir: Optional[str] = None
  119. ) -> Optional[str]:
  120. """
  121. 从缓存加载数据
  122. Args:
  123. cache_key: 缓存键
  124. phrase_a: 第一个短语
  125. phrase_b: 第二个短语
  126. model_name: 模型名称
  127. temperature: 温度参数
  128. cache_dir: 缓存目录
  129. Returns:
  130. 缓存的结果字符串,如果不存在则返回 None
  131. """
  132. if cache_dir is None:
  133. cache_dir = _get_default_cache_dir()
  134. cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, temperature, cache_dir)
  135. # 如果文件不存在,尝试通过哈希匹配查找
  136. if not cache_file.exists():
  137. # 查找所有以该哈希结尾的文件
  138. cache_path = Path(cache_dir)
  139. if cache_path.exists():
  140. hash_short = cache_key[:8]
  141. matching_files = list(cache_path.glob(f"*_{hash_short}.json"))
  142. if matching_files:
  143. cache_file = matching_files[0]
  144. else:
  145. return None
  146. else:
  147. return None
  148. try:
  149. with open(cache_file, 'r', encoding='utf-8') as f:
  150. cached_data = json.load(f)
  151. return cached_data['output']['raw']
  152. except (json.JSONDecodeError, IOError, KeyError):
  153. return None
  154. def _save_to_cache(
  155. cache_key: str,
  156. phrase_a: str,
  157. phrase_b: str,
  158. model_name: str,
  159. temperature: float,
  160. max_tokens: int,
  161. prompt_template: str,
  162. instructions: str,
  163. tools: str,
  164. result: str,
  165. cache_dir: Optional[str] = None
  166. ) -> None:
  167. """
  168. 保存数据到缓存
  169. Args:
  170. cache_key: 缓存键
  171. phrase_a: 第一个短语
  172. phrase_b: 第二个短语
  173. model_name: 模型名称
  174. temperature: 温度参数
  175. max_tokens: 最大token数
  176. prompt_template: 提示词模板
  177. instructions: Agent 系统指令
  178. tools: 工具列表的 JSON 字符串
  179. result: 结果数据(原始字符串)
  180. cache_dir: 缓存目录
  181. """
  182. if cache_dir is None:
  183. cache_dir = _get_default_cache_dir()
  184. cache_file = _get_cache_filepath(cache_key, phrase_a, phrase_b, model_name, temperature, cache_dir)
  185. # 确保缓存目录存在
  186. cache_file.parent.mkdir(parents=True, exist_ok=True)
  187. # 尝试解析 result 为 JSON
  188. parsed_result = parse_json_from_text(result)
  189. # 准备缓存数据(包含完整的输入输出信息)
  190. cache_data = {
  191. "input": {
  192. "phrase_a": phrase_a,
  193. "phrase_b": phrase_b,
  194. "model_name": model_name,
  195. "temperature": temperature,
  196. "max_tokens": max_tokens,
  197. "prompt_template": prompt_template,
  198. "instructions": instructions,
  199. "tools": tools
  200. },
  201. "output": {
  202. "raw": result, # 保留原始响应
  203. "parsed": parsed_result # 解析后的JSON对象
  204. },
  205. "metadata": {
  206. "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
  207. "cache_key": cache_key,
  208. "cache_file": str(cache_file.name)
  209. }
  210. }
  211. try:
  212. with open(cache_file, 'w', encoding='utf-8') as f:
  213. json.dump(cache_data, f, ensure_ascii=False, indent=2)
  214. except IOError:
  215. pass # 静默失败,不影响主流程
  216. async def _difference_between_phrases(
  217. phrase_a: str,
  218. phrase_b: str,
  219. model_name: str = 'openai/gpt-4.1-mini',
  220. temperature: float = 0.0,
  221. max_tokens: int = 65536,
  222. prompt_template: str = None,
  223. instructions: str = None,
  224. tools: list = None,
  225. name: str = "Semantic Similarity Analyzer",
  226. use_cache: bool = True,
  227. cache_dir: Optional[str] = None
  228. ) -> str:
  229. """
  230. 从语义角度判断两个短语的相似度
  231. Args:
  232. phrase_a: 第一个短语
  233. phrase_b: 第二个短语
  234. model_name: 使用的模型名称,可选值:
  235. - 'google/gemini-2.5-pro'
  236. - 'anthropic/claude-sonnet-4.5'
  237. - 'google/gemini-2.0-flash-001'
  238. - 'openai/gpt-5-mini'
  239. - 'anthropic/claude-haiku-4.5'
  240. - 'openai/gpt-4.1-mini' (默认)
  241. temperature: 模型温度参数,控制输出随机性,默认 0.0(确定性输出)
  242. max_tokens: 最大生成token数,默认 65536
  243. prompt_template: 自定义提示词模板,使用 {phrase_a} 和 {phrase_b} 作为占位符
  244. 如果为 None,使用默认模板
  245. instructions: Agent 的系统指令,默认为 None
  246. tools: Agent 可用的工具列表,默认为 []
  247. name: Agent 的名称,默认为 "Semantic Similarity Analyzer"(不参与缓存key构建)
  248. use_cache: 是否使用缓存,默认 True
  249. cache_dir: 缓存目录,默认从配置读取(可通过 lib.config 设置)
  250. Returns:
  251. JSON 格式的相似度分析结果字符串
  252. Examples:
  253. >>> # 使用默认模板和缓存
  254. >>> result = await difference_between_phrases("宿命感", "余华的小说")
  255. >>> print(result)
  256. {
  257. "说明": "简明扼要说明理由",
  258. "相似度": 0.0
  259. }
  260. >>> # 禁用缓存
  261. >>> result = await difference_between_phrases(
  262. ... "宿命感", "余华的小说",
  263. ... use_cache=False
  264. ... )
  265. >>> # 使用自定义模板
  266. >>> custom_template = '''
  267. ... 请分析【{phrase_a}】和【{phrase_b}】的语义关联度
  268. ... 输出格式:{{"score": 0.0, "reason": "..."}}
  269. ... '''
  270. >>> result = await difference_between_phrases(
  271. ... "宿命感", "余华的小说",
  272. ... prompt_template=custom_template
  273. ... )
  274. """
  275. # 使用自定义模板或默认模板
  276. if prompt_template is None:
  277. prompt_template = DEFAULT_PROMPT_TEMPLATE
  278. # 默认tools为空列表
  279. if tools is None:
  280. tools = []
  281. # 生成缓存键(tools转为JSON字符串以便哈希)
  282. tools_str = json.dumps(tools, sort_keys=True) if tools else "[]"
  283. cache_key = _generate_cache_key(
  284. phrase_a, phrase_b, model_name, temperature, max_tokens, prompt_template, instructions, tools_str
  285. )
  286. # 尝试从缓存加载
  287. if use_cache:
  288. cached_result = _load_from_cache(cache_key, phrase_a, phrase_b, model_name, temperature, cache_dir)
  289. if cached_result is not None:
  290. return cached_result
  291. # 缓存未命中,调用 API
  292. agent = Agent(
  293. name=name,
  294. model=get_model(model_name),
  295. model_settings=ModelSettings(
  296. temperature=temperature,
  297. max_tokens=max_tokens,
  298. ),
  299. instructions=instructions,
  300. tools=tools,
  301. )
  302. # 格式化提示词
  303. prompt = prompt_template.format(phrase_a=phrase_a, phrase_b=phrase_b)
  304. result = await Runner.run(agent, input=prompt)
  305. final_output = result.final_output
  306. # 注意:不在这里缓存,而是在解析成功后缓存
  307. # 这样可以避免缓存解析失败的响应
  308. return final_output
  309. async def _difference_between_phrases_parsed(
  310. phrase_a: str,
  311. phrase_b: str,
  312. model_name: str = 'openai/gpt-4.1-mini',
  313. temperature: float = 0.0,
  314. max_tokens: int = 65536,
  315. prompt_template: str = None,
  316. instructions: str = None,
  317. tools: list = None,
  318. name: str = "Semantic Similarity Analyzer",
  319. use_cache: bool = True,
  320. cache_dir: Optional[str] = None
  321. ) -> Dict[str, Any]:
  322. """
  323. 从语义角度判断两个短语的相似度,并解析返回结果为字典
  324. Args:
  325. phrase_a: 第一个短语
  326. phrase_b: 第二个短语
  327. model_name: 使用的模型名称
  328. temperature: 模型温度参数,控制输出随机性,默认 0.0(确定性输出)
  329. max_tokens: 最大生成token数,默认 65536
  330. prompt_template: 自定义提示词模板,使用 {phrase_a} 和 {phrase_b} 作为占位符
  331. instructions: Agent 的系统指令,默认为 None
  332. tools: Agent 可用的工具列表,默认为 []
  333. name: Agent 的名称,默认为 "Semantic Similarity Analyzer"
  334. use_cache: 是否使用缓存,默认 True
  335. cache_dir: 缓存目录,默认从配置读取(可通过 lib.config 设置)
  336. Returns:
  337. 解析后的字典,包含:
  338. - 说明: 相似度判断的理由
  339. - 相似度: 0-1之间的浮点数
  340. Raises:
  341. ValueError: 当无法解析AI响应为有效JSON时抛出
  342. Examples:
  343. >>> result = await difference_between_phrases_parsed("宿命感", "余华的小说")
  344. >>> print(result['相似度'])
  345. 0.3
  346. >>> print(result['说明'])
  347. "两个概念有一定关联..."
  348. """
  349. # 使用默认模板或自定义模板
  350. if prompt_template is None:
  351. prompt_template = DEFAULT_PROMPT_TEMPLATE
  352. # 默认tools为空列表
  353. if tools is None:
  354. tools = []
  355. # 生成缓存键
  356. tools_str = json.dumps(tools, sort_keys=True) if tools else "[]"
  357. cache_key = _generate_cache_key(
  358. phrase_a, phrase_b, model_name, temperature, max_tokens, prompt_template, instructions, tools_str
  359. )
  360. # 尝试从缓存加载
  361. if use_cache:
  362. cached_result = _load_from_cache(cache_key, phrase_a, phrase_b, model_name, temperature, cache_dir)
  363. if cached_result is not None:
  364. # 缓存命中,直接解析并返回
  365. parsed_result = parse_json_from_text(cached_result)
  366. if parsed_result:
  367. return parsed_result
  368. # 如果缓存的内容也无法解析,继续执行API调用(可能之前缓存了错误响应)
  369. # 调用AI获取原始响应(不传use_cache,因为我们在这里手动处理缓存)
  370. raw_result = await _difference_between_phrases(
  371. phrase_a, phrase_b, model_name, temperature, max_tokens,
  372. prompt_template, instructions, tools, name, use_cache=False, cache_dir=cache_dir
  373. )
  374. # 使用 utils.parse_json_from_text 解析结果
  375. parsed_result = parse_json_from_text(raw_result)
  376. # 如果解析失败(返回空字典),抛出异常并包含详细信息
  377. if not parsed_result:
  378. # 格式化prompt用于错误信息
  379. formatted_prompt = prompt_template.format(phrase_a=phrase_a, phrase_b=phrase_b)
  380. error_msg = f"""
  381. JSON解析失败!
  382. ================================================================================
  383. 短语A: {phrase_a}
  384. 短语B: {phrase_b}
  385. 模型: {model_name}
  386. 温度: {temperature}
  387. ================================================================================
  388. Prompt:
  389. {formatted_prompt}
  390. ================================================================================
  391. AI响应 (长度: {len(raw_result)}):
  392. {raw_result}
  393. ================================================================================
  394. """
  395. raise ValueError(error_msg)
  396. # 只有解析成功后才缓存
  397. if use_cache:
  398. _save_to_cache(
  399. cache_key, phrase_a, phrase_b, model_name,
  400. temperature, max_tokens, prompt_template,
  401. instructions, tools_str, raw_result, cache_dir
  402. )
  403. return parsed_result
  404. # ========== V1 版本(默认版本) ==========
  405. # 对外接口 - V1
  406. async def compare_phrases(
  407. phrase_a: str,
  408. phrase_b: str,
  409. model_name: str = 'openai/gpt-4.1-mini',
  410. temperature: float = 0.0,
  411. max_tokens: int = 65536,
  412. prompt_template: str = None,
  413. instructions: str = None,
  414. tools: list = None,
  415. name: str = "Semantic Similarity Analyzer",
  416. use_cache: bool = True,
  417. cache_dir: Optional[str] = None
  418. ) -> Dict[str, Any]:
  419. """
  420. 比较两个短语的语义相似度(对外唯一接口)
  421. Args:
  422. phrase_a: 第一个短语
  423. phrase_b: 第二个短语
  424. model_name: 使用的模型名称
  425. temperature: 模型温度参数,控制输出随机性,默认 0.0(确定性输出)
  426. max_tokens: 最大生成token数,默认 65536
  427. prompt_template: 自定义提示词模板,使用 {phrase_a} 和 {phrase_b} 作为占位符
  428. instructions: Agent 的系统指令,默认为 None
  429. tools: Agent 可用的工具列表,默认为 []
  430. name: Agent 的名称,默认为 "Semantic Similarity Analyzer"
  431. use_cache: 是否使用缓存,默认 True
  432. cache_dir: 缓存目录,默认从配置读取(可通过 lib.config 设置)
  433. Returns:
  434. 解析后的字典
  435. """
  436. return await _difference_between_phrases_parsed(
  437. phrase_a, phrase_b, model_name, temperature, max_tokens,
  438. prompt_template, instructions, tools, name, use_cache, cache_dir
  439. )
  440. if __name__ == "__main__":
  441. import asyncio
  442. async def main():
  443. """示例使用"""
  444. # 示例 1: 基本使用(使用缓存)
  445. print("示例 1: 基本使用")
  446. result = await compare_phrases("宿命感", "余华的小说")
  447. print(f"相似度: {result.get('相似度')}")
  448. print(f"说明: {result.get('说明')}")
  449. print()
  450. # 示例 2: 再次调用相同参数(应该从缓存读取)
  451. print("示例 2: 测试缓存")
  452. result = await compare_phrases("宿命感", "余华的小说")
  453. print(f"相似度: {result.get('相似度')}")
  454. print()
  455. # 示例 3: 自定义温度
  456. print("示例 3: 自定义温度(创意性输出)")
  457. result = await compare_phrases(
  458. "创意写作", "AI生成",
  459. temperature=0.7
  460. )
  461. print(f"相似度: {result.get('相似度')}")
  462. print(f"说明: {result.get('说明')}")
  463. print()
  464. # 示例 4: 自定义 Agent 名称
  465. print("示例 4: 自定义 Agent 名称")
  466. result = await compare_phrases(
  467. "人工智能", "机器学习",
  468. name="AI语义分析专家"
  469. )
  470. print(f"相似度: {result.get('相似度')}")
  471. print(f"说明: {result.get('说明')}")
  472. print()
  473. # 示例 5: 使用不同的模型
  474. print("示例 5: 使用 Claude 模型")
  475. result = await compare_phrases(
  476. "深度学习", "神经网络",
  477. model_name='anthropic/claude-haiku-4.5'
  478. )
  479. print(f"相似度: {result.get('相似度')}")
  480. print(f"说明: {result.get('说明')}")
  481. asyncio.run(main())
  482. # ========== V2 版本(示例:详细分析版本) ==========
  483. # V2 默认提示词模板(更详细的分析)
  484. DEFAULT_PROMPT_TEMPLATE_V2 = """
  485. 请深入分析【{phrase_a}】和【{phrase_b}】的语义关系,包括:
  486. 1. 语义相似度(0-1)
  487. 2. 关系类型(如:包含、相关、对立、无关等)
  488. 3. 详细说明
  489. 输出格式:
  490. ```json
  491. {{
  492. "相似度": 0.0,
  493. "关系类型": "相关/包含/对立/无关",
  494. "详细说明": "详细分析两者的语义关系...",
  495. "应用场景": "该关系在实际应用中的意义..."
  496. }}
  497. ```
  498. """.strip()
  499. # 对外接口 - V2
  500. async def compare_phrases_v2(
  501. phrase_a: str,
  502. phrase_b: str,
  503. model_name: str = 'anthropic/claude-sonnet-4.5', # V2 默认使用更强的模型
  504. temperature: float = 0.0,
  505. max_tokens: int = 65536,
  506. prompt_template: str = None,
  507. instructions: str = None,
  508. tools: list = None,
  509. name: str = "Advanced Semantic Analyzer",
  510. use_cache: bool = True,
  511. cache_dir: Optional[str] = None
  512. ) -> Dict[str, Any]:
  513. """
  514. 比较两个短语的语义相似度 - V2 版本(详细分析)
  515. V2 特点:
  516. - 默认使用更强的模型(Claude Sonnet 4.5)
  517. - 更详细的分析输出(包含关系类型和应用场景)
  518. - 适合需要深入分析的场景
  519. Args:
  520. phrase_a: 第一个短语
  521. phrase_b: 第二个短语
  522. model_name: 使用的模型名称,默认 'anthropic/claude-sonnet-4.5'
  523. temperature: 模型温度参数,默认 0.0
  524. max_tokens: 最大生成token数,默认 65536
  525. prompt_template: 自定义提示词模板,默认使用 V2 详细模板
  526. instructions: Agent 的系统指令,默认为 None
  527. tools: Agent 可用的工具列表,默认为 []
  528. name: Agent 的名称,默认 "Advanced Semantic Analyzer"
  529. use_cache: 是否使用缓存,默认 True
  530. cache_dir: 缓存目录,默认从配置读取(可通过 lib.config 设置)
  531. Returns:
  532. 解析后的字典,包含:
  533. - 相似度: 0-1之间的浮点数
  534. - 关系类型: 关系分类
  535. - 详细说明: 详细分析
  536. - 应用场景: 应用建议
  537. Examples:
  538. >>> result = await compare_phrases_v2("深度学习", "神经网络")
  539. >>> print(result['相似度'])
  540. 0.9
  541. >>> print(result['关系类型'])
  542. "包含"
  543. >>> print(result['详细说明'])
  544. "深度学习是基于人工神经网络的机器学习方法..."
  545. """
  546. # 使用 V2 默认模板(如果未指定)
  547. if prompt_template is None:
  548. prompt_template = DEFAULT_PROMPT_TEMPLATE_V2
  549. return await _difference_between_phrases_parsed(
  550. phrase_a, phrase_b, model_name, temperature, max_tokens,
  551. prompt_template, instructions, tools, name, use_cache, cache_dir
  552. )