"""
Trace WebSocket push.

Pushes real-time updates for in-progress Traces to connected clients and
supports resuming after a disconnect via ``since_event_id``.
"""

from datetime import datetime
from typing import Any, Dict, Optional, Set

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query

from agent.execution.protocols import TraceStore


router = APIRouter(prefix="/api/traces", tags=["websocket"])


# ===== Global state =====

# Injected via set_trace_store(); None until the application wires it up.
_trace_store: Optional[TraceStore] = None
# trace_id -> set of WebSockets currently watching that trace.
_active_connections: Dict[str, Set[WebSocket]] = {}

# Maximum number of missed events replayed on (re)connect; beyond this the
# client is told to reload through the REST API instead.
_MAX_REPLAY_EVENTS = 100


def set_trace_store(store: TraceStore) -> None:
    """Install the TraceStore instance used by this module."""
    global _trace_store
    _trace_store = store


def get_trace_store() -> TraceStore:
    """Return the installed TraceStore.

    Raises:
        RuntimeError: if set_trace_store() has not been called yet.
    """
    if _trace_store is None:
        raise RuntimeError("TraceStore not initialized")
    return _trace_store


# ===== WebSocket route =====

@router.websocket("/{trace_id}/watch")
async def watch_trace(
    websocket: WebSocket,
    trace_id: str,
    since_event_id: int = Query(0, description="从哪个事件 ID 开始(0=补发所有历史)"),
):
    """
    Watch a Trace for updates, with resume-after-disconnect support.

    Event types:
    - connected: connection established; carries goal_tree and branches
    - goal_added: a new Goal was added
    - goal_updated: a Goal's state changed (including cascading completion)
    - message_added: a new Message (including affected_goals)
    - branch_started: a branch started exploring
    - branch_goal_added: a Goal was added inside a branch
    - branch_completed: a branch finished
    - explore_completed: all branches finished
    - trace_completed: execution finished

    Args:
        trace_id: Trace ID
        since_event_id: replay stored events after this ID
            - 0: replay all stored events (first connection)
            - >0: replay events after the given ID (reconnect)
            - <0: skip the replay entirely
    """
    await websocket.accept()

    # Verify the Trace exists before registering the connection.
    store = get_trace_store()
    trace = await store.get_trace(trace_id)
    if not trace:
        await websocket.send_json({
            "event": "error",
            "message": "Trace not found"
        })
        await websocket.close()
        return

    # Register before sending the snapshot so live broadcasts are not lost in
    # between. An event appended in that window may arrive both replayed and
    # live; clients can de-duplicate by event_id.
    _active_connections.setdefault(trace_id, set()).add(websocket)

    try:
        await _send_initial_state(websocket, store, trace, trace_id, since_event_id)
        await _serve_heartbeats(websocket)
    except WebSocketDisconnect:
        # The client may leave at any point, including mid-replay; that is a
        # normal end of the session, not an error.
        pass
    finally:
        _unregister(trace_id, websocket)


async def _send_initial_state(
    websocket: WebSocket,
    store: TraceStore,
    trace: Any,
    trace_id: str,
    since_event_id: int,
) -> None:
    """Send the ``connected`` snapshot, then replay missed events if requested."""
    goal_tree = await store.get_goal_tree(trace_id)
    branches = await store.list_branches(trace_id)

    await websocket.send_json({
        "event": "connected",
        "trace_id": trace_id,
        "current_event_id": trace.last_event_id,
        "goal_tree": goal_tree.to_dict() if goal_tree else None,
        "branches": {b_id: b.to_dict() for b_id, b in branches.items()},
    })

    if since_event_id < 0:
        return

    missed_events = await store.get_events(trace_id, since_event_id)
    if len(missed_events) > _MAX_REPLAY_EVENTS:
        await websocket.send_json({
            "event": "error",
            "message": f"Too many missed events ({len(missed_events)}), please reload via REST API"
        })
        return

    # NOTE(review): stored events are forwarded exactly as get_events returns
    # them. The stored goal_updated payload uses "updates" while the live
    # message uses "patch" — confirm clients handle both shapes.
    for evt in missed_events:
        await websocket.send_json(evt)


async def _serve_heartbeats(websocket: WebSocket) -> None:
    """Reply to "ping" text frames with a pong until the client disconnects.

    Raises:
        WebSocketDisconnect: when the client closes the connection.
    """
    while True:
        data = await websocket.receive_text()
        if data == "ping":
            await websocket.send_json({"event": "pong"})


def _unregister(trace_id: str, websocket: WebSocket) -> None:
    """Remove a socket from a trace's watcher set, dropping the set when empty."""
    watchers = _active_connections.get(trace_id)
    if watchers is None:
        return
    watchers.discard(websocket)
    if not watchers:
        del _active_connections[trace_id]


# ===== Broadcast functions (called by AgentRunner or TraceStore) =====

async def broadcast_goal_added(trace_id: str, goal_dict: Dict[str, Any]):
    """
    Broadcast a goal_added event.

    Args:
        trace_id: Trace ID
        goal_dict: Goal dict (full data, including stats)
    """
    await _record_and_broadcast(trace_id, "goal_added", {"goal": goal_dict})


