| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378 |
- """
- Trace WebSocket 推送
- 实时推送进行中 Trace 的更新,支持断线续传
- """
- from typing import Dict, Set, Any
- from datetime import datetime
- from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
- from agent.execution.protocols import TraceStore
router = APIRouter(prefix="/api/traces", tags=["websocket"])

# ===== Global state =====
# Store shared by the WebSocket route and the broadcast helpers; installed via
# set_trace_store() and None until then.
_trace_store: "TraceStore | None" = None
# trace_id -> WebSocket connections currently watching that trace.
_active_connections: Dict[str, Set[WebSocket]] = {}
def set_trace_store(store: TraceStore):
    """Install the TraceStore used by the watch route and the broadcast helpers.

    Args:
        store: Store instance; replaces whatever was installed before.
    """
    global _trace_store
    _trace_store = store
def get_trace_store() -> TraceStore:
    """Return the installed TraceStore.

    Raises:
        RuntimeError: if set_trace_store() has not been called yet.
    """
    current = _trace_store
    if current is not None:
        return current
    raise RuntimeError("TraceStore not initialized")
# ===== WebSocket route =====
@router.websocket("/{trace_id}/watch")
async def watch_trace(
    websocket: WebSocket,
    trace_id: str,
    since_event_id: int = Query(0, description="从哪个事件 ID 开始(0=补发所有历史)")
):
    """
    Watch a trace for live updates, with resume-after-disconnect support.

    Event types:
    - connected: connection established; carries goal_tree and branches
    - goal_added: a new Goal was added
    - goal_updated: a Goal's status changed (including cascaded completion)
    - message_added: a new Message (including affected_goals); not emitted by
      any broadcast helper in this module — presumably produced elsewhere
    - branch_started: a branch started exploring
    - branch_goal_added: a Goal was added inside a branch
    - branch_completed: a branch finished
    - explore_completed: all branches finished
    - trace_completed: execution finished

    Args:
        trace_id: Trace ID
        since_event_id: event ID to replay from
            - 0: replay all stored history (first connection)
            - >0: replay events after this ID (reconnect after a drop)
            - <0: no replay at all
    """
    await websocket.accept()

    # Verify the trace exists; report and close otherwise.
    store = get_trace_store()
    trace = await store.get_trace(trace_id)
    if not trace:
        await websocket.send_json({
            "event": "error",
            "message": "Trace not found"
        })
        await websocket.close()
        return

    # Register this connection so the broadcast helpers push to it.
    # NOTE(review): registration happens before the snapshot below is sent, so
    # a concurrent broadcast can reach this client before the "connected"
    # message, or interleave with / duplicate the replayed events. Clients
    # presumably need to de-duplicate by event_id — confirm.
    if trace_id not in _active_connections:
        _active_connections[trace_id] = set()
    _active_connections[trace_id].add(websocket)

    try:
        # Fetch the GoalTree and branch metadata for the initial snapshot.
        goal_tree = await store.get_goal_tree(trace_id)
        branches = await store.list_branches(trace_id)

        # Send the connection acknowledgement together with the full state.
        await websocket.send_json({
            "event": "connected",
            "trace_id": trace_id,
            "current_event_id": trace.last_event_id,
            "goal_tree": goal_tree.to_dict() if goal_tree else None,
            "branches": {b_id: b.to_dict() for b_id, b in branches.items()}
        })

        # Replay stored events (since_event_id=0 replays everything).
        if since_event_id >= 0:
            missed_events = await store.get_events(trace_id, since_event_id)
            # Cap the replay at 100 events; beyond that, nothing is replayed and
            # the client is told to reload via REST (the socket stays open).
            if len(missed_events) > 100:
                await websocket.send_json({
                    "event": "error",
                    "message": f"Too many missed events ({len(missed_events)}), please reload via REST API"
                })
            else:
                # NOTE(review): events are forwarded exactly as the store returns
                # them; their shape may differ from the live messages built by the
                # broadcast helpers (e.g. goal_updated is stored with "updates"
                # but pushed live as "patch") — verify against the store.
                for evt in missed_events:
                    await websocket.send_json(evt)

        # Keep the connection open until the client disconnects; answer
        # "ping" text frames with a pong (heartbeat).
        while True:
            try:
                data = await websocket.receive_text()
                if data == "ping":
                    await websocket.send_json({"event": "pong"})
            except WebSocketDisconnect:
                break
    finally:
        # Unregister; drop the trace entry once no watcher remains. The entry
        # may already be gone (e.g. removed by broadcast_trace_completed).
        if trace_id in _active_connections:
            _active_connections[trace_id].discard(websocket)
            if not _active_connections[trace_id]:
                del _active_connections[trace_id]
