search_demo.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Claude Agent SDK Demo - 简洁版
  5. 演示 claude_agent_sdk 的核心功能:定义本地工具并让 agent 调用。
  6. """
  7. import asyncio
  8. from typing import Any, Optional, List
  9. import httpx
  10. import json
  11. from claude_agent_sdk import (
  12. tool,
  13. create_sdk_mcp_server,
  14. ClaudeSDKClient,
  15. ClaudeAgentOptions,
  16. AssistantMessage,
  17. ResultMessage,
  18. ToolUseBlock,
  19. ToolResultBlock,
  20. )
  21. from claude_agent_sdk.types import (
  22. HookMatcher,
  23. PermissionResultAllow,
  24. PermissionResultDeny,
  25. ToolPermissionContext,
  26. )
  27. BASE_URL = "http://aigc-channel.aiddit.com/aigc/channel"
  28. DEFAULT_TIMEOUT = 60.0
  29. # ============================================================================
  30. # 工具定义
  31. # ============================================================================
  32. @tool(
  33. "search_post",
  34. "搜索帖子",
  35. {
  36. "type": "object",
  37. "properties": {
  38. "keyword": {"type": "string", "description": "搜索关键词"}
  39. },
  40. "required": ["keyword"]
  41. }
  42. )
  43. async def search_post(args: dict[str, Any]) -> dict[str, Any]:
  44. try:
  45. keyword = args.get("keyword")
  46. payload = {
  47. "type": "zhihu",
  48. "keyword": keyword,
  49. "cursor": "0",
  50. "max_count": 2,
  51. "content_type": "图文",
  52. }
  53. print(f"search_post,payload:{payload}")
  54. async with httpx.AsyncClient(timeout=DEFAULT_TIMEOUT) as client:
  55. resp = await client.post(
  56. f"{BASE_URL}/data",
  57. json=payload,
  58. headers={"Content-Type": "application/json"},
  59. )
  60. resp.raise_for_status()
  61. data = resp.json()
  62. posts = data.get("data") or []
  63. posts_json_str = json.dumps(posts, ensure_ascii=False, indent=2)
  64. print(f"search_post,result posts_json_str: {posts_json_str}")
  65. return {"content": [{"type": "text", "text": posts_json_str}]}
  66. except Exception as e:
  67. return {"content": [{"type": "text", "text": f"错误: {str(e)}"}]}
  68. async def _auto_approve_tool(
  69. tool_name: str, input_data: dict, context: ToolPermissionContext
  70. ) -> PermissionResultAllow | PermissionResultDeny:
  71. if tool_name == "AskUserQuestion":
  72. questions = input_data.get("questions", [])
  73. answers = {}
  74. for q in questions:
  75. question_text = q.get("question", "")
  76. options = q.get("options", [])
  77. if options:
  78. answers[question_text] = options[0].get("label", "")
  79. else:
  80. answers[question_text] = ""
  81. print(f"[auto_approve] AskUserQuestion 自动选择: {answers}")
  82. return PermissionResultAllow(updated_input={**input_data, "answers": answers})
  83. return PermissionResultAllow(updated_input=input_data)
  84. # ============================================================================
  85. # Agent 封装
  86. # ============================================================================
  87. class SimpleAgent:
  88. def __init__(self, model: str = "claude-sonnet-4-6"):
  89. self.model = model
  90. self.server = create_sdk_mcp_server(
  91. name="demo-tools",
  92. version="1.0.0",
  93. tools=[search_post],
  94. )
  95. async def run(self, query: str, verbose: bool = True):
  96. """运行 agent"""
  97. if verbose:
  98. print(f"\n{'='*60}")
  99. print(f"查询: {query}")
  100. print(f"{'='*60}\n")
  101. options = ClaudeAgentOptions(
  102. system_prompt="你是一个内容收集专家,善于利用 search_post 工具搜索帖子内容,并对搜索结果进行格式化处理,理解并输出每个帖子的主要内容。\n "
  103. "注意:你需要提取帖子中的每张图片url并分析出图片中的关键信息,综合帖子其他信息返回帖子的主要内容",
  104. model=self.model,
  105. mcp_servers={"demo-tools": self.server},
  106. allowed_tools=["mcp__demo-tools__search_post"],
  107. disallowed_tools=[
  108. "Bash", "Read", "Write", "Edit", "MultiEdit",
  109. "Glob", "Grep", "WebSearch", "WebFetch",
  110. "TodoRead", "TodoWrite",
  111. ],
  112. permission_mode="bypassPermissions",
  113. max_turns=10,
  114. effort="low",
  115. can_use_tool=_auto_approve_tool,
  116. )
  117. tool_calls = []
  118. response = ""
  119. async with ClaudeSDKClient(options=options) as client:
  120. await client.query(query)
  121. async for msg in client.receive_response():
  122. if isinstance(msg, AssistantMessage):
  123. for block in msg.content:
  124. if isinstance(block, ToolUseBlock):
  125. tool_calls.append(f"{block.name}({block.input})")
  126. if verbose:
  127. print(f"🔧 工具调用: {block.name}({block.input})")
  128. elif isinstance(block, ToolResultBlock):
  129. if verbose and hasattr(block, 'content'):
  130. for c in block.content:
  131. if hasattr(c, 'text'):
  132. print(f"📥 工具结果: {c.text}")
  133. elif hasattr(block, 'text') and block.text:
  134. response = block.text
  135. if verbose:
  136. print(f"💬 回复: {block.text}")
  137. elif isinstance(msg, ResultMessage):
  138. if verbose:
  139. print(f"\n📊 统计: 成本=${msg.total_cost_usd:.4f}, 轮次={msg.num_turns}")
  140. return {"response": response, "tool_calls": tool_calls}
  141. # ============================================================================
  142. # 使用示例
  143. # ============================================================================
  144. async def main():
  145. agent = SimpleAgent()
  146. await agent.run("搜索 “北京秋天” 并返回帖子主要内容")
  147. # await search_post(keyword="柴犬")
  148. if __name__ == "__main__":
  149. asyncio.run(main())