websocket.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. """
  2. Step 树 WebSocket 推送
  3. 实时推送进行中 Trace 的 Step 更新
  4. """
from typing import Dict, Optional, Set

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from agent.trace.protocols import TraceStore
  8. router = APIRouter(prefix="/api/traces", tags=["websocket"])
  9. # ===== 全局状态 =====
  10. _trace_store: TraceStore = None
  11. _active_connections: Dict[str, Set[WebSocket]] = {} # trace_id -> Set[WebSocket]
  12. def set_trace_store(store: TraceStore):
  13. """设置 TraceStore 实例"""
  14. global _trace_store
  15. _trace_store = store
  16. def get_trace_store() -> TraceStore:
  17. """获取 TraceStore 实例"""
  18. if _trace_store is None:
  19. raise RuntimeError("TraceStore not initialized")
  20. return _trace_store
  21. # ===== WebSocket 路由 =====
  22. @router.websocket("/{trace_id}/watch")
  23. async def watch_trace(websocket: WebSocket, trace_id: str):
  24. """
  25. 监听 Trace 的 Step 更新
  26. Args:
  27. trace_id: Trace ID
  28. """
  29. await websocket.accept()
  30. # 验证 Trace 存在
  31. store = get_trace_store()
  32. trace = await store.get_trace(trace_id)
  33. if not trace:
  34. await websocket.send_json({
  35. "event": "error",
  36. "message": "Trace not found"
  37. })
  38. await websocket.close()
  39. return
  40. # 注册连接
  41. if trace_id not in _active_connections:
  42. _active_connections[trace_id] = set()
  43. _active_connections[trace_id].add(websocket)
  44. try:
  45. # 发送连接成功消息
  46. await websocket.send_json({
  47. "event": "connected",
  48. "trace_id": trace_id
  49. })
  50. # 保持连接(等待客户端断开或接收消息)
  51. while True:
  52. try:
  53. # 接收客户端消息(心跳检测)
  54. data = await websocket.receive_text()
  55. # 可以处理客户端请求(如请求完整状态)
  56. if data == "ping":
  57. await websocket.send_json({"event": "pong"})
  58. except WebSocketDisconnect:
  59. break
  60. finally:
  61. # 清理连接
  62. if trace_id in _active_connections:
  63. _active_connections[trace_id].discard(websocket)
  64. if not _active_connections[trace_id]:
  65. del _active_connections[trace_id]
  66. # ===== 广播函数(由 AgentRunner 调用)=====
  67. async def broadcast_step_added(trace_id: str, step_dict: Dict):
  68. """
  69. 广播 Step 添加事件
  70. Args:
  71. trace_id: Trace ID
  72. step_dict: Step 字典(from step.to_dict())
  73. """
  74. if trace_id not in _active_connections:
  75. return
  76. message = {
  77. "event": "step_added",
  78. "step": step_dict
  79. }
  80. # 发送给所有监听该 Trace 的客户端
  81. disconnected = []
  82. for websocket in _active_connections[trace_id]:
  83. try:
  84. await websocket.send_json(message)
  85. except Exception:
  86. disconnected.append(websocket)
  87. # 清理断开的连接
  88. for ws in disconnected:
  89. _active_connections[trace_id].discard(ws)
  90. async def broadcast_step_updated(trace_id: str, step_id: str, updates: Dict):
  91. """
  92. 广播 Step 更新事件
  93. Args:
  94. trace_id: Trace ID
  95. step_id: Step ID
  96. updates: 更新字段
  97. """
  98. if trace_id not in _active_connections:
  99. return
  100. message = {
  101. "event": "step_updated",
  102. "step_id": step_id,
  103. "updates": updates
  104. }
  105. disconnected = []
  106. for websocket in _active_connections[trace_id]:
  107. try:
  108. await websocket.send_json(message)
  109. except Exception:
  110. disconnected.append(websocket)
  111. for ws in disconnected:
  112. _active_connections[trace_id].discard(ws)
  113. async def broadcast_trace_completed(trace_id: str, total_steps: int):
  114. """
  115. 广播 Trace 完成事件
  116. Args:
  117. trace_id: Trace ID
  118. total_steps: 总 Step 数
  119. """
  120. if trace_id not in _active_connections:
  121. return
  122. message = {
  123. "event": "trace_completed",
  124. "trace_id": trace_id,
  125. "total_steps": total_steps
  126. }
  127. disconnected = []
  128. for websocket in _active_connections[trace_id]:
  129. try:
  130. await websocket.send_json(message)
  131. except Exception:
  132. disconnected.append(websocket)
  133. for ws in disconnected:
  134. _active_connections[trace_id].discard(ws)
  135. # 完成后清理所有连接
  136. if trace_id in _active_connections:
  137. del _active_connections[trace_id]