# run.py (2.7 KB)
"""
Feature extraction example.

Uses the Agent framework + Prompt loader + multimodal support.
"""
import os  # NOTE(review): not referenced in the visible part of this file — confirm before removing
import sys
import asyncio
from pathlib import Path
# Add the project root directory (three levels above this file) to the front
# of the Python path. This must run before the `agent.*` imports below.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from dotenv import load_dotenv
# Called before the `agent.*` imports; presumably so credentials from a .env
# file are already in the environment when those modules load — TODO confirm.
load_dotenv()
from agent.prompts import SimplePrompt
from agent.runner import AgentRunner
from agent.storage import MemoryTraceStore
from agent.llm.providers.gemini import create_gemini_llm_call
  17. async def main():
  18. # 路径配置
  19. base_dir = Path(__file__).parent
  20. prompt_path = base_dir / "test.prompt"
  21. feature_md_path = base_dir / "input_1" / "feature.md"
  22. image_path = base_dir / "input_1" / "image.png"
  23. output_dir = base_dir / "output_1"
  24. output_dir.mkdir(exist_ok=True)
  25. print("=" * 60)
  26. print("特征提取任务")
  27. print("=" * 60)
  28. print()
  29. # 1. 加载 prompt
  30. print("1. 加载 prompt...")
  31. prompt = SimplePrompt(prompt_path)
  32. # 2. 读取特征描述
  33. print("2. 读取特征描述...")
  34. with open(feature_md_path, 'r', encoding='utf-8') as f:
  35. feature_text = f.read()
  36. # 3. 构建多模态消息
  37. print("3. 构建多模态消息(文本 + 图片)...")
  38. messages = prompt.build_messages(
  39. text=feature_text,
  40. images=image_path # 框架自动处理图片
  41. )
  42. print(f" - 消息数量: {len(messages)}")
  43. print(f" - 图片: {image_path.name}")
  44. # 4. 创建 Agent Runner
  45. print("4. 创建 Agent Runner...")
  46. runner = AgentRunner(
  47. trace_store=MemoryTraceStore(),
  48. llm_call=create_gemini_llm_call(),
  49. debug=True # 启用 debug,输出到 .trace/tree.txt
  50. )
  51. # 5. 调用 Agent
  52. print(f"5. 调用模型: {prompt.config.get('model', 'gemini-2.5-flash')}...")
  53. print()
  54. result = await runner.call(
  55. messages=messages,
  56. model=prompt.config.get('model', 'gemini-2.5-flash'),
  57. temperature=float(prompt.config.get('temperature', 0.3)),
  58. trace=True # 启用 trace,配合 debug 输出 step tree
  59. )
  60. # 6. 输出结果
  61. print("=" * 60)
  62. print("模型响应:")
  63. print("=" * 60)
  64. print(result.reply)
  65. print("=" * 60)
  66. print()
  67. # 7. 保存结果
  68. output_file = output_dir / "result.txt"
  69. with open(output_file, 'w', encoding='utf-8') as f:
  70. f.write(result.reply)
  71. print(f"✓ 结果已保存到: {output_file}")
  72. print()
  73. # 8. 打印统计信息
  74. print("统计信息:")
  75. if result.tokens:
  76. print(f" 输入 tokens: {result.tokens.get('prompt', 0)}")
  77. print(f" 输出 tokens: {result.tokens.get('completion', 0)}")
  78. print(f" 费用: ${result.cost:.4f}")
  79. if __name__ == "__main__":
  80. asyncio.run(main())