"""
Trace RESTful API

Query endpoints for Trace, GoalTree and Message data.
"""

from typing import List, Optional, Dict, Any

from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field

from .protocols import TraceStore


router = APIRouter(prefix="/api/traces", tags=["traces"])

# goal_id query values that select messages whose ``goal_id`` is None
# (the initial phase, before any goal was assigned).
_INIT_GOAL_ALIASES = ("_init", "null")

# Values accepted by the ``mode`` query parameter of GET /{trace_id}/messages.
_MESSAGE_MODES = ("main_path", "all")

# Number of traces get_trace() scans when collecting sub-traces.
# NOTE(review): sub-traces are found by filtering a bounded list_traces()
# result, so children outside this window are silently omitted. A
# parent_trace_id filter on TraceStore would remove this limitation --
# confirm whether the protocol offers (or could offer) one.
_SUB_TRACE_SCAN_LIMIT = 1000


# ===== Response models =====

class TraceListResponse(BaseModel):
    """Trace list response."""
    traces: List[Dict[str, Any]]


class TraceDetailResponse(BaseModel):
    """Trace detail response (includes the GoalTree and sub-trace metadata)."""
    trace: Dict[str, Any]
    goal_tree: Optional[Dict[str, Any]] = None
    # Keyed by sub-trace ID; each value is that sub-trace's to_dict() output.
    sub_traces: Dict[str, Dict[str, Any]] = Field(default_factory=dict)


class MessagesResponse(BaseModel):
    """Messages response."""
    messages: List[Dict[str, Any]]


# ===== Global TraceStore (injected by api_server.py) =====

_trace_store: Optional[TraceStore] = None


def set_trace_store(store: TraceStore) -> None:
    """Register the TraceStore instance used by every route in this module."""
    global _trace_store
    _trace_store = store


def get_trace_store() -> TraceStore:
    """Return the registered TraceStore.

    Raises:
        RuntimeError: if set_trace_store() has not been called yet.
    """
    if _trace_store is None:
        raise RuntimeError("TraceStore not initialized")
    return _trace_store


# ===== Routes =====

@router.get("", response_model=TraceListResponse)
async def list_traces(
    mode: Optional[str] = None,
    agent_type: Optional[str] = None,
    uid: Optional[str] = None,
    status: Optional[str] = None,
    limit: int = Query(20, ge=1, le=100),
):
    """
    List traces.

    Args:
        mode: filter by mode (call/agent)
        agent_type: filter by agent type
        uid: filter by user ID
        status: filter by status (running/completed/failed)
        limit: maximum number of traces to return, 1-100. The lower bound
            keeps zero/negative values away from the store; some backends
            (e.g. SQLite) treat a negative LIMIT as "no limit", which would
            bypass the cap.
    """
    store = get_trace_store()
    traces = await store.list_traces(
        mode=mode,
        agent_type=agent_type,
        uid=uid,
        status=status,
        limit=limit,
    )
    return TraceListResponse(traces=[t.to_dict() for t in traces])


@router.get("/{trace_id}", response_model=TraceDetailResponse)
async def get_trace(trace_id: str):
    """
    Get trace details.

    Returns the trace metadata, its GoalTree, and metadata for its direct
    sub-traces (traces whose parent_trace_id is this trace). The sub-traces'
    own GoalTrees are not included.

    Args:
        trace_id: Trace ID

    Raises:
        HTTPException: 404 if the trace does not exist.
    """
    store = get_trace_store()

    trace = await store.get_trace(trace_id)
    if not trace:
        raise HTTPException(status_code=404, detail="Trace not found")

    goal_tree = await store.get_goal_tree(trace_id)

    # Sub-traces are found by scanning a bounded trace listing and matching
    # parent_trace_id (see the note on _SUB_TRACE_SCAN_LIMIT).
    candidates = await store.list_traces(limit=_SUB_TRACE_SCAN_LIMIT)
    sub_traces = {
        t.trace_id: t.to_dict()
        for t in candidates
        if t.parent_trace_id == trace_id
    }

    return TraceDetailResponse(
        trace=trace.to_dict(),
        goal_tree=goal_tree.to_dict() if goal_tree else None,
        sub_traces=sub_traces,
    )


@router.get("/{trace_id}/messages", response_model=MessagesResponse)
async def get_messages(
    trace_id: str,
    mode: str = Query("main_path", description="查询模式:main_path(当前主路径消息)或 all(全部消息含所有分支)"),
    head: Optional[int] = Query(None, description="主路径的 head sequence(仅 mode=main_path 有效,默认用 trace.head_sequence)"),
    goal_id: Optional[str] = Query(None, description="过滤指定 Goal 的消息。使用 '_init' 查询初始阶段(goal_id=None)的消息"),
):
    """
    Get messages of a trace.

    Args:
        trace_id: Trace ID
        mode: query mode
            - "main_path" (default): messages on the main path, walked back
              from ``head`` along the parent_sequence chain
            - "all": every message, including all branches
        head: optional head sequence for the main path (only used when
            mode="main_path"; defaults to trace.head_sequence)
        goal_id: optional goal filter
            - omitted: no goal filtering
            - "_init" or "null": only messages with goal_id=None, taken from
              the result selected by ``mode``
            - any other value: that goal's messages via a dedicated store
              query; ``mode`` and ``head`` are ignored in this case

    Raises:
        HTTPException: 422 if ``mode`` is not "main_path" or "all";
            404 if the trace does not exist.
    """
    # Reject unknown modes instead of silently falling through to "all".
    if mode not in _MESSAGE_MODES:
        raise HTTPException(
            status_code=422,
            detail=f"Invalid mode {mode!r}; expected one of: main_path, all",
        )

    store = get_trace_store()

    trace = await store.get_trace(trace_id)
    if not trace:
        raise HTTPException(status_code=404, detail="Trace not found")

    init_only = goal_id in _INIT_GOAL_ALIASES

    if goal_id and not init_only:
        # Concrete goal: standalone per-goal query.
        messages = await store.get_messages_by_goal(trace_id, goal_id)
    elif mode == "main_path":
        head_seq = head if head is not None else trace.head_sequence
        # Non-positive heads yield an empty result without querying the store.
        if head_seq > 0:
            messages = await store.get_main_path_messages(trace_id, head_seq)
        else:
            messages = []
    else:
        messages = await store.get_trace_messages(trace_id)

    if init_only:
        messages = [m for m in messages if m.goal_id is None]

    return MessagesResponse(messages=[m.to_dict() for m in messages])