#!/usr/bin/env python3
"""Automated smoke tests for the Agent Execution API.

Checks that the REST API and the WebSocket interface respond as expected.

Usage:
    python3 frontend/test_api.py

Requirements:
    - The API server is running (python3 api_server.py).
    - At least one Trace exists (run examples/feature_extract/run.py).
"""

import asyncio
import json
import sys
from collections import Counter
from datetime import datetime
from typing import List, Optional

try:
    import httpx
    import websockets
except ImportError:
    print("❌ 缺少依赖,请安装:")
    print(" pip install httpx websockets")
    sys.exit(1)

# Configuration
API_BASE = "http://localhost:8000"
WS_BASE = "ws://localhost:8000"

# Upper bound (seconds) on waiting for the first WebSocket message of a
# connection. The timeout failure message below hard-codes "3", so keep the
# two in sync when changing this value.
WS_HANDSHAKE_TIMEOUT = 3.0


def _elapsed_ms(start: datetime) -> int:
    """Return the whole milliseconds elapsed since *start*."""
    return int((datetime.now() - start).total_seconds() * 1000)


async def _recv_json(ws, timeout: float):
    """Receive one WebSocket message and decode it as JSON.

    Raises:
        asyncio.TimeoutError: if nothing arrives within *timeout* seconds.
    """
    raw = await asyncio.wait_for(ws.recv(), timeout=timeout)
    return json.loads(raw)


class TestResult:
    """Outcome of a single test case: name, pass/fail, message, duration."""

    def __init__(self, name: str):
        self.name = name
        self.passed = False
        self.message = ""
        self.duration_ms = 0

    def success(self, message: str = "") -> None:
        """Mark the test as passed, with an optional detail message."""
        self.passed = True
        self.message = message

    def fail(self, message: str) -> None:
        """Mark the test as failed and record why."""
        self.passed = False
        self.message = message

    def __str__(self) -> str:
        icon = "✅" if self.passed else "❌"
        headline = f"{icon} {self.name} ({self.duration_ms}ms)"
        if self.message:
            return headline + f"\n {self.message}"
        return headline


