| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
- """
- AgentRunner 测试
- """
- import pytest
- from agent import (
- AgentRunner,
- AgentEvent,
- Trace,
- Step,
- Experience,
- Skill,
- tool,
- get_tool_registry,
- )
- from agent.storage import MemoryTraceStore, MemoryMemoryStore
# Test tool: a deterministic fake search used by the registry and runner tests.
@tool(
    editable_params=["query"],
    display={"zh": {"name": "测试搜索", "params": {"query": "关键词"}}},
)
async def search_tool(query: str, limit: int = 10, uid: str = "") -> dict:
    """测试搜索工具"""
    # NOTE: the docstring above is intentionally left untranslated; the @tool
    # decorator presumably reads it as the tool description (confirm in agent).
    # `limit` and `uid` appear in the signature only (presumably so the
    # registry's generated schema / injection sees them); this stub ignores both.
    hits = [f"结果: {query}"]
    return {"results": hits, "count": len(hits)}
# Mock LLM call used in place of a real model client.
async def mock_llm_call(
    messages: list,
    model: str = "gpt-4o",
    tools: list = None,
    **kwargs
) -> dict:
    """Simulate an LLM chat-completion call for tests.

    Produces a single ``search_tool`` tool call when tools are offered, the
    latest message mentions "搜索" ("search"), and no tool result is in the
    history yet. Otherwise it returns a plain text reply echoing the latest
    message's content.

    Args:
        messages: Chat history as a list of ``{"role": ..., "content": ...}`` dicts.
        model: Ignored; accepted for signature compatibility.
        tools: Tool schemas offered to the model; a tool call is only produced
            when this is truthy.
        **kwargs: Ignored.

    Returns:
        A dict with ``content``, ``tool_calls``, ``prompt_tokens``,
        ``completion_tokens`` and ``cost`` keys.
    """
    # Fixed fake usage numbers reported by every response.
    usage = {"prompt_tokens": 100, "completion_tokens": 50, "cost": 0.01}

    last = messages[-1] if messages else {}
    # .get + `or ""`: assistant messages carrying tool_calls may have a missing
    # or None content, which made the previous `["content"]` lookup crash.
    user_msg = last.get("content") or ""
    # Two-phase behavior: once a tool result is in the history, answer with text
    # instead of requesting the tool again (otherwise a tool output mentioning
    # "搜索" would loop forever). Assumes OpenAI-style role "tool" for results.
    has_tool_result = any(m.get("role") == "tool" for m in messages)

    if tools and not has_tool_result and "搜索" in user_msg:
        return {
            "content": "",
            "tool_calls": [{
                "id": "call_123",
                "function": {
                    "name": "search_tool",
                    "arguments": '{"query": "测试查询"}',
                },
            }],
            **usage,
        }
    return {"content": f"回复: {user_msg}", "tool_calls": None, **usage}
class TestTraceAndStep:
    """Unit tests for the Trace and Step factory constructors."""

    def test_trace_create(self):
        # A fresh call-mode trace gets an id, keeps its uid, and starts as running.
        created = Trace.create(mode="call", uid="user123")

        assert created.trace_id is not None
        assert created.mode == "call"
        assert created.uid == "user123"
        assert created.status == "running"

    def test_step_create(self):
        # Step.create should store the identifiers, type, status and payload given.
        fields = dict(
            trace_id="trace_123",
            step_type="thought",
            sequence=0,
            status="completed",
            description="测试步骤",
            data={"content": "hello"},
        )
        made = Step.create(**fields)

        assert made.step_id is not None
        for attr in ("trace_id", "step_type", "status"):
            assert getattr(made, attr) == fields[attr]
        assert made.data["content"] == "hello"
class TestMemoryStore:
    """Round-trip tests for the in-memory trace and experience stores."""

    @pytest.mark.asyncio
    async def test_trace_store(self):
        trace_store = MemoryTraceStore()

        # Persist a new agent-mode trace; the store hands back its id.
        new_trace = Trace.create(mode="agent", task="测试任务")
        saved_id = await trace_store.create_trace(new_trace)
        assert saved_id == new_trace.trace_id

        # Reading it back preserves the task text.
        fetched = await trace_store.get_trace(saved_id)
        assert fetched is not None
        assert fetched.task == "测试任务"

        # One step attached to the trace is returned by get_trace_steps.
        await trace_store.add_step(
            Step.create(trace_id=saved_id, step_type="llm_call", sequence=0)
        )
        assert len(await trace_store.get_trace_steps(saved_id)) == 1

    @pytest.mark.asyncio
    async def test_memory_store(self):
        memory = MemoryMemoryStore()

        # Store one experience scoped to "agent:test"; its id is returned.
        experience = Experience.create(
            scope="agent:test",
            condition="测试条件",
            rule="测试规则",
        )
        assert await memory.add_experience(experience) == experience.exp_id

        # Searching that scope with an empty second argument (presumably the
        # query text) is expected to yield exactly the one stored entry.
        found = await memory.search_experiences("agent:test", "")
        assert len(found) == 1
