| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- """
- 特征提取示例
- 使用 Agent 框架 + Prompt loader + 多模态支持
- """
- import os
- import sys
- import asyncio
- from pathlib import Path
- # 添加项目根目录到 Python 路径
- sys.path.insert(0, str(Path(__file__).parent.parent.parent))
- from dotenv import load_dotenv
- load_dotenv()
- from agent.prompts import SimplePrompt
- from agent.runner import AgentRunner
- from agent.llm.providers.gemini import create_gemini_llm_call
async def main():
    """Run the feature-extraction example end to end.

    Loads ``test.prompt`` from this script's directory, reads the feature
    description from ``input_1/feature.md``, builds a text + image message
    list via the prompt object, sends it through ``AgentRunner`` with the
    callable returned by ``create_gemini_llm_call()``, prints the reply and
    writes it to ``output_1/result.txt``, then prints token and cost
    statistics.
    """
    # Path configuration: every input and output lives beside this script.
    here = Path(__file__).parent
    input_dir = here / "input_1"
    prompt_path = here / "test.prompt"
    feature_md_path = input_dir / "feature.md"
    image_path = input_dir / "image.png"
    output_dir = here / "output_1"
    output_dir.mkdir(exist_ok=True)

    rule = "=" * 60
    print(rule)
    print("特征提取任务")
    print(rule)
    print()

    # 1. Load the prompt.
    print("1. 加载 prompt...")
    prompt = SimplePrompt(prompt_path)

    # 2. Read the feature description.
    print("2. 读取特征描述...")
    feature_text = feature_md_path.read_text(encoding='utf-8')

    # 3. Build the multimodal message list (text + image); the image path is
    #    handed to the prompt object as-is.
    print("3. 构建多模态消息(文本 + 图片)...")
    messages = prompt.build_messages(text=feature_text, images=image_path)
    print(f" - 消息数量: {len(messages)}")
    print(f" - 图片: {image_path.name}")

    # 4. Create the agent runner.
    print("4. 创建 Agent Runner...")
    llm_call = create_gemini_llm_call()
    runner = AgentRunner(llm_call=llm_call)

    # 5. Call the agent. The model name is resolved once so the logged name
    #    and the name passed to the runner are the same value.
    model_name = prompt.config.get('model', 'gemini-2.5-flash')
    print(f"5. 调用模型: {model_name}...")
    print()
    result = await runner.call(
        messages=messages,
        model=model_name,
        temperature=float(prompt.config.get('temperature', 0.3)),
        trace=False,  # tracing is not recorded for now
    )

    # 6. Show the model's reply.
    reply = result.reply
    print(rule)
    print("模型响应:")
    print(rule)
    print(reply)
    print(rule)
    print()

    # 7. Persist the reply.
    output_file = output_dir / "result.txt"
    with output_file.open('w', encoding='utf-8') as sink:
        sink.write(reply)
    print(f"✓ 结果已保存到: {output_file}")
    print()

    # 8. Print usage statistics.
    print("统计信息:")
    usage = result.tokens
    if usage:
        print(f" 输入 tokens: {usage.get('prompt', 0)}")
        print(f" 输出 tokens: {usage.get('completion', 0)}")
    # NOTE(review): the text this was reconstructed from had lost its
    # indentation; the cost line is assumed to print unconditionally
    # (outside the token check) — confirm against the original file.
    print(f" 费用: ${result.cost:.4f}")
if __name__ == "__main__":
    # Executed as a script: run the async entry point to completion.
    asyncio.run(main())
|