websocket.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. """
  2. Step 树 WebSocket 推送
  3. 实时推送进行中 Trace 的 Step 更新,支持断线续传
  4. """
from datetime import datetime
from typing import Any, Dict, Optional, Set

from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect

from agent.execution.protocols import TraceStore
  # Router exposing the trace-watching WebSocket endpoint(s) of this module.
router = APIRouter(prefix="/api/traces", tags=["websocket"])

# ===== Global state =====
# TraceStore used by the handlers below; None until set_trace_store() is called.
_trace_store: Optional[TraceStore] = None
# trace_id -> WebSocket connections currently watching that trace.
_active_connections: Dict[str, Set[WebSocket]] = {}
  # Write side of the module-level TraceStore.
def set_trace_store(store: TraceStore):
    """
    Install ``store`` as the TraceStore shared by this module.

    Any previously installed store is replaced. The WebSocket handler and the
    broadcast functions fetch it through ``get_trace_store()``.
    """
    global _trace_store
    _trace_store = store
  # Read side of the module-level TraceStore.
def get_trace_store() -> TraceStore:
    """
    Return the TraceStore installed via ``set_trace_store()``.

    Raises:
        RuntimeError: if no store has been installed yet.
    """
    current = _trace_store
    if current is not None:
        return current
    raise RuntimeError("TraceStore not initialized")
  # ===== WebSocket routes =====

# Maximum number of missed events replayed over the socket; past this the
# client is told to reload the full tree via the REST API instead.
_MAX_REPLAY_EVENTS = 100


async def _replay_missed_events(
    websocket: WebSocket,
    store: TraceStore,
    trace_id: str,
    since_event_id: int,
    limit: int = _MAX_REPLAY_EVENTS,
) -> None:
    """Send the stored events after ``since_event_id``, or an error message if there are more than ``limit``."""
    missed_events = await store.get_events(trace_id, since_event_id)
    if len(missed_events) > limit:
        await websocket.send_json({
            "event": "error",
            "message": f"Too many missed events ({len(missed_events)}), please reload full tree via REST API"
        })
        return
    for evt in missed_events:
        await websocket.send_json(evt)


def _unregister_connection(trace_id: str, websocket: WebSocket) -> None:
    """Remove ``websocket`` from the watchers of ``trace_id``; drop the entry once it is empty."""
    watchers = _active_connections.get(trace_id)
    if watchers is None:
        # Already removed, e.g. by broadcast_trace_completed.
        return
    watchers.discard(websocket)
    if not watchers:
        del _active_connections[trace_id]


@router.websocket("/{trace_id}/watch")
async def watch_trace(
    websocket: WebSocket,
    trace_id: str,
    since_event_id: int = Query(0, description="从哪个事件 ID 开始(0=补发所有历史)")
):
    """
    Stream Step-tree updates of a trace to the client, with resume support.

    If the trace does not exist, an ``error`` message is sent and the socket
    is closed. Otherwise the client first receives
    ``{"event": "connected", "trace_id": ..., "current_event_id": ...}``,
    then the stored events after ``since_event_id`` (or an ``error`` message
    if there are more than ``_MAX_REPLAY_EVENTS``), then live broadcasts.
    A text frame ``"ping"`` is answered with ``{"event": "pong"}``.

    Live broadcasts may start as soon as the socket is registered, i.e. while
    the replay is still running, so live and replayed events can interleave
    and the same event may arrive twice; clients should order/de-duplicate by
    ``event_id``.

    Args:
        websocket: The client connection.
        trace_id: Trace ID.
        since_event_id: Replay the events after this ID.
            - 0: replay all history (first connection)
            - >0: replay only events after this ID (reconnect)
            - <0: skip the replay entirely
    """
    await websocket.accept()

    store = get_trace_store()
    trace = await store.get_trace(trace_id)
    if not trace:
        await websocket.send_json({
            "event": "error",
            "message": "Trace not found"
        })
        await websocket.close()
        return

    # Send "connected" BEFORE registering: broadcasts only reach registered
    # sockets, so this guarantees it is the first message the client sees.
    await websocket.send_json({
        "event": "connected",
        "trace_id": trace_id,
        "current_event_id": trace.last_event_id
    })

    # Register BEFORE replaying. Assuming events are written to the store
    # before they are broadcast, every event is then either in the replay
    # below or delivered live (possibly both), so none falls into a gap.
    _active_connections.setdefault(trace_id, set()).add(websocket)
    try:
        if since_event_id >= 0:
            await _replay_missed_events(websocket, store, trace_id, since_event_id)

        # Keep the connection open until the client disconnects.
        # NOTE(review): receive_text() expects text frames; how a binary frame
        # is handled depends on Starlette — confirm if clients may send them.
        while True:
            try:
                data = await websocket.receive_text()
            except WebSocketDisconnect:
                break
            if data == "ping":
                await websocket.send_json({"event": "pong"})
    finally:
        _unregister_connection(trace_id, websocket)
  # ===== Broadcast functions (called by AgentRunner) =====

