| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387 |
- #!/usr/bin/env python3
- """
- Agent Execution API 自动化测试
- 测试 REST API 和 WebSocket 接口是否正常工作
- 运行方式:
- python3 frontend/test_api.py
- 要求:
- - API Server 已启动(python3 api_server.py)
- - 有至少一个 Trace 数据(运行 examples/feature_extract/run.py)
- """
- import asyncio
- import json
- import sys
- from typing import Optional
- from datetime import datetime
- try:
- import httpx
- import websockets
- except ImportError:
- print("❌ 缺少依赖,请安装:")
- print(" pip install httpx websockets")
- sys.exit(1)
# Configuration: base URLs of the server under test.
# API_BASE prefixes the REST requests (f"{API_BASE}/api/traces...").
API_BASE = "http://localhost:8000"
# WS_BASE prefixes the WebSocket watch endpoint (f"{WS_BASE}/api/traces/<id>/watch").
WS_BASE = "ws://localhost:8000"
class TestResult:
    """Outcome of one test case: its name, verdict, detail text and timing.

    A fresh instance counts as failed with an empty message and a duration
    of 0 ms until ``success`` or ``fail`` is called.
    """

    def __init__(self, name: str):
        self.name = name
        self.passed = False
        self.message = ""
        self.duration_ms = 0

    def _settle(self, passed: bool, message: str) -> None:
        """Store the verdict together with its detail text."""
        self.passed = passed
        self.message = message

    def success(self, message: str = ""):
        """Mark the test as passed, optionally attaching a detail message."""
        self._settle(True, message)

    def fail(self, message: str):
        """Mark the test as failed and record why."""
        self._settle(False, message)

    def __str__(self):
        """Render a status line, followed by the message on a second line if set."""
        icon = "✅" if self.passed else "❌"
        headline = "{} {} ({}ms)".format(icon, self.name, self.duration_ms)
        if not self.message:
            return headline
        return headline + f"\n {self.message}"
