websocket.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. """
  2. Trace WebSocket 推送
  3. 实时推送进行中 Trace 的更新,支持断线续传
  4. """
  5. from typing import Dict, Set, Any
  6. from datetime import datetime
  7. from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
  8. from agent.execution.protocols import TraceStore
# Router for the trace WebSocket endpoints; all routes live under /api/traces.
router = APIRouter(prefix="/api/traces", tags=["websocket"])
# ===== Global state =====
# TraceStore used by this module; it stays None until set_trace_store() is called.
_trace_store: TraceStore = None
# Registry of live connections: trace_id -> set of WebSockets watching that trace.
_active_connections: Dict[str, Set[WebSocket]] = {} # trace_id -> Set[WebSocket]
  13. def set_trace_store(store: TraceStore):
  14. """设置 TraceStore 实例"""
  15. global _trace_store
  16. _trace_store = store
  17. def get_trace_store() -> TraceStore:
  18. """获取 TraceStore 实例"""
  19. if _trace_store is None:
  20. raise RuntimeError("TraceStore not initialized")
  21. return _trace_store
  22. # ===== WebSocket 路由 =====
  23. @router.websocket("/{trace_id}/watch")
  24. async def watch_trace(
  25. websocket: WebSocket,
  26. trace_id: str,
  27. since_event_id: int = Query(0, description="从哪个事件 ID 开始(0=补发所有历史)")
  28. ):
  29. """
  30. 监听 Trace 的更新,支持断线续传
  31. 事件类型:
  32. - connected: 连接成功,返回 goal_tree 和 branches
  33. - goal_added: 新增 Goal
  34. - goal_updated: Goal 状态变化(含级联完成)
  35. - message_added: 新 Message(含 affected_goals)
  36. - branch_started: 分支开始探索
  37. - branch_goal_added: 分支内新增 Goal
  38. - branch_completed: 分支完成
  39. - explore_completed: 所有分支完成
  40. - trace_completed: 执行完成
  41. Args:
  42. trace_id: Trace ID
  43. since_event_id: 从哪个事件 ID 开始
  44. - 0: 补发所有历史事件(初次连接)
  45. - >0: 补发指定 ID 之后的事件(断线重连)
  46. """
  47. await websocket.accept()
  48. # 验证 Trace 存在
  49. store = get_trace_store()
  50. trace = await store.get_trace(trace_id)
  51. if not trace:
  52. await websocket.send_json({
  53. "event": "error",
  54. "message": "Trace not found"
  55. })
  56. await websocket.close()
  57. return
  58. # 注册连接
  59. if trace_id not in _active_connections:
  60. _active_connections[trace_id] = set()
  61. _active_connections[trace_id].add(websocket)
  62. try:
  63. # 获取 GoalTree 和分支元数据
  64. goal_tree = await store.get_goal_tree(trace_id)
  65. branches = await store.list_branches(trace_id)
  66. # 发送连接成功消息 + 完整状态
  67. await websocket.send_json({
  68. "event": "connected",
  69. "trace_id": trace_id,
  70. "current_event_id": trace.last_event_id,
  71. "goal_tree": goal_tree.to_dict() if goal_tree else None,
  72. "branches": {b_id: b.to_dict() for b_id, b in branches.items()}
  73. })
  74. # 补发历史事件(since_event_id=0 表示补发所有历史)
  75. if since_event_id >= 0:
  76. missed_events = await store.get_events(trace_id, since_event_id)
  77. # 限制补发数量(最多 100 条)
  78. if len(missed_events) > 100:
  79. await websocket.send_json({
  80. "event": "error",
  81. "message": f"Too many missed events ({len(missed_events)}), please reload via REST API"
  82. })
  83. else:
  84. for evt in missed_events:
  85. await websocket.send_json(evt)
  86. # 保持连接(等待客户端断开或接收消息)
  87. while True:
  88. try:
  89. # 接收客户端消息(心跳检测)
  90. data = await websocket.receive_text()
  91. if data == "ping":
  92. await websocket.send_json({"event": "pong"})
  93. except WebSocketDisconnect:
  94. break
  95. finally:
  96. # 清理连接
  97. if trace_id in _active_connections:
  98. _active_connections[trace_id].discard(websocket)
  99. if not _active_connections[trace_id]:
  100. del _active_connections[trace_id]
  101. # ===== 广播函数(由 AgentRunner 或 TraceStore 调用)=====
  102. async def broadcast_goal_added(trace_id: str, goal_dict: Dict[str, Any]):
  103. """
  104. 广播 Goal 添加事件
  105. Args:
  106. trace_id: Trace ID
  107. goal_dict: Goal 字典(完整数据,含 stats)
  108. """
  109. if trace_id not in _active_connections:
  110. return
  111. store = get_trace_store()
  112. event_id = await store.append_event(trace_id, "goal_added", {
  113. "goal": goal_dict
  114. })
  115. message = {
  116. "event": "goal_added",
  117. "event_id": event_id,
  118. "ts": datetime.now().isoformat(),
  119. "goal": goal_dict
  120. }
  121. await _broadcast_to_trace(trace_id, message)
  122. async def broadcast_goal_updated(
  123. trace_id: str,
  124. goal_id: str,
  125. updates: Dict[str, Any],
  126. affected_goals: list[Dict[str, Any]] = None
  127. ):
  128. """
  129. 广播 Goal 更新事件(patch 语义)
  130. Args:
  131. trace_id: Trace ID
  132. goal_id: Goal ID
  133. updates: 更新字段(patch 格式)
  134. affected_goals: 受影响的 Goals(含级联完成的父节点)
  135. """
  136. if trace_id not in _active_connections:
  137. return
  138. store = get_trace_store()
  139. event_id = await store.append_event(trace_id, "goal_updated", {
  140. "goal_id": goal_id,
  141. "updates": updates,
  142. "affected_goals": affected_goals or []
  143. })
  144. message = {
  145. "event": "goal_updated",
  146. "event_id": event_id,
  147. "ts": datetime.now().isoformat(),
  148. "goal_id": goal_id,
  149. "patch": updates,
  150. "affected_goals": affected_goals or []
  151. }
  152. await _broadcast_to_trace(trace_id, message)
  153. async def broadcast_branch_started(trace_id: str, branch_dict: Dict[str, Any]):
  154. """
  155. 广播分支开始事件
  156. Args:
  157. trace_id: Trace ID
  158. branch_dict: BranchContext 字典
  159. """
  160. if trace_id not in _active_connections:
  161. return
  162. store = get_trace_store()
  163. event_id = await store.append_event(trace_id, "branch_started", {
  164. "branch": branch_dict
  165. })
  166. message = {
  167. "event": "branch_started",
  168. "event_id": event_id,
  169. "ts": datetime.now().isoformat(),
  170. "branch": branch_dict
  171. }
  172. await _broadcast_to_trace(trace_id, message)
  173. async def broadcast_branch_goal_added(
  174. trace_id: str,
  175. branch_id: str,
  176. goal_dict: Dict[str, Any]
  177. ):
  178. """
  179. 广播分支内新增 Goal
  180. Args:
  181. trace_id: Trace ID
  182. branch_id: 分支 ID
  183. goal_dict: Goal 字典
  184. """
  185. if trace_id not in _active_connections:
  186. return
  187. store = get_trace_store()
  188. event_id = await store.append_event(trace_id, "branch_goal_added", {
  189. "branch_id": branch_id,
  190. "goal": goal_dict
  191. })
  192. message = {
  193. "event": "branch_goal_added",
  194. "event_id": event_id,
  195. "ts": datetime.now().isoformat(),
  196. "branch_id": branch_id,
  197. "goal": goal_dict
  198. }
  199. await _broadcast_to_trace(trace_id, message)
  200. async def broadcast_branch_completed(
  201. trace_id: str,
  202. branch_id: str,
  203. summary: str,
  204. cumulative_stats: Dict[str, Any]
  205. ):
  206. """
  207. 广播分支完成事件
  208. Args:
  209. trace_id: Trace ID
  210. branch_id: 分支 ID
  211. summary: 分支总结
  212. cumulative_stats: 分支累计统计
  213. """
  214. if trace_id not in _active_connections:
  215. return
  216. store = get_trace_store()
  217. event_id = await store.append_event(trace_id, "branch_completed", {
  218. "branch_id": branch_id,
  219. "summary": summary,
  220. "cumulative_stats": cumulative_stats
  221. })
  222. message = {
  223. "event": "branch_completed",
  224. "event_id": event_id,
  225. "ts": datetime.now().isoformat(),
  226. "branch_id": branch_id,
  227. "summary": summary,
  228. "cumulative_stats": cumulative_stats
  229. }
  230. await _broadcast_to_trace(trace_id, message)
  231. async def broadcast_explore_completed(
  232. trace_id: str,
  233. explore_start_id: str,
  234. merge_summary: str
  235. ):
  236. """
  237. 广播探索完成事件
  238. Args:
  239. trace_id: Trace ID
  240. explore_start_id: explore_start Goal ID
  241. merge_summary: 汇总结果
  242. """
  243. if trace_id not in _active_connections:
  244. return
  245. store = get_trace_store()
  246. event_id = await store.append_event(trace_id, "explore_completed", {
  247. "explore_start_id": explore_start_id,
  248. "merge_summary": merge_summary
  249. })
  250. message = {
  251. "event": "explore_completed",
  252. "event_id": event_id,
  253. "ts": datetime.now().isoformat(),
  254. "explore_start_id": explore_start_id,
  255. "merge_summary": merge_summary
  256. }
  257. await _broadcast_to_trace(trace_id, message)
  258. async def broadcast_trace_completed(trace_id: str, total_messages: int):
  259. """
  260. 广播 Trace 完成事件
  261. Args:
  262. trace_id: Trace ID
  263. total_messages: 总 Message 数
  264. """
  265. if trace_id not in _active_connections:
  266. return
  267. store = get_trace_store()
  268. event_id = await store.append_event(trace_id, "trace_completed", {
  269. "total_messages": total_messages
  270. })
  271. message = {
  272. "event": "trace_completed",
  273. "event_id": event_id,
  274. "ts": datetime.now().isoformat(),
  275. "trace_id": trace_id,
  276. "total_messages": total_messages
  277. }
  278. await _broadcast_to_trace(trace_id, message)
  279. # 完成后清理所有连接
  280. if trace_id in _active_connections:
  281. del _active_connections[trace_id]
  282. # ===== 内部辅助函数 =====
  283. async def _broadcast_to_trace(trace_id: str, message: Dict[str, Any]):
  284. """
  285. 向指定 Trace 的所有连接广播消息
  286. Args:
  287. trace_id: Trace ID
  288. message: 消息内容
  289. """
  290. if trace_id not in _active_connections:
  291. return
  292. disconnected = []
  293. for websocket in _active_connections[trace_id]:
  294. try:
  295. await websocket.send_json(message)
  296. except Exception:
  297. disconnected.append(websocket)
  298. # 清理断开的连接
  299. for ws in disconnected:
  300. _active_connections[trace_id].discard(ws)