| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- """
- Step 树 WebSocket 推送
- 实时推送进行中 Trace 的 Step 更新,支持断线续传
- """
- from typing import Dict, Set, Any
- from datetime import datetime
- from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
- from agent.execution.protocols import TraceStore
router = APIRouter(tags=["websocket"], prefix="/api/traces")

# ===== Global state =====
# Store shared by the WebSocket route and the broadcast helpers; stays None
# until set_trace_store() is called (get_trace_store() raises in that case).
_trace_store: TraceStore = None
# trace_id -> set of WebSocket connections currently watching that trace.
_active_connections: Dict[str, Set[WebSocket]] = dict()
def set_trace_store(store: TraceStore):
    """Register ``store`` as the module-wide TraceStore.

    Until this is called, get_trace_store() raises RuntimeError, so the
    WebSocket route and the broadcast helpers cannot reach any store.
    """
    global _trace_store
    _trace_store = store
def get_trace_store() -> TraceStore:
    """Return the TraceStore registered via set_trace_store().

    Raises:
        RuntimeError: if no store has been registered yet.
    """
    current = _trace_store
    if current is not None:
        return current
    raise RuntimeError("TraceStore not initialized")
- # ===== WebSocket 路由 =====
@router.websocket("/{trace_id}/watch")
async def watch_trace(
    websocket: WebSocket,
    trace_id: str,
    since_event_id: int = Query(0, description="从哪个事件 ID 开始(0=补发所有历史)")
):
    """
    Stream Step updates of a Trace to a WebSocket client, with resume support.

    Flow:
      1. Accept the socket and verify the trace exists. If it does not, send an
         error frame and close.
      2. Register the socket so the broadcast_* helpers can reach it.
      3. Send a "connected" frame carrying the trace's current last_event_id.
      4. Replay stored events after ``since_event_id``.
      5. Answer "ping" text frames with "pong" until the client disconnects.

    The socket is always unregistered on exit.

    Args:
        trace_id: Trace ID
        since_event_id: event ID to replay from
            - 0: replay every stored event (first connection)
            - >0: replay only events after this ID (reconnect)
            At most 100 events are replayed. Beyond that the client gets an
            error frame asking it to reload the full tree via the REST API.
    """
    await websocket.accept()

    store = get_trace_store()
    trace = await store.get_trace(trace_id)
    if not trace:
        # Unknown trace: report it and close without registering the socket.
        await websocket.send_json({"event": "error", "message": "Trace not found"})
        await websocket.close()
        return

    # NOTE(review): registration happens before the replay below, so a live
    # broadcast fired during the replay can interleave with replayed events.
    _active_connections.setdefault(trace_id, set()).add(websocket)

    try:
        await websocket.send_json({
            "event": "connected",
            "trace_id": trace_id,
            "current_event_id": trace.last_event_id,
        })

        # A negative since_event_id skips the replay entirely.
        if since_event_id >= 0:
            backlog = await store.get_events(trace_id, since_event_id)
            backlog_size = len(backlog)
            if backlog_size <= 100:
                for event in backlog:
                    await websocket.send_json(event)
            else:
                await websocket.send_json({
                    "event": "error",
                    "message": f"Too many missed events ({backlog_size}), please reload full tree via REST API",
                })

        # Keep the socket open; the only client command understood is "ping".
        try:
            while True:
                if await websocket.receive_text() == "ping":
                    await websocket.send_json({"event": "pong"})
        except WebSocketDisconnect:
            pass
    finally:
        # The entry may already be gone (broadcast_trace_completed removes it).
        watchers = _active_connections.get(trace_id)
        if watchers is not None:
            watchers.discard(websocket)
            if not watchers:
                del _active_connections[trace_id]
