| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275 |
- """
- Trace Storage Protocol - Trace 存储接口定义
- 使用 Protocol 定义接口,允许不同的存储实现(内存、PostgreSQL、Neo4j 等)
- """
- from typing import Protocol, List, Optional, Dict, Any, runtime_checkable
- from agent.execution.models import Trace, Message
- from agent.goal.models import GoalTree, Goal, BranchContext
- @runtime_checkable
- class TraceStore(Protocol):
- """Trace + Message + GoalTree + Branch 存储接口"""
- # ===== Trace 操作 =====
- async def create_trace(self, trace: Trace) -> str:
- """
- 创建新的 Trace
- Args:
- trace: Trace 对象
- Returns:
- trace_id
- """
- ...
- async def get_trace(self, trace_id: str) -> Optional[Trace]:
- """获取 Trace"""
- ...
- async def update_trace(self, trace_id: str, **updates) -> None:
- """
- 更新 Trace
- Args:
- trace_id: Trace ID
- **updates: 要更新的字段
- """
- ...
- async def list_traces(
- self,
- mode: Optional[str] = None,
- agent_type: Optional[str] = None,
- uid: Optional[str] = None,
- status: Optional[str] = None,
- limit: int = 50
- ) -> List[Trace]:
- """列出 Traces"""
- ...
- # ===== GoalTree 操作 =====
- async def get_goal_tree(self, trace_id: str) -> Optional[GoalTree]:
- """
- 获取 GoalTree
- Args:
- trace_id: Trace ID
- Returns:
- GoalTree 对象,如果不存在返回 None
- """
- ...
- async def update_goal_tree(self, trace_id: str, tree: GoalTree) -> None:
- """
- 更新完整 GoalTree
- Args:
- trace_id: Trace ID
- tree: GoalTree 对象
- """
- ...
- async def add_goal(self, trace_id: str, goal: Goal) -> None:
- """
- 添加 Goal 到 GoalTree
- Args:
- trace_id: Trace ID
- goal: Goal 对象
- """
- ...
- async def update_goal(self, trace_id: str, goal_id: str, **updates) -> None:
- """
- 更新 Goal 字段
- Args:
- trace_id: Trace ID
- goal_id: Goal ID
- **updates: 要更新的字段(如 status, summary, self_stats, cumulative_stats)
- """
- ...
- # ===== Branch 操作 =====
- async def create_branch(self, trace_id: str, branch: BranchContext) -> None:
- """
- 创建分支上下文
- Args:
- trace_id: Trace ID
- branch: BranchContext 对象
- """
- ...
- async def get_branch(self, trace_id: str, branch_id: str) -> Optional[BranchContext]:
- """
- 获取分支元数据
- Args:
- trace_id: Trace ID
- branch_id: 分支 ID
- Returns:
- BranchContext 对象(不含分支内 GoalTree)
- """
- ...
- async def get_branch_goal_tree(self, trace_id: str, branch_id: str) -> Optional[GoalTree]:
- """
- 获取分支的 GoalTree
- Args:
- trace_id: Trace ID
- branch_id: 分支 ID
- Returns:
- 分支的 GoalTree 对象
- """
- ...
- async def update_branch_goal_tree(self, trace_id: str, branch_id: str, tree: GoalTree) -> None:
- """
- 更新分支的 GoalTree
- Args:
- trace_id: Trace ID
- branch_id: 分支 ID
- tree: GoalTree 对象
- """
- ...
- async def update_branch(self, trace_id: str, branch_id: str, **updates) -> None:
- """
- 更新分支元数据
- Args:
- trace_id: Trace ID
- branch_id: 分支 ID
- **updates: 要更新的字段(如 status, summary, cumulative_stats)
- """
- ...
- async def list_branches(self, trace_id: str) -> Dict[str, BranchContext]:
- """
- 列出所有分支元数据
- Args:
- trace_id: Trace ID
- Returns:
- Dict[branch_id, BranchContext]
- """
- ...
- # ===== Message 操作 =====
- async def add_message(self, message: Message) -> str:
- """
- 添加 Message
- 自动更新关联 Goal 的 stats(self_stats 和祖先的 cumulative_stats)
- Args:
- message: Message 对象
- Returns:
- message_id
- """
- ...
- async def get_message(self, message_id: str) -> Optional[Message]:
- """获取 Message"""
- ...
- async def get_trace_messages(
- self,
- trace_id: str,
- branch_id: Optional[str] = None
- ) -> List[Message]:
- """
- 获取 Trace 的所有 Messages(按 sequence 排序)
- Args:
- trace_id: Trace ID
- branch_id: 可选,过滤指定分支的消息
- Returns:
- Message 列表
- """
- ...
- async def get_messages_by_goal(
- self,
- trace_id: str,
- goal_id: str,
- branch_id: Optional[str] = None
- ) -> List[Message]:
- """
- 获取指定 Goal 关联的所有 Messages
- Args:
- trace_id: Trace ID
- goal_id: Goal ID
- branch_id: 可选,指定分支
- Returns:
- Message 列表
- """
- ...
- async def update_message(self, message_id: str, **updates) -> None:
- """
- 更新 Message 字段(用于状态变更、错误记录等)
- Args:
- message_id: Message ID
- **updates: 要更新的字段
- """
- ...
- # ===== 事件流操作(用于 WebSocket 断线续传)=====
- async def get_events(
- self,
- trace_id: str,
- since_event_id: int = 0
- ) -> List[Dict[str, Any]]:
- """
- 获取事件流(用于 WS 断线续传)
- Args:
- trace_id: Trace ID
- since_event_id: 从哪个事件 ID 开始(0 表示全部)
- Returns:
- 事件列表(按 event_id 排序)
- """
- ...
- async def append_event(
- self,
- trace_id: str,
- event_type: str,
- payload: Dict[str, Any]
- ) -> int:
- """
- 追加事件,返回 event_id
- Args:
- trace_id: Trace ID
- event_type: 事件类型
- payload: 事件数据
- Returns:
- event_id: 新事件的 ID
- """
- ...
|