"""
Step-tree WebSocket push.

Pushes real-time Step updates for in-progress Traces.
"""
- from typing import Dict, Set
- from fastapi import APIRouter, WebSocket, WebSocketDisconnect
- from agent.trace.protocols import TraceStore
# Router for the trace WebSocket endpoints, mounted under /api/traces.
router = APIRouter(prefix="/api/traces", tags=["websocket"])

# ===== Module-level state =====
# Live watchers: trace_id -> set of WebSocket connections watching that trace.
_active_connections: Dict[str, Set[WebSocket]] = {}
# Shared TraceStore; stays None until set_trace_store() is called.
_trace_store: TraceStore = None
def set_trace_store(store: TraceStore):
    """Install the TraceStore instance shared by this module.

    Args:
        store: Store that subsequent ``get_trace_store`` calls will return.
    """
    # Rebind the module-level reference rather than creating a local.
    global _trace_store
    _trace_store = store
def get_trace_store() -> TraceStore:
    """Return the TraceStore instance shared by this module.

    Returns:
        The module-level store.

    Raises:
        RuntimeError: If the module-level store is still ``None``.
    """
    store = _trace_store
    if store is not None:
        return store
    raise RuntimeError("TraceStore not initialized")
# ===== WebSocket routes =====
@router.websocket("/{trace_id}/watch")
async def watch_trace(websocket: WebSocket, trace_id: str):
    """Stream Step updates for one Trace to a WebSocket client.

    The socket is accepted first, then the Trace is looked up in the
    configured store. If the lookup returns a falsy value, an ``error``
    event is sent and the socket is closed. Otherwise the socket is
    registered in ``_active_connections`` and a ``connected`` event is sent;
    the handler then keeps reading text frames, answering ``"ping"`` with a
    ``pong`` event, until the client disconnects.

    Args:
        websocket: Incoming WebSocket connection.
        trace_id: ID of the Trace to watch (taken from the URL path).

    Raises:
        RuntimeError: If no TraceStore has been configured (raised by
            ``get_trace_store``).
    """
    await websocket.accept()

    # Make sure the Trace exists before registering the watcher.
    store = get_trace_store()
    trace = await store.get_trace(trace_id)
    if not trace:
        await websocket.send_json({
            "event": "error",
            "message": "Trace not found"
        })
        await websocket.close()
        return

    # Register this socket so the broadcast_* functions can reach it.
    _active_connections.setdefault(trace_id, set()).add(websocket)

    try:
        # Confirm the subscription to the client.
        await websocket.send_json({
            "event": "connected",
            "trace_id": trace_id
        })

        # Keep the connection open; incoming frames double as a heartbeat.
        while True:
            data = await websocket.receive_text()
            if data == "ping":
                await websocket.send_json({"event": "pong"})
    except WebSocketDisconnect:
        # Client went away. One handler covers the whole session so that a
        # disconnect surfacing during the initial "connected" send is treated
        # the same as one surfacing inside the receive loop.
        pass
    finally:
        # Unregister the socket. The entry may already be gone (for example
        # after broadcast_trace_completed dropped it), hence the lookup.
        watchers = _active_connections.get(trace_id)
        if watchers is not None:
            watchers.discard(websocket)
            if not watchers:
                del _active_connections[trace_id]
# ===== Broadcast functions (called by AgentRunner) =====
async def broadcast_step_added(trace_id: str, step_dict: Dict):
    """Broadcast a ``step_added`` event to every watcher of a Trace.

    Does nothing when no connection is registered for ``trace_id``.
    Sockets whose send fails are removed from the registry.

    Args:
        trace_id: Trace ID whose watchers should be notified.
        step_dict: Step payload, sent under the ``"step"`` key.
    """
    watchers = _active_connections.get(trace_id)
    if not watchers:
        return

    message = {
        "event": "step_added",
        "step": step_dict
    }

    # Iterate over a snapshot: every send yields to the event loop, during
    # which other coroutines may add or remove watchers. Iterating the live
    # set would then raise "Set changed size during iteration".
    failed = []
    for websocket in tuple(watchers):
        try:
            await websocket.send_json(message)
        except Exception:
            # Best effort: any send failure marks the socket as dead.
            failed.append(websocket)

    # Re-read the registry; the entry may have been removed while sending.
    remaining = _active_connections.get(trace_id)
    if remaining is not None:
        remaining.difference_update(failed)
async def broadcast_step_updated(trace_id: str, step_id: str, updates: Dict):
    """Broadcast a ``step_updated`` event to every watcher of a Trace.

    Does nothing when no connection is registered for ``trace_id``.
    Sockets whose send fails are removed from the registry.

    Args:
        trace_id: Trace ID whose watchers should be notified.
        step_id: ID of the Step that changed, sent under ``"step_id"``.
        updates: Changed fields, sent under the ``"updates"`` key.
    """
    watchers = _active_connections.get(trace_id)
    if not watchers:
        return

    message = {
        "event": "step_updated",
        "step_id": step_id,
        "updates": updates
    }

    # Iterate over a snapshot: every send yields to the event loop, during
    # which other coroutines may add or remove watchers. Iterating the live
    # set would then raise "Set changed size during iteration".
    failed = []
    for websocket in tuple(watchers):
        try:
            await websocket.send_json(message)
        except Exception:
            # Best effort: any send failure marks the socket as dead.
            failed.append(websocket)

    # Re-read the registry; the entry may have been removed while sending.
    remaining = _active_connections.get(trace_id)
    if remaining is not None:
        remaining.difference_update(failed)
async def broadcast_trace_completed(trace_id: str, total_steps: int):
    """Broadcast a ``trace_completed`` event and drop the Trace's watchers.

    Does nothing when no connection is registered for ``trace_id``.
    After the broadcast the registry entry for the Trace is removed; the
    sockets themselves are not closed here.

    Args:
        trace_id: Trace ID whose watchers should be notified.
        total_steps: Total number of Steps, sent under ``"total_steps"``.
    """
    watchers = _active_connections.get(trace_id)
    if watchers is None:
        return

    message = {
        "event": "trace_completed",
        "trace_id": trace_id,
        "total_steps": total_steps
    }

    # Iterate over a snapshot: every send yields to the event loop, during
    # which other coroutines may add or remove watchers. Iterating the live
    # set would then raise "Set changed size during iteration".
    for websocket in tuple(watchers):
        try:
            await websocket.send_json(message)
        except Exception:
            # Best effort: failed sockets need no individual pruning because
            # the whole entry is dropped just below.
            pass

    # The Trace is finished: forget all of its watchers. pop() tolerates the
    # entry having been removed while the sends were in flight.
    _active_connections.pop(trace_id, None)
|