api.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. """
  2. Trace RESTful API
  3. 提供 Trace、GoalTree、Message 的查询接口
  4. """
  5. from typing import List, Optional, Dict, Any
  6. from fastapi import APIRouter, HTTPException, Query
  7. from pydantic import BaseModel
  8. from .protocols import TraceStore
# Router for the trace query endpoints: every route declared on it is served
# under the "/api/traces" prefix and tagged "traces".
router = APIRouter(prefix="/api/traces", tags=["traces"])
  10. # ===== Response 模型 =====
class TraceListResponse(BaseModel):
    """Trace list response."""
    # One plain dict per trace.
    traces: List[Dict[str, Any]]
class TraceDetailResponse(BaseModel):
    """Trace detail response (includes the GoalTree and Sub-Traces metadata)."""
    # Metadata of the requested trace.
    trace: Dict[str, Any]
    # GoalTree of the trace; None when there is none to report.
    goal_tree: Optional[Dict[str, Any]] = None
    # Sub-trace metadata keyed by sub-trace ID.
    # NOTE(review): the default is a mutable `{}` literal. pydantic is documented
    # to copy field defaults per instance, so it should not be shared between
    # responses -- confirm for the pydantic version in use.
    sub_traces: Dict[str, Dict[str, Any]] = {}
class MessagesResponse(BaseModel):
    """Messages response."""
    # One plain dict per message.
    messages: List[Dict[str, Any]]
# ===== Global TraceStore (injected by api_server.py) =====
# Stays None until set_trace_store() is called; get_trace_store() raises while
# it is still None.
_trace_store: Optional[TraceStore] = None
def set_trace_store(store: TraceStore) -> None:
    """Set the TraceStore instance used by this module's routes.

    Rebinds the module-level ``_trace_store``; any previously set store is
    replaced.

    Args:
        store: The TraceStore instance to use from now on.
    """
    global _trace_store
    _trace_store = store
  28. def get_trace_store() -> TraceStore:
  29. """获取 TraceStore 实例"""
  30. if _trace_store is None:
  31. raise RuntimeError("TraceStore not initialized")
  32. return _trace_store
  33. # ===== 路由 =====
  34. @router.get("", response_model=TraceListResponse)
  35. async def list_traces(
  36. mode: Optional[str] = None,
  37. agent_type: Optional[str] = None,
  38. uid: Optional[str] = None,
  39. status: Optional[str] = None,
  40. limit: int = Query(20, le=100)
  41. ):
  42. """
  43. 列出 Traces
  44. Args:
  45. mode: 模式过滤(call/agent)
  46. agent_type: Agent 类型过滤
  47. uid: 用户 ID 过滤
  48. status: 状态过滤(running/completed/failed)
  49. limit: 最大返回数量
  50. """
  51. store = get_trace_store()
  52. traces = await store.list_traces(
  53. mode=mode,
  54. agent_type=agent_type,
  55. uid=uid,
  56. status=status,
  57. limit=limit
  58. )
  59. return TraceListResponse(
  60. traces=[t.to_dict() for t in traces]
  61. )
  62. @router.get("/{trace_id}", response_model=TraceDetailResponse)
  63. async def get_trace(trace_id: str):
  64. """
  65. 获取 Trace 详情
  66. 返回 Trace 元数据、GoalTree、Sub-Traces 元数据(不含 Sub-Trace 内 GoalTree)
  67. Args:
  68. trace_id: Trace ID
  69. """
  70. store = get_trace_store()
  71. # 获取 Trace
  72. trace = await store.get_trace(trace_id)
  73. if not trace:
  74. raise HTTPException(status_code=404, detail="Trace not found")
  75. # 获取 GoalTree
  76. goal_tree = await store.get_goal_tree(trace_id)
  77. # 获取所有 Sub-Traces(通过 parent_trace_id 查询)
  78. sub_traces = {}
  79. all_traces = await store.list_traces(limit=1000) # 获取所有 traces
  80. for t in all_traces:
  81. if t.parent_trace_id == trace_id:
  82. sub_traces[t.trace_id] = t.to_dict()
  83. return TraceDetailResponse(
  84. trace=trace.to_dict(),
  85. goal_tree=goal_tree.to_dict() if goal_tree else None,
  86. sub_traces=sub_traces
  87. )
  88. @router.get("/{trace_id}/messages", response_model=MessagesResponse)
  89. async def get_messages(
  90. trace_id: str,
  91. mode: str = Query("main_path", description="查询模式:main_path(当前主路径消息)或 all(全部消息含所有分支)"),
  92. head: Optional[int] = Query(None, description="主路径的 head sequence(仅 mode=main_path 有效,默认用 trace.head_sequence)"),
  93. goal_id: Optional[str] = Query(None, description="过滤指定 Goal 的消息。使用 '_init' 查询初始阶段(goal_id=None)的消息"),
  94. ):
  95. """
  96. 获取 Messages
  97. Args:
  98. trace_id: Trace ID
  99. mode: 查询模式
  100. - "main_path"(默认): 从 head 沿 parent_sequence 链回溯的主路径消息
  101. - "all": 返回所有消息(包含所有分支)
  102. head: 可选,指定主路径的 head sequence(仅 mode=main_path 有效)
  103. goal_id: 可选,过滤指定 Goal 的消息
  104. - 不指定: 不按 goal 过滤
  105. - "_init" 或 "null": 返回初始阶段(goal_id=None)的消息
  106. - 其他值: 返回指定 Goal 的消息
  107. """
  108. store = get_trace_store()
  109. # 验证 Trace 存在
  110. trace = await store.get_trace(trace_id)
  111. if not trace:
  112. raise HTTPException(status_code=404, detail="Trace not found")
  113. # 获取 Messages
  114. if goal_id and goal_id not in ("_init", "null"):
  115. # 按 Goal 过滤(独立查询)
  116. messages = await store.get_messages_by_goal(trace_id, goal_id)
  117. elif mode == "main_path":
  118. # 主路径模式
  119. head_seq = head if head is not None else trace.head_sequence
  120. if head_seq > 0:
  121. messages = await store.get_main_path_messages(trace_id, head_seq)
  122. else:
  123. messages = []
  124. else:
  125. # all 模式:返回所有消息
  126. messages = await store.get_trace_messages(trace_id)
  127. # goal_id 过滤(_init 表示 goal_id=None 的消息)
  128. if goal_id in ("_init", "null"):
  129. messages = [m for m in messages if m.goal_id is None]
  130. return MessagesResponse(
  131. messages=[m.to_dict() for m in messages]
  132. )