"""
Trace WebSocket push.

Streams live updates of an in-progress Trace to WebSocket clients and supports
resuming after a dropped connection by replaying stored events.
"""

from datetime import datetime
from typing import Any, Dict, Optional, Set

from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect

from .protocols import TraceStore


router = APIRouter(prefix="/api/traces", tags=["websocket"])

# Maximum number of stored events replayed over the socket on (re)connect;
# above this the client is told to reload through the REST API instead.
_MAX_REPLAY_EVENTS = 100

# Upper bound passed to ``store.list_traces`` when scanning for Sub-Traces.
_SUB_TRACE_SCAN_LIMIT = 1000


# ===== Global state =====

_trace_store: Optional[TraceStore] = None
_active_connections: Dict[str, Set[WebSocket]] = {}  # trace_id -> Set[WebSocket]


def set_trace_store(store: TraceStore) -> None:
    """Install the TraceStore instance used by this module."""
    global _trace_store
    _trace_store = store


def get_trace_store() -> TraceStore:
    """Return the installed TraceStore.

    Raises:
        RuntimeError: if ``set_trace_store`` has not been called yet.
    """
    if _trace_store is None:
        raise RuntimeError("TraceStore not initialized")
    return _trace_store


# ===== WebSocket route =====

@router.websocket("/{trace_id}/watch")
async def watch_trace(
    websocket: WebSocket,
    trace_id: str,
    since_event_id: int = Query(0, description="从哪个事件 ID 开始(0=补发所有历史)"),
):
    """
    Watch updates of a Trace, with support for resuming after a disconnect.

    Event types:
    - connected: connection established; carries goal_tree and sub_traces
    - goal_added: a Goal was added
    - goal_updated: a Goal changed state (including cascaded completion)
    - message_added: a new Message (with affected_goals)
    - sub_trace_started: a Sub-Trace started running
    - sub_trace_completed: a Sub-Trace finished
    - trace_completed: execution finished

    Args:
        websocket: the client connection.
        trace_id: Trace ID.
        since_event_id: event ID to resume from.
            - 0: replay all stored events (first connection)
            - >0: replay events after the given ID (reconnect)
            Negative values skip the replay entirely.
    """
    await websocket.accept()

    # Make sure the Trace exists before registering anything.
    store = get_trace_store()
    trace = await store.get_trace(trace_id)
    if not trace:
        await websocket.send_json({
            "event": "error",
            "message": "Trace not found"
        })
        await websocket.close()
        return

    # Register the connection so the broadcast functions can reach it.
    _active_connections.setdefault(trace_id, set()).add(websocket)

    try:
        goal_tree = await store.get_goal_tree(trace_id)

        # Collect Sub-Traces by filtering on parent_trace_id.
        all_traces = await store.list_traces(limit=_SUB_TRACE_SCAN_LIMIT)
        sub_traces = {
            t.trace_id: t.to_dict()
            for t in all_traces
            if t.parent_trace_id == trace_id
        }

        # Send the "connected" message together with the full current state.
        await websocket.send_json({
            "event": "connected",
            "trace_id": trace_id,
            "current_event_id": trace.last_event_id,
            "goal_tree": goal_tree.to_dict() if goal_tree else None,
            "sub_traces": sub_traces
        })

        # Replay stored events (since_event_id=0 means replay everything).
        # NOTE(review): the socket is already registered at this point, so an
        # event broadcast during the replay may reach the client twice;
        # presumably clients de-duplicate on event_id - confirm.
        if since_event_id >= 0:
            missed_events = await store.get_events(trace_id, since_event_id)

            if len(missed_events) > _MAX_REPLAY_EVENTS:
                await websocket.send_json({
                    "event": "error",
                    "message": f"Too many missed events ({len(missed_events)}), please reload via REST API"
                })
            else:
                for evt in missed_events:
                    await websocket.send_json(evt)

        # Keep the connection open; the only client message handled is the
        # "ping" heartbeat.
        while True:
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_json({"event": "pong"})

    except WebSocketDisconnect:
        # A client going away is the normal way for this handler to end,
        # whether it happens during the snapshot, the replay or the
        # heartbeat loop.
        pass

    finally:
        # Unregister. The entry may already be gone (broadcast_trace_completed
        # drops it), hence the tolerant lookup.
        connections = _active_connections.get(trace_id)
        if connections is not None:
            connections.discard(websocket)
            if not connections:
                del _active_connections[trace_id]


# ===== Broadcast functions (called by AgentRunner or TraceStore) =====

async def _record_and_broadcast(
    trace_id: str,
    event: str,
    payload: Dict[str, Any],
    fields: Dict[str, Any],
) -> None:
    """Persist one event and push it to every watcher of ``trace_id``.

    Does nothing when the Trace has no registered connection.

    NOTE(review): because of that guard, events are only appended to the store
    while at least one client is watching, so a client that reconnects after
    being the sole watcher may find nothing to replay. Kept as-is because
    callers may rely on this being a no-op when no store/server is set up -
    confirm the intended behaviour.

    Args:
        trace_id: Trace ID whose watchers receive the message.
        event: event type name.
        payload: data handed to ``store.append_event``.
        fields: event-specific keys appended to the outgoing message after
            the common ``event`` / ``event_id`` / ``ts`` envelope.
    """
    if trace_id not in _active_connections:
        return

    store = get_trace_store()
    event_id = await store.append_event(trace_id, event, payload)

    message = {
        "event": event,
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        **fields,
    }
    await _broadcast_to_trace(trace_id, message)


