gemini_basic_agent.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. """
  2. Gemini Agent 基础示例
  3. 使用 Gemini 2.5 Pro 模型,演示带工具调用的 Agent
  4. 依赖:
  5. pip install httpx python-dotenv
  6. 使用方法:
  7. python examples/gemini_basic_agent.py
  8. """
  9. import os
  10. import sys
  11. import json
  12. import asyncio
  13. from typing import Dict, Any, List, Optional
  14. from dotenv import load_dotenv
  15. # 添加项目根目录到 Python 路径
  16. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  17. # 加载环境变量
  18. load_dotenv()
  19. # 导入框架
  20. from agent.tools import tool, ToolResult, get_tool_registry
  21. from agent.runner import AgentRunner
  22. from agent.llm.providers.gemini import create_gemini_llm_call
  23. # ============================================================
  24. # 定义工具
  25. # ============================================================
  26. @tool()
  27. async def get_current_weather(location: str, unit: str = "celsius", uid: str = "") -> Dict[str, Any]:
  28. """
  29. 获取指定地点的当前天气
  30. Args:
  31. location: 城市名称,如 "北京"、"San Francisco"
  32. unit: 温度单位,"celsius" 或 "fahrenheit"
  33. Returns:
  34. 天气信息字典
  35. """
  36. # 模拟天气数据
  37. weather_data = {
  38. "北京": {"temp": 15, "condition": "晴朗", "humidity": 45},
  39. "上海": {"temp": 20, "condition": "多云", "humidity": 60},
  40. "San Francisco": {"temp": 18, "condition": "Foggy", "humidity": 70},
  41. "New York": {"temp": 10, "condition": "Rainy", "humidity": 80}
  42. }
  43. data = weather_data.get(location, {"temp": 22, "condition": "Unknown", "humidity": 50})
  44. if unit == "fahrenheit":
  45. data["temp"] = data["temp"] * 9/5 + 32
  46. return {
  47. "location": location,
  48. "temperature": data["temp"],
  49. "unit": unit,
  50. "condition": data["condition"],
  51. "humidity": data["humidity"]
  52. }
  53. @tool()
  54. async def calculate(expression: str, uid: str = "") -> ToolResult:
  55. """
  56. 执行数学计算
  57. Args:
  58. expression: 数学表达式,如 "2 + 2"、"10 * 5"
  59. Returns:
  60. 计算结果
  61. """
  62. try:
  63. # 安全地计算简单表达式
  64. # 注意:实际生产环境应使用更安全的方法
  65. result = eval(expression, {"__builtins__": {}}, {})
  66. return ToolResult(
  67. title="计算结果",
  68. output=f"{expression} = {result}",
  69. long_term_memory=f"计算了 {expression}"
  70. )
  71. except Exception as e:
  72. return ToolResult(
  73. title="计算错误",
  74. output=f"无法计算 '{expression}': {str(e)}",
  75. long_term_memory=f"计算失败: {expression}"
  76. )
  77. @tool()
  78. async def search_knowledge(query: str, max_results: int = 3, uid: str = "") -> ToolResult:
  79. """
  80. 搜索知识库
  81. Args:
  82. query: 搜索关键词
  83. max_results: 返回结果数量
  84. Returns:
  85. 搜索结果
  86. """
  87. # 模拟知识库搜索
  88. knowledge_base = {
  89. "Python": "Python 是一种高级编程语言,以简洁易读的语法著称。",
  90. "Agent": "Agent 是能够感知环境并采取行动以实现目标的智能体。",
  91. "Gemini": "Gemini 是 Google 开发的多模态大语言模型系列。",
  92. "AI": "人工智能(AI)是计算机科学的一个分支,致力于创建智能机器。"
  93. }
  94. results = []
  95. for key, value in knowledge_base.items():
  96. if query.lower() in key.lower() or query.lower() in value.lower():
  97. results.append({"title": key, "content": value})
  98. if len(results) >= max_results:
  99. break
  100. if not results:
  101. output = f"未找到关于 '{query}' 的信息"
  102. else:
  103. output = "\n\n".join([f"**{r['title']}**\n{r['content']}" for r in results])
  104. return ToolResult(
  105. title=f"搜索结果: {query}",
  106. output=output,
  107. long_term_memory=f"搜索了 '{query}',找到 {len(results)} 条结果"
  108. )
  109. # ============================================================
  110. # 主函数
  111. # ============================================================
  112. async def main():
  113. print("=" * 60)
  114. print("Gemini Agent 基础示例")
  115. print("=" * 60)
  116. print()
  117. # 获取工具注册表
  118. registry = get_tool_registry()
  119. # 打印可用工具
  120. print("可用工具:")
  121. for tool_name in registry.get_tool_names():
  122. print(f" - {tool_name}")
  123. print()
  124. # 创建 Gemini LLM 调用函数
  125. gemini_llm_call = create_gemini_llm_call()
  126. # 创建 Agent Runner
  127. runner = AgentRunner(
  128. tool_registry=registry,
  129. llm_call=gemini_llm_call,
  130. )
  131. # 测试任务
  132. task = "北京今天的天气怎么样?顺便帮我计算一下 15 * 8 等于多少。"
  133. print(f"任务: {task}")
  134. print("-" * 60)
  135. print()
  136. # 运行 Agent
  137. async for event in runner.run(
  138. task=task,
  139. model="gemini-2.5-pro",
  140. tools=["get_current_weather", "calculate", "search_knowledge"],
  141. max_iterations=5,
  142. enable_memory=False, # 暂不启用记忆
  143. auto_execute_tools=True,
  144. system_prompt="你是一个有用的AI助手,可以使用工具来帮助用户。请简洁明了地回答问题。"
  145. ):
  146. event_type = event.type
  147. data = event.data
  148. if event_type == "trace_started":
  149. print(f"✓ Trace 开始: {data['trace_id']}")
  150. print()
  151. elif event_type == "llm_call_completed":
  152. print(f"🤖 LLM 响应:")
  153. if data.get("content"):
  154. print(f" {data['content']}")
  155. if data.get("tool_calls"):
  156. print(f" 工具调用: {len(data['tool_calls'])} 个")
  157. print(f" Tokens: {data.get('tokens', 0)}")
  158. print()
  159. elif event_type == "tool_executing":
  160. print(f"🔧 执行工具: {data['tool_name']}")
  161. print(f" 参数: {json.dumps(data['arguments'], ensure_ascii=False)}")
  162. elif event_type == "tool_result":
  163. print(f" 结果: {data['result'][:100]}...")
  164. print()
  165. elif event_type == "conclusion":
  166. print(f"✅ 最终回答:")
  167. print(f" {data['content']}")
  168. print()
  169. elif event_type == "trace_completed":
  170. print(f"✓ Trace 完成")
  171. print(f" 总 Tokens: {data.get('total_tokens', 0)}")
  172. print(f" 总成本: ${data.get('total_cost', 0):.4f}")
  173. print()
  174. elif event_type == "trace_failed":
  175. print(f"❌ Trace 失败: {data.get('error')}")
  176. print()
  177. if __name__ == "__main__":
  178. asyncio.run(main())