"""
Step tree WebSocket push.

Pushes Step updates of in-progress Traces to connected clients in real time.
"""

import logging
from typing import Dict, Optional, Set

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from agent.execution.protocols import TraceStore

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/traces", tags=["websocket"])


# ===== Global state =====

_trace_store: Optional[TraceStore] = None
# trace_id -> set of WebSockets currently watching that Trace
_active_connections: Dict[str, Set[WebSocket]] = {}


def set_trace_store(store: TraceStore):
    """Install the TraceStore instance used by the WebSocket route."""
    global _trace_store
    _trace_store = store


def get_trace_store() -> TraceStore:
    """
    Return the installed TraceStore instance.

    Raises:
        RuntimeError: If ``set_trace_store`` has not been called yet.
    """
    if _trace_store is None:
        raise RuntimeError("TraceStore not initialized")
    return _trace_store


# ===== WebSocket route =====

@router.websocket("/{trace_id}/watch")
async def watch_trace(websocket: WebSocket, trace_id: str):
    """
    Watch a Trace for Step updates over a WebSocket.

    After accepting the connection, the Trace is looked up in the store. If it
    does not exist, an ``error`` event is sent and the socket is closed.
    Otherwise the socket is registered in ``_active_connections`` so the
    ``broadcast_*`` functions can push events to it. A ``connected`` event is
    sent, and the handler then replies to ``"ping"`` text frames with a
    ``pong`` event until the client disconnects.

    Args:
        websocket: The incoming WebSocket connection.
        trace_id: Trace ID taken from the URL path.
    """
    await websocket.accept()

    # Verify the Trace exists before registering the connection.
    store = get_trace_store()
    trace = await store.get_trace(trace_id)
    if not trace:
        await websocket.send_json({
            "event": "error",
            "message": "Trace not found"
        })
        await websocket.close()
        return

    # Register the connection (no await between lookup and insert).
    _active_connections.setdefault(trace_id, set()).add(websocket)

    try:
        await websocket.send_json({
            "event": "connected",
            "trace_id": trace_id
        })

        # Keep the connection open until the client disconnects.
        while True:
            try:
                # Client messages are used as a heartbeat.
                data = await websocket.receive_text()
                if data == "ping":
                    await websocket.send_json({"event": "pong"})
            except WebSocketDisconnect:
                break
    finally:
        # The entry may already have been removed (e.g. by
        # broadcast_trace_completed), so look it up defensively.
        connections = _active_connections.get(trace_id)
        if connections is not None:
            connections.discard(websocket)
            if not connections:
                del _active_connections[trace_id]


# ===== Broadcast functions (intended to be called by AgentRunner) =====

async def _broadcast(trace_id: str, message: Dict) -> None:
    """
    Send ``message`` to every WebSocket watching ``trace_id`` (best effort).

    Sockets whose send raises are dropped from the registry. If that leaves the
    Trace with no watchers, its entry is removed.
    """
    connections = _active_connections.get(trace_id)
    if not connections:
        return

    disconnected = []
    # Iterate over a snapshot. Each send awaits, and while suspended,
    # watch_trace() may discard from (or delete) the live set. Iterating the
    # live set would then raise "Set changed size during iteration".
    for websocket in list(connections):
        try:
            await websocket.send_json(message)
        except Exception:
            logger.debug(
                "Dropping websocket for trace %s after send failure",
                trace_id,
                exc_info=True,
            )
            disconnected.append(websocket)

    if not disconnected:
        return

    # Re-fetch: the entry may have been deleted while we were awaiting.
    live = _active_connections.get(trace_id)
    if live is None:
        return
    live.difference_update(disconnected)
    if not live:
        del _active_connections[trace_id]


async def broadcast_step_added(trace_id: str, step_dict: Dict):
    """
    Broadcast a ``step_added`` event to all watchers of a Trace.

    Args:
        trace_id: Trace ID.
        step_dict: Step dict (from ``step.to_dict()``).
    """
    await _broadcast(trace_id, {
        "event": "step_added",
        "step": step_dict
    })


async def broadcast_step_updated(trace_id: str, step_id: str, updates: Dict):
    """
    Broadcast a ``step_updated`` event to all watchers of a Trace.

    Args:
        trace_id: Trace ID.
        step_id: Step ID.
        updates: The updated fields.
    """
    await _broadcast(trace_id, {
        "event": "step_updated",
        "step_id": step_id,
        "updates": updates
    })


async def broadcast_trace_completed(trace_id: str, total_steps: int):
    """
    Broadcast a ``trace_completed`` event, then unregister all watchers.

    The sockets are not closed here. They are only removed from the registry,
    so no further broadcasts reach them.

    Args:
        trace_id: Trace ID.
        total_steps: Total number of Steps.
    """
    if trace_id not in _active_connections:
        return

    await _broadcast(trace_id, {
        "event": "trace_completed",
        "trace_id": trace_id,
        "total_steps": total_steps
    })

    # Once completed, drop every registered connection for this Trace.
    _active_connections.pop(trace_id, None)