| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424 |
- """
- Trace WebSocket 推送
- 实时推送进行中 Trace 的更新,支持断线续传
- """
- from typing import Dict, Set, Any
- from datetime import datetime
- import asyncio
- from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
- from .protocols import TraceStore
router = APIRouter(
    prefix="/api/traces",
    tags=["websocket"],
)

# ===== Module-level state =====

# Store shared by every handler in this module; stays None until
# set_trace_store() is called.
_trace_store: TraceStore = None

# Open watcher sockets, keyed by the trace id they subscribed to.
_active_connections: Dict[str, Set[WebSocket]] = {}
def set_trace_store(store: TraceStore):
    """Install *store* as the module-wide TraceStore used by all handlers."""
    global _trace_store

    _trace_store = store
def get_trace_store() -> TraceStore:
    """Return the module-wide TraceStore.

    Raises:
        RuntimeError: if no store has been installed yet.
    """
    current = _trace_store
    if current is not None:
        return current
    raise RuntimeError("TraceStore not initialized")
- # ===== WebSocket 路由 =====
# Largest backlog that is streamed over the socket in one pass; bigger
# backlogs are reported to the client instead.
_MAX_REPLAY_EVENTS = 100
# Seconds to wait for a client frame before polling the event store again.
_POLL_INTERVAL_SECONDS = 0.5


def _event_id_of(evt: Any):
    """Return the integer ``event_id`` of *evt*, or None when it has none."""
    if isinstance(evt, dict) and isinstance(evt.get("event_id"), int):
        return evt["event_id"]
    return None


async def _push_events_since(
    websocket: WebSocket, store: TraceStore, trace_id: str, cursor: int
) -> int:
    """Send the events stored after *cursor* and return the advanced cursor.

    When the backlog exceeds ``_MAX_REPLAY_EVENTS`` a single error frame is
    sent instead of the events, and the cursor is moved past the whole
    backlog. Without that step the same oversized backlog would be fetched
    again on every poll, the error frame would be re-sent each time and no
    later event would ever reach the client.
    """
    pending = await store.get_events(trace_id, cursor)

    if len(pending) > _MAX_REPLAY_EVENTS:
        await websocket.send_json({
            "event": "error",
            "message": f"Too many missed events ({len(pending)}), please reload via REST API"
        })
        # The client is told to reload, so skip everything it will reload.
        # If no event carries an integer event_id the cursor cannot move.
        skipped = [eid for eid in map(_event_id_of, pending) if eid is not None]
        return max(skipped + [cursor])

    for evt in pending:
        await websocket.send_json(evt)
        sent_id = _event_id_of(evt)
        if sent_id is not None:
            cursor = max(cursor, sent_id)
    return cursor


@router.websocket("/{trace_id}/watch")
async def watch_trace(
    websocket: WebSocket,
    trace_id: str,
    since_event_id: int = Query(0, description="从哪个事件 ID 开始(0=补发所有历史)")
):
    """
    Stream updates of one trace to a client, with resume-after-disconnect.

    Frames sent by this handler:
    - connected: first frame; carries the goal tree, the sub-traces, the
      trace status and the trace's last event id
    - error: the trace does not exist, or the backlog is too large to replay
    - pong: reply to a client text frame equal to "ping"
    - every stored event returned by the store after the client's cursor
      (the original notes list goal_added, goal_updated, message_added,
      sub_trace_started, sub_trace_completed and trace_completed)

    Args:
        websocket: The client connection.
        trace_id: Trace to watch.
        since_event_id: Replay cursor.
            - 0: replay the full history (first connection)
            - >0: replay only events after this id (reconnect)
            - <0: skip the initial replay; the poll loop still starts from
              this value
    """
    await websocket.accept()

    # Refuse to watch a trace that does not exist.
    store = get_trace_store()
    trace = await store.get_trace(trace_id)
    if not trace:
        await websocket.send_json({
            "event": "error",
            "message": "Trace not found"
        })
        await websocket.close()
        return

    # Register the socket so the broadcast_* functions can reach it.
    _active_connections.setdefault(trace_id, set()).add(websocket)

    try:
        goal_tree = await store.get_goal_tree(trace_id)

        # Sub-traces are found by scanning for traces whose parent is this one.
        all_traces = await store.list_traces(limit=1000)
        sub_traces = {
            t.trace_id: t.to_dict()
            for t in all_traces
            if t.parent_trace_id == trace_id
        }

        # Imported here rather than at module level to avoid a circular import.
        from .run_api import _running_tasks
        is_running = (
            trace_id in _running_tasks and not _running_tasks[trace_id].done()
        )

        # Initial snapshot; a live task overrides the stored status.
        await websocket.send_json({
            "event": "connected",
            "trace_id": trace_id,
            "current_event_id": trace.last_event_id,
            "trace_status": "running" if is_running else trace.status,
            "is_running": is_running,
            "goal_tree": goal_tree.to_dict() if goal_tree else None,
            "sub_traces": sub_traces
        })

        # Replay what the client missed (0 means the whole history).
        last_sent_event_id = since_event_id
        if since_event_id >= 0:
            last_sent_event_id = await _push_events_since(
                websocket, store, trace_id, last_sent_event_id
            )

        # Keep-alive loop: answer pings and poll the store, so events
        # written by another process are pushed as well.
        while True:
            try:
                data = await asyncio.wait_for(
                    websocket.receive_text(), timeout=_POLL_INTERVAL_SECONDS
                )
            except asyncio.TimeoutError:
                # No client frame within the interval: just poll the store.
                pass
            else:
                if data == "ping":
                    await websocket.send_json({"event": "pong"})

            last_sent_event_id = await _push_events_since(
                websocket, store, trace_id, last_sent_event_id
            )
    except WebSocketDisconnect:
        # The client going away at any stage is a normal end of the session.
        pass
    finally:
        # Unregister the socket and drop the per-trace entry once it is empty.
        watchers = _active_connections.get(trace_id)
        if watchers is not None:
            watchers.discard(websocket)
            if not watchers:
                del _active_connections[trace_id]
