Talegorithm 1 месяц назад
Родитель
Сommit
d79be2adf9
1 измененных файлов с 41 добавлено и 33 удалено
  1. 41 33
      frontend/test_api.py

+ 41 - 33
frontend/test_api.py

@@ -79,7 +79,7 @@ class APITester:
         print("-" * 60)
         await self.test_list_traces()
         await self.test_get_trace()
-        await self.test_get_tree()
+        await self.test_get_messages()
         print()
 
         # WebSocket 测试
@@ -154,21 +154,27 @@ class APITester:
                     result.fail(f"HTTP {response.status_code}")
                 else:
                     data = response.json()
-                    required_fields = ["trace_id", "mode", "status", "total_steps"]
-                    missing = [f for f in required_fields if f not in data]
-                    if missing:
-                        result.fail(f"响应缺少字段: {missing}")
+                    # API 返回格式:{trace: {...}, goal_tree: {...}, sub_traces: {...}}
+                    if "trace" not in data:
+                        result.fail("响应缺少 'trace' 字段")
                     else:
-                        result.success(f"Status: {data['status']}, Steps: {data['total_steps']}")
+                        trace = data["trace"]
+                        required_fields = ["trace_id", "mode", "status", "total_messages"]
+                        missing = [f for f in required_fields if f not in trace]
+                        if missing:
+                            result.fail(f"trace 缺少字段: {missing}")
+                        else:
+                            goal_count = len(data.get("goal_tree", {}).get("goals", []))
+                            result.success(f"Status: {trace['status']}, Messages: {trace['total_messages']}, Goals: {goal_count}")
         except Exception as e:
             result.fail(f"请求失败: {e}")
 
         self.results.append(result)
         print(result)
 
-    async def test_get_tree(self):
-        """测试:获取完整 Step 树"""
-        result = TestResult("GET /api/traces/{id}/tree - 获取 Step 树")
+    async def test_get_messages(self):
+        """测试:获取 Messages"""
+        result = TestResult("GET /api/traces/{id}/messages - 获取消息列表")
         start = datetime.now()
 
         if not self.test_trace_id:
@@ -180,7 +186,7 @@ class APITester:
         try:
             async with httpx.AsyncClient() as client:
                 response = await client.get(
-                    f"{API_BASE}/api/traces/{self.test_trace_id}/tree?view=compact"
+                    f"{API_BASE}/api/traces/{self.test_trace_id}/messages"
                 )
                 result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
 
@@ -188,20 +194,21 @@ class APITester:
                     result.fail(f"HTTP {response.status_code}")
                 else:
                     data = response.json()
-                    if "root_steps" not in data:
-                        result.fail("响应缺少 'root_steps' 字段")
+                    if "messages" not in data:
+                        result.fail("响应缺少 'messages' 字段")
                     else:
-                        root_count = len(data["root_steps"])
-                        # 统计总 step 数(递归)
-                        def count_steps(nodes):
-                            count = len(nodes)
-                            for node in nodes:
-                                if "children" in node:
-                                    count += count_steps(node["children"])
-                            return count
-
-                        total_steps = count_steps(data["root_steps"])
-                        result.success(f"根节点: {root_count}, 总 Steps: {total_steps}")
+                        message_count = len(data["messages"])
+                        if message_count == 0:
+                            result.fail("没有找到 Message")
+                        else:
+                            # 统计角色分布
+                            roles = {}
+                            for msg in data["messages"]:
+                                role = msg.get("role", "unknown")
+                                roles[role] = roles.get(role, 0) + 1
+
+                            role_str = ", ".join([f"{role}: {count}" for role, count in roles.items()])
+                            result.success(f"总计 {message_count} 条 ({role_str})")
         except Exception as e:
             result.fail(f"请求失败: {e}")
 
@@ -244,7 +251,7 @@ class APITester:
 
     async def test_websocket_events(self):
         """测试:WebSocket 接收历史事件"""
-        result = TestResult("WebSocket - 接收历史 step_added 事件")
+        result = TestResult("WebSocket - 接收历史事件")
         start = datetime.now()
 
         if not self.test_trace_id:
@@ -259,7 +266,7 @@ class APITester:
                 # 接收 connected
                 await ws.recv()
 
-                # 接收后续消息(应该是历史 step_added 事件
+                # 接收后续消息(应该是历史事件:message_added, goal_added 等
                 events_received = []
                 try:
                     while len(events_received) < 10:  # 最多接收 10 条
@@ -271,15 +278,16 @@ class APITester:
 
                 result.duration_ms = int((datetime.now() - start).total_seconds() * 1000)
 
-                step_added_count = events_received.count("step_added")
-                if step_added_count == 0:
-                    result.fail("没有收到 step_added 事件(Trace 可能为空)")
+                # 统计事件类型
+                event_counts = {}
+                for event in events_received:
+                    event_counts[event] = event_counts.get(event, 0) + 1
+
+                if len(events_received) == 0:
+                    result.fail("没有收到任何事件(Trace 可能为空)")
                 else:
-                    other_events = [e for e in events_received if e != "step_added"]
-                    result.success(
-                        f"收到 {step_added_count} 个 step_added" +
-                        (f", {len(other_events)} 个其他事件" if other_events else "")
-                    )
+                    event_summary = ", ".join([f"{evt}: {cnt}" for evt, cnt in event_counts.items()])
+                    result.success(f"收到 {len(events_received)} 个事件 ({event_summary})")
         except Exception as e:
             result.fail(f"测试失败: {e}")