class APITester:
    """Runs the REST and WebSocket checks and collects their results."""

    def __init__(self):
        self.results: List[TestResult] = []
        # Set by test_list_traces(); the later tests are skipped without it.
        self.test_trace_id: Optional[str] = None

    async def run_all_tests(self) -> bool:
        """Run every test in order and return True only if all passed.

        Returns False immediately (without running any test) when the API
        server cannot be reached.
        """
        print("🧪 Agent Execution API 自动化测试")
        print("=" * 60)
        print()

        # Probe the API server first
        if not await self.check_server():
            print("❌ API Server 未启动,请先运行:")
            print(" python3 api_server.py")
            return False
        print("✅ API Server 已启动\n")

        # REST API tests
        print("📡 测试 REST API")
        print("-" * 60)
        await self.test_list_traces()
        await self.test_get_trace()
        await self.test_get_messages()
        print()

        # WebSocket tests
        print("⚡ 测试 WebSocket")
        print("-" * 60)
        await self.test_websocket_connect()
        await self.test_websocket_events()
        await self.test_websocket_reconnect()
        print()

        # Summary
        self.print_summary()
        return all(r.passed for r in self.results)

    async def check_server(self) -> bool:
        """Return True if GET /api/traces answers with 200 or 404 within 2 s."""
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(f"{API_BASE}/api/traces", timeout=2.0)
            return response.status_code in [200, 404]
        except Exception:
            # Best-effort probe: any failure just means "not reachable".
            return False

    def _require_trace(self, result: TestResult) -> bool:
        """Return True if a trace id is available; otherwise fail *result* as skipped."""
        if self.test_trace_id:
            return True
        result.fail("跳过(没有可用的 Trace ID)")
        return False

    def _finish(self, result: TestResult, start: datetime) -> None:
        """Store and print *result*, filling in its duration if not yet measured."""
        if not result.duration_ms:
            # Error, timeout and skip paths never reach the explicit
            # measurement, so record the elapsed time here instead of
            # leaving a misleading 0ms.
            result.duration_ms = _elapsed_ms(start)
        self.results.append(result)
        print(result)

    async def test_list_traces(self) -> None:
        """Test: list traces, and remember the first trace id for later tests."""
        result = TestResult("GET /api/traces - 列出 Traces")
        start = datetime.now()
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(f"{API_BASE}/api/traces?limit=10")
            result.duration_ms = _elapsed_ms(start)

            if response.status_code != 200:
                result.fail(f"HTTP {response.status_code}")
            else:
                data = response.json()
                if "traces" not in data:
                    result.fail("响应缺少 'traces' 字段")
                elif not data["traces"]:
                    result.fail("没有找到 Trace(请先运行 example 生成数据)")
                else:
                    traces = data["traces"]
                    # Keep the first trace id for the follow-up tests
                    self.test_trace_id = traces[0]["trace_id"]
                    result.success(f"找到 {len(traces)} 个 Trace")
        except Exception as e:
            result.fail(f"请求失败: {e}")
        finally:
            self._finish(result, start)

    async def test_get_trace(self) -> None:
        """Test: fetch a single trace and check its required metadata fields."""
        result = TestResult("GET /api/traces/{id} - 获取 Trace 元数据")
        start = datetime.now()
        try:
            if not self._require_trace(result):
                return

            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{API_BASE}/api/traces/{self.test_trace_id}"
                )
            result.duration_ms = _elapsed_ms(start)

            if response.status_code != 200:
                result.fail(f"HTTP {response.status_code}")
                return

            # Expected response shape: {trace: {...}, goal_tree: {...}, sub_traces: {...}}
            data = response.json()
            if "trace" not in data:
                result.fail("响应缺少 'trace' 字段")
                return

            trace = data["trace"]
            required_fields = ["trace_id", "mode", "status", "total_messages"]
            missing = [f for f in required_fields if f not in trace]
            if missing:
                result.fail(f"trace 缺少字段: {missing}")
                return

            # Tolerate a missing or null goal_tree / goals instead of failing
            # an otherwise valid response.
            goals = (data.get("goal_tree") or {}).get("goals") or []
            goal_count = len(goals)
            result.success(
                f"Status: {trace['status']}, Messages: {trace['total_messages']}, Goals: {goal_count}"
            )
        except Exception as e:
            result.fail(f"请求失败: {e}")
        finally:
            self._finish(result, start)

    async def test_get_messages(self) -> None:
        """Test: fetch the trace's messages and summarise them by role."""
        result = TestResult("GET /api/traces/{id}/messages - 获取消息列表")
        start = datetime.now()
        try:
            if not self._require_trace(result):
                return

            async with httpx.AsyncClient() as client:
                response = await client.get(
                    f"{API_BASE}/api/traces/{self.test_trace_id}/messages"
                )
            result.duration_ms = _elapsed_ms(start)

            if response.status_code != 200:
                result.fail(f"HTTP {response.status_code}")
                return

            data = response.json()
            if "messages" not in data:
                result.fail("响应缺少 'messages' 字段")
                return

            messages = data["messages"]
            message_count = len(messages)
            if message_count == 0:
                result.fail("没有找到 Message")
                return

            # Role distribution, in order of first appearance
            roles = Counter(msg.get("role", "unknown") for msg in messages)
            role_str = ", ".join(f"{role}: {count}" for role, count in roles.items())
            result.success(f"总计 {message_count} 条 ({role_str})")
        except Exception as e:
            result.fail(f"请求失败: {e}")
        finally:
            self._finish(result, start)

    async def test_websocket_connect(self) -> None:
        """Test: open the watch socket and expect a 'connected' event first."""
        result = TestResult("WebSocket - 连接和接收 connected 事件")
        start = datetime.now()
        try:
            if not self._require_trace(result):
                return

            uri = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id=0"
            async with websockets.connect(uri) as ws:
                # The first message should be the 'connected' event
                data = await _recv_json(ws, WS_HANDSHAKE_TIMEOUT)
                result.duration_ms = _elapsed_ms(start)

                if data.get("event") != "connected":
                    result.fail(f"首条消息不是 'connected',而是: {data.get('event')}")
                elif "current_event_id" not in data:
                    result.fail("connected 消息缺少 'current_event_id' 字段")
                else:
                    event_id = data["current_event_id"]
                    result.success(f"Event ID: {event_id}")
        except asyncio.TimeoutError:
            result.fail("连接超时(3秒)")
        except Exception as e:
            result.fail(f"连接失败: {e}")
        finally:
            self._finish(result, start)

    async def test_websocket_events(self) -> None:
        """Test: after 'connected', receive replayed events (at most 10)."""
        result = TestResult("WebSocket - 接收历史事件")
        start = datetime.now()
        try:
            if not self._require_trace(result):
                return

            uri = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id=0"
            async with websockets.connect(uri) as ws:
                # Consume 'connected'. Bounded so a silent server cannot
                # hang the whole run.
                await _recv_json(ws, WS_HANDSHAKE_TIMEOUT)

                # Following messages should be the replayed events
                # (message_added, goal_added, ...)
                events_received = []
                try:
                    while len(events_received) < 10:  # collect at most 10
                        data = await _recv_json(ws, timeout=1.0)
                        events_received.append(data.get("event"))
                except asyncio.TimeoutError:
                    pass  # a quiet second means there is nothing more to read

                result.duration_ms = _elapsed_ms(start)

                if not events_received:
                    result.fail("没有收到任何事件(Trace 可能为空)")
                else:
                    # Count per event type, in order of first appearance
                    event_counts = Counter(events_received)
                    event_summary = ", ".join(
                        f"{evt}: {cnt}" for evt, cnt in event_counts.items()
                    )
                    result.success(
                        f"收到 {len(events_received)} 个事件 ({event_summary})"
                    )
        except asyncio.TimeoutError:
            result.fail("连接超时(3秒)")
        except Exception as e:
            result.fail(f"测试失败: {e}")
        finally:
            self._finish(result, start)

    async def test_websocket_reconnect(self) -> None:
        """Test: reconnect with since_event_id and expect 'connected' again."""
        result = TestResult("WebSocket - 断线续传(since_event_id)")
        start = datetime.now()
        try:
            if not self._require_trace(result):
                return

            # First connection: learn the current event id
            uri = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id=0"
            async with websockets.connect(uri) as ws:
                data = await _recv_json(ws, WS_HANDSHAKE_TIMEOUT)
                last_event_id = data.get("current_event_id", 0)

            if last_event_id == 0:
                result.fail("无法获取 event_id")
                return

            # Second connection: resume from that event id
            uri2 = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id={last_event_id}"
            async with websockets.connect(uri2) as ws:
                data = await _recv_json(ws, WS_HANDSHAKE_TIMEOUT)
                result.duration_ms = _elapsed_ms(start)

                if data.get("event") != "connected":
                    result.fail(f"重连后首条消息不是 'connected': {data.get('event')}")
                    return

                # Briefly check whether anything else is sent after resuming
                try:
                    await _recv_json(ws, timeout=0.5)
                except asyncio.TimeoutError:
                    # Nothing arrived: no events newer than last_event_id
                    result.success(f"重连成功,无新增事件(event_id={last_event_id})")
                else:
                    # Something arrived after resuming (possibly a new event)
                    result.success(f"重连成功,从 event_id={last_event_id} 继续")
        except asyncio.TimeoutError:
            result.fail("连接超时(3秒)")
        except Exception as e:
            result.fail(f"测试失败: {e}")
        finally:
            self._finish(result, start)

    def print_summary(self) -> None:
        """Print totals and the list of failed tests."""
        print("=" * 60)
        print("📊 测试摘要")
        print("=" * 60)

        passed = sum(1 for r in self.results if r.passed)
        total = len(self.results)
        failed = total - passed

        print(f"\n总计: {total} 个测试")
        print(f"✅ 通过: {passed}")
        if failed > 0:
            print(f"❌ 失败: {failed}")
        print()

        if failed > 0:
            print("失败的测试:")
            for r in self.results:
                if not r.passed:
                    print(f" - {r.name}")
                    if r.message:
                        print(f" {r.message}")
            print()

        if passed == total:
            print("🎉 所有测试通过!")
        else:
            print("⚠️ 部分测试失败,请检查上述错误")
        print()


async def main():
    """Run the suite and exit with status 0 on success, 1 otherwise."""
    tester = APITester()
    success = await tester.run_all_tests()
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n\n⚠️ 测试被中断")
        sys.exit(1)