| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
"""
Feature extraction example.

Uses the Agent framework + prompt loader + multimodal support.
"""
- import os
- import sys
- import asyncio
- from pathlib import Path
- # 添加项目根目录到 Python 路径
- sys.path.insert(0, str(Path(__file__).parent.parent.parent))
- from dotenv import load_dotenv
- load_dotenv()
- from agent.prompts import SimplePrompt
- from agent.runner import AgentRunner
- from agent.storage import MemoryTraceStore
- from agent.llm.providers.gemini import create_gemini_llm_call
async def main() -> None:
    """Run the feature-extraction example end to end.

    Steps, in order:

    1. Load the prompt file ``test.prompt`` located next to this script.
    2. Read the feature description from ``input_1/feature.md``.
    3. Build a multimodal (text + image) message list using
       ``input_1/image.png``.
    4. Create an ``AgentRunner`` backed by an in-memory trace store and the
       Gemini LLM call.
    5. Call the model named in the prompt config (falling back to
       ``gemini-2.5-flash``) with the configured temperature (default 0.3).
    6. Print the model reply.
    7. Save the reply to ``output_1/result.txt``.
    8. Print token and cost statistics.

    Raises:
        FileNotFoundError: If ``input_1/feature.md`` does not exist.
    """
    default_model = "gemini-2.5-flash"
    banner = "=" * 60

    # Path configuration: everything is resolved relative to this script.
    base_dir = Path(__file__).parent
    prompt_path = base_dir / "test.prompt"
    feature_md_path = base_dir / "input_1" / "feature.md"
    image_path = base_dir / "input_1" / "image.png"
    output_dir = base_dir / "output_1"
    output_dir.mkdir(exist_ok=True)

    print(banner)
    print("特征提取任务")
    print(banner)
    print()

    # 1. Load the prompt.
    print("1. 加载 prompt...")
    prompt = SimplePrompt(prompt_path)

    # 2. Read the feature description.
    print("2. 读取特征描述...")
    feature_text = feature_md_path.read_text(encoding="utf-8")

    # 3. Build the multimodal messages.
    print("3. 构建多模态消息(文本 + 图片)...")
    messages = prompt.build_messages(
        text=feature_text,
        images=image_path,  # the framework handles the image automatically
    )
    print(f" - 消息数量: {len(messages)}")
    print(f" - 图片: {image_path.name}")

    # 4. Create the agent runner.
    print("4. 创建 Agent Runner...")
    runner = AgentRunner(
        trace_store=MemoryTraceStore(),
        llm_call=create_gemini_llm_call(),
        debug=True,  # enable debug; output goes to .trace/tree.txt
    )

    # 5. Call the agent.
    # Resolve the model once so the name that is logged is guaranteed to be
    # the name that is actually sent to the runner.
    model = prompt.config.get('model', default_model)
    temperature = float(prompt.config.get('temperature', 0.3))
    print(f"5. 调用模型: {model}...")
    print()
    result = await runner.call(
        messages=messages,
        model=model,
        temperature=temperature,
        trace=True,  # enable trace; together with debug this emits the step tree
    )

    # 6. Print the result.
    print(banner)
    print("模型响应:")
    print(banner)
    print(result.reply)
    print(banner)
    print()

    # 7. Save the result.
    output_file = output_dir / "result.txt"
    output_file.write_text(result.reply, encoding="utf-8")
    print(f"✓ 结果已保存到: {output_file}")
    print()

    # 8. Print statistics.
    print("统计信息:")
    if result.tokens:
        print(f" 输入 tokens: {result.tokens.get('prompt', 0)}")
        print(f" 输出 tokens: {result.tokens.get('completion', 0)}")
    # NOTE(review): the cost line is placed outside the token guard because it
    # reads result.cost, not result.tokens; the pasted source lost its
    # indentation, so confirm this against the original file.
    print(f" 费用: ${result.cost:.4f}")
if __name__ == "__main__":
    # Script entry point: run the async workflow to completion.
    asyncio.run(main())
|