test_integration.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. """
  2. 测试新的 Goal + Message 系统(不需要 LLM)
  3. 验证 GoalTree 和工具注册是否正常工作
  4. """
  5. import asyncio
  6. from agent.goal.models import GoalTree
  7. from agent.execution.models import Trace, Message
  8. from agent.execution.fs_store import FileSystemTraceStore
  9. from agent.tools import get_tool_registry
  10. async def test_goal_tree_and_messages():
  11. """测试 GoalTree 和 Message 集成"""
  12. print("=" * 60)
  13. print("测试 Goal + Message 系统集成")
  14. print("=" * 60)
  15. print()
  16. # 1. 验证 goal 工具已注册
  17. print("1. 检查工具注册")
  18. registry = get_tool_registry()
  19. tools = registry.get_tool_names()
  20. print(f" 已注册 {len(tools)} 个工具")
  21. if "goal" in tools:
  22. print(" ✓ goal 工具已注册")
  23. else:
  24. print(" ✗ goal 工具未注册")
  25. return
  26. # 获取 goal 工具的 schema
  27. goal_schema = registry.get_schemas(["goal"])[0]
  28. print(f" - 工具名称: {goal_schema.get('name')}")
  29. print(f" - 描述: {goal_schema.get('description', '')[:60]}...")
  30. print()
  31. # 2. 创建 Trace 和 GoalTree
  32. print("2. 创建 Trace 和 GoalTree")
  33. store = FileSystemTraceStore(base_path=".trace_test")
  34. trace = Trace.create(mode="agent", task="测试任务:实现用户认证")
  35. await store.create_trace(trace)
  36. print(f" Trace ID: {trace.trace_id[:8]}...")
  37. tree = GoalTree(mission="测试任务:实现用户认证")
  38. await store.update_goal_tree(trace.trace_id, tree)
  39. print(f" GoalTree 创建成功")
  40. print()
  41. # 3. 测试 goal 工具调用(模拟 LLM 调用)
  42. print("3. 测试 goal 工具调用")
  43. # 设置 goal_tree
  44. from agent.tools.builtin.goal import set_goal_tree
  45. set_goal_tree(tree)
  46. # 模拟调用 goal 工具添加目标
  47. result1 = await registry.execute("goal", {"add": "分析代码, 设计方案, 实现功能"}, uid="test")
  48. print(" 调用: goal(add='分析代码, 设计方案, 实现功能')")
  49. print(f" 结果: {result1[:200]}...")
  50. print()
  51. # 保存更新后的 tree
  52. await store.update_goal_tree(trace.trace_id, tree)
  53. # 模拟调用 goal 工具 focus
  54. result2 = await registry.execute("goal", {"focus": "1"}, uid="test")
  55. print(" 调用: goal(focus='1')")
  56. print(f" 结果: {result2[:150]}...")
  57. print()
  58. # 4. 添加 Messages
  59. print("4. 添加 Messages 并自动更新 stats")
  60. # 添加 assistant message(与 goal 1 关联)
  61. msg1 = Message.create(
  62. trace_id=trace.trace_id,
  63. role="assistant",
  64. sequence=1,
  65. goal_id="1",
  66. content={"text": "开始分析代码结构", "tool_calls": []},
  67. tokens=50,
  68. cost=0.001
  69. )
  70. await store.add_message(msg1)
  71. print(f" Message 1: {msg1.description}")
  72. # 添加更多 messages
  73. msg2 = Message.create(
  74. trace_id=trace.trace_id,
  75. role="assistant",
  76. sequence=2,
  77. goal_id="1",
  78. content={"text": "", "tool_calls": [
  79. {"id": "call1", "function": {"name": "read_file", "arguments": "{}"}}
  80. ]},
  81. tokens=30,
  82. cost=0.0006
  83. )
  84. await store.add_message(msg2)
  85. print(f" Message 2: {msg2.description}")
  86. msg3 = Message.create(
  87. trace_id=trace.trace_id,
  88. role="tool",
  89. sequence=3,
  90. goal_id="1",
  91. tool_call_id="call1",
  92. content={"tool_name": "read_file", "result": "文件内容..."},
  93. tokens=100,
  94. cost=0.002
  95. )
  96. await store.add_message(msg3)
  97. print(f" Message 3: {msg3.description}")
  98. print()
  99. # 5. 查看自动更新的 stats
  100. print("5. 查看自动更新的 Goal stats")
  101. updated_tree = await store.get_goal_tree(trace.trace_id)
  102. goal1 = updated_tree.find("1")
  103. print(f" Goal 1: {goal1.description}")
  104. print(f" - status: {goal1.status}")
  105. print(f" - self_stats.message_count: {goal1.self_stats.message_count}")
  106. print(f" - self_stats.total_tokens: {goal1.self_stats.total_tokens}")
  107. print(f" - self_stats.total_cost: ${goal1.self_stats.total_cost:.4f}")
  108. print()
  109. # 6. 完成 goal
  110. print("6. 完成 goal 并查看更新")
  111. result3 = await registry.execute("goal", {"done": "代码分析完成,找到认证模块"}, uid="test")
  112. await store.update_goal_tree(trace.trace_id, tree)
  113. print(f" 调用: goal(done='...')")
  114. print(f" 结果: {result3[:200]}...")
  115. print()
  116. # 7. 显示最终的 GoalTree
  117. print("7. 最终 GoalTree 状态")
  118. final_tree = await store.get_goal_tree(trace.trace_id)
  119. print(final_tree.to_prompt())
  120. print()
  121. # 8. 查看 Trace 统计
  122. print("8. Trace 统计")
  123. final_trace = await store.get_trace(trace.trace_id)
  124. print(f" total_messages: {final_trace.total_messages}")
  125. print(f" total_tokens: {final_trace.total_tokens}")
  126. print(f" total_cost: ${final_trace.total_cost:.4f}")
  127. print()
  128. print("=" * 60)
  129. print("✅ 测试完成!所有功能正常工作")
  130. print("=" * 60)
  131. print()
  132. print(f"数据已保存到: .trace_test/{trace.trace_id[:8]}...")
  133. if __name__ == "__main__":
  134. asyncio.run(test_goal_tree_and_messages())