api.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. """
  2. Trace RESTful API
  3. 提供 Trace、GoalTree、Message 的查询接口
  4. """
  5. from typing import List, Optional, Dict, Any
  6. from fastapi import APIRouter, HTTPException, Query
  7. from pydantic import BaseModel
  8. from .protocols import TraceStore
  9. router = APIRouter(prefix="/api/traces", tags=["traces"])
  10. # ===== Response 模型 =====
  11. class TraceListResponse(BaseModel):
  12. """Trace 列表响应"""
  13. traces: List[Dict[str, Any]]
  14. class TraceDetailResponse(BaseModel):
  15. """Trace 详情响应(包含 GoalTree 和 Sub-Traces 元数据)"""
  16. trace: Dict[str, Any]
  17. goal_tree: Optional[Dict[str, Any]] = None
  18. sub_traces: Dict[str, Dict[str, Any]] = {}
  19. class MessagesResponse(BaseModel):
  20. """Messages 响应"""
  21. messages: List[Dict[str, Any]]
  22. # ===== 全局 TraceStore(由 api_server.py 注入)=====
  23. _trace_store: Optional[TraceStore] = None
  24. def set_trace_store(store: TraceStore):
  25. """设置 TraceStore 实例"""
  26. global _trace_store
  27. _trace_store = store
  28. def get_trace_store() -> TraceStore:
  29. """获取 TraceStore 实例"""
  30. if _trace_store is None:
  31. raise RuntimeError("TraceStore not initialized")
  32. return _trace_store
  33. # ===== 路由 =====
  34. @router.get("", response_model=TraceListResponse)
  35. async def list_traces(
  36. mode: Optional[str] = None,
  37. agent_type: Optional[str] = None,
  38. uid: Optional[str] = None,
  39. status: Optional[str] = None,
  40. limit: int = Query(20, le=100)
  41. ):
  42. """
  43. 列出 Traces
  44. Args:
  45. mode: 模式过滤(call/agent)
  46. agent_type: Agent 类型过滤
  47. uid: 用户 ID 过滤
  48. status: 状态过滤(running/completed/failed)
  49. limit: 最大返回数量
  50. """
  51. store = get_trace_store()
  52. traces = await store.list_traces(
  53. mode=mode,
  54. agent_type=agent_type,
  55. uid=uid,
  56. status=status,
  57. limit=limit
  58. )
  59. return TraceListResponse(
  60. traces=[t.to_dict() for t in traces]
  61. )
  62. @router.get("/{trace_id}", response_model=TraceDetailResponse)
  63. async def get_trace(trace_id: str):
  64. """
  65. 获取 Trace 详情
  66. 返回 Trace 元数据、GoalTree、Sub-Traces 元数据(不含 Sub-Trace 内 GoalTree)
  67. Args:
  68. trace_id: Trace ID
  69. """
  70. store = get_trace_store()
  71. # 获取 Trace
  72. trace = await store.get_trace(trace_id)
  73. if not trace:
  74. raise HTTPException(status_code=404, detail="Trace not found")
  75. # 获取 GoalTree
  76. goal_tree = await store.get_goal_tree(trace_id)
  77. # 获取所有 Sub-Traces(通过 parent_trace_id 查询)
  78. sub_traces = {}
  79. all_traces = await store.list_traces(limit=1000) # 获取所有 traces
  80. for t in all_traces:
  81. if t.parent_trace_id == trace_id:
  82. sub_traces[t.trace_id] = t.to_dict()
  83. return TraceDetailResponse(
  84. trace=trace.to_dict(),
  85. goal_tree=goal_tree.to_dict() if goal_tree else None,
  86. sub_traces=sub_traces
  87. )
  88. @router.get("/{trace_id}/messages", response_model=MessagesResponse)
  89. async def get_messages(
  90. trace_id: str,
  91. goal_id: Optional[str] = Query(None, description="过滤指定 Goal 的消息。使用 '_init' 查询初始阶段(goal_id=None)的消息")
  92. ):
  93. """
  94. 获取 Messages
  95. Args:
  96. trace_id: Trace ID
  97. goal_id: 可选,过滤指定 Goal 的消息
  98. - 不指定: 返回所有消息
  99. - "_init" 或 "null": 返回初始阶段(goal_id=None)的消息
  100. - 其他值: 返回指定 Goal 的消息
  101. """
  102. store = get_trace_store()
  103. # 验证 Trace 存在
  104. trace = await store.get_trace(trace_id)
  105. if not trace:
  106. raise HTTPException(status_code=404, detail="Trace not found")
  107. # 获取 Messages
  108. if goal_id is None:
  109. # 没有指定 goal_id,返回所有消息
  110. messages = await store.get_trace_messages(trace_id)
  111. elif goal_id in ("_init", "null"):
  112. # 特殊值:查询初始阶段的消息(goal_id=None)
  113. all_messages = await store.get_trace_messages(trace_id)
  114. messages = [m for m in all_messages if m.goal_id is None]
  115. else:
  116. # 查询指定 Goal 的消息
  117. messages = await store.get_messages_by_goal(trace_id, goal_id)
  118. return MessagesResponse(
  119. messages=[m.to_dict() for m in messages]
  120. )