websocket.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. """
Trace WebSocket push.
Streams real-time updates for in-progress Traces and supports resuming after a disconnect.
  4. """
  5. from typing import Dict, Set, Any
  6. from datetime import datetime
  7. import asyncio
  8. from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
  9. from .protocols import TraceStore
router = APIRouter(prefix="/api/traces", tags=["websocket"])

# ===== Global state =====
# TraceStore shared by the handlers in this module; stays None until
# set_trace_store() is called.
_trace_store: TraceStore = None
# Registry of open WebSocket connections, keyed by the trace they watch.
_active_connections: Dict[str, Set[WebSocket]] = {} # trace_id -> Set[WebSocket]
  14. def set_trace_store(store: TraceStore):
  15. """设置 TraceStore 实例"""
  16. global _trace_store
  17. _trace_store = store
  18. def get_trace_store() -> TraceStore:
  19. """获取 TraceStore 实例"""
  20. if _trace_store is None:
  21. raise RuntimeError("TraceStore not initialized")
  22. return _trace_store
# ===== WebSocket routes =====
  24. @router.websocket("/{trace_id}/watch")
  25. async def watch_trace(
  26. websocket: WebSocket,
  27. trace_id: str,
  28. since_event_id: int = Query(0, description="从哪个事件 ID 开始(0=补发所有历史)")
  29. ):
  30. """
  31. 监听 Trace 的更新,支持断线续传
  32. 事件类型:
  33. - connected: 连接成功,返回 goal_tree 和 sub_traces
  34. - goal_added: 新增 Goal
  35. - goal_updated: Goal 状态变化(含级联完成)
  36. - message_added: 新 Message(含 affected_goals)
  37. - sub_trace_started: Sub-Trace 开始执行
  38. - sub_trace_completed: Sub-Trace 完成
  39. - trace_completed: 执行完成
  40. Args:
  41. trace_id: Trace ID
  42. since_event_id: 从哪个事件 ID 开始
  43. - 0: 补发所有历史事件(初次连接)
  44. - >0: 补发指定 ID 之后的事件(断线重连)
  45. """
  46. await websocket.accept()
  47. # 验证 Trace 存在
  48. store = get_trace_store()
  49. trace = await store.get_trace(trace_id)
  50. if not trace:
  51. await websocket.send_json({
  52. "event": "error",
  53. "message": "Trace not found"
  54. })
  55. await websocket.close()
  56. return
  57. # 注册连接
  58. if trace_id not in _active_connections:
  59. _active_connections[trace_id] = set()
  60. _active_connections[trace_id].add(websocket)
  61. try:
  62. # 获取 GoalTree 和 Sub-Traces
  63. goal_tree = await store.get_goal_tree(trace_id)
  64. # 获取所有 Sub-Traces(通过 parent_trace_id 查询)
  65. sub_traces = {}
  66. all_traces = await store.list_traces(limit=1000)
  67. for t in all_traces:
  68. if t.parent_trace_id == trace_id:
  69. sub_traces[t.trace_id] = t.to_dict()
  70. # 发送连接成功消息 + 完整状态(含 trace 当前执行状态)
  71. from .run_api import _running_tasks # 避免循环导入,在函数内 import
  72. is_running = (
  73. trace_id in _running_tasks and not _running_tasks[trace_id].done()
  74. )
  75. await websocket.send_json({
  76. "event": "connected",
  77. "trace_id": trace_id,
  78. "current_event_id": trace.last_event_id,
  79. "trace_status": trace.status if not is_running else "running",
  80. "is_running": is_running,
  81. "goal_tree": goal_tree.to_dict() if goal_tree else None,
  82. "sub_traces": sub_traces
  83. })
  84. # 补发历史事件(since_event_id=0 表示补发所有历史)
  85. last_sent_event_id = since_event_id
  86. if since_event_id >= 0:
  87. missed_events = await store.get_events(trace_id, since_event_id)
  88. # 限制补发数量(最多 100 条)
  89. if len(missed_events) > 100:
  90. await websocket.send_json({
  91. "event": "error",
  92. "message": f"Too many missed events ({len(missed_events)}), please reload via REST API"
  93. })
  94. else:
  95. for evt in missed_events:
  96. await websocket.send_json(evt)
  97. if isinstance(evt, dict) and isinstance(evt.get("event_id"), int):
  98. last_sent_event_id = max(last_sent_event_id, evt["event_id"])
  99. # 保持连接:同时支持心跳 + 轮询 events.jsonl(跨进程写入时也能实时推送)
  100. while True:
  101. try:
  102. # 允许在没有客户端消息时继续轮询事件流
  103. data = await asyncio.wait_for(websocket.receive_text(), timeout=0.5)
  104. if data == "ping":
  105. await websocket.send_json({"event": "pong"})
  106. except WebSocketDisconnect:
  107. break
  108. except asyncio.TimeoutError:
  109. pass
  110. try:
  111. new_events = await store.get_events(trace_id, last_sent_event_id)
  112. if len(new_events) > 100:
  113. await websocket.send_json({
  114. "event": "error",
  115. "message": f"Too many missed events ({len(new_events)}), please reload via REST API"
  116. })
  117. continue
  118. for evt in new_events:
  119. await websocket.send_json(evt)
  120. if isinstance(evt, dict) and isinstance(evt.get("event_id"), int):
  121. last_sent_event_id = max(last_sent_event_id, evt["event_id"])
  122. except WebSocketDisconnect:
  123. break
  124. finally:
  125. # 清理连接
  126. if trace_id in _active_connections:
  127. _active_connections[trace_id].discard(websocket)
  128. if not _active_connections[trace_id]:
  129. del _active_connections[trace_id]
