| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- """
- 缓存机制测试
- 测试 Anthropic Prompt Caching 的命中率
- """
- import argparse
- import os
- import sys
- import asyncio
- from pathlib import Path
# Add the project root to the Python path
- sys.path.insert(0, str(Path(__file__).parent.parent.parent))
- from dotenv import load_dotenv
- load_dotenv()
- from agent.llm.prompts import SimplePrompt
- from agent.core.runner import AgentRunner, RunConfig
- from agent.trace import FileSystemTraceStore, Trace, Message
- from agent.llm import create_openrouter_llm_call
async def main() -> None:
    """Run the prompt-caching test.

    Starts a new agent trace (or resumes one via ``--trace``), streams
    progress to stdout, then prints a cache-usage analysis including the
    cache hit rate and an estimated cost saving.
    """
    parser = argparse.ArgumentParser(description="缓存机制测试")
    parser.add_argument(
        "--trace", type=str, default=None,
        help="已有的 Trace ID,用于恢复继续执行",
    )
    args = parser.parse_args()

    # Path configuration (prompt file and output dir live next to this script)
    base_dir = Path(__file__).parent
    prompt_path = base_dir / "test.prompt"
    output_dir = base_dir / "output"
    output_dir.mkdir(exist_ok=True)

    print("=" * 60)
    print("Prompt Caching 测试")
    print("=" * 60)
    print()

    # Load the prompt configuration
    print("1. 加载 prompt 配置...")
    prompt = SimplePrompt(prompt_path)

    # Build the task messages
    print("2. 构建任务消息...")
    messages = prompt.build_messages()

    # Hoist the model/temperature settings: they were previously rebuilt in
    # three places, which risked the values drifting apart.
    model_name = prompt.config.get('model', 'sonnet-4.6')
    full_model = f"anthropic/claude-{model_name}"
    temperature = float(prompt.config.get('temperature', 0.3))

    # Create the Agent Runner
    print("3. 创建 Agent Runner...")
    print(f" - 模型: {model_name}")
    store = FileSystemTraceStore(base_path=".trace")
    runner = AgentRunner(
        trace_store=store,
        llm_call=create_openrouter_llm_call(model=full_model),
        skills_dir=None,
        debug=True,
    )

    # Decide between resuming an existing trace and starting fresh
    resume_trace_id = args.trace
    if resume_trace_id:
        existing_trace = await store.get_trace(resume_trace_id)
        if not existing_trace:
            print(f"\n错误: Trace 不存在: {resume_trace_id}")
            sys.exit(1)
        print(f"4. 恢复已有 Trace: {resume_trace_id[:8]}...")
        print(f" - 状态: {existing_trace.status}")
        print(f" - 消息数: {existing_trace.total_messages}")
    else:
        # Plain string: the original used an f-string with no placeholders
        print("4. 启动新 Agent 模式...")
    print()

    current_trace_id = resume_trace_id
    try:
        if resume_trace_id:
            # Resuming: the runner reloads messages from the stored trace
            initial_messages = None
            config = RunConfig(
                model=full_model,
                temperature=temperature,
                max_iterations=50,
                trace_id=resume_trace_id,
            )
        else:
            initial_messages = messages
            config = RunConfig(
                model=full_model,
                temperature=temperature,
                max_iterations=50,
                name="缓存测试",
            )

        print("▶️ 开始执行...")
        print()

        async for item in runner.run(messages=initial_messages, config=config):
            # Trace objects carry run-level status and aggregate token counts
            if isinstance(item, Trace):
                current_trace_id = item.trace_id
                if item.status == "running":
                    print(f"[Trace] 开始: {item.trace_id[:8]}...")
                elif item.status == "completed":
                    print(f"\n[Trace] ✅ 完成")
                    print(f" - Total messages: {item.total_messages}")
                    print(f" - Total tokens: {item.total_tokens:,}")
                    print(f" - Cache creation: {item.total_cache_creation_tokens:,}")
                    print(f" - Cache read: {item.total_cache_read_tokens:,}")
                    # Guard: total_prompt_tokens can be 0 (e.g. empty trace),
                    # which previously raised ZeroDivisionError here.
                    if item.total_prompt_tokens:
                        hit_rate = item.total_cache_read_tokens / item.total_prompt_tokens * 100
                        print(f" - Cache hit rate: {hit_rate:.1f}%")
                    else:
                        print(" - Cache hit rate: N/A")
                    print(f" - Total cost: ${item.total_cost:.4f}")
                elif item.status == "failed":
                    print(f"\n[Trace] ❌ 失败: {item.error_message}")
            # Message objects: report assistant tool calls as they happen
            elif isinstance(item, Message):
                if item.role == "assistant":
                    content = item.content
                    if isinstance(content, dict):
                        tool_calls = content.get("tool_calls")
                        if tool_calls:
                            print(f"[{item.sequence}] Tool calls: {len(tool_calls)}")
    except KeyboardInterrupt:
        print("\n\n用户中断 (Ctrl+C)")
        if current_trace_id:
            await runner.stop(current_trace_id)

    # Cache analysis summary
    if current_trace_id:
        print()
        print("=" * 60)
        print("缓存分析")
        print("=" * 60)
        trace = await store.get_trace(current_trace_id)
        if trace:
            print(f"\nTrace ID: {current_trace_id}")
            print(f"总消息数: {trace.total_messages}")
            print(f"总 tokens: {trace.total_tokens:,}")
            print(f"Prompt tokens: {trace.total_prompt_tokens:,}")
            # Guard: avoid ZeroDivisionError when no prompt tokens were recorded
            prompt_tokens = trace.total_prompt_tokens
            if prompt_tokens:
                creation_pct = trace.total_cache_creation_tokens / prompt_tokens * 100
                read_pct = trace.total_cache_read_tokens / prompt_tokens * 100
                print(f"Cache creation: {trace.total_cache_creation_tokens:,} ({creation_pct:.1f}%)")
                print(f"Cache read: {trace.total_cache_read_tokens:,} ({read_pct:.1f}%)")
            else:
                print(f"Cache creation: {trace.total_cache_creation_tokens:,}")
                print(f"Cache read: {trace.total_cache_read_tokens:,}")
            print(f"总成本: ${trace.total_cost:.4f}")
            # Estimated savings: cache reads bill at ~0.1x the input price, so
            # each cached token saves ~90% of the assumed $3/MTok input rate.
            # Without caching these tokens would bill at the full input price.
            saved_tokens = trace.total_cache_read_tokens
            saved_cost = saved_tokens / 1_000_000 * 3 * 0.9  # 90% saved
            print(f"\n估算节省成本: ${saved_cost:.4f}")
        print()
        print(f"Trace 目录: .trace/{current_trace_id}")
# Script entry point: run the async test driver under asyncio's event loop.
if __name__ == "__main__":
    asyncio.run(main())
|