test_runner.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. """
  2. AgentRunner 测试
  3. """
  4. import pytest
  5. from agent import (
  6. AgentRunner,
  7. AgentEvent,
  8. Trace,
  9. Step,
  10. Experience,
  11. Skill,
  12. tool,
  13. get_tool_registry,
  14. )
  15. from agent.storage import MemoryTraceStore, MemoryMemoryStore
  16. # 测试工具
  17. @tool(
  18. editable_params=["query"],
  19. display={"zh": {"name": "测试搜索", "params": {"query": "关键词"}}}
  20. )
  21. async def search_tool(query: str, limit: int = 10, uid: str = "") -> dict:
  22. """测试搜索工具"""
  23. return {"results": [f"结果: {query}"], "count": 1}
  24. # Mock LLM 调用
  25. async def mock_llm_call(
  26. messages: list,
  27. model: str = "gpt-4o",
  28. tools: list = None,
  29. **kwargs
  30. ) -> dict:
  31. """模拟 LLM 调用"""
  32. # 简单模拟:如果有工具,第一次调用返回 tool_call,第二次返回结果
  33. user_msg = messages[-1]["content"] if messages else ""
  34. if "搜索" in user_msg and tools:
  35. return {
  36. "content": "",
  37. "tool_calls": [{
  38. "id": "call_123",
  39. "function": {
  40. "name": "search_tool",
  41. "arguments": '{"query": "测试查询"}'
  42. }
  43. }],
  44. "prompt_tokens": 100,
  45. "completion_tokens": 50,
  46. "cost": 0.01
  47. }
  48. return {
  49. "content": f"回复: {user_msg}",
  50. "tool_calls": None,
  51. "prompt_tokens": 100,
  52. "completion_tokens": 50,
  53. "cost": 0.01
  54. }
  55. class TestTraceAndStep:
  56. """测试 Trace 和 Step"""
  57. def test_trace_create(self):
  58. trace = Trace.create(mode="call", uid="user123")
  59. assert trace.trace_id is not None
  60. assert trace.mode == "call"
  61. assert trace.uid == "user123"
  62. assert trace.status == "running"
  63. def test_step_create(self):
  64. step = Step.create(
  65. trace_id="trace_123",
  66. step_type="llm_call",
  67. sequence=0,
  68. data={"response": "hello"}
  69. )
  70. assert step.step_id is not None
  71. assert step.trace_id == "trace_123"
  72. assert step.step_type == "llm_call"
  73. assert step.data["response"] == "hello"
  74. class TestMemoryStore:
  75. """测试内存存储"""
  76. @pytest.mark.asyncio
  77. async def test_trace_store(self):
  78. store = MemoryTraceStore()
  79. # 创建 Trace
  80. trace = Trace.create(mode="agent", task="测试任务")
  81. trace_id = await store.create_trace(trace)
  82. assert trace_id == trace.trace_id
  83. # 获取 Trace
  84. retrieved = await store.get_trace(trace_id)
  85. assert retrieved is not None
  86. assert retrieved.task == "测试任务"
  87. # 添加 Step
  88. step = Step.create(
  89. trace_id=trace_id,
  90. step_type="llm_call",
  91. sequence=0
  92. )
  93. await store.add_step(step)
  94. # 获取 Steps
  95. steps = await store.get_trace_steps(trace_id)
  96. assert len(steps) == 1
  97. @pytest.mark.asyncio
  98. async def test_memory_store(self):
  99. store = MemoryMemoryStore()
  100. # 添加 Experience
  101. exp = Experience.create(
  102. scope="agent:test",
  103. condition="测试条件",
  104. rule="测试规则"
  105. )
  106. exp_id = await store.add_experience(exp)
  107. assert exp_id == exp.exp_id
  108. # 搜索
  109. results = await store.search_experiences("agent:test", "")
  110. assert len(results) == 1
  111. class TestToolRegistry:
  112. """测试工具注册"""
  113. def test_tool_registered(self):
  114. registry = get_tool_registry()
  115. assert registry.is_registered("search_tool")
  116. def test_get_schemas(self):
  117. registry = get_tool_registry()
  118. schemas = registry.get_schemas(["search_tool"])
  119. assert len(schemas) == 1
  120. assert schemas[0]["function"]["name"] == "search_tool"
  121. @pytest.mark.asyncio
  122. async def test_execute_tool(self):
  123. registry = get_tool_registry()
  124. result = await registry.execute("search_tool", {"query": "hello"}, uid="test")
  125. assert "结果" in result
  126. class TestAgentRunner:
  127. """测试 AgentRunner"""
  128. @pytest.mark.asyncio
  129. async def test_call_simple(self):
  130. """测试简单调用"""
  131. runner = AgentRunner(
  132. trace_store=MemoryTraceStore(),
  133. llm_call=mock_llm_call
  134. )
  135. result = await runner.call(
  136. messages=[{"role": "user", "content": "你好"}],
  137. model="gpt-4o"
  138. )
  139. assert "你好" in result.reply
  140. assert result.trace_id is not None
  141. @pytest.mark.asyncio
  142. async def test_run_simple(self):
  143. """测试 Agent 运行"""
  144. runner = AgentRunner(
  145. trace_store=MemoryTraceStore(),
  146. memory_store=MemoryMemoryStore(),
  147. llm_call=mock_llm_call
  148. )
  149. events = []
  150. async for event in runner.run(
  151. task="简单任务",
  152. agent_type="test"
  153. ):
  154. events.append(event)
  155. # 检查事件序列
  156. event_types = [e.type for e in events]
  157. assert "trace_started" in event_types
  158. assert "trace_completed" in event_types
  159. @pytest.mark.asyncio
  160. async def test_run_with_tools(self):
  161. """测试带工具的 Agent 运行"""
  162. runner = AgentRunner(
  163. trace_store=MemoryTraceStore(),
  164. llm_call=mock_llm_call
  165. )
  166. events = []
  167. async for event in runner.run(
  168. task="请搜索相关内容",
  169. tools=["search_tool"],
  170. agent_type="test"
  171. ):
  172. events.append(event)
  173. event_types = [e.type for e in events]
  174. assert "tool_executing" in event_types
  175. assert "tool_result" in event_types
  176. @pytest.mark.asyncio
  177. async def test_add_feedback(self):
  178. """测试添加反馈"""
  179. trace_store = MemoryTraceStore()
  180. memory_store = MemoryMemoryStore()
  181. runner = AgentRunner(
  182. trace_store=trace_store,
  183. memory_store=memory_store,
  184. llm_call=mock_llm_call
  185. )
  186. # 先运行一个任务
  187. trace_id = None
  188. step_id = None
  189. async for event in runner.run(task="测试任务", agent_type="test"):
  190. if event.type == "trace_started":
  191. trace_id = event.data["trace_id"]
  192. if event.type == "llm_call_completed":
  193. step_id = event.data["step_id"]
  194. # 添加反馈
  195. exp_id = await runner.add_feedback(
  196. trace_id=trace_id,
  197. target_step_id=step_id,
  198. feedback_type="correction",
  199. content="应该这样做"
  200. )
  201. assert exp_id is not None
  202. # 验证经验被存储
  203. exp = await memory_store.get_experience(exp_id)
  204. assert exp is not None
  205. assert exp.rule == "应该这样做"
  206. if __name__ == "__main__":
  207. pytest.main([__file__, "-v"])