# ===== Broadcast functions (called by AgentRunner or TraceStore) =====
  131. async def broadcast_goal_added(trace_id: str, goal_dict: Dict[str, Any]):
  132. """
  133. 广播 Goal 添加事件
  134. Args:
  135. trace_id: Trace ID
  136. goal_dict: Goal 字典(完整数据,含 stats)
  137. """
  138. if trace_id not in _active_connections:
  139. return
  140. store = get_trace_store()
  141. event_id = await store.append_event(trace_id, "goal_added", {
  142. "goal": goal_dict
  143. })
  144. message = {
  145. "event": "goal_added",
  146. "event_id": event_id,
  147. "ts": datetime.now().isoformat(),
  148. "goal": goal_dict
  149. }
  150. await _broadcast_to_trace(trace_id, message)
  151. async def broadcast_goal_updated(
  152. trace_id: str,
  153. goal_id: str,
  154. updates: Dict[str, Any],
  155. affected_goals: list[Dict[str, Any]] = None
  156. ):
  157. """
  158. 广播 Goal 更新事件(patch 语义)
  159. Args:
  160. trace_id: Trace ID
  161. goal_id: Goal ID
  162. updates: 更新字段(patch 格式)
  163. affected_goals: 受影响的 Goals(含级联完成的父节点)
  164. """
  165. if trace_id not in _active_connections:
  166. return
  167. store = get_trace_store()
  168. event_id = await store.append_event(trace_id, "goal_updated", {
  169. "goal_id": goal_id,
  170. "updates": updates,
  171. "affected_goals": affected_goals or []
  172. })
  173. message = {
  174. "event": "goal_updated",
  175. "event_id": event_id,
  176. "ts": datetime.now().isoformat(),
  177. "goal_id": goal_id,
  178. "patch": updates,
  179. "affected_goals": affected_goals or []
  180. }
  181. await _broadcast_to_trace(trace_id, message)
  182. async def broadcast_message_added(
  183. trace_id: str,
  184. event_id: int,
  185. message_dict: Dict[str, Any],
  186. affected_goals: list[Dict[str, Any]] = None,
  187. ):
  188. """
  189. 广播 Message 添加事件(不在此处写入 events.jsonl)
  190. 说明:
  191. - message_added 的 events.jsonl 写入由 TraceStore.append_event 负责
  192. - 这里仅负责把“已经持久化”的事件推送给当前活跃连接
  193. """
  194. if trace_id not in _active_connections:
  195. return
  196. message = {
  197. "event": "message_added",
  198. "event_id": event_id,
  199. "ts": datetime.now().isoformat(),
  200. "message": message_dict,
  201. "affected_goals": affected_goals or [],
  202. }
  203. await _broadcast_to_trace(trace_id, message)
  204. async def broadcast_sub_trace_started(
  205. trace_id: str,
  206. sub_trace_id: str,
  207. parent_goal_id: str,
  208. agent_type: str,
  209. task: str
  210. ):
  211. """
  212. 广播 Sub-Trace 开始事件
  213. Args:
  214. trace_id: 主 Trace ID
  215. sub_trace_id: Sub-Trace ID
  216. parent_goal_id: 父 Goal ID
  217. agent_type: Agent 类型
  218. task: 任务描述
  219. """
  220. if trace_id not in _active_connections:
  221. return
  222. store = get_trace_store()
  223. event_id = await store.append_event(trace_id, "sub_trace_started", {
  224. "trace_id": sub_trace_id,
  225. "parent_trace_id": trace_id,
  226. "parent_goal_id": parent_goal_id,
  227. "agent_type": agent_type,
  228. "task": task
  229. })
  230. message = {
  231. "event": "sub_trace_started",
  232. "event_id": event_id,
  233. "ts": datetime.now().isoformat(),
  234. "trace_id": sub_trace_id,
  235. "parent_goal_id": parent_goal_id,
  236. "agent_type": agent_type,
  237. "task": task
  238. }
  239. await _broadcast_to_trace(trace_id, message)
  240. async def broadcast_sub_trace_completed(
  241. trace_id: str,
  242. sub_trace_id: str,
  243. status: str,
  244. summary: str = "",
  245. stats: Dict[str, Any] = None
  246. ):
  247. """
  248. 广播 Sub-Trace 完成事件
  249. Args:
  250. trace_id: 主 Trace ID
  251. sub_trace_id: Sub-Trace ID
  252. status: 状态(completed/failed)
  253. summary: 总结
  254. stats: 统计信息
  255. """
  256. if trace_id not in _active_connections:
  257. return
  258. store = get_trace_store()
  259. event_id = await store.append_event(trace_id, "sub_trace_completed", {
  260. "trace_id": sub_trace_id,
  261. "status": status,
  262. "summary": summary,
  263. "stats": stats or {}
  264. })
  265. message = {
  266. "event": "sub_trace_completed",
  267. "event_id": event_id,
  268. "ts": datetime.now().isoformat(),
  269. "trace_id": sub_trace_id,
  270. "status": status,
  271. "summary": summary,
  272. "stats": stats or {}
  273. }
  274. await _broadcast_to_trace(trace_id, message)
  275. async def broadcast_trace_completed(trace_id: str, total_messages: int):
  276. """
  277. 广播 Trace 完成事件
  278. Args:
  279. trace_id: Trace ID
  280. total_messages: 总 Message 数
  281. """
  282. if trace_id not in _active_connections:
  283. return
  284. store = get_trace_store()
  285. event_id = await store.append_event(trace_id, "trace_completed", {
  286. "total_messages": total_messages
  287. })
  288. message = {
  289. "event": "trace_completed",
  290. "event_id": event_id,
  291. "ts": datetime.now().isoformat(),
  292. "trace_id": trace_id,
  293. "total_messages": total_messages
  294. }
  295. await _broadcast_to_trace(trace_id, message)
  296. # 完成后清理所有连接
  297. if trace_id in _active_connections:
  298. del _active_connections[trace_id]
  299. async def broadcast_trace_status_changed(trace_id: str, status: str):
  300. """
  301. 广播 Trace 状态变化事件(用于暂停/继续等状态切换)
  302. Args:
  303. trace_id: Trace ID
  304. status: 新状态 (running/stopped/completed/failed)
  305. """
  306. if trace_id not in _active_connections:
  307. return
  308. store = get_trace_store()
  309. trace = await store.get_trace(trace_id)
  310. if not trace:
  311. return
  312. event_id = await store.append_event(trace_id, "trace_status_changed", {
  313. "status": status
  314. })
  315. message = {
  316. "event": "trace_status_changed",
  317. "event_id": event_id,
  318. "ts": datetime.now().isoformat(),
  319. "trace_id": trace_id,
  320. "status": status,
  321. "total_cost": trace.total_cost,
  322. "total_messages": trace.total_messages
  323. }
  324. await _broadcast_to_trace(trace_id, message)
# ===== Internal helpers =====
  326. async def _broadcast_to_trace(trace_id: str, message: Dict[str, Any]):
  327. """
  328. 向指定 Trace 的所有连接广播消息
  329. Args:
  330. trace_id: Trace ID
  331. message: 消息内容
  332. """
  333. if trace_id not in _active_connections:
  334. return
  335. disconnected = []
  336. for websocket in _active_connections[trace_id]:
  337. try:
  338. await websocket.send_json(message)
  339. except Exception:
  340. disconnected.append(websocket)
  341. # 清理断开的连接
  342. for ws in disconnected:
  343. _active_connections[trace_id].discard(ws)