run.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
"""
统一 subagent 工具集成测试(mock LLM)。

Integration test for the unified ``subagent`` tool, driven by a mocked LLM.
"""
  4. import asyncio
  5. import os
  6. import sys
  7. from pathlib import Path
  8. from tempfile import TemporaryDirectory
  9. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
  10. from agent.core.runner import AgentRunner
  11. from agent.trace.store import FileSystemTraceStore
  12. from agent.trace.models import Trace
  13. from agent.trace.goal_models import GoalTree
  14. from agent.tools.builtin.subagent import subagent
  15. async def mock_llm_call(messages, model="gpt-4o", tools=None, **kwargs):
  16. last_user = ""
  17. for msg in reversed(messages):
  18. if msg.get("role") == "user":
  19. last_user = str(msg.get("content", ""))
  20. break
  21. if "# 评估任务" in last_user:
  22. content = "## 评估结论\n通过\n\n## 评估理由\n满足需求。"
  23. elif "# 探索任务" in last_user:
  24. content = "探索完成:建议优先采用方案 1。"
  25. else:
  26. content = "委托任务已完成。"
  27. return {
  28. "content": content,
  29. "tool_calls": None,
  30. "finish_reason": "stop",
  31. "prompt_tokens": 10,
  32. "completion_tokens": 10,
  33. "cost": 0.0,
  34. }
  35. async def run_case():
  36. with TemporaryDirectory(prefix="subagent-unified-") as tmp_dir:
  37. store = FileSystemTraceStore(base_path=tmp_dir)
  38. runner = AgentRunner(trace_store=store, llm_call=mock_llm_call)
  39. # 创建主 Trace 与 GoalTree(供 subagent 作为父上下文)
  40. main_trace = Trace(
  41. trace_id="main-trace",
  42. mode="agent",
  43. task="主任务",
  44. agent_type="default",
  45. status="running",
  46. )
  47. await store.create_trace(main_trace)
  48. goal_tree = GoalTree(mission="主任务")
  49. new_goals = goal_tree.add_goals(["实现主流程"])
  50. goal_tree.focus(new_goals[0].id)
  51. await store.update_goal_tree(main_trace.trace_id, goal_tree)
  52. context = {
  53. "store": store,
  54. "trace_id": main_trace.trace_id,
  55. "goal_id": new_goals[0].id,
  56. "runner": runner,
  57. }
  58. # 1) delegate
  59. delegate_result = await subagent(
  60. mode="delegate",
  61. task="实现用户登录功能",
  62. context=context,
  63. )
  64. assert delegate_result["status"] == "completed", delegate_result
  65. assert delegate_result["summary"], delegate_result
  66. delegate_trace = await store.get_trace(delegate_result["sub_trace_id"])
  67. assert delegate_trace is not None
  68. assert delegate_trace.parent_trace_id == main_trace.trace_id
  69. assert delegate_trace.parent_goal_id == new_goals[0].id
  70. # 2) explore
  71. explore_result = await subagent(
  72. mode="explore",
  73. branches=["JWT 方案", "Session 方案"],
  74. background="请比较维护成本和安全性。",
  75. context=context,
  76. )
  77. assert explore_result["status"] == "completed", explore_result
  78. assert "探索" in explore_result["summary"], explore_result
  79. # 3) evaluate
  80. evaluate_result = await subagent(
  81. mode="evaluate",
  82. target_goal_id=new_goals[0].id,
  83. evaluation_input={"actual_result": "已实现登录接口并通过单元测试"},
  84. requirements="请评估是否满足安全和可维护性要求。",
  85. context=context,
  86. )
  87. assert evaluate_result["status"] == "completed", evaluate_result
  88. assert "评估结论" in evaluate_result["summary"], evaluate_result
  89. # 4) continue_from
  90. continue_result = await subagent(
  91. mode="delegate",
  92. task="继续补充边界条件处理",
  93. continue_from=delegate_result["sub_trace_id"],
  94. context=context,
  95. )
  96. assert continue_result["status"] == "completed", continue_result
  97. assert continue_result["sub_trace_id"] == delegate_result["sub_trace_id"]
  98. assert continue_result["continue_from"] is True
  99. print("✅ unified subagent tests passed")
  100. print(f"delegate: {delegate_result['sub_trace_id']}")
  101. print(f"explore : {explore_result['sub_trace_id']}")
  102. print(f"evaluate: {evaluate_result['sub_trace_id']}")
  103. def main():
  104. asyncio.run(run_case())
  105. if __name__ == "__main__":
  106. main()