"""
Trace WebSocket push.

Streams updates of an in-progress Trace to connected clients in real time and
supports resuming after a disconnect (via ``since_event_id``).
"""

from typing import Dict, List, Optional, Set, Any
from datetime import datetime
import asyncio

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query

from .protocols import TraceStore


router = APIRouter(prefix="/api/traces", tags=["websocket"])


# ===== Global state =====

_trace_store: Optional[TraceStore] = None
_active_connections: Dict[str, Set[WebSocket]] = {}  # trace_id -> Set[WebSocket]

# Largest batch of stored events replayed to one client at once; a bigger gap
# gets an error asking the client to reload via the REST API instead.
_MAX_BACKLOG_EVENTS = 100

# How long (seconds) the watch loop waits for a client message before it
# polls the store for new events.
_POLL_INTERVAL_SECONDS = 0.5


class _ConnectionState:
    """Delivery bookkeeping for one watching WebSocket.

    An event can reach a client in two ways: pushed live by
    ``_broadcast_to_trace`` (writers in this process), or picked up by the
    watch loop polling the store (any writer, including other processes).
    This state lets each path skip an event the other one already sent.

    Attributes:
        cursor: highest stored event_id the watch loop has passed (sent or
            deliberately skipped); polling resumes after it.
        pushed: event_ids pushed live that the watch loop has not reached yet.
    """

    __slots__ = ("cursor", "pushed")

    def __init__(self, cursor: int):
        self.cursor = cursor
        self.pushed: Set[int] = set()


# WebSocket -> delivery state. watch_trace registers an entry once its initial
# replay is done and removes it when the connection ends.
_connection_states: Dict[WebSocket, _ConnectionState] = {}


def set_trace_store(store: TraceStore):
    """Install the TraceStore instance used by this module."""
    global _trace_store
    _trace_store = store


def get_trace_store() -> TraceStore:
    """Return the installed TraceStore.

    Raises:
        RuntimeError: if ``set_trace_store`` has not been called.
    """
    if _trace_store is None:
        raise RuntimeError("TraceStore not initialized")
    return _trace_store


# ===== WebSocket route =====

@router.websocket("/{trace_id}/watch")
async def watch_trace(
    websocket: WebSocket,
    trace_id: str,
    since_event_id: int = Query(0, description="从哪个事件 ID 开始(0=补发所有历史)")
):
    """
    Watch a Trace for updates, with resume-after-disconnect support.

    Event types:
    - connected: connection established; carries goal_tree and sub_traces
    - goal_added: a Goal was added
    - goal_updated: a Goal's status changed (including cascaded completion)
    - message_added: a new Message (including affected_goals)
    - sub_trace_started: a Sub-Trace started executing
    - sub_trace_completed: a Sub-Trace finished
    - trace_completed: execution finished
    - trace_status_changed: the Trace's status changed (pause/resume etc.)
    - error: e.g. unknown Trace, or too many missed events to replay

    The client may send the text "ping" and receives {"event": "pong"}.

    Args:
        trace_id: Trace ID
        since_event_id: replay stored events after this ID
            - 0: replay all history (first connection)
            - >0: replay events after this ID (reconnect)
    """
    await websocket.accept()

    # Verify the Trace exists
    store = get_trace_store()
    trace = await store.get_trace(trace_id)
    if not trace:
        await websocket.send_json({
            "event": "error",
            "message": "Trace not found"
        })
        await websocket.close()
        return

    state = _ConnectionState(since_event_id)
    try:
        # Full current state first, so the client has a baseline before any event.
        await websocket.send_json(
            await _build_connected_payload(store, trace_id, trace)
        )

        # Replay stored history (since_event_id=0 replays everything).
        if since_event_id >= 0:
            missed_events = await store.get_events(trace_id, since_event_id)
            await _send_stored_events(websocket, state, missed_events)

        # Register for live pushes only now, so nothing is pushed ahead of the
        # "connected" message. Anything persisted in the meantime is picked up
        # by the polling below.
        _connection_states[websocket] = state
        _active_connections.setdefault(trace_id, set()).add(websocket)

        # Keep the connection open: answer heartbeats and poll the event log,
        # so that events written by other processes are delivered as well.
        while True:
            try:
                # Time out so polling continues even when the client is silent.
                data = await asyncio.wait_for(
                    websocket.receive_text(), timeout=_POLL_INTERVAL_SECONDS
                )
                if data == "ping":
                    await websocket.send_json({"event": "pong"})
            except WebSocketDisconnect:
                break
            except asyncio.TimeoutError:
                pass

            try:
                new_events = await store.get_events(trace_id, state.cursor)
                await _send_stored_events(websocket, state, new_events)
            except WebSocketDisconnect:
                break
    finally:
        # Unregister the connection (it may never have been registered, or the
        # trace's entry may already have been removed by broadcast_trace_completed).
        _connection_states.pop(websocket, None)
        connections = _active_connections.get(trace_id)
        if connections is not None:
            connections.discard(websocket)
            if not connections:
                del _active_connections[trace_id]