- # ===== 广播函数(由 AgentRunner 调用)=====
async def _broadcast_to_watchers(trace_id: str, message: Dict[str, Any]) -> None:
    """Send ``message`` to every WebSocket currently watching ``trace_id``.

    Iterates over a snapshot of the watcher set. Each ``send_json`` awaits, and
    while suspended other tasks may mutate the live set or remove it from
    ``_active_connections``. Such tasks include a new ``watch_trace``
    registering, a disconnecting one cleaning up, or
    ``broadcast_trace_completed``. Iterating the live set could then raise
    "Set changed size during iteration", and indexing
    ``_active_connections[trace_id]`` afterwards could raise ``KeyError``.

    Delivery is best effort. A socket whose send fails is unregistered, and the
    remaining watchers still receive the message.
    """
    watchers = _active_connections.get(trace_id)
    if not watchers:
        return
    failed = []
    for websocket in list(watchers):
        try:
            await websocket.send_json(message)
        except Exception:
            # Deliberately broad: any send failure just means "drop this socket".
            failed.append(websocket)
    if failed:
        # Re-read the registry: the entry may have been replaced or removed
        # while the sends above were awaiting.
        current = _active_connections.get(trace_id)
        if current is not None:
            current.difference_update(failed)


async def broadcast_step_added(trace_id: str, step_dict: Dict):
    """
    Broadcast a ``step_added`` event to every client watching ``trace_id``.

    The event is not persisted here. Per the original design note, the store
    appends it when the step is added (``add_step``), so this function only
    reads ``trace.last_event_id`` to label the message. It does nothing when
    nobody is watching or the trace does not exist.

    NOTE(review): if another event is appended between add_step and the
    get_trace call below, last_event_id would not match this step — confirm.

    Args:
        trace_id: Trace ID
        step_dict: Step dict (from step.to_dict(view="compact"))
    """
    if not _active_connections.get(trace_id):
        return
    store = get_trace_store()
    trace = await store.get_trace(trace_id)
    if not trace:
        return
    await _broadcast_to_watchers(trace_id, {
        "event": "step_added",
        "event_id": trace.last_event_id,
        "ts": datetime.now().isoformat(),
        "step": step_dict,  # compact view
    })


async def broadcast_step_updated(trace_id: str, step_id: str, updates: Dict):
    """
    Record a ``step_updated`` event and push it to watchers (patch semantics).

    The event is appended to the store whenever a store is configured, even if
    nobody is watching at the moment. ``watch_trace`` replays missed events
    from the store on reconnect. Persisting only while a client is connected
    would make exactly the disconnected gap unrecoverable.

    Args:
        trace_id: Trace ID
        step_id: Step ID
        updates: changed fields only (patch format)
    """
    store = _trace_store
    if store is None:
        # Streaming not configured. watch_trace cannot register sockets without
        # a store (get_trace_store raises), so there is nobody to notify.
        return
    event_id = await store.append_event(trace_id, "step_updated", {
        "step_id": step_id,
        "updates": updates,
    })
    await _broadcast_to_watchers(trace_id, {
        "event": "step_updated",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "step_id": step_id,
        "patch": updates,  # e.g. {"status": "completed", "duration_ms": 123}
    })


async def broadcast_trace_completed(trace_id: str, total_steps: int):
    """
    Record a ``trace_completed`` event, push it to watchers, then unregister them.

    Like ``broadcast_step_updated``, the event is persisted whenever a store is
    configured, so a client that reconnects later can still replay it.

    After the broadcast, the trace's entry is removed from the registry, so no
    further events reach the remaining sockets. The sockets themselves are not
    closed here; each stays open until its ``watch_trace`` loop ends.

    Args:
        trace_id: Trace ID
        total_steps: total number of steps in the trace
    """
    store = _trace_store
    if store is None:
        # Streaming not configured: no store to record into, nobody watching.
        return
    event_id = await store.append_event(trace_id, "trace_completed", {
        "total_steps": total_steps,
    })
    await _broadcast_to_watchers(trace_id, {
        "event": "trace_completed",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "trace_id": trace_id,
        "total_steps": total_steps,
    })
    # The trace is finished: stop routing anything else to its watchers.
    _active_connections.pop(trace_id, None)
|