test_api.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. #!/usr/bin/env python3
  2. """
  3. Agent Execution API 自动化测试
  4. 测试 REST API 和 WebSocket 接口是否正常工作
  5. 运行方式:
  6. python3 frontend/test_api.py
  7. 要求:
  8. - API Server 已启动(python3 api_server.py)
  9. - 有至少一个 Trace 数据(运行 examples/feature_extract/run.py)
  10. """
  11. import asyncio
  12. import json
  13. import sys
  14. from typing import Optional
  15. from datetime import datetime
  16. try:
  17. import httpx
  18. import websockets
  19. except ImportError:
  20. print("❌ 缺少依赖,请安装:")
  21. print(" pip install httpx websockets")
  22. sys.exit(1)
  23. # 配置
  24. API_BASE = "http://localhost:8000"
  25. WS_BASE = "ws://localhost:8000"
  26. class TestResult:
  27. """测试结果"""
  28. def __init__(self, name: str):
  29. self.name = name
  30. self.passed = False
  31. self.message = ""
  32. self.duration_ms = 0
  33. def success(self, message: str = ""):
  34. self.passed = True
  35. self.message = message
  36. def fail(self, message: str):
  37. self.passed = False
  38. self.message = message
  39. def __str__(self):
  40. icon = "✅" if self.passed else "❌"
  41. return f"{icon} {self.name} ({self.duration_ms}ms)" + (f"\n {self.message}" if self.message else "")
  42. class APITester:
  43. """API 测试器"""
  44. def __init__(self):
  45. self.results = []
  46. self.test_trace_id: Optional[str] = None
  47. async def run_all_tests(self):
  48. """运行所有测试"""
  49. print("🧪 Agent Execution API 自动化测试")
  50. print("=" * 60)
  51. print()
  52. # 检查 API Server
  53. if not await self.check_server():
  54. print("❌ API Server 未启动,请先运行:")
  55. print(" python3 api_server.py")
  56. return False
  57. print("✅ API Server 已启动\n")
  58. # REST API 测试
  59. print("📡 测试 REST API")
  60. print("-" * 60)
  61. await self.test_list_traces()
  62. await self.test_get_trace()
  63. await self.test_get_tree()
  64. print()
  65. # WebSocket 测试
  66. print("⚡ 测试 WebSocket")
  67. print("-" * 60)
  68. await self.test_websocket_connect()
  69. await self.test_websocket_events()
  70. await self.test_websocket_reconnect()
  71. print()
  72. # 汇总结果
  73. self.print_summary()
  74. return all(r.passed for r in self.results)
  75. async def check_server(self) -> bool:
  76. """检查 API Server 是否运行"""
  77. try:
  78. async with httpx.AsyncClient() as client:
  79. response = await client.get(f"{API_BASE}/api/traces", timeout=2.0)
  80. return response.status_code in [200, 404]
  81. except Exception:
  82. return False
  83. async def test_list_traces(self):
  84. """测试:列出 Traces"""
  85. result = TestResult("GET /api/traces - 列出 Traces")
  86. start = datetime.now()
  87. try:
  88. async with httpx.AsyncClient() as client:
  89. response = await client.get(f"{API_BASE}/api/traces?limit=10")
  90. result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
  91. if response.status_code != 200:
  92. result.fail(f"HTTP {response.status_code}")
  93. else:
  94. data = response.json()
  95. if "traces" not in data:
  96. result.fail("响应缺少 'traces' 字段")
  97. else:
  98. trace_count = len(data["traces"])
  99. if trace_count == 0:
  100. result.fail("没有找到 Trace(请先运行 example 生成数据)")
  101. else:
  102. # 保存第一个 Trace ID 用于后续测试
  103. self.test_trace_id = data["traces"][0]["trace_id"]
  104. result.success(f"找到 {trace_count} 个 Trace")
  105. except Exception as e:
  106. result.fail(f"请求失败: {e}")
  107. self.results.append(result)
  108. print(result)
  109. async def test_get_trace(self):
  110. """测试:获取单个 Trace"""
  111. result = TestResult("GET /api/traces/{id} - 获取 Trace 元数据")
  112. start = datetime.now()
  113. if not self.test_trace_id:
  114. result.fail("跳过(没有可用的 Trace ID)")
  115. self.results.append(result)
  116. print(result)
  117. return
  118. try:
  119. async with httpx.AsyncClient() as client:
  120. response = await client.get(f"{API_BASE}/api/traces/{self.test_trace_id}")
  121. result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
  122. if response.status_code != 200:
  123. result.fail(f"HTTP {response.status_code}")
  124. else:
  125. data = response.json()
  126. required_fields = ["trace_id", "mode", "status", "total_steps"]
  127. missing = [f for f in required_fields if f not in data]
  128. if missing:
  129. result.fail(f"响应缺少字段: {missing}")
  130. else:
  131. result.success(f"Status: {data['status']}, Steps: {data['total_steps']}")
  132. except Exception as e:
  133. result.fail(f"请求失败: {e}")
  134. self.results.append(result)
  135. print(result)
  136. async def test_get_tree(self):
  137. """测试:获取完整 Step 树"""
  138. result = TestResult("GET /api/traces/{id}/tree - 获取 Step 树")
  139. start = datetime.now()
  140. if not self.test_trace_id:
  141. result.fail("跳过(没有可用的 Trace ID)")
  142. self.results.append(result)
  143. print(result)
  144. return
  145. try:
  146. async with httpx.AsyncClient() as client:
  147. response = await client.get(
  148. f"{API_BASE}/api/traces/{self.test_trace_id}/tree?view=compact"
  149. )
  150. result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
  151. if response.status_code != 200:
  152. result.fail(f"HTTP {response.status_code}")
  153. else:
  154. data = response.json()
  155. if "root_steps" not in data:
  156. result.fail("响应缺少 'root_steps' 字段")
  157. else:
  158. root_count = len(data["root_steps"])
  159. # 统计总 step 数(递归)
  160. def count_steps(nodes):
  161. count = len(nodes)
  162. for node in nodes:
  163. if "children" in node:
  164. count += count_steps(node["children"])
  165. return count
  166. total_steps = count_steps(data["root_steps"])
  167. result.success(f"根节点: {root_count}, 总 Steps: {total_steps}")
  168. except Exception as e:
  169. result.fail(f"请求失败: {e}")
  170. self.results.append(result)
  171. print(result)
  172. async def test_websocket_connect(self):
  173. """测试:WebSocket 连接"""
  174. result = TestResult("WebSocket - 连接和接收 connected 事件")
  175. start = datetime.now()
  176. if not self.test_trace_id:
  177. result.fail("跳过(没有可用的 Trace ID)")
  178. self.results.append(result)
  179. print(result)
  180. return
  181. try:
  182. uri = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id=0"
  183. async with websockets.connect(uri) as ws:
  184. # 接收第一条消息(connected)
  185. message = await asyncio.wait_for(ws.recv(), timeout=3.0)
  186. result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
  187. data = json.loads(message)
  188. if data.get("event") != "connected":
  189. result.fail(f"首条消息不是 'connected',而是: {data.get('event')}")
  190. elif "current_event_id" not in data:
  191. result.fail("connected 消息缺少 'current_event_id' 字段")
  192. else:
  193. event_id = data["current_event_id"]
  194. result.success(f"Event ID: {event_id}")
  195. except asyncio.TimeoutError:
  196. result.fail("连接超时(3秒)")
  197. except Exception as e:
  198. result.fail(f"连接失败: {e}")
  199. self.results.append(result)
  200. print(result)
  201. async def test_websocket_events(self):
  202. """测试:WebSocket 接收历史事件"""
  203. result = TestResult("WebSocket - 接收历史 step_added 事件")
  204. start = datetime.now()
  205. if not self.test_trace_id:
  206. result.fail("跳过(没有可用的 Trace ID)")
  207. self.results.append(result)
  208. print(result)
  209. return
  210. try:
  211. uri = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id=0"
  212. async with websockets.connect(uri) as ws:
  213. # 接收 connected
  214. await ws.recv()
  215. # 接收后续消息(应该是历史 step_added 事件)
  216. events_received = []
  217. try:
  218. while len(events_received) < 10: # 最多接收 10 条
  219. message = await asyncio.wait_for(ws.recv(), timeout=1.0)
  220. data = json.loads(message)
  221. events_received.append(data.get("event"))
  222. except asyncio.TimeoutError:
  223. pass # 没有更多消息了
  224. result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
  225. step_added_count = events_received.count("step_added")
  226. if step_added_count == 0:
  227. result.fail("没有收到 step_added 事件(Trace 可能为空)")
  228. else:
  229. other_events = [e for e in events_received if e != "step_added"]
  230. result.success(
  231. f"收到 {step_added_count} 个 step_added" +
  232. (f", {len(other_events)} 个其他事件" if other_events else "")
  233. )
  234. except Exception as e:
  235. result.fail(f"测试失败: {e}")
  236. self.results.append(result)
  237. print(result)
  238. async def test_websocket_reconnect(self):
  239. """测试:WebSocket 断线续传"""
  240. result = TestResult("WebSocket - 断线续传(since_event_id)")
  241. start = datetime.now()
  242. if not self.test_trace_id:
  243. result.fail("跳过(没有可用的 Trace ID)")
  244. self.results.append(result)
  245. print(result)
  246. return
  247. try:
  248. uri = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id=0"
  249. # 第一次连接,获取 event_id
  250. last_event_id = 0
  251. async with websockets.connect(uri) as ws:
  252. message = await ws.recv()
  253. data = json.loads(message)
  254. last_event_id = data.get("current_event_id", 0)
  255. if last_event_id == 0:
  256. result.fail("无法获取 event_id")
  257. self.results.append(result)
  258. print(result)
  259. return
  260. # 第二次连接,使用 since_event_id
  261. uri2 = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id={last_event_id}"
  262. async with websockets.connect(uri2) as ws:
  263. message = await ws.recv()
  264. data = json.loads(message)
  265. result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
  266. if data.get("event") != "connected":
  267. result.fail(f"重连后首条消息不是 'connected': {data.get('event')}")
  268. else:
  269. # 检查是否不再接收历史事件
  270. try:
  271. message = await asyncio.wait_for(ws.recv(), timeout=0.5)
  272. data2 = json.loads(message)
  273. # 如果收到消息,说明补发了历史(可能是新增的)
  274. result.success(f"重连成功,从 event_id={last_event_id} 继续")
  275. except asyncio.TimeoutError:
  276. # 没有收到消息,说明没有新增的事件(正常)
  277. result.success(f"重连成功,无新增事件(event_id={last_event_id})")
  278. except Exception as e:
  279. result.fail(f"测试失败: {e}")
  280. self.results.append(result)
  281. print(result)
  282. def print_summary(self):
  283. """打印测试摘要"""
  284. print("=" * 60)
  285. print("📊 测试摘要")
  286. print("=" * 60)
  287. passed = sum(1 for r in self.results if r.passed)
  288. total = len(self.results)
  289. failed = total - passed
  290. print(f"\n总计: {total} 个测试")
  291. print(f"✅ 通过: {passed}")
  292. if failed > 0:
  293. print(f"❌ 失败: {failed}")
  294. print()
  295. if failed > 0:
  296. print("失败的测试:")
  297. for r in self.results:
  298. if not r.passed:
  299. print(f" - {r.name}")
  300. if r.message:
  301. print(f" {r.message}")
  302. print()
  303. if passed == total:
  304. print("🎉 所有测试通过!")
  305. else:
  306. print("⚠️ 部分测试失败,请检查上述错误")
  307. print()
  308. async def main():
  309. """主函数"""
  310. tester = APITester()
  311. success = await tester.run_all_tests()
  312. sys.exit(0 if success else 1)
  313. if __name__ == "__main__":
  314. try:
  315. asyncio.run(main())
  316. except KeyboardInterrupt:
  317. print("\n\n⚠️ 测试被中断")
  318. sys.exit(1)