# ===== Broadcast functions (called by AgentRunner or TraceStore) =====
#
# Each function below persists its event through the installed TraceStore
# before pushing it to live connections. Persisting does not depend on anyone
# watching, so history replay and cross-process polling in watch_trace see
# every event. With no TraceStore installed they return quietly.

async def broadcast_goal_added(trace_id: str, goal_dict: Dict[str, Any]):
    """
    Persist and broadcast a goal_added event.

    Args:
        trace_id: Trace ID
        goal_dict: full Goal dict (including stats)
    """
    store = _trace_store
    if store is None:
        return  # no TraceStore installed: nothing to persist to
    event_id = await store.append_event(trace_id, "goal_added", {
        "goal": goal_dict
    })
    message = {
        "event": "goal_added",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "goal": goal_dict
    }
    await _broadcast_to_trace(trace_id, message)


async def broadcast_goal_updated(
    trace_id: str,
    goal_id: str,
    updates: Dict[str, Any],
    affected_goals: Optional[list[Dict[str, Any]]] = None
):
    """
    Persist and broadcast a goal_updated event (patch semantics).

    The stored event keeps the changes under "updates"; the pushed message
    carries them under "patch".

    Args:
        trace_id: Trace ID
        goal_id: Goal ID
        updates: changed fields (patch format)
        affected_goals: affected Goals (including parents completed by cascade)
    """
    store = _trace_store
    if store is None:
        return  # no TraceStore installed: nothing to persist to
    affected = affected_goals or []
    event_id = await store.append_event(trace_id, "goal_updated", {
        "goal_id": goal_id,
        "updates": updates,
        "affected_goals": affected
    })
    message = {
        "event": "goal_updated",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "goal_id": goal_id,
        "patch": updates,
        "affected_goals": affected
    }
    await _broadcast_to_trace(trace_id, message)


async def broadcast_message_added(
    trace_id: str,
    event_id: int,
    message_dict: Dict[str, Any],
    affected_goals: Optional[list[Dict[str, Any]]] = None,
):
    """
    Broadcast a message_added event (does NOT write events.jsonl).

    Notes:
    - Persisting message_added to events.jsonl is TraceStore.append_event's job.
    - This function only pushes the already-persisted event to live connections.
    """
    if trace_id not in _active_connections:
        return

    message = {
        "event": "message_added",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "message": message_dict,
        "affected_goals": affected_goals or [],
    }
    await _broadcast_to_trace(trace_id, message)


async def broadcast_sub_trace_started(
    trace_id: str,
    sub_trace_id: str,
    parent_goal_id: str,
    agent_type: str,
    task: str
):
    """
    Persist and broadcast a sub_trace_started event.

    Args:
        trace_id: main Trace ID
        sub_trace_id: Sub-Trace ID
        parent_goal_id: parent Goal ID
        agent_type: Agent type
        task: task description
    """
    store = _trace_store
    if store is None:
        return  # no TraceStore installed: nothing to persist to
    event_id = await store.append_event(trace_id, "sub_trace_started", {
        "trace_id": sub_trace_id,
        "parent_trace_id": trace_id,
        "parent_goal_id": parent_goal_id,
        "agent_type": agent_type,
        "task": task
    })
    message = {
        "event": "sub_trace_started",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "trace_id": sub_trace_id,
        "parent_goal_id": parent_goal_id,
        "agent_type": agent_type,
        "task": task
    }
    await _broadcast_to_trace(trace_id, message)


async def broadcast_sub_trace_completed(
    trace_id: str,
    sub_trace_id: str,
    status: str,
    summary: str = "",
    stats: Optional[Dict[str, Any]] = None
):
    """
    Persist and broadcast a sub_trace_completed event.

    Args:
        trace_id: main Trace ID
        sub_trace_id: Sub-Trace ID
        status: status (completed/failed)
        summary: summary text
        stats: statistics
    """
    store = _trace_store
    if store is None:
        return  # no TraceStore installed: nothing to persist to
    stats_dict = stats or {}
    event_id = await store.append_event(trace_id, "sub_trace_completed", {
        "trace_id": sub_trace_id,
        "status": status,
        "summary": summary,
        "stats": stats_dict
    })
    message = {
        "event": "sub_trace_completed",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "trace_id": sub_trace_id,
        "status": status,
        "summary": summary,
        "stats": stats_dict
    }
    await _broadcast_to_trace(trace_id, message)


async def broadcast_trace_completed(trace_id: str, total_messages: int):
    """
    Persist and broadcast a trace_completed event.

    After broadcasting, the trace's connection registry entry is dropped so
    no further live pushes go out for it. The sockets themselves are not
    closed; their watch loops keep polling the store.

    Args:
        trace_id: Trace ID
        total_messages: total number of Messages
    """
    store = _trace_store
    if store is None:
        return  # no TraceStore installed: nothing to persist to
    event_id = await store.append_event(trace_id, "trace_completed", {
        "total_messages": total_messages
    })
    message = {
        "event": "trace_completed",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "trace_id": trace_id,
        "total_messages": total_messages
    }
    await _broadcast_to_trace(trace_id, message)

    # Drop the live-push registry for this trace after completion.
    _active_connections.pop(trace_id, None)


