| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162 |
- """
- 测试新的 Goal + Message 系统(不需要 LLM)
- 验证 GoalTree 和工具注册是否正常工作
- """
- import asyncio
- from agent.goal.models import GoalTree
- from agent.execution.models import Trace, Message
- from agent.execution.fs_store import FileSystemTraceStore
- from agent.tools import get_tool_registry
- async def test_goal_tree_and_messages():
- """测试 GoalTree 和 Message 集成"""
- print("=" * 60)
- print("测试 Goal + Message 系统集成")
- print("=" * 60)
- print()
- # 1. 验证 goal 工具已注册
- print("1. 检查工具注册")
- registry = get_tool_registry()
- tools = registry.get_tool_names()
- print(f" 已注册 {len(tools)} 个工具")
- if "goal" in tools:
- print(" ✓ goal 工具已注册")
- else:
- print(" ✗ goal 工具未注册")
- return
- # 获取 goal 工具的 schema
- goal_schema = registry.get_schemas(["goal"])[0]
- print(f" - 工具名称: {goal_schema.get('name')}")
- print(f" - 描述: {goal_schema.get('description', '')[:60]}...")
- print()
- # 2. 创建 Trace 和 GoalTree
- print("2. 创建 Trace 和 GoalTree")
- store = FileSystemTraceStore(base_path=".trace_test")
- trace = Trace.create(mode="agent", task="测试任务:实现用户认证")
- await store.create_trace(trace)
- print(f" Trace ID: {trace.trace_id[:8]}...")
- tree = GoalTree(mission="测试任务:实现用户认证")
- await store.update_goal_tree(trace.trace_id, tree)
- print(f" GoalTree 创建成功")
- print()
- # 3. 测试 goal 工具调用(模拟 LLM 调用)
- print("3. 测试 goal 工具调用")
- # 设置 goal_tree
- from agent.tools.builtin.goal import set_goal_tree
- set_goal_tree(tree)
- # 模拟调用 goal 工具添加目标
- result1 = await registry.execute("goal", {"add": "分析代码, 设计方案, 实现功能"}, uid="test")
- print(" 调用: goal(add='分析代码, 设计方案, 实现功能')")
- print(f" 结果: {result1[:200]}...")
- print()
- # 保存更新后的 tree
- await store.update_goal_tree(trace.trace_id, tree)
- # 模拟调用 goal 工具 focus
- result2 = await registry.execute("goal", {"focus": "1"}, uid="test")
- print(" 调用: goal(focus='1')")
- print(f" 结果: {result2[:150]}...")
- print()
- # 4. 添加 Messages
- print("4. 添加 Messages 并自动更新 stats")
- # 添加 assistant message(与 goal 1 关联)
- msg1 = Message.create(
- trace_id=trace.trace_id,
- role="assistant",
- sequence=1,
- goal_id="1",
- content={"text": "开始分析代码结构", "tool_calls": []},
- tokens=50,
- cost=0.001
- )
- await store.add_message(msg1)
- print(f" Message 1: {msg1.description}")
- # 添加更多 messages
- msg2 = Message.create(
- trace_id=trace.trace_id,
- role="assistant",
- sequence=2,
- goal_id="1",
- content={"text": "", "tool_calls": [
- {"id": "call1", "function": {"name": "read_file", "arguments": "{}"}}
- ]},
- tokens=30,
- cost=0.0006
- )
- await store.add_message(msg2)
- print(f" Message 2: {msg2.description}")
- msg3 = Message.create(
- trace_id=trace.trace_id,
- role="tool",
- sequence=3,
- goal_id="1",
- tool_call_id="call1",
- content={"tool_name": "read_file", "result": "文件内容..."},
- tokens=100,
- cost=0.002
- )
- await store.add_message(msg3)
- print(f" Message 3: {msg3.description}")
- print()
- # 5. 查看自动更新的 stats
- print("5. 查看自动更新的 Goal stats")
- updated_tree = await store.get_goal_tree(trace.trace_id)
- goal1 = updated_tree.find("1")
- print(f" Goal 1: {goal1.description}")
- print(f" - status: {goal1.status}")
- print(f" - self_stats.message_count: {goal1.self_stats.message_count}")
- print(f" - self_stats.total_tokens: {goal1.self_stats.total_tokens}")
- print(f" - self_stats.total_cost: ${goal1.self_stats.total_cost:.4f}")
- print()
- # 6. 完成 goal
- print("6. 完成 goal 并查看更新")
- result3 = await registry.execute("goal", {"done": "代码分析完成,找到认证模块"}, uid="test")
- await store.update_goal_tree(trace.trace_id, tree)
- print(f" 调用: goal(done='...')")
- print(f" 结果: {result3[:200]}...")
- print()
- # 7. 显示最终的 GoalTree
- print("7. 最终 GoalTree 状态")
- final_tree = await store.get_goal_tree(trace.trace_id)
- print(final_tree.to_prompt())
- print()
- # 8. 查看 Trace 统计
- print("8. Trace 统计")
- final_trace = await store.get_trace(trace.trace_id)
- print(f" total_messages: {final_trace.total_messages}")
- print(f" total_tokens: {final_trace.total_tokens}")
- print(f" total_cost: ${final_trace.total_cost:.4f}")
- print()
- print("=" * 60)
- print("✅ 测试完成!所有功能正常工作")
- print("=" * 60)
- print()
- print(f"数据已保存到: .trace_test/{trace.trace_id[:8]}...")
- if __name__ == "__main__":
- asyncio.run(test_goal_tree_and_messages())
|