class APITester:
    """Run the REST and WebSocket smoke tests against a running API server.

    Every ``test_*`` coroutine creates one ``TestResult``, appends it to
    ``self.results`` and prints it. ``test_list_traces`` also stores the ID
    of the first listed trace in ``self.test_trace_id``; the remaining tests
    are recorded as failed ("skipped") when that ID is not available.
    """

    # Seconds to wait for the first frame on a freshly opened WebSocket.
    # The timeout failure message below hard-codes "3秒"; keep them in sync.
    WS_FIRST_FRAME_TIMEOUT = 3.0

    def __init__(self):
        self.results = []
        self.test_trace_id: Optional[str] = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _elapsed_ms(start) -> int:
        """Return the whole milliseconds elapsed since the datetime *start*."""
        return int((datetime.now() - start).total_seconds() * 1000)

    def _record(self, result) -> None:
        """Store *result* in ``self.results`` and print it."""
        self.results.append(result)
        print(result)

    def _skip_without_trace(self, result) -> bool:
        """Fail and record *result* when no trace ID is known.

        Returns True when the calling test must stop (the result has already
        been recorded), False when a trace ID is available.
        """
        if self.test_trace_id:
            return False
        result.fail("跳过(没有可用的 Trace ID)")
        self._record(result)
        return True

    @staticmethod
    async def _recv_json(ws, timeout: float):
        """Receive one frame from *ws* within *timeout* seconds and JSON-decode it.

        Raises ``asyncio.TimeoutError`` when no frame arrives in time, so a
        silent server cannot block the suite indefinitely.
        """
        message = await asyncio.wait_for(ws.recv(), timeout=timeout)
        return json.loads(message)

    @staticmethod
    def _count_steps(nodes) -> int:
        """Count every node in a step forest, following ``children`` lists.

        A missing or null ``children`` entry is treated as "no children".
        The walk is iterative, so deep trees cannot hit the recursion limit.
        """
        total = 0
        pending = list(nodes)
        while pending:
            node = pending.pop()
            total += 1
            pending.extend(node.get("children") or ())
        return total

    # ------------------------------------------------------------------
    # Orchestration
    # ------------------------------------------------------------------

    async def run_all_tests(self):
        """Run every test in order and print a summary.

        Returns:
            False when the server is unreachable (no test is run); otherwise
            True only if every recorded result passed.
        """
        print("🧪 Agent Execution API 自动化测试")
        print("=" * 60)
        print()

        # Bail out early when the API server is not reachable.
        if not await self.check_server():
            print("❌ API Server 未启动,请先运行:")
            print(" python3 api_server.py")
            return False
        print("✅ API Server 已启动\n")

        # REST API tests. test_list_traces must run first: it discovers the
        # trace ID used by all later tests.
        print("📡 测试 REST API")
        print("-" * 60)
        for rest_test in (self.test_list_traces, self.test_get_trace, self.test_get_tree):
            await rest_test()
        print()

        # WebSocket tests.
        print("⚡ 测试 WebSocket")
        print("-" * 60)
        for ws_test in (
            self.test_websocket_connect,
            self.test_websocket_events,
            self.test_websocket_reconnect,
        ):
            await ws_test()
        print()

        # Summary.
        self.print_summary()
        return all(r.passed for r in self.results)

    async def check_server(self) -> bool:
        """Return True when GET /api/traces answers with 200 or 404 within 2 s.

        Any exception (connection refused, timeout, ...) counts as "not running".
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(f"{API_BASE}/api/traces", timeout=2.0)
                return response.status_code in (200, 404)
        except Exception:
            return False

    # ------------------------------------------------------------------
    # REST tests
    # ------------------------------------------------------------------

    async def test_list_traces(self):
        """Test: list traces, and remember the first trace ID for later tests."""
        result = TestResult("GET /api/traces - 列出 Traces")
        start = datetime.now()
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(f"{API_BASE}/api/traces?limit=10")
                result.duration_ms = self._elapsed_ms(start)
                if response.status_code != 200:
                    result.fail(f"HTTP {response.status_code}")
                else:
                    data = response.json()
                    if "traces" not in data:
                        result.fail("响应缺少 'traces' 字段")
                    else:
                        trace_count = len(data["traces"])
                        if trace_count == 0:
                            result.fail("没有找到 Trace(请先运行 example 生成数据)")
                        else:
                            # Keep the first trace ID for the follow-up tests.
                            self.test_trace_id = data["traces"][0]["trace_id"]
                            result.success(f"找到 {trace_count} 个 Trace")
        except Exception as e:
            result.fail(f"请求失败: {e}")
        self._record(result)

    async def test_get_trace(self):
        """Test: fetch a single trace and check its required metadata fields."""
        result = TestResult("GET /api/traces/{id} - 获取 Trace 元数据")
        start = datetime.now()
        if self._skip_without_trace(result):
            return
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(f"{API_BASE}/api/traces/{self.test_trace_id}")
                result.duration_ms = self._elapsed_ms(start)
                if response.status_code != 200:
                    result.fail(f"HTTP {response.status_code}")
                else:
                    data = response.json()
                    required_fields = ["trace_id", "mode", "status", "total_steps"]
                    missing = [f for f in required_fields if f not in data]
                    if missing:
                        result.fail(f"响应缺少字段: {missing}")
                    else:
                        result.success(f"Status: {data['status']}, Steps: {data['total_steps']}")
        except Exception as e:
            result.fail(f"请求失败: {e}")
        self._record(result)

    async def test_get_tree(self):
        """Test: fetch the compact step tree and count its nodes."""
        result = TestResult("GET /api/traces/{id}/tree - 获取 Step 树")
        start = datetime.now()
        if self._skip_without_trace(result):
            return
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{API_BASE}/api/traces/{self.test_trace_id}/tree?view=compact"
                )
                result.duration_ms = self._elapsed_ms(start)
                if response.status_code != 200:
                    result.fail(f"HTTP {response.status_code}")
                else:
                    data = response.json()
                    if "root_steps" not in data:
                        result.fail("响应缺少 'root_steps' 字段")
                    else:
                        root_count = len(data["root_steps"])
                        total_steps = self._count_steps(data["root_steps"])
                        result.success(f"根节点: {root_count}, 总 Steps: {total_steps}")
        except Exception as e:
            result.fail(f"请求失败: {e}")
        self._record(result)

    # ------------------------------------------------------------------
    # WebSocket tests
    # ------------------------------------------------------------------

    async def test_websocket_connect(self):
        """Test: open the watch socket and expect a well-formed 'connected' frame."""
        result = TestResult("WebSocket - 连接和接收 connected 事件")
        start = datetime.now()
        if self._skip_without_trace(result):
            return
        try:
            uri = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id=0"
            async with websockets.connect(uri) as ws:
                # The first frame must be the 'connected' handshake.
                data = await self._recv_json(ws, self.WS_FIRST_FRAME_TIMEOUT)
                result.duration_ms = self._elapsed_ms(start)
                if data.get("event") != "connected":
                    result.fail(f"首条消息不是 'connected',而是: {data.get('event')}")
                elif "current_event_id" not in data:
                    result.fail("connected 消息缺少 'current_event_id' 字段")
                else:
                    event_id = data["current_event_id"]
                    result.success(f"Event ID: {event_id}")
        except asyncio.TimeoutError:
            result.fail("连接超时(3秒)")
        except Exception as e:
            result.fail(f"连接失败: {e}")
        self._record(result)

    async def test_websocket_events(self):
        """Test: after connecting with since_event_id=0, expect step_added events."""
        result = TestResult("WebSocket - 接收历史 step_added 事件")
        start = datetime.now()
        if self._skip_without_trace(result):
            return
        try:
            uri = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id=0"
            async with websockets.connect(uri) as ws:
                # Discard the 'connected' frame. The read is bounded so that a
                # server which never sends it cannot hang the suite.
                await asyncio.wait_for(ws.recv(), timeout=self.WS_FIRST_FRAME_TIMEOUT)

                # Collect up to 10 follow-up events; one quiet second means
                # the server has nothing more to send.
                events_received = []
                try:
                    while len(events_received) < 10:
                        data = await self._recv_json(ws, 1.0)
                        events_received.append(data.get("event"))
                except asyncio.TimeoutError:
                    pass  # no more messages

                result.duration_ms = self._elapsed_ms(start)
                step_added_count = events_received.count("step_added")
                if step_added_count == 0:
                    result.fail("没有收到 step_added 事件(Trace 可能为空)")
                else:
                    other_events = [e for e in events_received if e != "step_added"]
                    result.success(
                        f"收到 {step_added_count} 个 step_added"
                        + (f", {len(other_events)} 个其他事件" if other_events else "")
                    )
        except asyncio.TimeoutError:
            result.fail("连接超时(3秒)")
        except Exception as e:
            result.fail(f"测试失败: {e}")
        self._record(result)

    async def test_websocket_reconnect(self):
        """Test: reconnect with since_event_id set to the last known event ID."""
        result = TestResult("WebSocket - 断线续传(since_event_id)")
        start = datetime.now()
        if self._skip_without_trace(result):
            return
        try:
            # First connection: learn the current event ID.
            uri = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id=0"
            async with websockets.connect(uri) as ws:
                data = await self._recv_json(ws, self.WS_FIRST_FRAME_TIMEOUT)
                last_event_id = data.get("current_event_id", 0)

            if last_event_id == 0:
                result.fail("无法获取 event_id")
            else:
                # Second connection: resume from that event ID.
                uri2 = (
                    f"{WS_BASE}/api/traces/{self.test_trace_id}"
                    f"/watch?since_event_id={last_event_id}"
                )
                async with websockets.connect(uri2) as ws:
                    data = await self._recv_json(ws, self.WS_FIRST_FRAME_TIMEOUT)
                    result.duration_ms = self._elapsed_ms(start)
                    if data.get("event") != "connected":
                        result.fail(f"重连后首条消息不是 'connected': {data.get('event')}")
                    else:
                        try:
                            # A frame arriving here is taken to be a newly added
                            # event; it is still JSON-decoded so a malformed
                            # frame fails the test.
                            await self._recv_json(ws, 0.5)
                            result.success(f"重连成功,从 event_id={last_event_id} 继续")
                        except asyncio.TimeoutError:
                            # Silence means there is nothing new to replay (normal).
                            result.success(f"重连成功,无新增事件(event_id={last_event_id})")
        except asyncio.TimeoutError:
            result.fail("连接超时(3秒)")
        except Exception as e:
            result.fail(f"测试失败: {e}")
        self._record(result)

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------

    def print_summary(self):
        """Print totals, the list of failed tests and a closing verdict."""
        print("=" * 60)
        print("📊 测试摘要")
        print("=" * 60)

        failures = [r for r in self.results if not r.passed]
        total = len(self.results)
        failed = len(failures)
        passed = total - failed

        print(f"\n总计: {total} 个测试")
        print(f"✅ 通过: {passed}")
        if failed > 0:
            print(f"❌ 失败: {failed}")
        print()

        if failed > 0:
            print("失败的测试:")
            for r in failures:
                print(f" - {r.name}")
                if r.message:
                    print(f" {r.message}")
            print()

        if passed == total:
            print("🎉 所有测试通过!")
        else:
            print("⚠️ 部分测试失败,请检查上述错误")
        print()
async def main():
    """Entry coroutine: run the full suite and exit with 0 on success, 1 otherwise."""
    all_passed = await APITester().run_all_tests()
    exit_code = 0 if all_passed else 1
    sys.exit(exit_code)
# Script entry point: runs only when the file is executed directly.
if __name__ == "__main__":
    try:
        # main() terminates the process itself via sys.exit().
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C: report the interruption and exit with a failure status.
        print("\n\n⚠️ 测试被中断")
        sys.exit(1)
|