run_multi.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. """
  2. 测试多轮对话的 Prompt Caching
  3. """
  4. import asyncio
  5. import os
  6. import sys
  7. from pathlib import Path
  8. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  9. from dotenv import load_dotenv
  10. load_dotenv()
  11. from agent.core.runner import AgentRunner, RunConfig
  12. from agent.trace import FileSystemTraceStore, Trace, Message
  13. from agent.llm import create_openrouter_llm_call
  14. async def main():
  15. print("=" * 60)
  16. print("测试多轮对话 Prompt Caching")
  17. print("=" * 60)
  18. print()
  19. base_dir = Path(__file__).parent
  20. project_root = base_dir.parent.parent
  21. trace_dir = project_root / ".trace"
  22. runner = AgentRunner(
  23. trace_store=FileSystemTraceStore(base_path=str(trace_dir)),
  24. llm_call=create_openrouter_llm_call(model="anthropic/claude-sonnet-4.5"),
  25. debug=True
  26. )
  27. # 超长 system prompt 确保 >1024 tokens
  28. system_prompt = """你是一个专业的 AI 助手,专注于帮助用户解决技术问题。
  29. ## 核心能力
  30. - 代码分析和生成
  31. - 问题解决和调试
  32. - 技术文档编写
  33. - 架构设计建议
  34. - 性能优化建议
  35. - 安全审计
  36. ## 工作原则
  37. 1. 准确性优先:确保提供的信息和代码是正确的
  38. 2. 清晰表达:用简洁明了的语言解释复杂概念
  39. 3. 实用导向:提供可直接使用的解决方案
  40. 4. 持续学习:根据反馈不断改进
  41. 5. 安全意识:始终考虑安全性和最佳实践
  42. 6. 性能考虑:提供高效的解决方案
  43. ## 技术栈
  44. - 编程语言:Python, JavaScript, TypeScript, Go, Rust, Java
  45. - 前端框架:React, Vue, Angular, Svelte
  46. - 后端框架:Node.js, Django, Flask, FastAPI, Spring Boot
  47. - 数据库:PostgreSQL, MongoDB, Redis, MySQL, Elasticsearch
  48. - 云平台:AWS, GCP, Azure
  49. - DevOps:Docker, Kubernetes, CI/CD, Terraform
  50. - 机器学习:TensorFlow, PyTorch, scikit-learn
  51. ## 响应格式
  52. - 提供清晰的步骤说明
  53. - 包含代码示例
  54. - 解释关键概念
  55. - 指出潜在问题
  56. - 给出最佳实践建议
  57. 这是一个足够长的 system prompt,用于测试 Anthropic Prompt Caching 功能。
  58. 缓存需要至少 1024 tokens 才能生效,所以我们需要让这个 prompt 足够长。
  59. """ * 5 # 重复 5 次确保足够长
  60. messages = [
  61. {"role": "user", "content": "请用一句话介绍 Python"}
  62. ]
  63. print("开始多轮对话测试...")
  64. print("-" * 60)
  65. trace_id = None
  66. iteration = 0
  67. async for item in runner.run(
  68. messages=messages,
  69. config=RunConfig(
  70. system_prompt=system_prompt,
  71. model="anthropic/claude-sonnet-4.5",
  72. temperature=0.3,
  73. max_iterations=5, # 多轮对话
  74. enable_prompt_caching=True,
  75. name="多轮缓存测试"
  76. )
  77. ):
  78. if isinstance(item, Trace):
  79. trace_id = item.trace_id
  80. if item.status == "completed":
  81. print(f"\n✓ Trace 完成")
  82. print(f" Total messages: {item.total_messages}")
  83. print(f" Total tokens: {item.total_tokens}")
  84. print(f" Total cache creation: {item.total_cache_creation_tokens}")
  85. print(f" Total cache read: {item.total_cache_read_tokens}")
  86. print(f" Total cost: ${item.total_cost:.6f}")
  87. elif isinstance(item, Message):
  88. if item.role == "assistant":
  89. iteration += 1
  90. print(f"\n[Iteration {iteration}]")
  91. print(f" Prompt tokens: {item.prompt_tokens}")
  92. print(f" Completion tokens: {item.completion_tokens}")
  93. print(f" Cache creation: {item.cache_creation_tokens}")
  94. print(f" Cache read: {item.cache_read_tokens}")
  95. print(f" Cost: ${item.cost:.6f}")
  96. content = item.content
  97. if isinstance(content, dict):
  98. text = content.get("text", "")
  99. tool_calls = content.get("tool_calls")
  100. if text and not tool_calls:
  101. preview = text[:80] + "..." if len(text) > 80 else text
  102. print(f" Response: {preview}")
  103. if tool_calls:
  104. print(f" Tool calls: {len(tool_calls)}")
  105. print()
  106. print("=" * 60)
  107. print("测试完成")
  108. print("=" * 60)
  109. print()
  110. if trace_id:
  111. print("分析:")
  112. print("- 第 1 次调用:应该有 cache_creation_tokens > 0(创建缓存)")
  113. print("- 第 2+ 次调用:应该有 cache_read_tokens > 0(命中缓存)")
  114. print(f"\nTrace ID: {trace_id}")
  115. if __name__ == "__main__":
  116. asyncio.run(main())