- # ===== 广播函数(由 AgentRunner 或 TraceStore 调用)=====
async def broadcast_goal_added(trace_id: str, goal_dict: Dict[str, Any]):
    """
    Persist a ``goal_added`` event and push it to the trace's watchers.

    Nothing is persisted or sent when no connection is registered for
    *trace_id*.

    Args:
        trace_id: Trace ID
        goal_dict: Goal as a dict (the original notes say: full data, with stats)
    """
    if trace_id in _active_connections:
        new_event_id = await get_trace_store().append_event(
            trace_id,
            "goal_added",
            {"goal": goal_dict},
        )
        await _broadcast_to_trace(
            trace_id,
            {
                "event": "goal_added",
                "event_id": new_event_id,
                "ts": datetime.now().isoformat(),
                "goal": goal_dict,
            },
        )
async def broadcast_goal_updated(
    trace_id: str,
    goal_id: str,
    updates: Dict[str, Any],
    affected_goals: list[Dict[str, Any]] = None
):
    """
    Persist a ``goal_updated`` event and push it to the trace's watchers.

    Nothing is persisted or sent when no connection is registered for
    *trace_id*. The persisted payload stores the changed fields under
    ``updates`` while the pushed message carries them under ``patch``.

    Args:
        trace_id: Trace ID
        goal_id: Goal ID
        updates: Changed fields (patch format)
        affected_goals: Goals touched by the change; None is sent as ``[]``
    """
    if trace_id not in _active_connections:
        return

    touched = affected_goals or []
    new_event_id = await get_trace_store().append_event(
        trace_id,
        "goal_updated",
        {
            "goal_id": goal_id,
            "updates": updates,
            "affected_goals": touched,
        },
    )

    outgoing = {
        "event": "goal_updated",
        "event_id": new_event_id,
        "ts": datetime.now().isoformat(),
        "goal_id": goal_id,
        "patch": updates,
        "affected_goals": touched,
    }
    await _broadcast_to_trace(trace_id, outgoing)
async def broadcast_message_added(
    trace_id: str,
    event_id: int,
    message_dict: Dict[str, Any],
    affected_goals: list[Dict[str, Any]] = None,
):
    """
    Push a ``message_added`` event to the trace's watchers.

    This function does not persist anything: the caller passes the
    *event_id* of an event that was already written (per the original
    notes, by TraceStore.append_event). Nothing is sent when no connection
    is registered for *trace_id*.

    Args:
        trace_id: Trace ID
        event_id: ID of the already-persisted event
        message_dict: Message as a dict
        affected_goals: Goals touched by the message; None is sent as ``[]``
    """
    if trace_id in _active_connections:
        outgoing = {
            "event": "message_added",
            "event_id": event_id,
            "ts": datetime.now().isoformat(),
            "message": message_dict,
            "affected_goals": affected_goals if affected_goals else [],
        }
        await _broadcast_to_trace(trace_id, outgoing)
