run.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. """
  2. 特征提取示例
  3. 使用 Agent 框架 + Prompt loader + 多模态支持
  4. """
  5. import os
  6. import sys
  7. import asyncio
  8. from pathlib import Path
  9. # 添加项目根目录到 Python 路径
  10. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  11. from dotenv import load_dotenv
  12. load_dotenv()
  13. from agent.prompts import SimplePrompt
  14. from agent.runner import AgentRunner
  15. from agent.llm.providers.gemini import create_gemini_llm_call
  16. async def main():
  17. # 路径配置
  18. base_dir = Path(__file__).parent
  19. prompt_path = base_dir / "test.prompt"
  20. feature_md_path = base_dir / "input_1" / "feature.md"
  21. image_path = base_dir / "input_1" / "image.png"
  22. output_dir = base_dir / "output_1"
  23. output_dir.mkdir(exist_ok=True)
  24. print("=" * 60)
  25. print("特征提取任务")
  26. print("=" * 60)
  27. print()
  28. # 1. 加载 prompt
  29. print("1. 加载 prompt...")
  30. prompt = SimplePrompt(prompt_path)
  31. # 2. 读取特征描述
  32. print("2. 读取特征描述...")
  33. with open(feature_md_path, 'r', encoding='utf-8') as f:
  34. feature_text = f.read()
  35. # 3. 构建多模态消息
  36. print("3. 构建多模态消息(文本 + 图片)...")
  37. messages = prompt.build_messages(
  38. text=feature_text,
  39. images=image_path # 框架自动处理图片
  40. )
  41. print(f" - 消息数量: {len(messages)}")
  42. print(f" - 图片: {image_path.name}")
  43. # 4. 创建 Agent Runner
  44. print("4. 创建 Agent Runner...")
  45. runner = AgentRunner(
  46. llm_call=create_gemini_llm_call()
  47. )
  48. # 5. 调用 Agent
  49. print(f"5. 调用模型: {prompt.config.get('model', 'gemini-2.5-flash')}...")
  50. print()
  51. result = await runner.call(
  52. messages=messages,
  53. model=prompt.config.get('model', 'gemini-2.5-flash'),
  54. temperature=float(prompt.config.get('temperature', 0.3)),
  55. trace=False # 暂不记录 trace
  56. )
  57. # 6. 输出结果
  58. print("=" * 60)
  59. print("模型响应:")
  60. print("=" * 60)
  61. print(result.reply)
  62. print("=" * 60)
  63. print()
  64. # 7. 保存结果
  65. output_file = output_dir / "result.txt"
  66. with open(output_file, 'w', encoding='utf-8') as f:
  67. f.write(result.reply)
  68. print(f"✓ 结果已保存到: {output_file}")
  69. print()
  70. # 8. 打印统计信息
  71. print("统计信息:")
  72. if result.tokens:
  73. print(f" 输入 tokens: {result.tokens.get('prompt', 0)}")
  74. print(f" 输出 tokens: {result.tokens.get('completion', 0)}")
  75. print(f" 费用: ${result.cost:.4f}")
  76. if __name__ == "__main__":
  77. asyncio.run(main())