llm_cached.py

#!/usr/bin/env python3
"""
Generic cached LLM analysis module.
Results are cached by prompt + model + parameters; identical inputs return the cached result directly.
"""
import asyncio
import hashlib
import json
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from agents import Agent, Runner, ModelSettings
from lib.client import get_model
from lib.config import get_cache_dir
from lib.utils import parse_json_from_text
from lib.my_trace import set_trace_smith as set_trace


# ===== Configuration =====
@dataclass
class LLMConfig:
    """LLM configuration."""
    model_name: str = "google/gemini-3-pro-preview"
    temperature: float = 0.0
    max_tokens: int = 65536

    def to_dict(self) -> Dict:
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }


@dataclass
class AnalyzeResult:
    """Analysis result (with metadata)."""
    data: Any                      # parsed data, or the raw string
    cache_hit: bool                # whether the cache was hit
    model_name: str                # model used
    cache_key: str                 # cache key
    log_url: Optional[str] = None  # trace URL (only set when the LLM was actually called)
    retries: int = 0               # number of retries

    def to_dict(self) -> Dict:
        return {
            "data": self.data,
            "cache_hit": self.cache_hit,
            "model_name": self.model_name,
            "cache_key": self.cache_key,
            "log_url": self.log_url,
            "retries": self.retries,
        }


# Default model
DEFAULT_MODEL = "google/gemini-3-pro-preview"

# Preset configurations
PRESETS = {
    "default": LLMConfig(
        model_name=DEFAULT_MODEL,
        temperature=0.0,
        max_tokens=65536,
    ),
    "fast": LLMConfig(
        model_name="openai/gpt-4.1-mini",
        temperature=0.0,
        max_tokens=65536,
    ),
    "balanced": LLMConfig(
        model_name="google/gemini-2.5-flash-preview-05-20",
        temperature=0.0,
        max_tokens=65536,
    ),
    "quality": LLMConfig(
        model_name="anthropic/claude-sonnet-4",
        temperature=0.0,
        max_tokens=65536,
    ),
    "best": LLMConfig(
        model_name="google/gemini-2.5-pro-preview-05-06",
        temperature=0.0,
        max_tokens=65536,
    ),
}
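
# Illustrative usage sketch (not part of the module's runtime path): callers can
# pick a preset by name or pass a custom LLMConfig; the values below are
# examples only.
#
#   custom = LLMConfig(model_name="openai/gpt-4.1-mini", temperature=0.2)
#   result = await analyze("Analyze ...", task_name="demo", config=custom)
#   result = await analyze("Analyze ...", task_name="demo", preset="quality")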


# ===== Cache helpers =====
def _get_cache_dir(task_name: str) -> Path:
    """Return the cache directory for a task."""
    return Path(get_cache_dir(f"llm_cached/{task_name}"))


def _generate_cache_key(
    prompt: str,
    config: LLMConfig,
) -> str:
    """Generate the cache key (MD5 hash of prompt + model + parameters)."""
    cache_string = f"{prompt}||{config.model_name}||{config.temperature}||{config.max_tokens}"
    return hashlib.md5(cache_string.encode('utf-8')).hexdigest()
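
# For reference (illustrative only), the key for the "fast" preset is simply the
# MD5 of the joined string, e.g.:
#   _generate_cache_key("hi", PRESETS["fast"])
#   == hashlib.md5("hi||openai/gpt-4.1-mini||0.0||65536".encode("utf-8")).hexdigest()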


def _sanitize_filename(text: str, max_length: int = 30) -> str:
    """Convert text into a safe filename fragment."""
    sanitized = re.sub(r'[^\w\u4e00-\u9fff]', '_', text)
    sanitized = re.sub(r'_+', '_', sanitized)
    if len(sanitized) > max_length:
        sanitized = sanitized[:max_length]
    return sanitized.strip('_')


def _get_cache_filepath(
    task_name: str,
    cache_key: str,
    prompt_preview: str,
    config: LLMConfig,
) -> Path:
    """
    Return the cache file path.
    Filename format: {prompt_preview}_{model}_{hash[:8]}.json
    """
    cache_dir = _get_cache_dir(task_name)
    # Sanitize the prompt preview
    clean_preview = _sanitize_filename(prompt_preview, max_length=40)
    # Shorten the model name
    model_short = config.model_name.split('/')[-1]
    model_short = _sanitize_filename(model_short, max_length=20)
    # First 8 characters of the hash
    hash_short = cache_key[:8]
    filename = f"{clean_preview}_{model_short}_{hash_short}.json"
    return cache_dir / filename


def _load_from_cache(
    task_name: str,
    cache_key: str,
    prompt_preview: str,
    config: LLMConfig,
) -> Optional[Dict]:
    """Load from the cache; returns {"raw": str, "log_url": str} or None."""
    cache_file = _get_cache_filepath(task_name, cache_key, prompt_preview, config)
    # If the exact file does not exist, try matching by hash prefix
    if not cache_file.exists():
        cache_dir = _get_cache_dir(task_name)
        if cache_dir.exists():
            hash_short = cache_key[:8]
            matching_files = list(cache_dir.glob(f"*_{hash_short}.json"))
            if matching_files:
                cache_file = matching_files[0]
            else:
                return None
        else:
            return None
    try:
        with open(cache_file, 'r', encoding='utf-8') as f:
            cached_data = json.load(f)
        return {
            "raw": cached_data['output']['raw'],
            "log_url": cached_data.get('metadata', {}).get('log_url'),
        }
    except (json.JSONDecodeError, IOError, KeyError):
        return None


def _save_to_cache(
    task_name: str,
    cache_key: str,
    prompt_preview: str,
    prompt: str,
    config: LLMConfig,
    result: str,
    log_url: Optional[str] = None,
) -> None:
    """Save a result to the cache (including log_url)."""
    cache_file = _get_cache_filepath(task_name, cache_key, prompt_preview, config)
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    # Try to parse the result as JSON (stored alongside the raw output)
    parsed_result = parse_json_from_text(result)
    cache_data = {
        "input": {
            "prompt": prompt,
            "prompt_preview": prompt_preview,
            **config.to_dict(),
        },
        "output": {
            "raw": result,
            "parsed": parsed_result,
        },
        "metadata": {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "cache_key": cache_key,
            "cache_file": str(cache_file.name),
            "log_url": log_url,
        }
    }
    try:
        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(cache_data, f, ensure_ascii=False, indent=2)
    except IOError:
        pass
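
# For illustration (hypothetical values): with task_name="demo", the "fast"
# preset, and a prompt preview of "Analyze cats", the entry is written to
#   _get_cache_dir("demo") / "Analyze_cats_gpt_4_1_mini_1a2b3c4d.json"
# where "1a2b3c4d" stands in for the first 8 characters of the MD5 cache key.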


# ===== Core API =====
async def analyze(
    prompt: str,
    task_name: str = "default",
    config: Optional[LLMConfig] = None,
    preset: Optional[str] = None,
    force: bool = False,
    parse_json: bool = True,
    max_retries: int = 3,
    log_url: Optional[str] = None,
) -> AnalyzeResult:
    """
    Generic LLM analysis (with caching).

    Args:
        prompt: the full prompt
        task_name: task name (used to group cache directories)
        config: LLM configuration; if None, the preset or the default configuration is used
        preset: preset name ("default", "fast", "balanced", "quality", "best")
        force: force re-analysis (skip the cache), defaults to False
        parse_json: whether to parse the output as JSON
        max_retries: maximum number of retries (default 3)
        log_url: externally supplied trace URL (if None and the cache misses, one is generated automatically)

    Returns:
        AnalyzeResult object containing:
            - data: parsed data, or the raw string
            - cache_hit: whether the cache was hit
            - model_name: model used
            - cache_key: cache key
            - log_url: trace URL (only when the LLM was actually called)
            - retries: actual number of retries

    Examples:
        >>> # Use the cache (default)
        >>> result = await analyze("Analyze ...", task_name="origin")
        >>> # Force re-analysis
        >>> result = await analyze("Analyze ...", task_name="origin", force=True)
        >>> # Externally controlled trace (several analyses share one trace)
        >>> _, log_url = set_trace()
        >>> result1 = await analyze("Analysis 1 ...", log_url=log_url)
        >>> result2 = await analyze("Analysis 2 ...", log_url=log_url)
    """
    # Resolve the configuration
    if config is None:
        if preset and preset in PRESETS:
            config = PRESETS[preset]
        else:
            config = PRESETS["default"]

    # Generate the cache key
    cache_key = _generate_cache_key(prompt, config)
    # Prompt preview (used in the cache filename)
    prompt_preview = prompt[:50].replace('\n', ' ')

    # Try to load from the cache (unless force=True)
    if not force:
        cached_data = _load_from_cache(task_name, cache_key, prompt_preview, config)
        if cached_data is not None:
            cached_raw = cached_data["raw"]
            cached_log_url = cached_data.get("log_url")
            if parse_json:
                parsed = parse_json_from_text(cached_raw)
                if parsed:
                    return AnalyzeResult(
                        data=parsed,
                        cache_hit=True,
                        model_name=config.model_name,
                        cache_key=cache_key,
                        log_url=cached_log_url,  # log_url recorded when the entry was cached
                        retries=0,
                    )
            else:
                return AnalyzeResult(
                    data=cached_raw,
                    cache_hit=True,
                    model_name=config.model_name,
                    cache_key=cache_key,
                    log_url=cached_log_url,  # log_url recorded when the entry was cached
                    retries=0,
                )

    # Set up the trace (only when the LLM is actually called and no log_url was supplied)
    if log_url is None:
        _, log_url = set_trace()

    # Create the Agent
    agent = Agent(
        name=f"LLM-{task_name}",
        model=get_model(config.model_name),
        model_settings=ModelSettings(
            temperature=config.temperature,
            max_tokens=config.max_tokens,
        ),
        tools=[],
    )

    last_error = None
    retries = 0
    for attempt in range(max_retries):
        try:
            result = await Runner.run(agent, input=prompt)
            raw_output = result.final_output
            if parse_json:
                parsed = parse_json_from_text(raw_output)
                if parsed:
                    # Parsing succeeded: cache and return
                    _save_to_cache(task_name, cache_key, prompt_preview, prompt, config, raw_output, log_url)
                    return AnalyzeResult(
                        data=parsed,
                        cache_hit=False,
                        model_name=config.model_name,
                        cache_key=cache_key,
                        log_url=log_url,
                        retries=retries,
                    )
                else:
                    # Parsing failed: retry
                    retries += 1
                    last_error = f"JSON parsing failed (attempt {attempt + 1}/{max_retries})\nResponse: {raw_output[:500]}..."
                    print(f" ⚠️ {last_error}")
                    if attempt < max_retries - 1:
                        await asyncio.sleep(1)
            else:
                # No JSON parsing required
                _save_to_cache(task_name, cache_key, prompt_preview, prompt, config, raw_output, log_url)
                return AnalyzeResult(
                    data=raw_output,
                    cache_hit=False,
                    model_name=config.model_name,
                    cache_key=cache_key,
                    log_url=log_url,
                    retries=retries,
                )
        except Exception as e:
            retries += 1
            last_error = f"API call failed (attempt {attempt + 1}/{max_retries}): {str(e)}"
            print(f" ⚠️ {last_error}")
            if attempt < max_retries - 1:
                await asyncio.sleep(1)

    raise ValueError(f"All {max_retries} retries failed: {last_error}")


async def analyze_batch(
    prompts: list[str],
    task_name: str = "default",
    config: Optional[LLMConfig] = None,
    preset: Optional[str] = None,
    force: bool = False,
    parse_json: bool = True,
    max_concurrent: int = 10,
    log_url: Optional[str] = None,
    progress_callback: Optional[Callable] = None,
) -> list[AnalyzeResult]:
    """
    Batch LLM analysis (with concurrency control).

    Args:
        prompts: list of prompts
        task_name: task name
        config: LLM configuration
        preset: preset name
        force: force re-analysis (skip the cache), defaults to False
        parse_json: whether to parse the output as JSON
        max_concurrent: maximum concurrency
        log_url: externally supplied trace URL (all analyses share the same trace)
        progress_callback: progress callback

    Returns:
        List of AnalyzeResult objects
    """
    # If no log_url was supplied, generate one shared by the whole batch
    if log_url is None:
        _, log_url = set_trace()

    semaphore = asyncio.Semaphore(max_concurrent)

    async def limited_analyze(prompt: str):
        async with semaphore:
            result = await analyze(
                prompt=prompt,
                task_name=task_name,
                config=config,
                preset=preset,
                force=force,
                parse_json=parse_json,
                log_url=log_url,
            )
            if progress_callback:
                progress_callback(1)
            return result

    tasks = [limited_analyze(p) for p in prompts]
    return await asyncio.gather(*tasks)
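
# Illustrative usage sketch for analyze_batch (hypothetical prompts and progress
# handling; not part of the module itself):
#
#   def on_progress(n: int) -> None:
#       print(f"  finished {n} more prompt(s)")
#
#   results = await analyze_batch(
#       prompts=["Analyze A ...", "Analyze B ..."],
#       task_name="demo",
#       preset="fast",
#       max_concurrent=5,
#       progress_callback=on_progress,
#   )
#   parsed = [r.data for r in results]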


# ===== Convenience helpers =====
async def analyze_fast(prompt: str, task_name: str = "default", **kwargs) -> AnalyzeResult:
    """Fast analysis (uses the "fast" preset)."""
    return await analyze(prompt, task_name=task_name, preset="fast", **kwargs)


async def analyze_quality(prompt: str, task_name: str = "default", **kwargs) -> AnalyzeResult:
    """High-quality analysis (uses the "quality" preset)."""
    return await analyze(prompt, task_name=task_name, preset="quality", **kwargs)


# ===== Tests =====
if __name__ == "__main__":
    async def main():
        prompt = """
Analyze the relationship between "cat" and "pet", and output JSON:
```json
{
    "relation": "...",
    "explanation": "..."
}
```
"""
        print("Test 1: basic usage (trace generated automatically)")
        result = await analyze(prompt, task_name="test", preset="fast")
        print(f" cache_hit: {result.cache_hit}")
        print(f" model: {result.model_name}")
        print(f" log_url: {result.log_url}")
        print(f" data: {result.data}")

        print("\nTest 2: cache hit")
        result = await analyze(prompt, task_name="test", preset="fast")
        print(f" cache_hit: {result.cache_hit}")
        print(f" log_url: {result.log_url}")  # the log_url recorded when the entry was cached

        print("\nTest 3: forced re-analysis (force=True)")
        _, shared_log_url = set_trace()
        print(f" shared trace: {shared_log_url}")
        result = await analyze(
            "Output JSON: {\"test\": 123}",
            task_name="test",
            log_url=shared_log_url,
            force=True,
        )
        print(f" cache_hit: {result.cache_hit}")
        print(f" log_url: {result.log_url}")

    asyncio.run(main())