async def broadcast_sub_trace_started(
    trace_id: str,
    sub_trace_id: str,
    parent_goal_id: str,
    agent_type: str,
    task: str
):
    """
    Persist a ``sub_trace_started`` event and push it to the watchers.

    Nothing is persisted or sent when no connection is registered for
    *trace_id*. The persisted payload also records ``parent_trace_id``;
    the pushed message does not include that key.

    Args:
        trace_id: ID of the main (parent) trace
        sub_trace_id: ID of the sub-trace that started
        parent_goal_id: ID of the parent goal
        agent_type: Agent type
        task: Task description
    """
    if trace_id not in _active_connections:
        return

    persisted = {
        "trace_id": sub_trace_id,
        "parent_trace_id": trace_id,
        "parent_goal_id": parent_goal_id,
        "agent_type": agent_type,
        "task": task,
    }
    new_event_id = await get_trace_store().append_event(
        trace_id, "sub_trace_started", persisted
    )

    await _broadcast_to_trace(
        trace_id,
        {
            "event": "sub_trace_started",
            "event_id": new_event_id,
            "ts": datetime.now().isoformat(),
            "trace_id": sub_trace_id,
            "parent_goal_id": parent_goal_id,
            "agent_type": agent_type,
            "task": task,
        },
    )
async def broadcast_sub_trace_completed(
    trace_id: str,
    sub_trace_id: str,
    status: str,
    summary: str = "",
    stats: Dict[str, Any] = None
):
    """
    Persist a ``sub_trace_completed`` event and push it to the watchers.

    Nothing is persisted or sent when no connection is registered for
    *trace_id*.

    Args:
        trace_id: ID of the main (parent) trace
        sub_trace_id: ID of the sub-trace that finished
        status: Final status (the original notes list completed/failed)
        summary: Summary text
        stats: Statistics; None is sent as ``{}``
    """
    if trace_id not in _active_connections:
        return

    final_stats = stats or {}
    new_event_id = await get_trace_store().append_event(
        trace_id,
        "sub_trace_completed",
        {
            "trace_id": sub_trace_id,
            "status": status,
            "summary": summary,
            "stats": final_stats,
        },
    )

    outgoing = {
        "event": "sub_trace_completed",
        "event_id": new_event_id,
        "ts": datetime.now().isoformat(),
        "trace_id": sub_trace_id,
        "status": status,
        "summary": summary,
        "stats": final_stats,
    }
    await _broadcast_to_trace(trace_id, outgoing)
async def broadcast_trace_completed(trace_id: str, total_messages: int):
    """
    Persist a ``trace_completed`` event, push it, then forget the watchers.

    Nothing is persisted or sent when no connection is registered for
    *trace_id*. After the push the trace's entry is removed from the
    connection registry; the sockets themselves are not closed here.

    Args:
        trace_id: Trace ID
        total_messages: Total number of messages in the trace
    """
    if trace_id not in _active_connections:
        return

    new_event_id = await get_trace_store().append_event(
        trace_id,
        "trace_completed",
        {"total_messages": total_messages},
    )
    await _broadcast_to_trace(
        trace_id,
        {
            "event": "trace_completed",
            "event_id": new_event_id,
            "ts": datetime.now().isoformat(),
            "trace_id": trace_id,
            "total_messages": total_messages,
        },
    )

    # The trace is done: drop its registry entry if it is still there.
    _active_connections.pop(trace_id, None)
async def broadcast_trace_status_changed(trace_id: str, status: str):
    """
    Persist a ``trace_status_changed`` event and push it to the watchers.

    Nothing is persisted or sent when no connection is registered for
    *trace_id*, or when the store has no such trace. The pushed message
    also carries the trace's current ``total_cost`` and ``total_messages``.

    Args:
        trace_id: Trace ID
        status: New status (the original notes list
            running/stopped/completed/failed)
    """
    if trace_id not in _active_connections:
        return

    trace_store = get_trace_store()
    current = await trace_store.get_trace(trace_id)
    if current:
        new_event_id = await trace_store.append_event(
            trace_id,
            "trace_status_changed",
            {"status": status},
        )
        await _broadcast_to_trace(
            trace_id,
            {
                "event": "trace_status_changed",
                "event_id": new_event_id,
                "ts": datetime.now().isoformat(),
                "trace_id": trace_id,
                "status": status,
                "total_cost": current.total_cost,
                "total_messages": current.total_messages,
            },
        )
- # ===== 内部辅助函数 =====
async def _broadcast_to_trace(trace_id: str, message: Dict[str, Any]):
    """
    Send *message* as JSON to every connection registered for a trace.

    Delivery is best-effort: a socket whose send raises is removed from
    the registry and the remaining sockets are still served.

    Args:
        trace_id: Trace ID
        message: JSON-serialisable message to send
    """
    watchers = _active_connections.get(trace_id)
    if not watchers:
        return

    # Iterate over a snapshot. Each send awaits, and a client connecting or
    # disconnecting in the meantime mutates the live set, which would raise
    # "RuntimeError: Set changed size during iteration".
    failed = []
    for websocket in tuple(watchers):
        try:
            await websocket.send_json(message)
        except Exception:
            failed.append(websocket)

    if failed:
        # Look the entry up again: it may have been removed while we were
        # awaiting, in which case there is nothing left to clean.
        remaining = _active_connections.get(trace_id)
        if remaining is not None:
            remaining.difference_update(failed)
|