# run.py
  1. """
  2. 缓存机制测试
  3. 测试 Anthropic Prompt Caching 的命中率
  4. """
  5. import argparse
  6. import os
  7. import sys
  8. import asyncio
  9. from pathlib import Path
  10. # 添加项目根目录到 Python 路径
  11. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  12. from dotenv import load_dotenv
  13. load_dotenv()
  14. from agent.llm.prompts import SimplePrompt
  15. from agent.core.runner import AgentRunner, RunConfig
  16. from agent.trace import FileSystemTraceStore, Trace, Message
  17. from agent.llm import create_openrouter_llm_call
  18. async def main():
  19. parser = argparse.ArgumentParser(description="缓存机制测试")
  20. parser.add_argument(
  21. "--trace", type=str, default=None,
  22. help="已有的 Trace ID,用于恢复继续执行",
  23. )
  24. args = parser.parse_args()
  25. # 路径配置
  26. base_dir = Path(__file__).parent
  27. prompt_path = base_dir / "test.prompt"
  28. output_dir = base_dir / "output"
  29. output_dir.mkdir(exist_ok=True)
  30. print("=" * 60)
  31. print("Prompt Caching 测试")
  32. print("=" * 60)
  33. print()
  34. # 加载 prompt
  35. print("1. 加载 prompt 配置...")
  36. prompt = SimplePrompt(prompt_path)
  37. # 构建消息
  38. print("2. 构建任务消息...")
  39. messages = prompt.build_messages()
  40. # 创建 Agent Runner
  41. print("3. 创建 Agent Runner...")
  42. print(f" - 模型: {prompt.config.get('model', 'sonnet-4.6')}")
  43. store = FileSystemTraceStore(base_path=".trace")
  44. runner = AgentRunner(
  45. trace_store=store,
  46. llm_call=create_openrouter_llm_call(model=f"anthropic/claude-{prompt.config.get('model', 'sonnet-4.6')}"),
  47. skills_dir=None,
  48. debug=True
  49. )
  50. # 判断是新建还是恢复
  51. resume_trace_id = args.trace
  52. if resume_trace_id:
  53. existing_trace = await store.get_trace(resume_trace_id)
  54. if not existing_trace:
  55. print(f"\n错误: Trace 不存在: {resume_trace_id}")
  56. sys.exit(1)
  57. print(f"4. 恢复已有 Trace: {resume_trace_id[:8]}...")
  58. print(f" - 状态: {existing_trace.status}")
  59. print(f" - 消息数: {existing_trace.total_messages}")
  60. else:
  61. print(f"4. 启动新 Agent 模式...")
  62. print()
  63. current_trace_id = resume_trace_id
  64. try:
  65. if resume_trace_id:
  66. initial_messages = None
  67. config = RunConfig(
  68. model=f"anthropic/claude-{prompt.config.get('model', 'sonnet-4.6')}",
  69. temperature=float(prompt.config.get('temperature', 0.3)),
  70. max_iterations=50,
  71. trace_id=resume_trace_id,
  72. )
  73. else:
  74. initial_messages = messages
  75. config = RunConfig(
  76. model=f"anthropic/claude-{prompt.config.get('model', 'sonnet-4.6')}",
  77. temperature=float(prompt.config.get('temperature', 0.3)),
  78. max_iterations=50,
  79. name="缓存测试",
  80. )
  81. print("▶️ 开始执行...")
  82. print()
  83. async for item in runner.run(messages=initial_messages, config=config):
  84. # 处理 Trace 对象
  85. if isinstance(item, Trace):
  86. current_trace_id = item.trace_id
  87. if item.status == "running":
  88. print(f"[Trace] 开始: {item.trace_id[:8]}...")
  89. elif item.status == "completed":
  90. print(f"\n[Trace] ✅ 完成")
  91. print(f" - Total messages: {item.total_messages}")
  92. print(f" - Total tokens: {item.total_tokens:,}")
  93. print(f" - Cache creation: {item.total_cache_creation_tokens:,}")
  94. print(f" - Cache read: {item.total_cache_read_tokens:,}")
  95. print(f" - Cache hit rate: {item.total_cache_read_tokens / item.total_prompt_tokens * 100:.1f}%")
  96. print(f" - Total cost: ${item.total_cost:.4f}")
  97. elif item.status == "failed":
  98. print(f"\n[Trace] ❌ 失败: {item.error_message}")
  99. # 处理 Message 对象
  100. elif isinstance(item, Message):
  101. if item.role == "assistant":
  102. content = item.content
  103. if isinstance(content, dict):
  104. tool_calls = content.get("tool_calls")
  105. if tool_calls:
  106. print(f"[{item.sequence}] Tool calls: {len(tool_calls)}")
  107. except KeyboardInterrupt:
  108. print("\n\n用户中断 (Ctrl+C)")
  109. if current_trace_id:
  110. await runner.stop(current_trace_id)
  111. # 分析缓存情况
  112. if current_trace_id:
  113. print()
  114. print("=" * 60)
  115. print("缓存分析")
  116. print("=" * 60)
  117. trace = await store.get_trace(current_trace_id)
  118. if trace:
  119. print(f"\nTrace ID: {current_trace_id}")
  120. print(f"总消息数: {trace.total_messages}")
  121. print(f"总 tokens: {trace.total_tokens:,}")
  122. print(f"Prompt tokens: {trace.total_prompt_tokens:,}")
  123. print(f"Cache creation: {trace.total_cache_creation_tokens:,} ({trace.total_cache_creation_tokens / trace.total_prompt_tokens * 100:.1f}%)")
  124. print(f"Cache read: {trace.total_cache_read_tokens:,} ({trace.total_cache_read_tokens / trace.total_prompt_tokens * 100:.1f}%)")
  125. print(f"总成本: ${trace.total_cost:.4f}")
  126. # 计算节省的成本
  127. # cache_read 的价格是 input 的 0.1x
  128. # 如果没有缓存,这些 tokens 会按 input 价格计费
  129. saved_tokens = trace.total_cache_read_tokens
  130. # 假设 input 价格是 $3/MTok,cache_read 是 $0.3/MTok
  131. saved_cost = saved_tokens / 1_000_000 * 3 * 0.9 # 节省了 90%
  132. print(f"\n估算节省成本: ${saved_cost:.4f}")
  133. print()
  134. print(f"Trace 目录: .trace/{current_trace_id}")
  135. if __name__ == "__main__":
  136. asyncio.run(main())