protocols.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. """
  2. Trace Storage Protocol - Trace 存储接口定义
  3. 使用 Protocol 定义接口,允许不同的存储实现(内存、PostgreSQL、Neo4j 等)
  4. """
  5. from typing import Protocol, List, Optional, Dict, Any, runtime_checkable
  6. from agent.execution.models import Trace, Step
  7. @runtime_checkable
  8. class TraceStore(Protocol):
  9. """Trace + Step 存储接口"""
  10. # ===== Trace 操作 =====
  11. async def create_trace(self, trace: Trace) -> str:
  12. """
  13. 创建新的 Trace
  14. Args:
  15. trace: Trace 对象
  16. Returns:
  17. trace_id
  18. """
  19. ...
  20. async def get_trace(self, trace_id: str) -> Optional[Trace]:
  21. """获取 Trace"""
  22. ...
  23. async def update_trace(self, trace_id: str, **updates) -> None:
  24. """
  25. 更新 Trace
  26. Args:
  27. trace_id: Trace ID
  28. **updates: 要更新的字段
  29. """
  30. ...
  31. async def list_traces(
  32. self,
  33. mode: Optional[str] = None,
  34. agent_type: Optional[str] = None,
  35. uid: Optional[str] = None,
  36. status: Optional[str] = None,
  37. limit: int = 50
  38. ) -> List[Trace]:
  39. """列出 Traces"""
  40. ...
  41. # ===== Step 操作 =====
  42. async def add_step(self, step: Step) -> str:
  43. """
  44. 添加 Step
  45. Args:
  46. step: Step 对象
  47. Returns:
  48. step_id
  49. """
  50. ...
  51. async def get_step(self, step_id: str) -> Optional[Step]:
  52. """获取 Step"""
  53. ...
  54. async def get_trace_steps(self, trace_id: str) -> List[Step]:
  55. """获取 Trace 的所有 Steps(按 sequence 排序)"""
  56. ...
  57. async def get_step_children(self, step_id: str) -> List[Step]:
  58. """获取 Step 的子节点"""
  59. ...
  60. async def update_step(self, step_id: str, **updates) -> None:
  61. """
  62. 更新 Step 字段(用于状态变更、错误记录等)
  63. Args:
  64. step_id: Step ID
  65. **updates: 要更新的字段
  66. """
  67. ...
  68. # ===== 事件流操作(用于 WebSocket 断线续传)=====
  69. async def get_events(
  70. self,
  71. trace_id: str,
  72. since_event_id: int = 0
  73. ) -> List[Dict[str, Any]]:
  74. """
  75. 获取事件流(用于 WS 断线续传)
  76. Args:
  77. trace_id: Trace ID
  78. since_event_id: 从哪个事件 ID 开始(0 表示全部)
  79. Returns:
  80. 事件列表(按 event_id 排序)
  81. """
  82. ...
  83. async def append_event(
  84. self,
  85. trace_id: str,
  86. event_type: str,
  87. payload: Dict[str, Any]
  88. ) -> int:
  89. """
  90. 追加事件,返回 event_id
  91. Args:
  92. trace_id: Trace ID
  93. event_type: 事件类型(step_added/step_updated/trace_completed)
  94. payload: 事件数据
  95. Returns:
  96. event_id: 新事件的 ID
  97. """
  98. ...