test_plan.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. """
  2. 测试新的 Plan 系统
  3. 测试 GoalTree、Message、TraceStore 的基本功能
  4. """
  5. import asyncio
  6. import sys
  7. from pathlib import Path
  8. # 添加项目根目录到 Python 路径
  9. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  10. from agent.goal.models import GoalTree, Goal, GoalStats
  11. from agent.execution.models import Trace, Message
  12. from agent.execution.fs_store import FileSystemTraceStore
  13. from agent.goal.tool import goal_tool
  14. async def test_basic_plan():
  15. """测试基本的计划功能"""
  16. print("=" * 60)
  17. print("测试 1: 基本计划功能")
  18. print("=" * 60)
  19. print()
  20. # 1. 创建 GoalTree
  21. tree = GoalTree(mission="实现用户认证功能")
  22. print("1. 创建 GoalTree")
  23. print(f" Mission: {tree.mission}")
  24. print()
  25. # 2. 添加顶层目标
  26. print("2. 添加顶层目标")
  27. result = goal_tool(tree, add="分析代码, 实现功能, 测试")
  28. print(result)
  29. print()
  30. # 3. Focus 到目标 2
  31. print("3. Focus 到目标 2")
  32. result = goal_tool(tree, focus="2")
  33. print(result)
  34. print()
  35. # 4. 添加子目标
  36. print("4. 在目标 2 下添加子目标")
  37. result = goal_tool(tree, add="设计接口, 实现代码, 单元测试")
  38. print(result)
  39. print()
  40. # 5. Focus 到子目标并完成
  41. print("5. Focus 到 2.1 并完成")
  42. result = goal_tool(tree, focus="3") # 内部 ID 是 3("2.1" 的内部 ID)
  43. print(result)
  44. print()
  45. # 通过显示 ID 完成
  46. print("6. 完成目标(使用内部 ID)")
  47. result = goal_tool(tree, done="API 接口设计完成,定义了登录和注册端点")
  48. print(result)
  49. print()
  50. # 7. 测试 abandon
  51. print("7. Focus 到 4(2.2)并放弃")
  52. result = goal_tool(tree, focus="4")
  53. print(result)
  54. print()
  55. result = goal_tool(tree, abandon="发现需求变更,需要改用 OAuth")
  56. print(result)
  57. print()
  58. # 8. 添加新方案
  59. print("8. 添加新的实现方案")
  60. result = goal_tool(tree, add="实现 OAuth 认证")
  61. print(result)
  62. print()
  63. # 9. 查看最终状态
  64. print("9. 查看完整计划(包含废弃目标)")
  65. print(tree.to_prompt(include_abandoned=True))
  66. print()
  67. # 10. 查看过滤后的计划(默认不显示废弃目标)
  68. print("10. 查看过滤后的计划(默认)")
  69. print(tree.to_prompt())
  70. print()
  71. async def test_trace_store():
  72. """测试 TraceStore 的 GoalTree 和 Message 存储"""
  73. print("=" * 60)
  74. print("测试 2: TraceStore 存储功能")
  75. print("=" * 60)
  76. print()
  77. # 创建 TraceStore
  78. store = FileSystemTraceStore(base_path=".trace_test")
  79. # 1. 创建 Trace
  80. print("1. 创建 Trace")
  81. trace = Trace.create(
  82. mode="agent",
  83. task="测试任务"
  84. )
  85. await store.create_trace(trace)
  86. print(f" Trace ID: {trace.trace_id[:8]}...")
  87. print()
  88. # 2. 创建并保存 GoalTree
  89. print("2. 创建并保存 GoalTree")
  90. tree = GoalTree(mission="测试任务")
  91. tree.add_goals(["分析", "实现", "测试"])
  92. await store.update_goal_tree(trace.trace_id, tree)
  93. print(f" 添加了 {len(tree.goals)} 个目标")
  94. print()
  95. # 3. 添加 Messages
  96. print("3. 添加 Messages")
  97. # Focus 到第一个目标
  98. tree.focus("1")
  99. await store.update_goal_tree(trace.trace_id, tree)
  100. # 添加 assistant message
  101. msg1 = Message.create(
  102. trace_id=trace.trace_id,
  103. role="assistant",
  104. sequence=1,
  105. goal_id="1",
  106. content={"text": "开始分析代码", "tool_calls": [
  107. {
  108. "id": "call_1",
  109. "function": {
  110. "name": "read_file",
  111. "arguments": '{"path": "src/main.py"}'
  112. }
  113. }
  114. ]},
  115. tokens=100,
  116. cost=0.002
  117. )
  118. await store.add_message(msg1)
  119. print(f" Message 1: {msg1.description}")
  120. # 添加 tool message
  121. msg2 = Message.create(
  122. trace_id=trace.trace_id,
  123. role="tool",
  124. sequence=2,
  125. goal_id="1",
  126. tool_call_id="call_1",
  127. content={"tool_name": "read_file", "result": "文件内容..."},
  128. tokens=50,
  129. cost=0.001
  130. )
  131. await store.add_message(msg2)
  132. print(f" Message 2: {msg2.description}")
  133. print()
  134. # 4. 查看更新后的 GoalTree(stats 应该自动更新)
  135. print("4. 查看更新后的 GoalTree(含 stats)")
  136. updated_tree = await store.get_goal_tree(trace.trace_id)
  137. goal1 = updated_tree.find("1")
  138. print(f" Goal 1 stats:")
  139. print(f" - message_count: {goal1.self_stats.message_count}")
  140. print(f" - total_tokens: {goal1.self_stats.total_tokens}")
  141. print(f" - total_cost: ${goal1.self_stats.total_cost:.4f}")
  142. print()
  143. # 5. 添加子目标和 Message
  144. print("5. 添加子目标和 Message")
  145. updated_tree.add_goals(["读取配置", "解析代码"], parent_id="1")
  146. updated_tree.focus("4") # Focus 到第一个子目标
  147. await store.update_goal_tree(trace.trace_id, updated_tree)
  148. msg3 = Message.create(
  149. trace_id=trace.trace_id,
  150. role="assistant",
  151. sequence=3,
  152. goal_id="4",
  153. content={"text": "读取配置文件"},
  154. tokens=80,
  155. cost=0.0015
  156. )
  157. await store.add_message(msg3)
  158. print(f" 添加子目标 Message: {msg3.description}")
  159. print()
  160. # 6. 查看累计 stats(父节点应该包含子节点的统计)
  161. print("6. 查看累计 stats")
  162. updated_tree = await store.get_goal_tree(trace.trace_id)
  163. goal1 = updated_tree.find("1")
  164. print(f" Goal 1 cumulative stats:")
  165. print(f" - message_count: {goal1.cumulative_stats.message_count}")
  166. print(f" - total_tokens: {goal1.cumulative_stats.total_tokens}")
  167. print(f" - total_cost: ${goal1.cumulative_stats.total_cost:.4f}")
  168. print()
  169. # 7. 查看 Messages
  170. print("7. 查询 Messages")
  171. all_messages = await store.get_trace_messages(trace.trace_id)
  172. print(f" 总共 {len(all_messages)} 条 Messages")
  173. goal1_messages = await store.get_messages_by_goal(trace.trace_id, "1")
  174. print(f" Goal 1 的 Messages: {len(goal1_messages)} 条")
  175. print()
  176. # 8. 显示完整 GoalTree
  177. print("8. 完整 GoalTree")
  178. print(updated_tree.to_prompt())
  179. print()
  180. # 9. 测试级联完成
  181. print("9. 测试级联完成")
  182. updated_tree.focus("4")
  183. updated_tree.complete("4", "配置读取完成")
  184. updated_tree.focus("5")
  185. updated_tree.complete("5", "代码解析完成")
  186. # 检查父节点是否自动完成
  187. goal1 = updated_tree.find("1")
  188. print(f" Goal 1 status: {goal1.status}")
  189. print(f" Goal 1 summary: {goal1.summary}")
  190. print()
  191. print("✅ TraceStore 测试完成!")
  192. print(f" 数据保存在: .trace_test/{trace.trace_id[:8]}...")
  193. print()
  194. async def test_display_ids():
  195. """测试显示 ID 的生成"""
  196. print("=" * 60)
  197. print("测试 3: 显示 ID 生成")
  198. print("=" * 60)
  199. print()
  200. tree = GoalTree(mission="测试显示 ID")
  201. # 添加多层嵌套目标
  202. tree.add_goals(["A", "B", "C"])
  203. tree.focus("2")
  204. tree.add_goals(["B1", "B2"])
  205. tree.focus("4")
  206. tree.add_goals(["B1-1", "B1-2"])
  207. print("完整结构:")
  208. print(tree.to_prompt())
  209. print()
  210. # 测试 abandon 后的重新编号
  211. print("放弃 B1-1 后:")
  212. tree.focus("6")
  213. tree.abandon("6", "测试废弃")
  214. print(tree.to_prompt())
  215. print()
  216. print("包含废弃目标的完整视图:")
  217. print(tree.to_prompt(include_abandoned=True))
  218. print()
  219. async def main():
  220. """运行所有测试"""
  221. await test_basic_plan()
  222. await test_trace_store()
  223. await test_display_ids()
  224. print("=" * 60)
  225. print("所有测试完成!")
  226. print("=" * 60)
  227. if __name__ == "__main__":
  228. asyncio.run(main())