api.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. """
  2. Trace RESTful API
  3. 提供 Trace、GoalTree、Message、Branch 的查询接口
  4. """
  5. from typing import List, Optional, Dict, Any
  6. from fastapi import APIRouter, HTTPException, Query
  7. from pydantic import BaseModel
  8. from agent.execution.protocols import TraceStore
  9. router = APIRouter(prefix="/api/traces", tags=["traces"])
# ===== Response models =====
class TraceListResponse(BaseModel):
    """Response body of the trace list endpoint."""
    traces: List[Dict[str, Any]]  # one serialized dict per trace
class TraceDetailResponse(BaseModel):
    """Trace detail response (includes the main-line GoalTree and branch metadata)."""
    trace: Dict[str, Any]  # serialized trace metadata
    # Serialized main-line GoalTree; None when the store has no tree for the trace.
    goal_tree: Optional[Dict[str, Any]] = None
    # Branch metadata keyed by branch ID (branch GoalTrees are not included).
    # NOTE(review): the `{}` default relies on pydantic copying mutable defaults
    # per instance; Field(default_factory=dict) would make that explicit.
    branches: Dict[str, Dict[str, Any]] = {}
class MessagesResponse(BaseModel):
    """Response body of the messages endpoint."""
    messages: List[Dict[str, Any]]  # one serialized dict per message
class BranchDetailResponse(BaseModel):
    """Branch detail response (includes the branch's own GoalTree)."""
    branch: Dict[str, Any]  # serialized branch metadata
    # Serialized branch GoalTree; None when the store has no tree for the branch.
    goal_tree: Optional[Dict[str, Any]] = None
# ===== Global TraceStore (injected by api_server.py) =====
# Module-level store shared by every route via get_trace_store(); it stays None
# until set_trace_store() is called.
_trace_store: Optional[TraceStore] = None
def set_trace_store(store: TraceStore) -> None:
    """Register the TraceStore instance used by the routes in this module.

    Any previously registered store is replaced.

    Args:
        store: The TraceStore that subsequent requests will query.
    """
    global _trace_store
    _trace_store = store
  32. def get_trace_store() -> TraceStore:
  33. """获取 TraceStore 实例"""
  34. if _trace_store is None:
  35. raise RuntimeError("TraceStore not initialized")
  36. return _trace_store
  37. # ===== 路由 =====
  38. @router.get("", response_model=TraceListResponse)
  39. async def list_traces(
  40. mode: Optional[str] = None,
  41. agent_type: Optional[str] = None,
  42. uid: Optional[str] = None,
  43. status: Optional[str] = None,
  44. limit: int = Query(20, le=100)
  45. ):
  46. """
  47. 列出 Traces
  48. Args:
  49. mode: 模式过滤(call/agent)
  50. agent_type: Agent 类型过滤
  51. uid: 用户 ID 过滤
  52. status: 状态过滤(running/completed/failed)
  53. limit: 最大返回数量
  54. """
  55. store = get_trace_store()
  56. traces = await store.list_traces(
  57. mode=mode,
  58. agent_type=agent_type,
  59. uid=uid,
  60. status=status,
  61. limit=limit
  62. )
  63. return TraceListResponse(
  64. traces=[t.to_dict() for t in traces]
  65. )
  66. @router.get("/{trace_id}", response_model=TraceDetailResponse)
  67. async def get_trace(trace_id: str):
  68. """
  69. 获取 Trace 详情
  70. 返回 Trace 元数据、主线 GoalTree、分支元数据(不含分支内 GoalTree)
  71. Args:
  72. trace_id: Trace ID
  73. """
  74. store = get_trace_store()
  75. # 获取 Trace
  76. trace = await store.get_trace(trace_id)
  77. if not trace:
  78. raise HTTPException(status_code=404, detail="Trace not found")
  79. # 获取 GoalTree
  80. goal_tree = await store.get_goal_tree(trace_id)
  81. # 获取所有分支元数据
  82. branches = await store.list_branches(trace_id)
  83. return TraceDetailResponse(
  84. trace=trace.to_dict(),
  85. goal_tree=goal_tree.to_dict() if goal_tree else None,
  86. branches={b_id: b.to_dict() for b_id, b in branches.items()}
  87. )
  88. @router.get("/{trace_id}/messages", response_model=MessagesResponse)
  89. async def get_messages(
  90. trace_id: str,
  91. goal_id: Optional[str] = Query(None, description="过滤指定 Goal 的消息"),
  92. branch_id: Optional[str] = Query(None, description="过滤指定分支的消息")
  93. ):
  94. """
  95. 获取 Messages
  96. Args:
  97. trace_id: Trace ID
  98. goal_id: 可选,过滤指定 Goal 的消息
  99. branch_id: 可选,过滤指定分支的消息
  100. """
  101. store = get_trace_store()
  102. # 验证 Trace 存在
  103. trace = await store.get_trace(trace_id)
  104. if not trace:
  105. raise HTTPException(status_code=404, detail="Trace not found")
  106. # 获取 Messages
  107. if goal_id:
  108. messages = await store.get_messages_by_goal(trace_id, goal_id, branch_id)
  109. else:
  110. messages = await store.get_trace_messages(trace_id, branch_id)
  111. return MessagesResponse(
  112. messages=[m.to_dict() for m in messages]
  113. )
  114. @router.get("/{trace_id}/branches/{branch_id}", response_model=BranchDetailResponse)
  115. async def get_branch_detail(
  116. trace_id: str,
  117. branch_id: str
  118. ):
  119. """
  120. 获取分支详情
  121. 返回分支元数据和分支的 GoalTree(按需加载)
  122. Args:
  123. trace_id: Trace ID
  124. branch_id: 分支 ID
  125. """
  126. store = get_trace_store()
  127. # 验证 Trace 存在
  128. trace = await store.get_trace(trace_id)
  129. if not trace:
  130. raise HTTPException(status_code=404, detail="Trace not found")
  131. # 获取分支元数据
  132. branch = await store.get_branch(trace_id, branch_id)
  133. if not branch:
  134. raise HTTPException(status_code=404, detail="Branch not found")
  135. # 获取分支的 GoalTree
  136. goal_tree = await store.get_branch_goal_tree(trace_id, branch_id)
  137. return BranchDetailResponse(
  138. branch=branch.to_dict(),
  139. goal_tree=goal_tree.to_dict() if goal_tree else None
  140. )