""" Trace RESTful API 提供 Trace、GoalTree、Message、Branch 的查询接口 """ from typing import List, Optional, Dict, Any from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel from agent.execution.protocols import TraceStore router = APIRouter(prefix="/api/traces", tags=["traces"]) # ===== Response 模型 ===== class TraceListResponse(BaseModel): """Trace 列表响应""" traces: List[Dict[str, Any]] class TraceDetailResponse(BaseModel): """Trace 详情响应(包含 GoalTree 和分支元数据)""" trace: Dict[str, Any] goal_tree: Optional[Dict[str, Any]] = None branches: Dict[str, Dict[str, Any]] = {} class MessagesResponse(BaseModel): """Messages 响应""" messages: List[Dict[str, Any]] class BranchDetailResponse(BaseModel): """分支详情响应(包含分支的 GoalTree)""" branch: Dict[str, Any] goal_tree: Optional[Dict[str, Any]] = None # ===== 全局 TraceStore(由 api_server.py 注入)===== _trace_store: Optional[TraceStore] = None def set_trace_store(store: TraceStore): """设置 TraceStore 实例""" global _trace_store _trace_store = store def get_trace_store() -> TraceStore: """获取 TraceStore 实例""" if _trace_store is None: raise RuntimeError("TraceStore not initialized") return _trace_store # ===== 路由 ===== @router.get("", response_model=TraceListResponse) async def list_traces( mode: Optional[str] = None, agent_type: Optional[str] = None, uid: Optional[str] = None, status: Optional[str] = None, limit: int = Query(20, le=100) ): """ 列出 Traces Args: mode: 模式过滤(call/agent) agent_type: Agent 类型过滤 uid: 用户 ID 过滤 status: 状态过滤(running/completed/failed) limit: 最大返回数量 """ store = get_trace_store() traces = await store.list_traces( mode=mode, agent_type=agent_type, uid=uid, status=status, limit=limit ) return TraceListResponse( traces=[t.to_dict() for t in traces] ) @router.get("/{trace_id}", response_model=TraceDetailResponse) async def get_trace(trace_id: str): """ 获取 Trace 详情 返回 Trace 元数据、主线 GoalTree、分支元数据(不含分支内 GoalTree) Args: trace_id: Trace ID """ store = get_trace_store() # 获取 Trace trace = await store.get_trace(trace_id) if not trace: raise HTTPException(status_code=404, detail="Trace not found") # 获取 GoalTree goal_tree = await store.get_goal_tree(trace_id) # 获取所有分支元数据 branches = await store.list_branches(trace_id) return TraceDetailResponse( trace=trace.to_dict(), goal_tree=goal_tree.to_dict() if goal_tree else None, branches={b_id: b.to_dict() for b_id, b in branches.items()} ) @router.get("/{trace_id}/messages", response_model=MessagesResponse) async def get_messages( trace_id: str, goal_id: Optional[str] = Query(None, description="过滤指定 Goal 的消息"), branch_id: Optional[str] = Query(None, description="过滤指定分支的消息") ): """ 获取 Messages Args: trace_id: Trace ID goal_id: 可选,过滤指定 Goal 的消息 branch_id: 可选,过滤指定分支的消息 """ store = get_trace_store() # 验证 Trace 存在 trace = await store.get_trace(trace_id) if not trace: raise HTTPException(status_code=404, detail="Trace not found") # 获取 Messages if goal_id: messages = await store.get_messages_by_goal(trace_id, goal_id, branch_id) else: messages = await store.get_trace_messages(trace_id, branch_id) return MessagesResponse( messages=[m.to_dict() for m in messages] ) @router.get("/{trace_id}/branches/{branch_id}", response_model=BranchDetailResponse) async def get_branch_detail( trace_id: str, branch_id: str ): """ 获取分支详情 返回分支元数据和分支的 GoalTree(按需加载) Args: trace_id: Trace ID branch_id: 分支 ID """ store = get_trace_store() # 验证 Trace 存在 trace = await store.get_trace(trace_id) if not trace: raise HTTPException(status_code=404, detail="Trace not found") # 获取分支元数据 branch = await store.get_branch(trace_id, branch_id) if not branch: raise HTTPException(status_code=404, detail="Branch not found") # 获取分支的 GoalTree goal_tree = await store.get_branch_goal_tree(trace_id, branch_id) return BranchDetailResponse( branch=branch.to_dict(), goal_tree=goal_tree.to_dict() if goal_tree else None )