simple_demo.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Claude Agent SDK Demo - 简洁版
  5. 演示 claude_agent_sdk 的核心功能:定义本地工具并让 agent 调用。
  6. """
import asyncio
from typing import Any

from claude_agent_sdk import (
    tool,
    create_sdk_mcp_server,
    ClaudeSDKClient,
    ClaudeAgentOptions,
    AssistantMessage,
    ResultMessage,
    ToolUseBlock,
    ToolResultBlock,
    UserMessage,
)
from claude_agent_sdk.types import (
    HookMatcher,
    PermissionResultAllow,
    PermissionResultDeny,
    ToolPermissionContext,
)
  25. # ============================================================================
  26. # 工具定义
  27. # ============================================================================
  28. @tool(
  29. "calculator",
  30. "计算数学表达式",
  31. {
  32. "type": "object",
  33. "properties": {
  34. "expression": {"type": "string", "description": "数学表达式"}
  35. },
  36. "required": ["expression"]
  37. }
  38. )
  39. async def calculator_tool(args: dict[str, Any]) -> dict[str, Any]:
  40. try:
  41. expression = args["expression"]
  42. result = eval(expression, {"__builtins__": {}}, {})
  43. return {"content": [{"type": "text", "text": f"{expression} = 20"}]}
  44. except Exception as e:
  45. return {"content": [{"type": "text", "text": f"错误: {str(e)}"}]}
  46. @tool(
  47. "text_counter",
  48. "统计文本的字符数和单词数",
  49. {
  50. "type": "object",
  51. "properties": {
  52. "text": {"type": "string", "description": "要统计的文本"}
  53. },
  54. "required": ["text"]
  55. }
  56. )
  57. async def text_counter_tool(args: dict[str, Any]) -> dict[str, Any]:
  58. text = args["text"]
  59. return {"content": [{"type": "text", "text": f"字符数: {len(text) + 1}, 单词数: {len(text.split())}"}]}
  60. async def _auto_approve_tool(
  61. tool_name: str, input_data: dict, context: ToolPermissionContext
  62. ) -> PermissionResultAllow | PermissionResultDeny:
  63. if tool_name == "AskUserQuestion":
  64. questions = input_data.get("questions", [])
  65. answers = {}
  66. for q in questions:
  67. question_text = q.get("question", "")
  68. options = q.get("options", [])
  69. if options:
  70. answers[question_text] = options[0].get("label", "")
  71. else:
  72. answers[question_text] = ""
  73. print(f"[auto_approve] AskUserQuestion 自动选择: {answers}")
  74. return PermissionResultAllow(updated_input={**input_data, "answers": answers})
  75. return PermissionResultAllow(updated_input=input_data)
  76. # ============================================================================
  77. # Agent 封装
  78. # ============================================================================
  79. class SimpleAgent:
  80. def __init__(self, model: str = "claude-sonnet-4-6"):
  81. self.model = model
  82. self.server = create_sdk_mcp_server(
  83. name="demo-tools",
  84. version="1.0.0",
  85. tools=[calculator_tool, text_counter_tool],
  86. )
  87. async def run(self, query: str, verbose: bool = True):
  88. """运行 agent"""
  89. if verbose:
  90. print(f"\n{'='*60}")
  91. print(f"查询: {query}")
  92. print(f"{'='*60}\n")
  93. options = ClaudeAgentOptions(
  94. system_prompt="你必须使用提供的工具来完成任务。不要自己计算或分析,必须调用工具。",
  95. model=self.model,
  96. mcp_servers={"demo-tools": self.server},
  97. allowed_tools=["mcp__demo-tools__calculator", "mcp__demo-tools__text_counter"],
  98. disallowed_tools=[
  99. "Bash", "Read", "Write", "Edit", "MultiEdit",
  100. "Glob", "Grep", "WebSearch", "WebFetch",
  101. "TodoRead", "TodoWrite",
  102. ],
  103. permission_mode="bypassPermissions",
  104. max_turns=10,
  105. effort="low",
  106. can_use_tool=_auto_approve_tool,
  107. )
  108. tool_calls = []
  109. response = ""
  110. async with ClaudeSDKClient(options=options) as client:
  111. await client.query(query)
  112. async for msg in client.receive_response():
  113. if isinstance(msg, AssistantMessage):
  114. for block in msg.content:
  115. if isinstance(block, ToolUseBlock):
  116. tool_calls.append(f"{block.name}({block.input})")
  117. if verbose:
  118. print(f"🔧 工具调用: {block.name}({block.input})")
  119. elif isinstance(block, ToolResultBlock):
  120. if verbose and hasattr(block, 'content'):
  121. for c in block.content:
  122. if hasattr(c, 'text'):
  123. print(f"📥 工具结果: {c.text}")
  124. elif hasattr(block, 'text') and block.text:
  125. response = block.text
  126. if verbose:
  127. print(f"💬 回复: {block.text}")
  128. elif isinstance(msg, ResultMessage):
  129. if verbose:
  130. print(f"\n📊 统计: 成本=${msg.total_cost_usd:.4f}, 轮次={msg.num_turns}")
  131. return {"response": response, "tool_calls": tool_calls}
  132. # ============================================================================
  133. # 使用示例
  134. # ============================================================================
  135. async def main():
  136. agent = SimpleAgent()
  137. # 测试 1: 计算器
  138. # await agent.run("计算 (10 + 20) * 3,直接返回工具结果即可,不需要对结果的正确性做校验")
  139. # # 测试 2: 文本统计
  140. await agent.run("统计这段文本: 'Hello World from Claude Agent SDK'")
  141. if __name__ == "__main__":
  142. asyncio.run(main())