#!/usr/bin/env python3
"""
Generic LLM analysis module with caching.

Results are cached by (prompt, model, parameters); identical inputs return the
cached result directly instead of calling the LLM again.
"""
import asyncio
import hashlib
import json
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from agents import Agent, Runner, ModelSettings

from lib.client import get_model
from lib.config import get_cache_dir
from lib.utils import parse_json_from_text
from lib.my_trace import set_trace_smith as set_trace


# ===== Configuration =====

@dataclass
class LLMConfig:
    """LLM configuration."""
    model_name: str = "google/gemini-3-pro-preview"
    temperature: float = 0.0
    max_tokens: int = 65536

    def to_dict(self) -> Dict:
        return {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }


@dataclass
class AnalyzeResult:
    """Analysis result (with metadata)."""
    data: Any                      # parsed data, or the raw string
    cache_hit: bool                # whether the cache was hit
    model_name: str                # model that was used
    cache_key: str                 # cache key
    log_url: Optional[str] = None  # trace URL (only set when the LLM was actually called)
    retries: int = 0               # number of retries

    def to_dict(self) -> Dict:
        return {
            "data": self.data,
            "cache_hit": self.cache_hit,
            "model_name": self.model_name,
            "cache_key": self.cache_key,
            "log_url": self.log_url,
            "retries": self.retries,
        }


# Default model
DEFAULT_MODEL = "google/gemini-3-pro-preview"

# Preset configurations
PRESETS = {
    "default": LLMConfig(
        model_name=DEFAULT_MODEL,
        temperature=0.0,
        max_tokens=65536,
    ),
    "fast": LLMConfig(
        model_name="openai/gpt-4.1-mini",
        temperature=0.0,
        max_tokens=65536,
    ),
    "balanced": LLMConfig(
        model_name="google/gemini-2.5-flash-preview-05-20",
        temperature=0.0,
        max_tokens=65536,
    ),
    "quality": LLMConfig(
        model_name="anthropic/claude-sonnet-4",
        temperature=0.0,
        max_tokens=65536,
    ),
    "best": LLMConfig(
        model_name="google/gemini-2.5-pro-preview-05-06",
        temperature=0.0,
        max_tokens=65536,
    ),
}
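
# Example (illustrative sketch): pick a preset by name, or pass an explicit
# LLMConfig. The prompt strings below are placeholders.
#
#   result = await analyze("Summarize ...", preset="fast")
#   result = await analyze("Summarize ...",
#                          config=LLMConfig(model_name=DEFAULT_MODEL, temperature=0.2))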


# ===== Cache helpers =====

def _get_cache_dir(task_name: str) -> Path:
    """Return the cache directory for a task."""
    return Path(get_cache_dir(f"llm_cached/{task_name}"))


def _generate_cache_key(
    prompt: str,
    config: LLMConfig,
) -> str:
    """Generate the cache key (MD5 hash of prompt + model + parameters)."""
    cache_string = f"{prompt}||{config.model_name}||{config.temperature}||{config.max_tokens}"
    return hashlib.md5(cache_string.encode('utf-8')).hexdigest()


def _sanitize_filename(text: str, max_length: int = 30) -> str:
    """Turn arbitrary text into a safe filename fragment."""
    sanitized = re.sub(r'[^\w\u4e00-\u9fff]', '_', text)
    sanitized = re.sub(r'_+', '_', sanitized)
    if len(sanitized) > max_length:
        sanitized = sanitized[:max_length]
    return sanitized.strip('_')
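
# Example (derived from the regexes above):
#   _sanitize_filename("Analyze: cats vs pets!")  ->  "Analyze_cats_vs_pets"
# Runs of non-word characters collapse into single underscores, the result is
# truncated to max_length, and leading/trailing underscores are stripped.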


def _get_cache_filepath(
    task_name: str,
    cache_key: str,
    prompt_preview: str,
    config: LLMConfig,
) -> Path:
    """
    Return the cache file path.

    Filename format: {prompt_preview}_{model}_{hash[:8]}.json
    """
    cache_dir = _get_cache_dir(task_name)
    # Clean up the prompt preview
    clean_preview = _sanitize_filename(prompt_preview, max_length=40)
    # Shorten the model name
    model_short = config.model_name.split('/')[-1]
    model_short = _sanitize_filename(model_short, max_length=20)
    # First 8 characters of the hash
    hash_short = cache_key[:8]
    filename = f"{clean_preview}_{model_short}_{hash_short}.json"
    return cache_dir / filename
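
# Example path (illustrative; the cache root comes from get_cache_dir and the
# 8-character hash prefix is made up):
#   <cache_root>/llm_cached/origin/Analyze_cats_vs_pets_gpt_4_1_mini_3f2a9c1d.json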


def _load_from_cache(
    task_name: str,
    cache_key: str,
    prompt_preview: str,
    config: LLMConfig,
) -> Optional[Dict]:
    """Load from the cache; returns {"raw": str, "log_url": str} or None."""
    cache_file = _get_cache_filepath(task_name, cache_key, prompt_preview, config)

    # If the exact file does not exist, try to match by hash
    if not cache_file.exists():
        cache_dir = _get_cache_dir(task_name)
        if cache_dir.exists():
            hash_short = cache_key[:8]
            matching_files = list(cache_dir.glob(f"*_{hash_short}.json"))
            if matching_files:
                cache_file = matching_files[0]
            else:
                return None
        else:
            return None

    try:
        with open(cache_file, 'r', encoding='utf-8') as f:
            cached_data = json.load(f)
        return {
            "raw": cached_data['output']['raw'],
            "log_url": cached_data.get('metadata', {}).get('log_url'),
        }
    except (json.JSONDecodeError, IOError, KeyError):
        return None


def _save_to_cache(
    task_name: str,
    cache_key: str,
    prompt_preview: str,
    prompt: str,
    config: LLMConfig,
    result: str,
    log_url: Optional[str] = None,
) -> None:
    """Save to the cache (including log_url)."""
    cache_file = _get_cache_filepath(task_name, cache_key, prompt_preview, config)
    cache_file.parent.mkdir(parents=True, exist_ok=True)

    # Try to parse the result as JSON
    parsed_result = parse_json_from_text(result)

    cache_data = {
        "input": {
            "prompt": prompt,
            "prompt_preview": prompt_preview,
            **config.to_dict(),
        },
        "output": {
            "raw": result,
            "parsed": parsed_result,
        },
        "metadata": {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "cache_key": cache_key,
            "cache_file": str(cache_file.name),
            "log_url": log_url,
        }
    }

    try:
        with open(cache_file, 'w', encoding='utf-8') as f:
            json.dump(cache_data, f, ensure_ascii=False, indent=2)
    except IOError:
        pass
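
# Each cache entry is a standalone JSON document shaped roughly like this
# (values are illustrative):
#   {
#     "input":    {"prompt": "...", "prompt_preview": "...", "model_name": "...",
#                  "temperature": 0.0, "max_tokens": 65536},
#     "output":   {"raw": "...", "parsed": {...} or null},
#     "metadata": {"timestamp": "...", "cache_key": "...", "cache_file": "...", "log_url": "..."}
#   }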


# ===== Core API =====

async def analyze(
    prompt: str,
    task_name: str = "default",
    config: Optional[LLMConfig] = None,
    preset: Optional[str] = None,
    force: bool = False,
    parse_json: bool = True,
    max_retries: int = 3,
    log_url: Optional[str] = None,
) -> AnalyzeResult:
    """
    Generic LLM analysis with caching.

    Args:
        prompt: The full prompt.
        task_name: Task name (used to group cache directories).
        config: LLM configuration; if None, the preset (or the default config) is used.
        preset: Preset name ("default", "fast", "balanced", "quality", "best").
        force: Force re-analysis (skip the cache). Defaults to False.
        parse_json: Whether to parse the output as JSON.
        max_retries: Maximum number of retries (default 3).
        log_url: Externally supplied trace URL (if None and the cache misses, one is generated).

    Returns:
        An AnalyzeResult containing:
            - data: parsed data, or the raw string
            - cache_hit: whether the cache was hit
            - model_name: model that was used
            - cache_key: cache key
            - log_url: trace URL (only when the LLM was actually called)
            - retries: actual number of retries

    Examples:
        >>> # Use the cache (default)
        >>> result = await analyze("Analyze ...", task_name="origin")
        >>> # Force re-analysis
        >>> result = await analyze("Analyze ...", task_name="origin", force=True)
        >>> # Control tracing externally (several analyses share one trace)
        >>> _, log_url = set_trace()
        >>> result1 = await analyze("Analysis 1 ...", log_url=log_url)
        >>> result2 = await analyze("Analysis 2 ...", log_url=log_url)
    """
    # Resolve the configuration
    if config is None:
        if preset and preset in PRESETS:
            config = PRESETS[preset]
        else:
            config = PRESETS["default"]

    # Generate the cache key
    cache_key = _generate_cache_key(prompt, config)

    # Prompt preview (used in the filename)
    prompt_preview = prompt[:50].replace('\n', ' ')

    # Try loading from the cache (unless force=True)
    if not force:
        cached_data = _load_from_cache(task_name, cache_key, prompt_preview, config)
        if cached_data is not None:
            cached_raw = cached_data["raw"]
            cached_log_url = cached_data.get("log_url")
            if parse_json:
                parsed = parse_json_from_text(cached_raw)
                if parsed:
                    return AnalyzeResult(
                        data=parsed,
                        cache_hit=True,
                        model_name=config.model_name,
                        cache_key=cache_key,
                        log_url=cached_log_url,  # log_url recorded when the entry was cached
                        retries=0,
                    )
            else:
                return AnalyzeResult(
                    data=cached_raw,
                    cache_hit=True,
                    model_name=config.model_name,
                    cache_key=cache_key,
                    log_url=cached_log_url,  # log_url recorded when the entry was cached
                    retries=0,
                )

    # Set up the trace (only when the LLM is actually called and no log_url was supplied)
    if log_url is None:
        _, log_url = set_trace()

    # Create the agent
    agent = Agent(
        name=f"LLM-{task_name}",
        model=get_model(config.model_name),
        model_settings=ModelSettings(
            temperature=config.temperature,
            max_tokens=config.max_tokens,
        ),
        tools=[],
    )

    last_error = None
    retries = 0

    for attempt in range(max_retries):
        try:
            result = await Runner.run(agent, input=prompt)
            raw_output = result.final_output

            if parse_json:
                parsed = parse_json_from_text(raw_output)
                if parsed:
                    # Parsing succeeded: cache and return
                    _save_to_cache(task_name, cache_key, prompt_preview, prompt, config, raw_output, log_url)
                    return AnalyzeResult(
                        data=parsed,
                        cache_hit=False,
                        model_name=config.model_name,
                        cache_key=cache_key,
                        log_url=log_url,
                        retries=retries,
                    )
                else:
                    # Parsing failed: retry
                    retries += 1
                    last_error = f"JSON parsing failed (attempt {attempt + 1}/{max_retries})\nResponse: {raw_output[:500]}..."
                    print(f"  ⚠️ {last_error}")
                    if attempt < max_retries - 1:
                        await asyncio.sleep(1)
            else:
                # No JSON parsing required
                _save_to_cache(task_name, cache_key, prompt_preview, prompt, config, raw_output, log_url)
                return AnalyzeResult(
                    data=raw_output,
                    cache_hit=False,
                    model_name=config.model_name,
                    cache_key=cache_key,
                    log_url=log_url,
                    retries=retries,
                )
        except Exception as e:
            retries += 1
            last_error = f"API call failed (attempt {attempt + 1}/{max_retries}): {str(e)}"
            print(f"  ⚠️ {last_error}")
            if attempt < max_retries - 1:
                await asyncio.sleep(1)

    raise ValueError(f"All {max_retries} retries failed: {last_error}")


async def analyze_batch(
    prompts: list[str],
    task_name: str = "default",
    config: Optional[LLMConfig] = None,
    preset: Optional[str] = None,
    force: bool = False,
    parse_json: bool = True,
    max_concurrent: int = 10,
    log_url: Optional[str] = None,
    progress_callback: Optional[Callable] = None,
) -> list[AnalyzeResult]:
    """
    Batch LLM analysis with concurrency control.

    Args:
        prompts: List of prompts.
        task_name: Task name.
        config: LLM configuration.
        preset: Preset name.
        force: Force re-analysis (skip the cache). Defaults to False.
        parse_json: Whether to parse the output as JSON.
        max_concurrent: Maximum number of concurrent requests.
        log_url: Externally supplied trace URL (all analyses share the same trace).
        progress_callback: Progress callback.

    Returns:
        A list of AnalyzeResult objects.
    """
    # If no log_url was supplied, generate one shared by all analyses
    if log_url is None:
        _, log_url = set_trace()

    semaphore = asyncio.Semaphore(max_concurrent)

    async def limited_analyze(prompt: str):
        async with semaphore:
            result = await analyze(
                prompt=prompt,
                task_name=task_name,
                config=config,
                preset=preset,
                force=force,
                parse_json=parse_json,
                log_url=log_url,
            )
            if progress_callback:
                progress_callback(1)
            return result

    tasks = [limited_analyze(p) for p in prompts]
    return await asyncio.gather(*tasks)
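
# Note (illustrative): progress_callback, when given, is called with 1 after each
# prompt completes, so a simple counter can be driven from it:
#
#   done = []
#   results = await analyze_batch(my_prompts, task_name="origin", preset="fast",
#                                 progress_callback=lambda n: done.append(n))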


# ===== Convenience functions =====

async def analyze_fast(prompt: str, task_name: str = "default", **kwargs) -> AnalyzeResult:
    """Fast analysis (uses the "fast" preset)."""
    return await analyze(prompt, task_name=task_name, preset="fast", **kwargs)


async def analyze_quality(prompt: str, task_name: str = "default", **kwargs) -> AnalyzeResult:
    """High-quality analysis (uses the "quality" preset)."""
    return await analyze(prompt, task_name=task_name, preset="quality", **kwargs)


# ===== Tests =====

if __name__ == "__main__":
    async def main():
        prompt = """
Analyze the relationship between "cat" and "pet", and output JSON:
```json
{
    "relationship": "...",
    "explanation": "..."
}
```
"""
        print("Test 1: basic usage (trace generated automatically)")
        result = await analyze(prompt, task_name="test", preset="fast")
        print(f"  cache_hit: {result.cache_hit}")
        print(f"  model: {result.model_name}")
        print(f"  log_url: {result.log_url}")
        print(f"  data: {result.data}")

        print("\nTest 2: cache hit")
        result = await analyze(prompt, task_name="test", preset="fast")
        print(f"  cache_hit: {result.cache_hit}")
        print(f"  log_url: {result.log_url}")  # log_url recorded when the entry was cached

        print("\nTest 3: forced re-analysis (force=True) with an externally shared trace")
        _, shared_log_url = set_trace()
        print(f"  shared trace: {shared_log_url}")
        result = await analyze(
            "Output JSON: {\"test\": 123}",
            task_name="test",
            log_url=shared_log_url,
            force=True,
        )
        print(f"  cache_hit: {result.cache_hit}")
        print(f"  log_url: {result.log_url}")
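
        # Test 4 (illustrative sketch, not part of the original tests): batch
        # analysis. The prompts are made up; on a cache miss this issues real
        # calls with the "fast" preset.
        print("\nTest 4: batch analysis (analyze_batch)")
        batch_results = await analyze_batch(
            prompts=[prompt, "Output JSON: {\"test\": 456}"],
            task_name="test",
            preset="fast",
            max_concurrent=2,
            progress_callback=lambda n: print(f"  progress +{n}"),
        )
        for r in batch_results:
            print(f"  cache_hit: {r.cache_hit}, model: {r.model_name}")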

    asyncio.run(main())