#!/usr/bin/env python3
"""
Generic LLM cached-analysis module.

Results are cached by prompt + model + parameters; identical inputs return the
cached result directly.
"""

import asyncio
import hashlib
import json
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from agents import Agent, Runner, ModelSettings

from lib.client import get_model
from lib.config import get_cache_dir
from lib.utils import parse_json_from_text
from lib.my_trace import set_trace_smith as set_trace


# ===== Configuration =====

@dataclass
class LLMConfig:
    """LLM configuration."""
    model_name: str = "google/gemini-3-pro-preview"
    temperature: float = 0.0
    max_tokens: int = 65536

    def to_dict(self) -> Dict:
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }


@dataclass
class AnalyzeResult:
    """Analysis result (with metadata)."""
    data: Any                      # parsed data, or the raw string
    cache_hit: bool                # whether the cache was hit
    model_name: str                # model used
    cache_key: str                 # cache key
    log_url: Optional[str] = None  # trace URL (from the live call, or stored in the cache entry)
    retries: int = 0               # number of retries

    def to_dict(self) -> Dict:
        return {
            "data": self.data,
            "cache_hit": self.cache_hit,
            "model_name": self.model_name,
            "cache_key": self.cache_key,
            "log_url": self.log_url,
            "retries": self.retries,
        }


# Default model
DEFAULT_MODEL = "google/gemini-3-pro-preview"

# Preset configurations
PRESETS = {
    "default": LLMConfig(
        model_name=DEFAULT_MODEL,
        temperature=0.0,
        max_tokens=65536,
    ),
    "fast": LLMConfig(
        model_name="openai/gpt-4.1-mini",
        temperature=0.0,
        max_tokens=65536,
    ),
    "balanced": LLMConfig(
        model_name="google/gemini-2.5-flash-preview-05-20",
        temperature=0.0,
        max_tokens=65536,
    ),
    "quality": LLMConfig(
        model_name="anthropic/claude-sonnet-4",
        temperature=0.0,
        max_tokens=65536,
    ),
    "best": LLMConfig(
        model_name="google/gemini-2.5-pro-preview-05-06",
        temperature=0.0,
        max_tokens=65536,
    ),
}
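
# Besides the presets above, a custom LLMConfig can be passed straight to
# `analyze()`. A minimal sketch; the model name, parameters, and task_name here
# are illustrative assumptions only, not values used elsewhere in this module:
#
#   config = LLMConfig(model_name="openai/gpt-4.1", temperature=0.2, max_tokens=8192)
#   result = await analyze(prompt, task_name="custom_task", config=config)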

# ===== Cache utility functions =====

def _get_cache_dir(task_name: str) -> Path:
    """Return the cache directory for a task."""
    return Path(get_cache_dir(f"llm_cached/{task_name}"))


def _generate_cache_key(
    prompt: str,
    config: LLMConfig,
) -> str:
    """Generate the cache key (MD5 hash)."""
    cache_string = f"{prompt}||{config.model_name}||{config.temperature}||{config.max_tokens}"
    return hashlib.md5(cache_string.encode('utf-8')).hexdigest()


def _sanitize_filename(text: str, max_length: int = 30) -> str:
    """Turn text into a safe filename fragment."""
    sanitized = re.sub(r'[^\w\u4e00-\u9fff]', '_', text)
    sanitized = re.sub(r'_+', '_', sanitized)
    if len(sanitized) > max_length:
        sanitized = sanitized[:max_length]
    return sanitized.strip('_')


def _get_cache_filepath(
    task_name: str,
    cache_key: str,
    prompt_preview: str,
    config: LLMConfig,
) -> Path:
    """
    Return the cache file path.

    Filename format: {prompt_preview}_{model}_{hash[:8]}.json
    """
    cache_dir = _get_cache_dir(task_name)

    # Clean up the prompt preview
    clean_preview = _sanitize_filename(prompt_preview, max_length=40)

    # Shorten the model name
    model_short = config.model_name.split('/')[-1]
    model_short = _sanitize_filename(model_short, max_length=20)

    # First 8 characters of the hash
    hash_short = cache_key[:8]

    filename = f"{clean_preview}_{model_short}_{hash_short}.json"
    return cache_dir / filename


def _load_from_cache(
    task_name: str,
    cache_key: str,
    prompt_preview: str,
    config: LLMConfig,
) -> Optional[Dict]:
    """Load from cache; returns {raw: str, log_url: str} or None."""
    cache_file = _get_cache_filepath(task_name, cache_key, prompt_preview, config)

    # If the file does not exist, fall back to matching by hash
    if not cache_file.exists():
        cache_dir = _get_cache_dir(task_name)
        if cache_dir.exists():
            hash_short = cache_key[:8]
            matching_files = list(cache_dir.glob(f"*_{hash_short}.json"))
            if matching_files:
                cache_file = matching_files[0]
            else:
                return None
        else:
            return None

    try:
        with open(cache_file, 'r', encoding='utf-8') as f:
            cached_data = json.load(f)
        return {
            "raw": cached_data['output']['raw'],
            "log_url": cached_data.get('metadata', {}).get('log_url'),
        }
    except (json.JSONDecodeError, IOError, KeyError):
        return None


def _save_to_cache(
    task_name: str,
    cache_key: str,
    prompt_preview: str,
    prompt: str,
    config: LLMConfig,
    result: str,
    log_url: Optional[str] = None,
) -> None:
    """Save to cache (including log_url)."""
    cache_file = _get_cache_filepath(task_name, cache_key, prompt_preview, config)
    cache_file.parent.mkdir(parents=True, exist_ok=True)

    # Try to parse the output as JSON
    parsed_result = parse_json_from_text(result)

    cache_data = {
        "input": {
            "prompt": prompt,
            "prompt_preview": prompt_preview,
            **config.to_dict(),
        },
        "output": {
            "raw": result,
            "parsed": parsed_result,
        },
        "metadata": {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "cache_key": cache_key,
            "cache_file": str(cache_file.name),
            "log_url": log_url,
        }
    }

    try:
        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(cache_data, f, ensure_ascii=False, indent=2)
    except IOError:
        pass


# ===== Core API =====

async def analyze(
    prompt: str,
    task_name: str = "default",
    config: Optional[LLMConfig] = None,
    preset: Optional[str] = None,
    force: bool = False,
    parse_json: bool = True,
    max_retries: int = 3,
    log_url: Optional[str] = None,
) -> AnalyzeResult:
    """
    Generic LLM analysis with caching.

    Args:
        prompt: Full prompt text
        task_name: Task name (used to group cache directories)
        config: LLM configuration; if None, the preset or default configuration is used
        preset: Preset name ("default", "fast", "balanced", "quality", "best")
        force: Force re-analysis (skip the cache); defaults to False
        parse_json: Whether to parse the output as JSON
        max_retries: Maximum number of retries (default 3)
        log_url: Externally supplied trace URL (if None and the cache misses, one is generated)

    Returns:
        AnalyzeResult containing:
        - data: parsed data, or the raw string
        - cache_hit: whether the cache was hit
        - model_name: model used
        - cache_key: cache key
        - log_url: trace URL (from the live call, or stored in the cache entry)
        - retries: number of retries actually performed

    Examples:
        >>> # Use the cache (default)
        >>> result = await analyze("Analyze ...", task_name="origin")

        >>> # Force re-analysis
        >>> result = await analyze("Analyze ...", task_name="origin", force=True)

        >>> # Control the trace externally (several analyses share one trace)
        >>> _, log_url = set_trace()
        >>> result1 = await analyze("Analysis 1 ...", log_url=log_url)
        >>> result2 = await analyze("Analysis 2 ...", log_url=log_url)
    """
    # Resolve the configuration
    if config is None:
        if preset and preset in PRESETS:
            config = PRESETS[preset]
        else:
            config = PRESETS["default"]

    # Build the cache key
    cache_key = _generate_cache_key(prompt, config)

    # Prompt preview (used in the filename)
    prompt_preview = prompt[:50].replace('\n', ' ')

    # Try the cache first (unless force=True)
    if not force:
        cached_data = _load_from_cache(task_name, cache_key, prompt_preview, config)
        if cached_data is not None:
            cached_raw = cached_data["raw"]
            cached_log_url = cached_data.get("log_url")
            if parse_json:
                parsed = parse_json_from_text(cached_raw)
                if parsed:
                    return AnalyzeResult(
                        data=parsed,
                        cache_hit=True,
                        model_name=config.model_name,
                        cache_key=cache_key,
                        log_url=cached_log_url,  # log_url recorded when the cache entry was written
                        retries=0,
                    )
            else:
                return AnalyzeResult(
                    data=cached_raw,
                    cache_hit=True,
                    model_name=config.model_name,
                    cache_key=cache_key,
                    log_url=cached_log_url,  # log_url recorded when the cache entry was written
                    retries=0,
                )

    # Set up the trace (only when the LLM is actually called and no log_url was supplied)
    if log_url is None:
        _, log_url = set_trace()

    # Create the Agent
    agent = Agent(
        name=f"LLM-{task_name}",
        model=get_model(config.model_name),
        model_settings=ModelSettings(
            temperature=config.temperature,
            max_tokens=config.max_tokens,
        ),
        tools=[],
    )

    last_error = None
    retries = 0

    for attempt in range(max_retries):
        try:
            result = await Runner.run(agent, input=prompt)
            raw_output = result.final_output

            if parse_json:
                parsed = parse_json_from_text(raw_output)
                if parsed:
                    # Parsed successfully: cache and return
                    _save_to_cache(task_name, cache_key, prompt_preview, prompt, config, raw_output, log_url)
                    return AnalyzeResult(
                        data=parsed,
                        cache_hit=False,
                        model_name=config.model_name,
                        cache_key=cache_key,
                        log_url=log_url,
                        retries=retries,
                    )
                else:
                    # Parsing failed: retry
                    retries += 1
                    last_error = f"JSON parsing failed (attempt {attempt + 1}/{max_retries})\nResponse: {raw_output[:500]}..."
                    print(f"  ⚠️ {last_error}")
                    if attempt < max_retries - 1:
                        await asyncio.sleep(1)
            else:
                # No JSON parsing required
                _save_to_cache(task_name, cache_key, prompt_preview, prompt, config, raw_output, log_url)
                return AnalyzeResult(
                    data=raw_output,
                    cache_hit=False,
                    model_name=config.model_name,
                    cache_key=cache_key,
                    log_url=log_url,
                    retries=retries,
                )

        except Exception as e:
            retries += 1
            last_error = f"API call failed (attempt {attempt + 1}/{max_retries}): {str(e)}"
            print(f"  ⚠️ {last_error}")
            if attempt < max_retries - 1:
                await asyncio.sleep(1)

    raise ValueError(f"All {max_retries} attempts failed: {last_error}")
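
# When every retry fails, `analyze()` raises ValueError instead of returning a
# result. A minimal sketch of defensive call-site handling; the task_name,
# preset, and fallback behaviour below are illustrative assumptions only:
#
#   try:
#       result = await analyze(prompt, task_name="origin", preset="fast")
#   except ValueError as exc:
#       result = None  # e.g. log the failure and skip this item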

async def analyze_batch(
    prompts: list[str],
    task_name: str = "default",
    config: Optional[LLMConfig] = None,
    preset: Optional[str] = None,
    force: bool = False,
    parse_json: bool = True,
    max_concurrent: int = 10,
    log_url: Optional[str] = None,
    progress_callback: Optional[Callable] = None,
) -> list[AnalyzeResult]:
    """
    Batch LLM analysis with concurrency control.

    Args:
        prompts: List of prompts
        task_name: Task name
        config: LLM configuration
        preset: Preset name
        force: Force re-analysis (skip the cache); defaults to False
        parse_json: Whether to parse the output as JSON
        max_concurrent: Maximum concurrency
        log_url: Externally supplied trace URL (shared by all analyses)
        progress_callback: Progress callback

    Returns:
        List of AnalyzeResult
    """
    # If no log_url was supplied, generate one shared trace
    if log_url is None:
        _, log_url = set_trace()

    semaphore = asyncio.Semaphore(max_concurrent)

    async def limited_analyze(prompt: str):
        async with semaphore:
            result = await analyze(
                prompt=prompt,
                task_name=task_name,
                config=config,
                preset=preset,
                force=force,
                parse_json=parse_json,
                log_url=log_url,
            )
            if progress_callback:
                progress_callback(1)
            return result

    tasks = [limited_analyze(p) for p in prompts]
    return await asyncio.gather(*tasks)


# ===== Convenience functions =====

async def analyze_fast(prompt: str, task_name: str = "default", **kwargs) -> AnalyzeResult:
    """Fast analysis (uses the "fast" preset)."""
    return await analyze(prompt, task_name=task_name, preset="fast", **kwargs)


async def analyze_quality(prompt: str, task_name: str = "default", **kwargs) -> AnalyzeResult:
    """High-quality analysis (uses the "quality" preset)."""
    return await analyze(prompt, task_name=task_name, preset="quality", **kwargs)


# ===== Tests =====

if __name__ == "__main__":
    async def main():
        prompt = """
Analyze the relationship between "cat" and "pet", and output JSON:

```json
{
  "relationship": "...",
  "explanation": "..."
}
```
"""

        print("Test 1: basic usage (trace generated automatically)")
        result = await analyze(prompt, task_name="test", preset="fast")
        print(f"  cache_hit: {result.cache_hit}")
        print(f"  model: {result.model_name}")
        print(f"  log_url: {result.log_url}")
        print(f"  data: {result.data}")

        print("\nTest 2: cache hit")
        result = await analyze(prompt, task_name="test", preset="fast")
        print(f"  cache_hit: {result.cache_hit}")
        print(f"  log_url: {result.log_url}")  # the log_url stored with the cache entry

        print("\nTest 3: force re-analysis (force=True)")
        _, shared_log_url = set_trace()
        print(f"  shared trace: {shared_log_url}")
        result = await analyze(
            "Output JSON: {\"test\": 123}",
            task_name="test",
            log_url=shared_log_url,
            force=True,
        )
        print(f"  cache_hit: {result.cache_hit}")
        print(f"  log_url: {result.log_url}")

    asyncio.run(main())
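
# A minimal sketch of batch usage with a shared trace and a progress callback;
# the prompts, task_name, and preset below are illustrative assumptions only:
#
#   async def demo_batch():
#       _, url = set_trace()
#       results = await analyze_batch(
#           ["Analyze A ...", "Analyze B ..."],
#           task_name="demo",
#           preset="fast",
#           max_concurrent=5,
#           log_url=url,
#           progress_callback=lambda n: print(f"finished {n} more"),
#       )
#       for r in results:
#           print(r.cache_hit, r.model_name, r.retries)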