# ===== Broadcast functions (called by AgentRunner or TraceStore) =====
async def broadcast_goal_added(trace_id: str, goal_dict: Dict[str, Any]):
    """
    Record a goal_added event and push it to every client watching the trace.

    The event is appended to the store even when no client is connected, so a
    client that later reconnects with ``since_event_id`` can replay it.

    Args:
        trace_id: Trace ID
        goal_dict: Goal dict (full data, including stats)
    """
    # With no store installed and nobody watching there is nothing to record
    # or push (e.g. running without the API server) — keep this a no-op.
    if _trace_store is None and trace_id not in _active_connections:
        return
    store = get_trace_store()
    event_id = await store.append_event(trace_id, "goal_added", {
        "goal": goal_dict
    })
    message = {
        "event": "goal_added",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "goal": goal_dict
    }
    # _broadcast_to_trace is itself a no-op when nobody is connected.
    await _broadcast_to_trace(trace_id, message)
async def broadcast_goal_updated(
    trace_id: str,
    goal_id: str,
    updates: Dict[str, Any],
    affected_goals: list[Dict[str, Any]] = None
):
    """
    Record a goal_updated event (patch semantics) and push it to watchers.

    The event is appended to the store even when no client is connected, so a
    client that later reconnects with ``since_event_id`` can replay it.

    Args:
        trace_id: Trace ID
        goal_id: Goal ID
        updates: Changed fields (patch format)
        affected_goals: Goals affected by the change (including parents
            completed by cascade); ``None`` is sent as an empty list.
    """
    # With no store installed and nobody watching there is nothing to record
    # or push (e.g. running without the API server) — keep this a no-op.
    if _trace_store is None and trace_id not in _active_connections:
        return
    store = get_trace_store()
    affected = affected_goals or []
    # NOTE(review): the stored payload names the patch "updates" while the live
    # message names it "patch"; replayed events may differ in shape from live
    # ones depending on how the store renders them — confirm with clients.
    event_id = await store.append_event(trace_id, "goal_updated", {
        "goal_id": goal_id,
        "updates": updates,
        "affected_goals": affected
    })
    message = {
        "event": "goal_updated",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "goal_id": goal_id,
        "patch": updates,
        "affected_goals": affected
    }
    # _broadcast_to_trace is itself a no-op when nobody is connected.
    await _broadcast_to_trace(trace_id, message)
async def broadcast_branch_started(trace_id: str, branch_dict: Dict[str, Any]):
    """
    Record a branch_started event and push it to every watcher of the trace.

    The event is appended to the store even when no client is connected, so a
    client that later reconnects with ``since_event_id`` can replay it.

    Args:
        trace_id: Trace ID
        branch_dict: BranchContext dict
    """
    # With no store installed and nobody watching there is nothing to record
    # or push (e.g. running without the API server) — keep this a no-op.
    if _trace_store is None and trace_id not in _active_connections:
        return
    store = get_trace_store()
    event_id = await store.append_event(trace_id, "branch_started", {
        "branch": branch_dict
    })
    message = {
        "event": "branch_started",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "branch": branch_dict
    }
    # _broadcast_to_trace is itself a no-op when nobody is connected.
    await _broadcast_to_trace(trace_id, message)
async def broadcast_branch_goal_added(
    trace_id: str,
    branch_id: str,
    goal_dict: Dict[str, Any]
):
    """
    Record a branch_goal_added event (a Goal added inside a branch) and push it.

    The event is appended to the store even when no client is connected, so a
    client that later reconnects with ``since_event_id`` can replay it.

    Args:
        trace_id: Trace ID
        branch_id: Branch ID
        goal_dict: Goal dict
    """
    # With no store installed and nobody watching there is nothing to record
    # or push (e.g. running without the API server) — keep this a no-op.
    if _trace_store is None and trace_id not in _active_connections:
        return
    store = get_trace_store()
    event_id = await store.append_event(trace_id, "branch_goal_added", {
        "branch_id": branch_id,
        "goal": goal_dict
    })
    message = {
        "event": "branch_goal_added",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "branch_id": branch_id,
        "goal": goal_dict
    }
    # _broadcast_to_trace is itself a no-op when nobody is connected.
    await _broadcast_to_trace(trace_id, message)
