| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- """
- Trace RESTful API
- Provides query endpoints for Trace, GoalTree, and Message.
- """
- from typing import List, Optional, Dict, Any
- from fastapi import APIRouter, HTTPException, Query
- from pydantic import BaseModel
- from .protocols import TraceStore
# Router for the trace query endpoints; routes registered on it share the
# "/api/traces" prefix and the "traces" tag.
router = APIRouter(prefix="/api/traces", tags=["traces"])
- # ===== Response models =====
class TraceListResponse(BaseModel):
    """Response body for the trace list endpoint."""

    # Traces serialized as plain dicts.
    traces: List[Dict[str, Any]]
class TraceDetailResponse(BaseModel):
    """Response body for trace details (includes the GoalTree and Sub-Trace metadata)."""

    # The trace itself, serialized as a dict.
    trace: Dict[str, Any]
    # Serialized GoalTree of the trace; None when there is none.
    goal_tree: Optional[Dict[str, Any]] = None
    # Sub-trace metadata keyed by sub-trace ID; defaults to an empty dict.
    sub_traces: Dict[str, Dict[str, Any]] = {}
class MessagesResponse(BaseModel):
    """Response body for the messages endpoint."""

    # Messages serialized as plain dicts.
    messages: List[Dict[str, Any]]
# ===== Global TraceStore (injected by api_server.py) =====
# Starts as None; set_trace_store() replaces it with the real instance.
_trace_store: Optional[TraceStore] = None
def set_trace_store(store: TraceStore) -> None:
    """Set the module-wide TraceStore instance.

    Args:
        store: The TraceStore that get_trace_store() will return from now on.
    """
    # Rebind the module-level variable rather than creating a local one.
    global _trace_store
    _trace_store = store
def get_trace_store() -> TraceStore:
    """Return the module-wide TraceStore instance.

    Returns:
        The TraceStore previously registered via set_trace_store().

    Raises:
        RuntimeError: If no TraceStore has been registered yet.
    """
    if _trace_store is not None:
        return _trace_store
    raise RuntimeError("TraceStore not initialized")
- # ===== Routes =====
@router.get("", response_model=TraceListResponse)
async def list_traces(
    mode: Optional[str] = None,
    agent_type: Optional[str] = None,
    uid: Optional[str] = None,
    status: Optional[str] = None,
    limit: int = Query(20, le=100)
):
    """
    List traces, optionally narrowed by filters.

    Args:
        mode: Mode filter (call/agent).
        agent_type: Agent type filter.
        uid: User ID filter.
        status: Status filter (running/completed/failed).
        limit: Maximum number of traces to return (default 20, at most 100).

    Returns:
        A TraceListResponse holding each returned trace as a dict.
    """
    # All filters are forwarded to the store unchanged, including None values.
    filters = {
        "mode": mode,
        "agent_type": agent_type,
        "uid": uid,
        "status": status,
        "limit": limit,
    }
    matched = await get_trace_store().list_traces(**filters)
    payload = [item.to_dict() for item in matched]
    return TraceListResponse(traces=payload)
@router.get("/{trace_id}", response_model=TraceDetailResponse)
async def get_trace(trace_id: str):
    """
    Get the details of a single trace.

    Returns the trace metadata, its GoalTree, and the metadata of its
    Sub-Traces (the GoalTrees of the Sub-Traces are not included).

    Args:
        trace_id: Trace ID.

    Returns:
        A TraceDetailResponse for the requested trace.

    Raises:
        HTTPException: 404 when the store has no trace with this ID.
    """
    backend = get_trace_store()

    record = await backend.get_trace(trace_id)
    if not record:
        raise HTTPException(status_code=404, detail="Trace not found")

    tree = await backend.get_goal_tree(trace_id)

    # Sub-traces are the traces whose parent_trace_id points at this trace.
    # NOTE(review): only the traces returned by list_traces(limit=1000) are
    # scanned, so children outside that window would be missed - confirm
    # whether the store offers a direct parent lookup.
    candidates = await backend.list_traces(limit=1000)
    children = {
        child.trace_id: child.to_dict()
        for child in candidates
        if child.parent_trace_id == trace_id
    }

    return TraceDetailResponse(
        trace=record.to_dict(),
        goal_tree=None if not tree else tree.to_dict(),
        sub_traces=children,
    )
@router.get("/{trace_id}/messages", response_model=MessagesResponse)
async def get_messages(
    trace_id: str,
    mode: str = Query("main_path", description="查询模式:main_path(当前主路径消息)或 all(全部消息含所有分支)"),
    head: Optional[int] = Query(None, description="主路径的 head sequence(仅 mode=main_path 有效,默认用 trace.head_sequence)"),
    goal_id: Optional[str] = Query(None, description="过滤指定 Goal 的消息。使用 '_init' 查询初始阶段(goal_id=None)的消息"),
):
    """
    Get the messages of a trace.

    Args:
        trace_id: Trace ID.
        mode: Query mode.
            - "main_path" (default): messages on the main path, walked back
              from the head along the parent_sequence chain.
            - any other value (documented as "all"): every message of the
              trace, including all branches.
        head: Optional head sequence of the main path (only used when
            mode is "main_path"); falls back to trace.head_sequence.
        goal_id: Optional goal filter.
            - omitted: no filtering by goal.
            - "_init" or "null": only messages of the initial phase
              (those whose goal_id is None).
            - any other value: messages of that goal, queried on their own
              regardless of mode.

    Returns:
        A MessagesResponse holding each selected message as a dict.

    Raises:
        HTTPException: 404 when the store has no trace with this ID.
    """
    backend = get_trace_store()

    # The trace must exist before any message lookup.
    record = await backend.get_trace(trace_id)
    if not record:
        raise HTTPException(status_code=404, detail="Trace not found")

    # "_init" / "null" are sentinels for "messages without a goal".
    init_phase_only = goal_id in ("_init", "null")

    if goal_id and not init_phase_only:
        # A concrete goal takes precedence over the mode.
        selected = await backend.get_messages_by_goal(trace_id, goal_id)
    elif mode != "main_path":
        # Every message of the trace, branches included.
        selected = await backend.get_trace_messages(trace_id)
    else:
        start = record.head_sequence if head is None else head
        # A non-positive head yields no messages.
        if start > 0:
            selected = await backend.get_main_path_messages(trace_id, start)
        else:
            selected = []

    if init_phase_only:
        selected = [msg for msg in selected if msg.goal_id is None]

    return MessagesResponse(messages=[msg.to_dict() for msg in selected])
|