websocket.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. """
  2. Trace WebSocket 推送
  3. 实时推送进行中 Trace 的更新,支持断线续传
  4. """
  5. from typing import Dict, Set, Any
  6. from datetime import datetime
  7. import asyncio
  8. from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
  9. from .protocols import TraceStore
  10. router = APIRouter(prefix="/api/traces", tags=["websocket"])
  11. # ===== 全局状态 =====
  12. _trace_store: TraceStore = None
  13. _active_connections: Dict[str, Set[WebSocket]] = {} # trace_id -> Set[WebSocket]
  14. def set_trace_store(store: TraceStore):
  15. """设置 TraceStore 实例"""
  16. global _trace_store
  17. _trace_store = store
  18. def get_trace_store() -> TraceStore:
  19. """获取 TraceStore 实例"""
  20. if _trace_store is None:
  21. raise RuntimeError("TraceStore not initialized")
  22. return _trace_store
  23. # ===== WebSocket 路由 =====
  24. @router.websocket("/{trace_id}/watch")
  25. async def watch_trace(
  26. websocket: WebSocket,
  27. trace_id: str,
  28. since_event_id: int = Query(0, description="从哪个事件 ID 开始(0=补发所有历史)")
  29. ):
  30. """
  31. 监听 Trace 的更新,支持断线续传
  32. 事件类型:
  33. - connected: 连接成功,返回 goal_tree 和 sub_traces
  34. - goal_added: 新增 Goal
  35. - goal_updated: Goal 状态变化(含级联完成)
  36. - message_added: 新 Message(含 affected_goals)
  37. - sub_trace_started: Sub-Trace 开始执行
  38. - sub_trace_completed: Sub-Trace 完成
  39. - trace_completed: 执行完成
  40. Args:
  41. trace_id: Trace ID
  42. since_event_id: 从哪个事件 ID 开始
  43. - 0: 补发所有历史事件(初次连接)
  44. - >0: 补发指定 ID 之后的事件(断线重连)
  45. """
  46. await websocket.accept()
  47. # 验证 Trace 存在
  48. store = get_trace_store()
  49. trace = await store.get_trace(trace_id)
  50. if not trace:
  51. await websocket.send_json({
  52. "event": "error",
  53. "message": "Trace not found"
  54. })
  55. await websocket.close()
  56. return
  57. # 注册连接
  58. if trace_id not in _active_connections:
  59. _active_connections[trace_id] = set()
  60. _active_connections[trace_id].add(websocket)
  61. try:
  62. # 获取 GoalTree 和 Sub-Traces
  63. goal_tree = await store.get_goal_tree(trace_id)
  64. # 获取所有 Sub-Traces(通过 parent_trace_id 查询)
  65. sub_traces = {}
  66. all_traces = await store.list_traces(limit=1000)
  67. for t in all_traces:
  68. if t.parent_trace_id == trace_id:
  69. sub_traces[t.trace_id] = t.to_dict()
  70. # 发送连接成功消息 + 完整状态(含 trace 当前执行状态)
  71. from .run_api import _running_tasks # 避免循环导入,在函数内 import
  72. is_running = (
  73. trace_id in _running_tasks and not _running_tasks[trace_id].done()
  74. )
  75. await websocket.send_json({
  76. "event": "connected",
  77. "trace_id": trace_id,
  78. "current_event_id": trace.last_event_id,
  79. "trace_status": trace.status if not is_running else "running",
  80. "is_running": is_running,
  81. "goal_tree": goal_tree.to_dict() if goal_tree else None,
  82. "sub_traces": sub_traces
  83. })
  84. # 补发历史事件(since_event_id=0 表示补发所有历史)
  85. last_sent_event_id = since_event_id
  86. if since_event_id >= 0:
  87. missed_events = await store.get_events(trace_id, since_event_id)
  88. # 限制补发数量(最多 100 条)
  89. if len(missed_events) > 100:
  90. await websocket.send_json({
  91. "event": "error",
  92. "message": f"Too many missed events ({len(missed_events)}), please reload via REST API"
  93. })
  94. else:
  95. for evt in missed_events:
  96. await websocket.send_json(evt)
  97. if isinstance(evt, dict) and isinstance(evt.get("event_id"), int):
  98. last_sent_event_id = max(last_sent_event_id, evt["event_id"])
  99. # 保持连接:同时支持心跳 + 轮询 events.jsonl(跨进程写入时也能实时推送)
  100. while True:
  101. try:
  102. # 允许在没有客户端消息时继续轮询事件流
  103. data = await asyncio.wait_for(websocket.receive_text(), timeout=0.5)
  104. if data == "ping":
  105. await websocket.send_json({"event": "pong"})
  106. except WebSocketDisconnect:
  107. break
  108. except asyncio.TimeoutError:
  109. pass
  110. new_events = await store.get_events(trace_id, last_sent_event_id)
  111. if len(new_events) > 100:
  112. await websocket.send_json({
  113. "event": "error",
  114. "message": f"Too many missed events ({len(new_events)}), please reload via REST API"
  115. })
  116. continue
  117. for evt in new_events:
  118. await websocket.send_json(evt)
  119. if isinstance(evt, dict) and isinstance(evt.get("event_id"), int):
  120. last_sent_event_id = max(last_sent_event_id, evt["event_id"])
  121. finally:
  122. # 清理连接
  123. if trace_id in _active_connections:
  124. _active_connections[trace_id].discard(websocket)
  125. if not _active_connections[trace_id]:
  126. del _active_connections[trace_id]
  127. # ===== 广播函数(由 AgentRunner 或 TraceStore 调用)=====
  128. async def broadcast_goal_added(trace_id: str, goal_dict: Dict[str, Any]):
  129. """
  130. 广播 Goal 添加事件
  131. Args:
  132. trace_id: Trace ID
  133. goal_dict: Goal 字典(完整数据,含 stats)
  134. """
  135. if trace_id not in _active_connections:
  136. return
  137. store = get_trace_store()
  138. event_id = await store.append_event(trace_id, "goal_added", {
  139. "goal": goal_dict
  140. })
  141. message = {
  142. "event": "goal_added",
  143. "event_id": event_id,
  144. "ts": datetime.now().isoformat(),
  145. "goal": goal_dict
  146. }
  147. await _broadcast_to_trace(trace_id, message)
  148. async def broadcast_goal_updated(
  149. trace_id: str,
  150. goal_id: str,
  151. updates: Dict[str, Any],
  152. affected_goals: list[Dict[str, Any]] = None
  153. ):
  154. """
  155. 广播 Goal 更新事件(patch 语义)
  156. Args:
  157. trace_id: Trace ID
  158. goal_id: Goal ID
  159. updates: 更新字段(patch 格式)
  160. affected_goals: 受影响的 Goals(含级联完成的父节点)
  161. """
  162. if trace_id not in _active_connections:
  163. return
  164. store = get_trace_store()
  165. event_id = await store.append_event(trace_id, "goal_updated", {
  166. "goal_id": goal_id,
  167. "updates": updates,
  168. "affected_goals": affected_goals or []
  169. })
  170. message = {
  171. "event": "goal_updated",
  172. "event_id": event_id,
  173. "ts": datetime.now().isoformat(),
  174. "goal_id": goal_id,
  175. "patch": updates,
  176. "affected_goals": affected_goals or []
  177. }
  178. await _broadcast_to_trace(trace_id, message)
  179. async def broadcast_message_added(
  180. trace_id: str,
  181. event_id: int,
  182. message_dict: Dict[str, Any],
  183. affected_goals: list[Dict[str, Any]] = None,
  184. ):
  185. """
  186. 广播 Message 添加事件(不在此处写入 events.jsonl)
  187. 说明:
  188. - message_added 的 events.jsonl 写入由 TraceStore.append_event 负责
  189. - 这里仅负责把“已经持久化”的事件推送给当前活跃连接
  190. """
  191. if trace_id not in _active_connections:
  192. return
  193. message = {
  194. "event": "message_added",
  195. "event_id": event_id,
  196. "ts": datetime.now().isoformat(),
  197. "message": message_dict,
  198. "affected_goals": affected_goals or [],
  199. }
  200. await _broadcast_to_trace(trace_id, message)
  201. async def broadcast_sub_trace_started(
  202. trace_id: str,
  203. sub_trace_id: str,
  204. parent_goal_id: str,
  205. agent_type: str,
  206. task: str
  207. ):
  208. """
  209. 广播 Sub-Trace 开始事件
  210. Args:
  211. trace_id: 主 Trace ID
  212. sub_trace_id: Sub-Trace ID
  213. parent_goal_id: 父 Goal ID
  214. agent_type: Agent 类型
  215. task: 任务描述
  216. """
  217. if trace_id not in _active_connections:
  218. return
  219. store = get_trace_store()
  220. event_id = await store.append_event(trace_id, "sub_trace_started", {
  221. "trace_id": sub_trace_id,
  222. "parent_trace_id": trace_id,
  223. "parent_goal_id": parent_goal_id,
  224. "agent_type": agent_type,
  225. "task": task
  226. })
  227. message = {
  228. "event": "sub_trace_started",
  229. "event_id": event_id,
  230. "ts": datetime.now().isoformat(),
  231. "trace_id": sub_trace_id,
  232. "parent_goal_id": parent_goal_id,
  233. "agent_type": agent_type,
  234. "task": task
  235. }
  236. await _broadcast_to_trace(trace_id, message)
  237. async def broadcast_sub_trace_completed(
  238. trace_id: str,
  239. sub_trace_id: str,
  240. status: str,
  241. summary: str = "",
  242. stats: Dict[str, Any] = None
  243. ):
  244. """
  245. 广播 Sub-Trace 完成事件
  246. Args:
  247. trace_id: 主 Trace ID
  248. sub_trace_id: Sub-Trace ID
  249. status: 状态(completed/failed)
  250. summary: 总结
  251. stats: 统计信息
  252. """
  253. if trace_id not in _active_connections:
  254. return
  255. store = get_trace_store()
  256. event_id = await store.append_event(trace_id, "sub_trace_completed", {
  257. "trace_id": sub_trace_id,
  258. "status": status,
  259. "summary": summary,
  260. "stats": stats or {}
  261. })
  262. message = {
  263. "event": "sub_trace_completed",
  264. "event_id": event_id,
  265. "ts": datetime.now().isoformat(),
  266. "trace_id": sub_trace_id,
  267. "status": status,
  268. "summary": summary,
  269. "stats": stats or {}
  270. }
  271. await _broadcast_to_trace(trace_id, message)
  272. async def broadcast_trace_completed(trace_id: str, total_messages: int):
  273. """
  274. 广播 Trace 完成事件
  275. Args:
  276. trace_id: Trace ID
  277. total_messages: 总 Message 数
  278. """
  279. if trace_id not in _active_connections:
  280. return
  281. store = get_trace_store()
  282. event_id = await store.append_event(trace_id, "trace_completed", {
  283. "total_messages": total_messages
  284. })
  285. message = {
  286. "event": "trace_completed",
  287. "event_id": event_id,
  288. "ts": datetime.now().isoformat(),
  289. "trace_id": trace_id,
  290. "total_messages": total_messages
  291. }
  292. await _broadcast_to_trace(trace_id, message)
  293. # 完成后清理所有连接
  294. if trace_id in _active_connections:
  295. del _active_connections[trace_id]
  296. async def broadcast_trace_status_changed(trace_id: str, status: str):
  297. """
  298. 广播 Trace 状态变化事件(用于暂停/继续等状态切换)
  299. Args:
  300. trace_id: Trace ID
  301. status: 新状态 (running/stopped/completed/failed)
  302. """
  303. if trace_id not in _active_connections:
  304. return
  305. store = get_trace_store()
  306. trace = await store.get_trace(trace_id)
  307. if not trace:
  308. return
  309. event_id = await store.append_event(trace_id, "trace_status_changed", {
  310. "status": status
  311. })
  312. message = {
  313. "event": "trace_status_changed",
  314. "event_id": event_id,
  315. "ts": datetime.now().isoformat(),
  316. "trace_id": trace_id,
  317. "status": status,
  318. "total_cost": trace.total_cost,
  319. "total_messages": trace.total_messages
  320. }
  321. await _broadcast_to_trace(trace_id, message)
  322. # ===== 内部辅助函数 =====
  323. async def _broadcast_to_trace(trace_id: str, message: Dict[str, Any]):
  324. """
  325. 向指定 Trace 的所有连接广播消息
  326. Args:
  327. trace_id: Trace ID
  328. message: 消息内容
  329. """
  330. if trace_id not in _active_connections:
  331. return
  332. disconnected = []
  333. for websocket in _active_connections[trace_id]:
  334. try:
  335. await websocket.send_json(message)
  336. except Exception:
  337. disconnected.append(websocket)
  338. # 清理断开的连接
  339. for ws in disconnected:
  340. _active_connections[trace_id].discard(ws)