async def broadcast_step_added(trace_id: str, step_dict: Dict):
    """
    Broadcast a ``step_added`` event to every client watching ``trace_id``.

    The event is not written to the store here: per the original design it
    has already been appended by ``add_step``, and this function only reads
    the trace's ``last_event_id`` to stamp the message. Delivery is
    best-effort: a socket whose send fails is dropped from the watcher set.

    Args:
        trace_id: Trace ID.
        step_dict: Step dictionary (from ``step.to_dict(view="compact")``).
    """
    if not _active_connections.get(trace_id):
        return

    store = get_trace_store()
    trace = await store.get_trace(trace_id)
    if not trace:
        return

    # NOTE(review): last_event_id is read after the fact; if another event is
    # appended concurrently this may not be the id of *this* step's event.
    message = {
        "event": "step_added",
        "event_id": trace.last_event_id,
        "ts": datetime.now().isoformat(),
        "step": step_dict,  # compact view
    }

    # Re-read after the await: the last watcher may have disconnected
    # meanwhile, in which case watch_trace has already removed the entry.
    watchers = _active_connections.get(trace_id)
    if not watchers:
        return

    # Iterate over a snapshot: every send yields to the event loop, and a
    # connection added/removed meanwhile would otherwise raise
    # "RuntimeError: Set changed size during iteration".
    dead = []
    for ws in list(watchers):
        try:
            await ws.send_json(message)
        except Exception:
            dead.append(ws)
    for ws in dead:
        watchers.discard(ws)
    # Drop the entry once nobody is left, unless it was replaced meanwhile.
    if not watchers and _active_connections.get(trace_id) is watchers:
        del _active_connections[trace_id]
  # step_updated: persisted to the store first, then pushed to live watchers.
async def broadcast_step_updated(trace_id: str, step_id: str, updates: Dict):
    """
    Persist and broadcast a ``step_updated`` event (patch semantics).

    The event is appended to the store even when nobody is watching, so a
    client that connects or reconnects later can replay it through
    ``since_event_id``. Delivery to sockets is best-effort: a socket whose
    send fails is dropped from the watcher set.

    Args:
        trace_id: Trace ID.
        step_id: Step ID.
        updates: Changed fields in patch form, e.g.
            ``{"status": "completed", "duration_ms": 123}``.
    """
    try:
        store = get_trace_store()
    except RuntimeError:
        # No store configured: nothing can be persisted, and no client can be
        # watching either (watch_trace needs the store), so this is a no-op.
        return

    # Persist unconditionally so resume/replay never misses an update.
    event_id = await store.append_event(trace_id, "step_updated", {
        "step_id": step_id,
        "updates": updates
    })

    watchers = _active_connections.get(trace_id)
    if not watchers:
        return

    message = {
        "event": "step_updated",
        "event_id": event_id,
        "ts": datetime.now().isoformat(),
        "step_id": step_id,
        "patch": updates,  # JSON-Patch-like: {"status": "completed", "duration_ms": 123}
    }

    # Snapshot the set: sends yield to the event loop and the set may change.
    dead = []
    for ws in list(watchers):
        try:
            await ws.send_json(message)
        except Exception:
            dead.append(ws)
    for ws in dead:
        watchers.discard(ws)
    # Drop the entry once nobody is left, unless it was replaced meanwhile.
    if not watchers and _active_connections.get(trace_id) is watchers:
        del _active_connections[trace_id]
  # trace_completed: persisted, pushed, then the trace's watcher set is dropped.
async def broadcast_trace_completed(trace_id: str, total_steps: int):
    """
    Persist and broadcast a ``trace_completed`` event, then stop tracking the trace's watchers.

    The event is appended to the store even when nobody is watching, so a
    client connecting later still sees the completion when it replays.
    Open sockets are NOT closed here; they are only removed from
    ``_active_connections``, so they receive no further broadcasts.

    Args:
        trace_id: Trace ID.
        total_steps: Total number of steps in the trace.
    """
    try:
        store = get_trace_store()
    except RuntimeError:
        # No store configured: nothing can be persisted, and no client can be
        # watching either (watch_trace needs the store), so this is a no-op.
        return

    event_id = await store.append_event(trace_id, "trace_completed", {
        "total_steps": total_steps
    })

    watchers = _active_connections.get(trace_id)
    if watchers:
        message = {
            "event": "trace_completed",
            "event_id": event_id,
            "ts": datetime.now().isoformat(),
            "trace_id": trace_id,
            "total_steps": total_steps,
        }
        # Snapshot the set: sends yield to the event loop and the set may change.
        for ws in list(watchers):
            try:
                await ws.send_json(message)
            except Exception:
                # Best-effort: keep notifying the others; every connection of
                # this trace is dropped from tracking right below anyway.
                pass

    # Completed: forget all connections of this trace (tolerates a missing entry).
    _active_connections.pop(trace_id, None)