async def broadcast_trace_status_changed(trace_id: str, status: str):
    """
    Persist and broadcast a trace_status_changed event (pause/resume etc.).

    Does nothing if the Trace cannot be found.

    Args:
        trace_id: Trace ID
        status: new status (running/stopped/completed/failed)
    """
    store = _trace_store
    if store is None:
        return  # no TraceStore installed: nothing to persist to
    trace = await store.get_trace(trace_id)
    if not trace:
        return

    event_id = await store.append_event(trace_id, "trace_status_changed", {
        "status": status
    })
    message = {
        "event": "trace_status_changed",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "trace_id": trace_id,
        "status": status,
        "total_cost": trace.total_cost,
        "total_messages": trace.total_messages
    }
    await _broadcast_to_trace(trace_id, message)


# ===== Internal helpers =====

async def _build_connected_payload(
    store: TraceStore, trace_id: str, trace: Any
) -> Dict[str, Any]:
    """Build the "connected" message: current event id, run status, GoalTree and Sub-Traces."""
    goal_tree = await store.get_goal_tree(trace_id)

    # Sub-Traces are identified by their parent_trace_id.
    # NOTE(review): only the traces returned by list_traces(limit=1000) are scanned.
    all_traces = await store.list_traces(limit=1000)
    sub_traces = {
        t.trace_id: t.to_dict()
        for t in all_traces
        if t.parent_trace_id == trace_id
    }

    # Imported here rather than at module level to avoid a circular import.
    from .run_api import _running_tasks

    is_running = (
        trace_id in _running_tasks and not _running_tasks[trace_id].done()
    )

    return {
        "event": "connected",
        "trace_id": trace_id,
        "current_event_id": trace.last_event_id,
        "trace_status": "running" if is_running else trace.status,
        "is_running": is_running,
        "goal_tree": goal_tree.to_dict() if goal_tree else None,
        "sub_traces": sub_traces
    }


def _event_id_of(evt: Any) -> Optional[int]:
    """Return ``evt["event_id"]`` if evt is a dict whose event_id is an int, else None."""
    if isinstance(evt, dict):
        event_id = evt.get("event_id")
        if isinstance(event_id, int):
            return event_id
    return None


async def _send_stored_events(
    websocket: WebSocket, state: _ConnectionState, events: List[Any]
) -> None:
    """Deliver events read from the store to one client and advance its cursor.

    - More than _MAX_BACKLOG_EVENTS: nothing is replayed. The client gets an
      error telling it to reload via REST, and the cursor still moves past the
      batch so the same error is not re-sent on every poll.
    - Otherwise each event is sent unless it was already pushed live.
      Events without an int event_id are always sent and leave the cursor alone.
    """
    if len(events) > _MAX_BACKLOG_EVENTS:
        await websocket.send_json({
            "event": "error",
            "message": f"Too many missed events ({len(events)}), please reload via REST API"
        })
        for evt in events:
            event_id = _event_id_of(evt)
            if event_id is not None:
                state.cursor = max(state.cursor, event_id)
    else:
        for evt in events:
            event_id = _event_id_of(evt)
            if event_id is None:
                await websocket.send_json(evt)
                continue
            already_pushed = event_id in state.pushed
            # Advance before awaiting the send, so a concurrent live push of
            # the same event treats it as delivered.
            state.cursor = max(state.cursor, event_id)
            if not already_pushed:
                await websocket.send_json(evt)

    # Polling asks get_events for events after the cursor, so pushed ids at or
    # below it will not come back; forget them to keep the set small.
    state.pushed = {i for i in state.pushed if i > state.cursor}


async def _broadcast_to_trace(trace_id: str, message: Dict[str, Any]):
    """
    Push a message to every live connection watching a Trace.

    A connection is skipped if it already received this event_id, either by
    polling or by an earlier push. Connections whose send raises are removed
    from the registry (best effort: the error itself is not propagated).

    Args:
        trace_id: Trace ID
        message: message payload
    """
    connections = _active_connections.get(trace_id)
    if not connections:
        return

    event_id = message.get("event_id")
    disconnected = []
    # Iterate over a snapshot: watch_trace may add or remove connections while
    # we are awaiting send_json.
    for websocket in list(connections):
        state = _connection_states.get(websocket)
        track = state is not None and isinstance(event_id, int)
        if track:
            if event_id <= state.cursor or event_id in state.pushed:
                continue  # this client already has the event
            state.pushed.add(event_id)
        try:
            await websocket.send_json(message)
        except Exception:
            if track:
                # Not delivered: let polling send it if the socket is still alive.
                state.pushed.discard(event_id)
            disconnected.append(websocket)

    # Drop failed connections; the trace's entry may have been removed meanwhile.
    remaining = _active_connections.get(trace_id)
    if remaining is not None:
        remaining.difference_update(disconnected)