async def broadcast_goal_updated(
    trace_id: str,
    goal_id: str,
    updates: Dict[str, Any],
    affected_goals: Optional[list[Dict[str, Any]]] = None
):
    """
    Broadcast a goal_updated event (patch semantics).

    The store records the fields under "updates"; clients receive them
    under "patch".

    Args:
        trace_id: Trace ID
        goal_id: Goal ID
        updates: changed fields (patch format)
        affected_goals: Goals affected by the change (including parents
            completed by cascade); None is sent as an empty list
    """
    affected = affected_goals or []
    await _record_and_broadcast(
        trace_id,
        "goal_updated",
        {"goal_id": goal_id, "updates": updates, "affected_goals": affected},
        {"goal_id": goal_id, "patch": updates, "affected_goals": affected},
    )


async def broadcast_branch_started(trace_id: str, branch_dict: Dict[str, Any]):
    """
    Broadcast a branch_started event.

    Args:
        trace_id: Trace ID
        branch_dict: BranchContext dict
    """
    await _record_and_broadcast(trace_id, "branch_started", {"branch": branch_dict})


async def broadcast_branch_goal_added(
    trace_id: str,
    branch_id: str,
    goal_dict: Dict[str, Any]
):
    """
    Broadcast a Goal added inside a branch.

    Args:
        trace_id: Trace ID
        branch_id: branch ID
        goal_dict: Goal dict
    """
    await _record_and_broadcast(
        trace_id,
        "branch_goal_added",
        {"branch_id": branch_id, "goal": goal_dict},
    )


async def broadcast_branch_completed(
    trace_id: str,
    branch_id: str,
    summary: str,
    cumulative_stats: Dict[str, Any]
):
    """
    Broadcast a branch_completed event.

    Args:
        trace_id: Trace ID
        branch_id: branch ID
        summary: branch summary
        cumulative_stats: cumulative statistics for the branch
    """
    await _record_and_broadcast(
        trace_id,
        "branch_completed",
        {
            "branch_id": branch_id,
            "summary": summary,
            "cumulative_stats": cumulative_stats,
        },
    )


async def broadcast_explore_completed(
    trace_id: str,
    explore_start_id: str,
    merge_summary: str
):
    """
    Broadcast an explore_completed event.

    Args:
        trace_id: Trace ID
        explore_start_id: ID of the explore_start Goal
        merge_summary: merged result summary
    """
    await _record_and_broadcast(
        trace_id,
        "explore_completed",
        {"explore_start_id": explore_start_id, "merge_summary": merge_summary},
    )


async def broadcast_trace_completed(trace_id: str, total_messages: int):
    """
    Broadcast a trace_completed event, then stop tracking the trace's watchers.

    Sockets are only dropped from the registry, not closed; each stays open
    until its client disconnects.

    Args:
        trace_id: Trace ID
        total_messages: total number of Messages
    """
    await _record_and_broadcast(
        trace_id,
        "trace_completed",
        {"total_messages": total_messages},
        {"trace_id": trace_id, "total_messages": total_messages},
    )
    _active_connections.pop(trace_id, None)


# ===== Internal helpers =====

async def _record_and_broadcast(
    trace_id: str,
    event_type: str,
    stored_payload: Dict[str, Any],
    message_fields: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Persist an event via the TraceStore and push it to the trace's watchers.

    Does nothing at all, including no persistence, when nobody is watching
    trace_id.

    Args:
        trace_id: Trace ID
        event_type: event name, used for the store and the "event" key
        stored_payload: payload passed to store.append_event
        message_fields: fields sent to clients after event/event_id/ts;
            defaults to stored_payload
    """
    # NOTE(review): events are only persisted while someone is watching, so a
    # reconnect after the last watcher left cannot replay what happened in
    # between — confirm this is intended given the resume feature.
    if trace_id not in _active_connections:
        return

    store = get_trace_store()
    event_id = await store.append_event(trace_id, event_type, stored_payload)
    message = {
        "event": event_type,
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        **(stored_payload if message_fields is None else message_fields),
    }
    await _broadcast_to_trace(trace_id, message)


async def _broadcast_to_trace(trace_id: str, message: Dict[str, Any]):
    """
    Send a message to every connection watching the given Trace.

    Connections whose send fails are unregistered.

    Args:
        trace_id: Trace ID
        message: message payload
    """
    watchers = _active_connections.get(trace_id)
    if not watchers:
        return

    disconnected: list[WebSocket] = []
    # Iterate over a snapshot: every send awaits, and other coroutines
    # (watch_trace cleanup, new connections, broadcast_trace_completed) may
    # mutate or remove the live set meanwhile.
    for websocket in list(watchers):
        try:
            await websocket.send_json(message)
        except Exception:
            # Best effort: a failed send is treated as a dead client.
            disconnected.append(websocket)

    # Re-resolve the entry, since it may have been removed during the awaits.
    for ws in disconnected:
        _unregister(trace_id, ws)