async def broadcast_goal_added(trace_id: str, goal_dict: Dict[str, Any]):
    """
    Broadcast a "goal_added" event.

    Args:
        trace_id: Trace ID
        goal_dict: Goal dictionary (full data, including stats)
    """
    await _record_and_broadcast(
        trace_id,
        "goal_added",
        {"goal": goal_dict},
        {"goal": goal_dict},
    )


async def broadcast_goal_updated(
    trace_id: str,
    goal_id: str,
    updates: Dict[str, Any],
    affected_goals: Optional[list[Dict[str, Any]]] = None
):
    """
    Broadcast a "goal_updated" event (patch semantics).

    Args:
        trace_id: Trace ID
        goal_id: Goal ID
        updates: changed fields (patch format)
        affected_goals: affected Goals (including parents completed by
            cascade); ``None`` is treated as an empty list
    """
    affected = affected_goals or []

    # NOTE(review): the stored payload names the patch "updates" while the
    # live message names it "patch"; if replayed events are sent as stored,
    # clients see two shapes for the same event type - confirm.
    await _record_and_broadcast(
        trace_id,
        "goal_updated",
        {
            "goal_id": goal_id,
            "updates": updates,
            "affected_goals": affected,
        },
        {
            "goal_id": goal_id,
            "patch": updates,
            "affected_goals": affected,
        },
    )


async def broadcast_sub_trace_started(
    trace_id: str,
    sub_trace_id: str,
    parent_goal_id: str,
    agent_type: str,
    task: str
):
    """
    Broadcast a "sub_trace_started" event.

    Args:
        trace_id: main Trace ID
        sub_trace_id: Sub-Trace ID
        parent_goal_id: parent Goal ID
        agent_type: agent type
        task: task description
    """
    await _record_and_broadcast(
        trace_id,
        "sub_trace_started",
        {
            "trace_id": sub_trace_id,
            "parent_trace_id": trace_id,
            "parent_goal_id": parent_goal_id,
            "agent_type": agent_type,
            "task": task,
        },
        {
            "trace_id": sub_trace_id,
            "parent_goal_id": parent_goal_id,
            "agent_type": agent_type,
            "task": task,
        },
    )


async def broadcast_sub_trace_completed(
    trace_id: str,
    sub_trace_id: str,
    status: str,
    summary: str = "",
    stats: Optional[Dict[str, Any]] = None
):
    """
    Broadcast a "sub_trace_completed" event.

    Args:
        trace_id: main Trace ID
        sub_trace_id: Sub-Trace ID
        status: final status (completed/failed)
        summary: summary text
        stats: statistics; ``None`` is treated as an empty dict
    """
    stats_payload = stats or {}

    await _record_and_broadcast(
        trace_id,
        "sub_trace_completed",
        {
            "trace_id": sub_trace_id,
            "status": status,
            "summary": summary,
            "stats": stats_payload,
        },
        {
            "trace_id": sub_trace_id,
            "status": status,
            "summary": summary,
            "stats": stats_payload,
        },
    )


async def broadcast_trace_completed(trace_id: str, total_messages: int):
    """
    Broadcast a "trace_completed" event and drop the Trace's connection set.

    Args:
        trace_id: Trace ID
        total_messages: total number of Messages
    """
    await _record_and_broadcast(
        trace_id,
        "trace_completed",
        {"total_messages": total_messages},
        {
            "trace_id": trace_id,
            "total_messages": total_messages,
        },
    )

    # Forget every connection of the finished Trace. The sockets themselves
    # are not closed here; each handler unregisters itself when it exits.
    _active_connections.pop(trace_id, None)


# ===== Internal helpers =====

async def _broadcast_to_trace(trace_id: str, message: Dict[str, Any]):
    """
    Send ``message`` to every connection registered for ``trace_id``.

    Connections whose send fails are removed from the registry.

    Args:
        trace_id: Trace ID
        message: JSON-serialisable message
    """
    connections = _active_connections.get(trace_id)
    if not connections:
        return

    disconnected = []
    # Iterate over a snapshot: each ``await`` yields to the event loop, where
    # other handlers may add or remove sockets; iterating the live set would
    # then raise "Set changed size during iteration".
    for websocket in list(connections):
        try:
            await websocket.send_json(message)
        except Exception:
            # Best effort: any send failure is treated as a dead connection.
            disconnected.append(websocket)

    if not disconnected:
        return

    # Look the entry up again - it may have been deleted (or replaced) while
    # we were awaiting the sends above.
    remaining = _active_connections.get(trace_id)
    if remaining is not None:
        remaining.difference_update(disconnected)