class TestToolRegistry:
    """Tests that `search_tool` is registered and callable through the registry."""

    def test_tool_registered(self):
        assert get_tool_registry().is_registered("search_tool")

    def test_get_schemas(self):
        found = get_tool_registry().get_schemas(["search_tool"])

        assert len(found) == 1
        assert found[0]["function"]["name"] == "search_tool"

    @pytest.mark.asyncio
    async def test_execute_tool(self):
        # Execute the tool by name; `uid` is passed separately from tool arguments.
        output = await get_tool_registry().execute(
            "search_tool", {"query": "hello"}, uid="test"
        )
        # NOTE(review): `in` is a substring test only if `execute` returns a str;
        # on a dict it would check keys ("results"/"count") instead. Confirm the
        # registry's return type (and JSON ensure_ascii setting, if serialized).
        assert "结果" in output
class TestAgentRunner:
    """End-to-end tests for AgentRunner driven by `mock_llm_call`."""

    @pytest.mark.asyncio
    async def test_call_simple(self):
        """A single `call` without tools echoes the user message in the reply."""
        runner = AgentRunner(
            trace_store=MemoryTraceStore(),
            llm_call=mock_llm_call,
        )
        result = await runner.call(
            messages=[{"role": "user", "content": "你好"}],
            model="gpt-4o",
        )
        # mock_llm_call answers "回复: <last message>", so the input must appear.
        assert "你好" in result.reply
        assert result.trace_id is not None

    @pytest.mark.asyncio
    async def test_run_simple(self):
        """`run` on a tool-free task emits start and completion events."""
        runner = AgentRunner(
            trace_store=MemoryTraceStore(),
            memory_store=MemoryMemoryStore(),
            llm_call=mock_llm_call,
        )
        events = []
        async for event in runner.run(
            task="简单任务",
            agent_type="test",
        ):
            events.append(event)

        # Check the lifecycle events are present in the stream.
        event_types = [e.type for e in events]
        assert "trace_started" in event_types
        assert "trace_completed" in event_types

    @pytest.mark.asyncio
    async def test_run_with_tools(self):
        """A task mentioning "搜索" with a tool available triggers tool execution."""
        runner = AgentRunner(
            trace_store=MemoryTraceStore(),
            llm_call=mock_llm_call,
        )
        events = []
        async for event in runner.run(
            task="请搜索相关内容",
            tools=["search_tool"],
            agent_type="test",
        ):
            events.append(event)

        event_types = [e.type for e in events]
        assert "tool_executing" in event_types
        assert "tool_result" in event_types

    @pytest.mark.asyncio
    async def test_add_feedback(self):
        """Feedback on a step is stored as an experience whose rule is the content."""
        trace_store = MemoryTraceStore()
        memory_store = MemoryMemoryStore()
        runner = AgentRunner(
            trace_store=trace_store,
            memory_store=memory_store,
            llm_call=mock_llm_call,
        )

        # Run a task first and capture the ids needed to attach feedback.
        trace_id = None
        step_id = None
        async for event in runner.run(task="测试任务", agent_type="test"):
            if event.type == "trace_started":
                trace_id = event.data["trace_id"]
            if event.type == "llm_call_completed":
                step_id = event.data["step_id"]

        # Fail here with a clear message rather than passing None ids into
        # add_feedback and failing somewhere less obvious.
        assert trace_id is not None, "run() emitted no trace_started event"
        assert step_id is not None, "run() emitted no llm_call_completed event"

        # Attach a correction to the captured LLM step.
        exp_id = await runner.add_feedback(
            trace_id=trace_id,
            target_step_id=step_id,
            feedback_type="correction",
            content="应该这样做",
        )
        assert exp_id is not None

        # The feedback must be persisted as an experience in the memory store.
        exp = await memory_store.get_experience(exp_id)
        assert exp is not None
        assert exp.rule == "应该这样做"
if __name__ == "__main__":
    # pytest.main returns the session exit code; propagate it so a failing
    # direct run (`python this_file.py`) is visible to the shell / CI.
    raise SystemExit(pytest.main([__file__, "-v"]))
|