test_api.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. #!/usr/bin/env python3
  2. """
  3. Agent Execution API 自动化测试
  4. 测试 REST API 和 WebSocket 接口是否正常工作
  5. 运行方式:
  6. python3 frontend/test_api.py
  7. 要求:
  8. - API Server 已启动(python3 api_server.py)
  9. - 有至少一个 Trace 数据(运行 examples/feature_extract/run.py)
  10. """
import asyncio
import json
import os
import sys
import time
from datetime import datetime
from typing import Optional

try:
    import httpx
    import websockets
except ImportError:
    print("❌ 缺少依赖,请安装:")
    print(" pip install httpx websockets")
    sys.exit(1)
  23. # 配置
  24. API_BASE = "http://43.106.118.91:8000"
  25. WS_BASE = "ws://43.106.118.91:8000"
  26. class TestResult:
  27. """测试结果"""
  28. def __init__(self, name: str):
  29. self.name = name
  30. self.passed = False
  31. self.message = ""
  32. self.duration_ms = 0
  33. def success(self, message: str = ""):
  34. self.passed = True
  35. self.message = message
  36. def fail(self, message: str):
  37. self.passed = False
  38. self.message = message
  39. def __str__(self):
  40. icon = "✅" if self.passed else "❌"
  41. return f"{icon} {self.name} ({self.duration_ms}ms)" + (f"\n {self.message}" if self.message else "")
  42. class APITester:
  43. """API 测试器"""
  44. def __init__(self):
  45. self.results = []
  46. self.test_trace_id: Optional[str] = None
  47. async def run_all_tests(self):
  48. """运行所有测试"""
  49. print("🧪 Agent Execution API 自动化测试")
  50. print("=" * 60)
  51. print()
  52. # 检查 API Server
  53. if not await self.check_server():
  54. print("❌ API Server 未启动,请先运行:")
  55. print(" python3 api_server.py")
  56. return False
  57. print("✅ API Server 已启动\n")
  58. # REST API 测试
  59. print("📡 测试 REST API")
  60. print("-" * 60)
  61. await self.test_list_traces()
  62. await self.test_get_trace()
  63. await self.test_get_messages()
  64. print()
  65. # WebSocket 测试
  66. print("⚡ 测试 WebSocket")
  67. print("-" * 60)
  68. await self.test_websocket_connect()
  69. await self.test_websocket_events()
  70. await self.test_websocket_reconnect()
  71. print()
  72. # 汇总结果
  73. self.print_summary()
  74. return all(r.passed for r in self.results)
  75. async def check_server(self) -> bool:
  76. """检查 API Server 是否运行"""
  77. try:
  78. async with httpx.AsyncClient() as client:
  79. response = await client.get(f"{API_BASE}/api/traces", timeout=2.0)
  80. return response.status_code in [200, 404]
  81. except Exception:
  82. return False
  83. async def test_list_traces(self):
  84. """测试:列出 Traces"""
  85. result = TestResult("GET /api/traces - 列出 Traces")
  86. start = datetime.now()
  87. try:
  88. async with httpx.AsyncClient() as client:
  89. response = await client.get(f"{API_BASE}/api/traces?limit=10")
  90. result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
  91. if response.status_code != 200:
  92. result.fail(f"HTTP {response.status_code}")
  93. else:
  94. data = response.json()
  95. if "traces" not in data:
  96. result.fail("响应缺少 'traces' 字段")
  97. else:
  98. trace_count = len(data["traces"])
  99. if trace_count == 0:
  100. result.fail("没有找到 Trace(请先运行 example 生成数据)")
  101. else:
  102. # 保存第一个 Trace ID 用于后续测试
  103. self.test_trace_id = data["traces"][0]["trace_id"]
  104. result.success(f"找到 {trace_count} 个 Trace")
  105. except Exception as e:
  106. result.fail(f"请求失败: {e}")
  107. self.results.append(result)
  108. print(result)
  109. async def test_get_trace(self):
  110. """测试:获取单个 Trace"""
  111. result = TestResult("GET /api/traces/{id} - 获取 Trace 元数据")
  112. start = datetime.now()
  113. if not self.test_trace_id:
  114. result.fail("跳过(没有可用的 Trace ID)")
  115. self.results.append(result)
  116. print(result)
  117. return
  118. try:
  119. async with httpx.AsyncClient() as client:
  120. response = await client.get(f"{API_BASE}/api/traces/{self.test_trace_id}")
  121. result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
  122. if response.status_code != 200:
  123. result.fail(f"HTTP {response.status_code}")
  124. else:
  125. data = response.json()
  126. # API 返回格式:{trace: {...}, goal_tree: {...}, sub_traces: {...}}
  127. if "trace" not in data:
  128. result.fail("响应缺少 'trace' 字段")
  129. else:
  130. trace = data["trace"]
  131. required_fields = ["trace_id", "mode", "status", "total_messages"]
  132. missing = [f for f in required_fields if f not in trace]
  133. if missing:
  134. result.fail(f"trace 缺少字段: {missing}")
  135. else:
  136. goal_count = len(data.get("goal_tree", {}).get("goals", []))
  137. result.success(f"Status: {trace['status']}, Messages: {trace['total_messages']}, Goals: {goal_count}")
  138. except Exception as e:
  139. result.fail(f"请求失败: {e}")
  140. self.results.append(result)
  141. print(result)
  142. async def test_get_messages(self):
  143. """测试:获取 Messages"""
  144. result = TestResult("GET /api/traces/{id}/messages - 获取消息列表")
  145. start = datetime.now()
  146. if not self.test_trace_id:
  147. result.fail("跳过(没有可用的 Trace ID)")
  148. self.results.append(result)
  149. print(result)
  150. return
  151. try:
  152. async with httpx.AsyncClient() as client:
  153. response = await client.get(
  154. f"{API_BASE}/api/traces/{self.test_trace_id}/messages"
  155. )
  156. result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
  157. if response.status_code != 200:
  158. result.fail(f"HTTP {response.status_code}")
  159. else:
  160. data = response.json()
  161. if "messages" not in data:
  162. result.fail("响应缺少 'messages' 字段")
  163. else:
  164. message_count = len(data["messages"])
  165. if message_count == 0:
  166. result.fail("没有找到 Message")
  167. else:
  168. # 统计角色分布
  169. roles = {}
  170. for msg in data["messages"]:
  171. role = msg.get("role", "unknown")
  172. roles[role] = roles.get(role, 0) + 1
  173. role_str = ", ".join([f"{role}: {count}" for role, count in roles.items()])
  174. result.success(f"总计 {message_count} 条 ({role_str})")
  175. except Exception as e:
  176. result.fail(f"请求失败: {e}")
  177. self.results.append(result)
  178. print(result)
  179. async def test_websocket_connect(self):
  180. """测试:WebSocket 连接"""
  181. result = TestResult("WebSocket - 连接和接收 connected 事件")
  182. start = datetime.now()
  183. if not self.test_trace_id:
  184. result.fail("跳过(没有可用的 Trace ID)")
  185. self.results.append(result)
  186. print(result)
  187. return
  188. try:
  189. uri = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id=0"
  190. async with websockets.connect(uri) as ws:
  191. # 接收第一条消息(connected)
  192. message = await asyncio.wait_for(ws.recv(), timeout=3.0)
  193. result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
  194. data = json.loads(message)
  195. if data.get("event") != "connected":
  196. result.fail(f"首条消息不是 'connected',而是: {data.get('event')}")
  197. elif "current_event_id" not in data:
  198. result.fail("connected 消息缺少 'current_event_id' 字段")
  199. else:
  200. event_id = data["current_event_id"]
  201. result.success(f"Event ID: {event_id}")
  202. except asyncio.TimeoutError:
  203. result.fail("连接超时(3秒)")
  204. except Exception as e:
  205. result.fail(f"连接失败: {e}")
  206. self.results.append(result)
  207. print(result)
  208. async def test_websocket_events(self):
  209. """测试:WebSocket 接收历史事件"""
  210. result = TestResult("WebSocket - 接收历史事件")
  211. start = datetime.now()
  212. if not self.test_trace_id:
  213. result.fail("跳过(没有可用的 Trace ID)")
  214. self.results.append(result)
  215. print(result)
  216. return
  217. try:
  218. uri = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id=0"
  219. async with websockets.connect(uri) as ws:
  220. # 接收 connected
  221. await ws.recv()
  222. # 接收后续消息(应该是历史事件:message_added, goal_added 等)
  223. events_received = []
  224. try:
  225. while len(events_received) < 10: # 最多接收 10 条
  226. message = await asyncio.wait_for(ws.recv(), timeout=1.0)
  227. data = json.loads(message)
  228. events_received.append(data.get("event"))
  229. except asyncio.TimeoutError:
  230. pass # 没有更多消息了
  231. result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
  232. # 统计事件类型
  233. event_counts = {}
  234. for event in events_received:
  235. event_counts[event] = event_counts.get(event, 0) + 1
  236. if len(events_received) == 0:
  237. result.fail("没有收到任何事件(Trace 可能为空)")
  238. else:
  239. event_summary = ", ".join([f"{evt}: {cnt}" for evt, cnt in event_counts.items()])
  240. result.success(f"收到 {len(events_received)} 个事件 ({event_summary})")
  241. except Exception as e:
  242. result.fail(f"测试失败: {e}")
  243. self.results.append(result)
  244. print(result)
  245. async def test_websocket_reconnect(self):
  246. """测试:WebSocket 断线续传"""
  247. result = TestResult("WebSocket - 断线续传(since_event_id)")
  248. start = datetime.now()
  249. if not self.test_trace_id:
  250. result.fail("跳过(没有可用的 Trace ID)")
  251. self.results.append(result)
  252. print(result)
  253. return
  254. try:
  255. uri = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id=0"
  256. # 第一次连接,获取 event_id
  257. last_event_id = 0
  258. async with websockets.connect(uri) as ws:
  259. message = await ws.recv()
  260. data = json.loads(message)
  261. last_event_id = data.get("current_event_id", 0)
  262. if last_event_id == 0:
  263. result.fail("无法获取 event_id")
  264. self.results.append(result)
  265. print(result)
  266. return
  267. # 第二次连接,使用 since_event_id
  268. uri2 = f"{WS_BASE}/api/traces/{self.test_trace_id}/watch?since_event_id={last_event_id}"
  269. async with websockets.connect(uri2) as ws:
  270. message = await ws.recv()
  271. data = json.loads(message)
  272. result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
  273. if data.get("event") != "connected":
  274. result.fail(f"重连后首条消息不是 'connected': {data.get('event')}")
  275. else:
  276. # 检查是否不再接收历史事件
  277. try:
  278. message = await asyncio.wait_for(ws.recv(), timeout=0.5)
  279. data2 = json.loads(message)
  280. # 如果收到消息,说明补发了历史(可能是新增的)
  281. result.success(f"重连成功,从 event_id={last_event_id} 继续")
  282. except asyncio.TimeoutError:
  283. # 没有收到消息,说明没有新增的事件(正常)
  284. result.success(f"重连成功,无新增事件(event_id={last_event_id})")
  285. except Exception as e:
  286. result.fail(f"测试失败: {e}")
  287. self.results.append(result)
  288. print(result)
  289. def print_summary(self):
  290. """打印测试摘要"""
  291. print("=" * 60)
  292. print("📊 测试摘要")
  293. print("=" * 60)
  294. passed = sum(1 for r in self.results if r.passed)
  295. total = len(self.results)
  296. failed = total - passed
  297. print(f"\n总计: {total} 个测试")
  298. print(f"✅ 通过: {passed}")
  299. if failed > 0:
  300. print(f"❌ 失败: {failed}")
  301. print()
  302. if failed > 0:
  303. print("失败的测试:")
  304. for r in self.results:
  305. if not r.passed:
  306. print(f" - {r.name}")
  307. if r.message:
  308. print(f" {r.message}")
  309. print()
  310. if passed == total:
  311. print("🎉 所有测试通过!")
  312. else:
  313. print("⚠️ 部分测试失败,请检查上述错误")
  314. print()
  315. async def main():
  316. """主函数"""
  317. tester = APITester()
  318. success = await tester.run_all_tests()
  319. sys.exit(0 if success else 1)
  320. if __name__ == "__main__":
  321. try:
  322. asyncio.run(main())
  323. except KeyboardInterrupt:
  324. print("\n\n⚠️ 测试被中断")
  325. sys.exit(1)