protocols.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
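"""
Storage Protocols - storage interface definitions.

Interfaces are declared with ``typing.Protocol`` so that different storage
implementations (in-memory, PostgreSQL, Neo4j, etc.) can be plugged in.
"""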
  5. from typing import Protocol, List, Optional, Dict, Any, runtime_checkable
  6. from agent.models.trace import Trace, Step
  7. from agent.models.memory import Experience, Skill
  8. @runtime_checkable
  9. class TraceStore(Protocol):
  10. """Trace + Step 存储接口"""
  11. # ===== Trace 操作 =====
  12. async def create_trace(self, trace: Trace) -> str:
  13. """
  14. 创建新的 Trace
  15. Args:
  16. trace: Trace 对象
  17. Returns:
  18. trace_id
  19. """
  20. ...
  21. async def get_trace(self, trace_id: str) -> Optional[Trace]:
  22. """获取 Trace"""
  23. ...
  24. async def update_trace(self, trace_id: str, **updates) -> None:
  25. """
  26. 更新 Trace
  27. Args:
  28. trace_id: Trace ID
  29. **updates: 要更新的字段
  30. """
  31. ...
  32. async def list_traces(
  33. self,
  34. mode: Optional[str] = None,
  35. agent_type: Optional[str] = None,
  36. uid: Optional[str] = None,
  37. status: Optional[str] = None,
  38. limit: int = 50
  39. ) -> List[Trace]:
  40. """列出 Traces"""
  41. ...
  42. # ===== Step 操作 =====
  43. async def add_step(self, step: Step) -> str:
  44. """
  45. 添加 Step
  46. Args:
  47. step: Step 对象
  48. Returns:
  49. step_id
  50. """
  51. ...
  52. async def get_step(self, step_id: str) -> Optional[Step]:
  53. """获取 Step"""
  54. ...
  55. async def get_trace_steps(self, trace_id: str) -> List[Step]:
  56. """获取 Trace 的所有 Steps(按 sequence 排序)"""
  57. ...
  58. async def get_step_children(self, step_id: str) -> List[Step]:
  59. """获取 Step 的子节点"""
  60. ...
  61. @runtime_checkable
  62. class MemoryStore(Protocol):
  63. """Experience + Skill 存储接口"""
  64. # ===== Experience 操作 =====
  65. async def add_experience(self, exp: Experience) -> str:
  66. """添加 Experience"""
  67. ...
  68. async def get_experience(self, exp_id: str) -> Optional[Experience]:
  69. """获取 Experience"""
  70. ...
  71. async def search_experiences(
  72. self,
  73. scope: str,
  74. context: str,
  75. limit: int = 10
  76. ) -> List[Experience]:
  77. """
  78. 搜索相关 Experience
  79. Args:
  80. scope: 范围(如 "agent:researcher")
  81. context: 当前上下文,用于语义匹配
  82. limit: 最大返回数量
  83. """
  84. ...
  85. async def update_experience_stats(
  86. self,
  87. exp_id: str,
  88. success: bool
  89. ) -> None:
  90. """更新 Experience 使用统计"""
  91. ...
  92. # ===== Skill 操作 =====
  93. async def add_skill(self, skill: Skill) -> str:
  94. """添加 Skill"""
  95. ...
  96. async def get_skill(self, skill_id: str) -> Optional[Skill]:
  97. """获取 Skill"""
  98. ...
  99. async def get_skill_tree(self, scope: str) -> List[Skill]:
  100. """获取技能树"""
  101. ...
  102. async def search_skills(
  103. self,
  104. scope: str,
  105. context: str,
  106. limit: int = 5
  107. ) -> List[Skill]:
  108. """搜索相关 Skills"""
  109. ...
  110. @runtime_checkable
  111. class StateStore(Protocol):
  112. """短期状态存储接口(用于 Task State,通常用 Redis)"""
  113. async def get(self, key: str) -> Optional[Dict[str, Any]]:
  114. """获取状态"""
  115. ...
  116. async def set(
  117. self,
  118. key: str,
  119. value: Dict[str, Any],
  120. ttl: Optional[int] = None
  121. ) -> None:
  122. """
  123. 设置状态
  124. Args:
  125. key: 键
  126. value: 值
  127. ttl: 过期时间(秒)
  128. """
  129. ...
  130. async def update(self, key: str, **updates) -> None:
  131. """部分更新"""
  132. ...
  133. async def delete(self, key: str) -> None:
  134. """删除"""
  135. ...