websocket.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. """
  2. Trace WebSocket 推送
  3. 实时推送进行中 Trace 的更新,支持断线续传
  4. """
  5. from typing import Dict, Set, Any
  6. from datetime import datetime
  7. from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
  8. from agent.execution.protocols import TraceStore
# Router for the trace WebSocket endpoints, mounted under /api/traces.
router = APIRouter(prefix="/api/traces", tags=["websocket"])


# ===== Global state =====

# Module-wide TraceStore; stays None until set_trace_store() installs one.
_trace_store: TraceStore = None
# Open watcher sockets grouped by the trace they are watching.
_active_connections: Dict[str, Set[WebSocket]] = {}  # trace_id -> Set[WebSocket]
  13. def set_trace_store(store: TraceStore):
  14. """设置 TraceStore 实例"""
  15. global _trace_store
  16. _trace_store = store
  17. def get_trace_store() -> TraceStore:
  18. """获取 TraceStore 实例"""
  19. if _trace_store is None:
  20. raise RuntimeError("TraceStore not initialized")
  21. return _trace_store
  22. # ===== WebSocket 路由 =====
  23. @router.websocket("/{trace_id}/watch")
  24. async def watch_trace(
  25. websocket: WebSocket,
  26. trace_id: str,
  27. since_event_id: int = Query(0, description="从哪个事件 ID 开始(0=补发所有历史)")
  28. ):
  29. """
  30. 监听 Trace 的更新,支持断线续传
  31. 事件类型:
  32. - connected: 连接成功,返回 goal_tree 和 sub_traces
  33. - goal_added: 新增 Goal
  34. - goal_updated: Goal 状态变化(含级联完成)
  35. - message_added: 新 Message(含 affected_goals)
  36. - sub_trace_started: Sub-Trace 开始执行
  37. - sub_trace_completed: Sub-Trace 完成
  38. - trace_completed: 执行完成
  39. Args:
  40. trace_id: Trace ID
  41. since_event_id: 从哪个事件 ID 开始
  42. - 0: 补发所有历史事件(初次连接)
  43. - >0: 补发指定 ID 之后的事件(断线重连)
  44. """
  45. await websocket.accept()
  46. # 验证 Trace 存在
  47. store = get_trace_store()
  48. trace = await store.get_trace(trace_id)
  49. if not trace:
  50. await websocket.send_json({
  51. "event": "error",
  52. "message": "Trace not found"
  53. })
  54. await websocket.close()
  55. return
  56. # 注册连接
  57. if trace_id not in _active_connections:
  58. _active_connections[trace_id] = set()
  59. _active_connections[trace_id].add(websocket)
  60. try:
  61. # 获取 GoalTree 和 Sub-Traces
  62. goal_tree = await store.get_goal_tree(trace_id)
  63. # 获取所有 Sub-Traces(通过 parent_trace_id 查询)
  64. sub_traces = {}
  65. all_traces = await store.list_traces(limit=1000)
  66. for t in all_traces:
  67. if t.parent_trace_id == trace_id:
  68. sub_traces[t.trace_id] = t.to_dict()
  69. # 发送连接成功消息 + 完整状态
  70. await websocket.send_json({
  71. "event": "connected",
  72. "trace_id": trace_id,
  73. "current_event_id": trace.last_event_id,
  74. "goal_tree": goal_tree.to_dict() if goal_tree else None,
  75. "sub_traces": sub_traces
  76. })
  77. # 补发历史事件(since_event_id=0 表示补发所有历史)
  78. if since_event_id >= 0:
  79. missed_events = await store.get_events(trace_id, since_event_id)
  80. # 限制补发数量(最多 100 条)
  81. if len(missed_events) > 100:
  82. await websocket.send_json({
  83. "event": "error",
  84. "message": f"Too many missed events ({len(missed_events)}), please reload via REST API"
  85. })
  86. else:
  87. for evt in missed_events:
  88. await websocket.send_json(evt)
  89. # 保持连接(等待客户端断开或接收消息)
  90. while True:
  91. try:
  92. # 接收客户端消息(心跳检测)
  93. data = await websocket.receive_text()
  94. if data == "ping":
  95. await websocket.send_json({"event": "pong"})
  96. except WebSocketDisconnect:
  97. break
  98. finally:
  99. # 清理连接
  100. if trace_id in _active_connections:
  101. _active_connections[trace_id].discard(websocket)
  102. if not _active_connections[trace_id]:
  103. del _active_connections[trace_id]
  104. # ===== 广播函数(由 AgentRunner 或 TraceStore 调用)=====
  105. async def broadcast_goal_added(trace_id: str, goal_dict: Dict[str, Any]):
  106. """
  107. 广播 Goal 添加事件
  108. Args:
  109. trace_id: Trace ID
  110. goal_dict: Goal 字典(完整数据,含 stats)
  111. """
  112. if trace_id not in _active_connections:
  113. return
  114. store = get_trace_store()
  115. event_id = await store.append_event(trace_id, "goal_added", {
  116. "goal": goal_dict
  117. })
  118. message = {
  119. "event": "goal_added",
  120. "event_id": event_id,
  121. "ts": datetime.now().isoformat(),
  122. "goal": goal_dict
  123. }
  124. await _broadcast_to_trace(trace_id, message)
  125. async def broadcast_goal_updated(
  126. trace_id: str,
  127. goal_id: str,
  128. updates: Dict[str, Any],
  129. affected_goals: list[Dict[str, Any]] = None
  130. ):
  131. """
  132. 广播 Goal 更新事件(patch 语义)
  133. Args:
  134. trace_id: Trace ID
  135. goal_id: Goal ID
  136. updates: 更新字段(patch 格式)
  137. affected_goals: 受影响的 Goals(含级联完成的父节点)
  138. """
  139. if trace_id not in _active_connections:
  140. return
  141. store = get_trace_store()
  142. event_id = await store.append_event(trace_id, "goal_updated", {
  143. "goal_id": goal_id,
  144. "updates": updates,
  145. "affected_goals": affected_goals or []
  146. })
  147. message = {
  148. "event": "goal_updated",
  149. "event_id": event_id,
  150. "ts": datetime.now().isoformat(),
  151. "goal_id": goal_id,
  152. "patch": updates,
  153. "affected_goals": affected_goals or []
  154. }
  155. await _broadcast_to_trace(trace_id, message)
  156. async def broadcast_sub_trace_started(
  157. trace_id: str,
  158. sub_trace_id: str,
  159. parent_goal_id: str,
  160. agent_type: str,
  161. task: str
  162. ):
  163. """
  164. 广播 Sub-Trace 开始事件
  165. Args:
  166. trace_id: 主 Trace ID
  167. sub_trace_id: Sub-Trace ID
  168. parent_goal_id: 父 Goal ID
  169. agent_type: Agent 类型
  170. task: 任务描述
  171. """
  172. if trace_id not in _active_connections:
  173. return
  174. store = get_trace_store()
  175. event_id = await store.append_event(trace_id, "sub_trace_started", {
  176. "trace_id": sub_trace_id,
  177. "parent_trace_id": trace_id,
  178. "parent_goal_id": parent_goal_id,
  179. "agent_type": agent_type,
  180. "task": task
  181. })
  182. message = {
  183. "event": "sub_trace_started",
  184. "event_id": event_id,
  185. "ts": datetime.now().isoformat(),
  186. "trace_id": sub_trace_id,
  187. "parent_goal_id": parent_goal_id,
  188. "agent_type": agent_type,
  189. "task": task
  190. }
  191. await _broadcast_to_trace(trace_id, message)
  192. async def broadcast_sub_trace_completed(
  193. trace_id: str,
  194. sub_trace_id: str,
  195. status: str,
  196. summary: str = "",
  197. stats: Dict[str, Any] = None
  198. ):
  199. """
  200. 广播 Sub-Trace 完成事件
  201. Args:
  202. trace_id: 主 Trace ID
  203. sub_trace_id: Sub-Trace ID
  204. status: 状态(completed/failed)
  205. summary: 总结
  206. stats: 统计信息
  207. """
  208. if trace_id not in _active_connections:
  209. return
  210. store = get_trace_store()
  211. event_id = await store.append_event(trace_id, "sub_trace_completed", {
  212. "trace_id": sub_trace_id,
  213. "status": status,
  214. "summary": summary,
  215. "stats": stats or {}
  216. })
  217. message = {
  218. "event": "sub_trace_completed",
  219. "event_id": event_id,
  220. "ts": datetime.now().isoformat(),
  221. "trace_id": sub_trace_id,
  222. "status": status,
  223. "summary": summary,
  224. "stats": stats or {}
  225. }
  226. await _broadcast_to_trace(trace_id, message)
  227. async def broadcast_trace_completed(trace_id: str, total_messages: int):
  228. """
  229. 广播 Trace 完成事件
  230. Args:
  231. trace_id: Trace ID
  232. total_messages: 总 Message 数
  233. """
  234. if trace_id not in _active_connections:
  235. return
  236. store = get_trace_store()
  237. event_id = await store.append_event(trace_id, "trace_completed", {
  238. "total_messages": total_messages
  239. })
  240. message = {
  241. "event": "trace_completed",
  242. "event_id": event_id,
  243. "ts": datetime.now().isoformat(),
  244. "trace_id": trace_id,
  245. "total_messages": total_messages
  246. }
  247. await _broadcast_to_trace(trace_id, message)
  248. # 完成后清理所有连接
  249. if trace_id in _active_connections:
  250. del _active_connections[trace_id]
  251. # ===== 内部辅助函数 =====
  252. async def _broadcast_to_trace(trace_id: str, message: Dict[str, Any]):
  253. """
  254. 向指定 Trace 的所有连接广播消息
  255. Args:
  256. trace_id: Trace ID
  257. message: 消息内容
  258. """
  259. if trace_id not in _active_connections:
  260. return
  261. disconnected = []
  262. for websocket in _active_connections[trace_id]:
  263. try:
  264. await websocket.send_json(message)
  265. except Exception:
  266. disconnected.append(websocket)
  267. # 清理断开的连接
  268. for ws in disconnected:
  269. _active_connections[trace_id].discard(ws)