| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- """
- Trace RESTful API
- 提供 Trace、GoalTree、Message、Branch 的查询接口
- """
- from typing import List, Optional, Dict, Any
- from fastapi import APIRouter, HTTPException, Query
- from pydantic import BaseModel
- from agent.execution.protocols import TraceStore
# Router for the trace query endpoints; every route registered on it is
# mounted under the /api/traces prefix and grouped under the "traces" tag.
router = APIRouter(prefix="/api/traces", tags=["traces"])
- # ===== Response 模型 =====
class TraceListResponse(BaseModel):
    """Response body for the trace listing endpoint."""

    # Serialized traces, one plain dict per trace.
    traces: List[Dict[str, Any]]
class TraceDetailResponse(BaseModel):
    """Response body for a single trace.

    Bundles the trace metadata, its main-line GoalTree (if one exists), and
    the metadata of each branch keyed by branch ID. Branch GoalTrees are not
    part of this payload.
    """

    # Serialized trace metadata.
    trace: Dict[str, Any]
    # None when no GoalTree is available for the trace.
    goal_tree: Optional[Dict[str, Any]] = None
    # Branch ID -> serialized branch metadata. Pydantic deep-copies
    # non-hashable defaults per instance, so this literal is not shared.
    branches: Dict[str, Dict[str, Any]] = {}
class MessagesResponse(BaseModel):
    """Response body for the trace messages endpoint."""

    # Serialized messages, one plain dict per message.
    messages: List[Dict[str, Any]]
class BranchDetailResponse(BaseModel):
    """Response body for a single branch, including that branch's GoalTree."""

    # Serialized branch metadata.
    branch: Dict[str, Any]
    # None when no GoalTree is available for this branch.
    goal_tree: Optional[Dict[str, Any]] = None
# ===== Global TraceStore (injected by api_server.py) =====
# Module-level store read by every route via get_trace_store(); it stays None
# until set_trace_store() is called.
_trace_store: Optional[TraceStore] = None
def set_trace_store(store: TraceStore):
    """Install the TraceStore that the route handlers will query.

    Args:
        store: Store instance to use; it replaces any previously installed one.
    """
    # Rebind the module-level global rather than creating a local name.
    global _trace_store
    _trace_store = store
def get_trace_store() -> TraceStore:
    """Return the installed TraceStore.

    Raises:
        RuntimeError: If set_trace_store() has not been called yet.
    """
    installed = _trace_store
    if installed is not None:
        return installed
    raise RuntimeError("TraceStore not initialized")
- # ===== 路由 =====
@router.get("", response_model=TraceListResponse)
async def list_traces(
    mode: Optional[str] = None,
    agent_type: Optional[str] = None,
    uid: Optional[str] = None,
    status: Optional[str] = None,
    # ge=1 closes the lower bound: without it, 0 or negative limits reached the
    # store unchecked, and some backends treat a negative LIMIT as "no upper
    # bound" (e.g. SQLite), which would bypass the le=100 cap entirely.
    limit: int = Query(20, ge=1, le=100),
):
    """List traces, optionally filtered.

    Every filter left as None is passed to the store as None.

    Args:
        mode: Mode filter (call/agent).
        agent_type: Agent type filter.
        uid: User ID filter.
        status: Status filter (running/completed/failed).
        limit: Maximum number of traces to return, from 1 to 100 (default 20).
            Out-of-range values are rejected by FastAPI request validation
            before this handler runs.

    Returns:
        TraceListResponse with each trace serialized via ``to_dict()``.
    """
    store = get_trace_store()
    traces = await store.list_traces(
        mode=mode,
        agent_type=agent_type,
        uid=uid,
        status=status,
        limit=limit,
    )
    return TraceListResponse(traces=[t.to_dict() for t in traces])
@router.get("/{trace_id}", response_model=TraceDetailResponse)
async def get_trace(trace_id: str):
    """Fetch one trace with its main-line GoalTree and branch metadata.

    Branch GoalTrees are deliberately left out; they are loaded on demand by
    the branch detail endpoint.

    Args:
        trace_id: Trace ID.

    Raises:
        HTTPException: 404 when the trace does not exist.
    """
    store = get_trace_store()

    trace = await store.get_trace(trace_id)
    if not trace:
        raise HTTPException(status_code=404, detail="Trace not found")

    tree = await store.get_goal_tree(trace_id)
    branch_map = await store.list_branches(trace_id)

    # Serialize in the same order as the response fields.
    trace_payload = trace.to_dict()
    tree_payload = None if not tree else tree.to_dict()
    branches_payload = {name: meta.to_dict() for name, meta in branch_map.items()}

    return TraceDetailResponse(
        trace=trace_payload,
        goal_tree=tree_payload,
        branches=branches_payload,
    )
@router.get("/{trace_id}/messages", response_model=MessagesResponse)
async def get_messages(
    trace_id: str,
    goal_id: Optional[str] = Query(None, description="过滤指定 Goal 的消息"),
    branch_id: Optional[str] = Query(None, description="过滤指定分支的消息"),
):
    """Return the messages of a trace.

    Args:
        trace_id: Trace ID.
        goal_id: Optional; only return messages belonging to this goal.
        branch_id: Optional; only return messages belonging to this branch.

    Raises:
        HTTPException: 404 when the trace does not exist.
    """
    store = get_trace_store()

    if not await store.get_trace(trace_id):
        raise HTTPException(status_code=404, detail="Trace not found")

    # A truthy goal_id narrows the query to that goal; otherwise every message
    # of the trace is returned. branch_id is forwarded unchanged either way.
    if not goal_id:
        found = await store.get_trace_messages(trace_id, branch_id)
    else:
        found = await store.get_messages_by_goal(trace_id, goal_id, branch_id)

    return MessagesResponse(messages=[msg.to_dict() for msg in found])
@router.get("/{trace_id}/branches/{branch_id}", response_model=BranchDetailResponse)
async def get_branch_detail(
    trace_id: str,
    branch_id: str,
):
    """Fetch one branch's metadata together with its GoalTree.

    The branch GoalTree is loaded only here, on demand, rather than as part
    of the trace detail response.

    Args:
        trace_id: Trace ID.
        branch_id: Branch ID.

    Raises:
        HTTPException: 404 if the trace or the branch does not exist. The
            trace is checked first, so a missing trace takes precedence.
    """
    store = get_trace_store()

    if not await store.get_trace(trace_id):
        raise HTTPException(status_code=404, detail="Trace not found")

    branch_meta = await store.get_branch(trace_id, branch_id)
    if not branch_meta:
        raise HTTPException(status_code=404, detail="Branch not found")

    branch_tree = await store.get_branch_goal_tree(trace_id, branch_id)
    return BranchDetailResponse(
        branch=branch_meta.to_dict(),
        goal_tree=branch_tree.to_dict() if branch_tree else None,
    )
|