async def broadcast_branch_completed(
    trace_id: str,
    branch_id: str,
    summary: str,
    cumulative_stats: Dict[str, Any]
):
    """
    Record a branch_completed event and push it to every watcher of the trace.

    The event is appended to the store even when no client is connected, so a
    client that later reconnects with ``since_event_id`` can replay it.

    Args:
        trace_id: Trace ID
        branch_id: Branch ID
        summary: Branch summary
        cumulative_stats: Cumulative statistics of the branch
    """
    # With no store installed and nobody watching there is nothing to record
    # or push (e.g. running without the API server) — keep this a no-op.
    if _trace_store is None and trace_id not in _active_connections:
        return
    store = get_trace_store()
    event_id = await store.append_event(trace_id, "branch_completed", {
        "branch_id": branch_id,
        "summary": summary,
        "cumulative_stats": cumulative_stats
    })
    message = {
        "event": "branch_completed",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "branch_id": branch_id,
        "summary": summary,
        "cumulative_stats": cumulative_stats
    }
    # _broadcast_to_trace is itself a no-op when nobody is connected.
    await _broadcast_to_trace(trace_id, message)
async def broadcast_explore_completed(
    trace_id: str,
    explore_start_id: str,
    merge_summary: str
):
    """
    Record an explore_completed event (all branches finished) and push it.

    The event is appended to the store even when no client is connected, so a
    client that later reconnects with ``since_event_id`` can replay it.

    Args:
        trace_id: Trace ID
        explore_start_id: ID of the explore_start Goal
        merge_summary: Merged summary of the branches
    """
    # With no store installed and nobody watching there is nothing to record
    # or push (e.g. running without the API server) — keep this a no-op.
    if _trace_store is None and trace_id not in _active_connections:
        return
    store = get_trace_store()
    event_id = await store.append_event(trace_id, "explore_completed", {
        "explore_start_id": explore_start_id,
        "merge_summary": merge_summary
    })
    message = {
        "event": "explore_completed",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "explore_start_id": explore_start_id,
        "merge_summary": merge_summary
    }
    # _broadcast_to_trace is itself a no-op when nobody is connected.
    await _broadcast_to_trace(trace_id, message)
async def broadcast_trace_completed(trace_id: str, total_messages: int):
    """
    Record a trace_completed event, push it, then unregister all watchers.

    The event is appended to the store even when no client is connected, so a
    client that later reconnects with ``since_event_id`` can replay it.

    Args:
        trace_id: Trace ID
        total_messages: Total number of Messages
    """
    # With no store installed and nobody watching there is nothing to record
    # or push (e.g. running without the API server) — keep this a no-op.
    if _trace_store is None and trace_id not in _active_connections:
        return
    store = get_trace_store()
    event_id = await store.append_event(trace_id, "trace_completed", {
        "total_messages": total_messages
    })
    message = {
        "event": "trace_completed",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "trace_id": trace_id,
        "total_messages": total_messages
    }
    await _broadcast_to_trace(trace_id, message)
    # Stop pushing to this trace's watchers once it is complete.
    # NOTE(review): this only drops the registry entry; the sockets themselves
    # are not closed here and stay open until each client disconnects.
    _active_connections.pop(trace_id, None)
# ===== Internal helpers =====
async def _broadcast_to_trace(trace_id: str, message: Dict[str, Any]):
    """
    Send ``message`` to every connection watching ``trace_id`` (best effort).

    Connections whose send fails are unregistered; the trace entry is removed
    once it has no connections left. No-op when nobody is watching.

    Args:
        trace_id: Trace ID
        message: Message payload (sent as JSON)
    """
    connections = _active_connections.get(trace_id)
    if not connections:
        return

    disconnected = []
    # Iterate over a snapshot: every send awaits, and meanwhile watch_trace may
    # add or discard sockets in this very set, which would otherwise raise
    # "Set changed size during iteration".
    for websocket in list(connections):
        try:
            await websocket.send_json(message)
        except Exception:
            # Deliberately best-effort: a failed send means the peer is gone.
            disconnected.append(websocket)

    if not disconnected:
        return
    # Re-fetch the entry: it may have been deleted (or replaced) while we were
    # awaiting the sends above.
    current = _active_connections.get(trace_id)
    if current is None:
        return
    for ws in disconnected:
        current.discard(ws)
    if not current:
        del _active_connections[trace_id]
|