test_runner.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. """
  2. AgentRunner 测试
  3. """
  4. import pytest
  5. from agent import (
  6. AgentRunner,
  7. AgentEvent,
  8. Trace,
  9. Step,
  10. Experience,
  11. Skill,
  12. tool,
  13. get_tool_registry,
  14. )
  15. from agent.storage import MemoryTraceStore, MemoryMemoryStore
  # Test tool: a trivial async "search" registered through the @tool decorator.
@tool(
    editable_params=["query"],
    display={"zh": {"name": "测试搜索", "params": {"query": "关键词"}}},
)
async def search_tool(query: str, limit: int = 10, uid: str = "") -> dict:
    """Fake search tool used by the tests.

    Echoes ``query`` back as a single result entry. ``limit`` and ``uid`` are
    accepted but unused by this fake.
    """
    hits = [f"结果: {query}"]
    return {"results": hits, "count": len(hits)}
  # Mock LLM call handed to AgentRunner in place of a real model client.
async def mock_llm_call(
    messages: list,
    model: str = "gpt-4o",
    tools: list = None,
    **kwargs
) -> dict:
    """Simulate one LLM completion for the tests.

    Only the last entry of ``messages`` is inspected:

    * If ``tools`` is non-empty, the last message is not a tool result
      (``role != "tool"``) and its text content contains "搜索" ("search"),
      return a single ``search_tool`` tool call with empty content.
    * Otherwise return plain text ``"回复: <content>"`` and no tool calls.

    Ignoring tool-result messages means the call after a tool execution yields
    a final answer even if the tool output contains the keyword. This assumes
    the runner appends tool output as a ``role == "tool"`` message; if it
    does not, the keyword check simply applies as before.

    Args:
        messages: Chat messages as dicts with optional "role"/"content" keys.
        model: Accepted for signature compatibility; ignored.
        tools: Tool schemas; only their truthiness matters here.
        **kwargs: Accepted for signature compatibility; ignored.

    Returns:
        A dict with "content", "tool_calls" (list or None),
        "prompt_tokens", "completion_tokens" and "cost".
    """
    last = messages[-1] if messages else {}
    content = last.get("content")
    if content is None:
        # e.g. an assistant tool-call message may carry content=None; treat it
        # as empty text instead of failing on the `in` check below.
        content = ""

    wants_tool_call = (
        bool(tools)
        and last.get("role") != "tool"
        and isinstance(content, str)
        and "搜索" in content
    )
    if wants_tool_call:
        return {
            "content": "",
            "tool_calls": [{
                "id": "call_123",
                "function": {
                    "name": "search_tool",
                    "arguments": '{"query": "测试查询"}'
                }
            }],
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "cost": 0.01
        }
    return {
        "content": f"回复: {content}",
        "tool_calls": None,
        "prompt_tokens": 100,
        "completion_tokens": 50,
        "cost": 0.01
    }
  # Unit tests for the Trace / Step factory constructors.
class TestTraceAndStep:
    """Tests for Trace and Step."""

    def test_trace_create(self):
        """Trace.create assigns an id and starts in the "running" status."""
        t = Trace.create(mode="call", uid="user123")
        assert t.trace_id is not None
        assert (t.mode, t.uid, t.status) == ("call", "user123", "running")

    def test_step_create(self):
        """Step.create keeps the fields it was given and assigns a step id."""
        fields = dict(
            trace_id="trace_123",
            step_type="thought",
            sequence=0,
            status="completed",
            description="测试步骤",
            data={"content": "hello"},
        )
        s = Step.create(**fields)
        assert s.step_id is not None
        assert s.trace_id == fields["trace_id"]
        assert s.step_type == fields["step_type"]
        assert s.status == fields["status"]
        assert s.data["content"] == "hello"
  # Tests for the in-memory storage backends.
class TestMemoryStore:
    """Tests for the in-memory stores."""

    @pytest.mark.asyncio
    async def test_trace_store(self):
        """Round-trip a Trace and one Step through MemoryTraceStore."""
        trace_store = MemoryTraceStore()

        # Creating a trace returns the trace's own id.
        new_trace = Trace.create(mode="agent", task="测试任务")
        created_id = await trace_store.create_trace(new_trace)
        assert created_id == new_trace.trace_id

        # The stored trace can be fetched back by that id.
        fetched = await trace_store.get_trace(created_id)
        assert fetched is not None
        assert fetched.task == "测试任务"

        # A single added step shows up in the trace's step list.
        await trace_store.add_step(
            Step.create(trace_id=created_id, step_type="llm_call", sequence=0)
        )
        assert len(await trace_store.get_trace_steps(created_id)) == 1

    @pytest.mark.asyncio
    async def test_memory_store(self):
        """Store one Experience and find it again via search_experiences."""
        mem = MemoryMemoryStore()
        experience = Experience.create(
            scope="agent:test",
            condition="测试条件",
            rule="测试规则",
        )
        assert await mem.add_experience(experience) == experience.exp_id

        # Empty query string: the test expects the scope alone to match.
        found = await mem.search_experiences("agent:test", "")
        assert len(found) == 1
  # Tests for tool registration via the @tool decorator.
class TestToolRegistry:
    """Tests for the global tool registry."""

    def test_tool_registered(self):
        """search_tool, decorated at import time in this module, is registered."""
        assert get_tool_registry().is_registered("search_tool")

    def test_get_schemas(self):
        """Asking for one tool's schema returns exactly that tool's schema."""
        schemas = get_tool_registry().get_schemas(["search_tool"])
        assert len(schemas) == 1
        first = schemas[0]
        assert first["function"]["name"] == "search_tool"

    @pytest.mark.asyncio
    async def test_execute_tool(self):
        """Executing search_tool through the registry surfaces its result text."""
        registry = get_tool_registry()
        output = await registry.execute("search_tool", {"query": "hello"}, uid="test")
        # NOTE(review): search_tool returns a dict with no "结果" key, so this
        # presumably relies on execute() serializing the result to a string
        # (without ASCII-escaping) — confirm against the registry implementation.
        assert "结果" in output
  # End-to-end tests for AgentRunner driven by mock_llm_call.
class TestAgentRunner:
    """Tests for AgentRunner."""

    @pytest.mark.asyncio
    async def test_call_simple(self):
        """Single-shot call(): the mock echoes the user message into the reply."""
        runner = AgentRunner(
            trace_store=MemoryTraceStore(),
            llm_call=mock_llm_call,
        )
        result = await runner.call(
            messages=[{"role": "user", "content": "你好"}],
            model="gpt-4o",
        )
        assert "你好" in result.reply
        assert result.trace_id is not None

    @pytest.mark.asyncio
    async def test_run_simple(self):
        """run() without tools emits trace_started and trace_completed events."""
        runner = AgentRunner(
            trace_store=MemoryTraceStore(),
            memory_store=MemoryMemoryStore(),
            llm_call=mock_llm_call,
        )
        events = []
        async for event in runner.run(
            task="简单任务",
            agent_type="test",
        ):
            events.append(event)
        # Only presence of the events is checked, not their order.
        event_types = [e.type for e in events]
        assert "trace_started" in event_types
        assert "trace_completed" in event_types

    @pytest.mark.asyncio
    async def test_run_with_tools(self):
        """A task containing "搜索" makes the mock request search_tool, which runs."""
        runner = AgentRunner(
            trace_store=MemoryTraceStore(),
            llm_call=mock_llm_call,
        )
        events = []
        async for event in runner.run(
            task="请搜索相关内容",
            tools=["search_tool"],
            agent_type="test",
        ):
            events.append(event)
        event_types = [e.type for e in events]
        assert "tool_executing" in event_types
        assert "tool_result" in event_types

    @pytest.mark.asyncio
    async def test_add_feedback(self):
        """Feedback on a step of a finished run is stored as an Experience."""
        trace_store = MemoryTraceStore()
        memory_store = MemoryMemoryStore()
        runner = AgentRunner(
            trace_store=trace_store,
            memory_store=memory_store,
            llm_call=mock_llm_call,
        )

        # Run a task first, capturing its trace id and an LLM step id.
        trace_id = None
        step_id = None
        async for event in runner.run(task="测试任务", agent_type="test"):
            if event.type == "trace_started":
                trace_id = event.data["trace_id"]
            elif event.type == "llm_call_completed":
                # Keeps the last LLM step seen.
                step_id = event.data["step_id"]
        # Fail here with a clear message rather than passing None into add_feedback.
        assert trace_id is not None, "run() never emitted trace_started"
        assert step_id is not None, "run() never emitted llm_call_completed"

        # Attach a correction to that step.
        exp_id = await runner.add_feedback(
            trace_id=trace_id,
            target_step_id=step_id,
            feedback_type="correction",
            content="应该这样做",
        )
        assert exp_id is not None

        # The feedback text is persisted as the experience's rule.
        exp = await memory_store.get_experience(exp_id)
        assert exp is not None
        assert exp.rule == "应该这样做"
  # Allow running this file directly as a script.
if __name__ == "__main__":
    # pytest.main() returns an exit code instead of exiting; propagate it so a
    # failing run yields a non-zero process status.
    raise SystemExit(pytest.main([__file__, "-v"]))