memory_impl.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. """
  2. Memory Implementation - 内存存储实现
  3. 用于测试和简单场景,数据不持久化
  4. """
  5. from typing import Dict, List, Optional, Any
  6. from datetime import datetime
  7. from agent.models.trace import Trace, Step
  8. from agent.models.memory import Experience, Skill
  9. class MemoryTraceStore:
  10. """内存 Trace 存储"""
  11. def __init__(self):
  12. self._traces: Dict[str, Trace] = {}
  13. self._steps: Dict[str, Step] = {}
  14. self._trace_steps: Dict[str, List[str]] = {} # trace_id -> [step_ids]
  15. async def create_trace(self, trace: Trace) -> str:
  16. self._traces[trace.trace_id] = trace
  17. self._trace_steps[trace.trace_id] = []
  18. return trace.trace_id
  19. async def get_trace(self, trace_id: str) -> Optional[Trace]:
  20. return self._traces.get(trace_id)
  21. async def update_trace(self, trace_id: str, **updates) -> None:
  22. trace = self._traces.get(trace_id)
  23. if trace:
  24. for key, value in updates.items():
  25. if hasattr(trace, key):
  26. setattr(trace, key, value)
  27. async def list_traces(
  28. self,
  29. mode: Optional[str] = None,
  30. agent_type: Optional[str] = None,
  31. uid: Optional[str] = None,
  32. status: Optional[str] = None,
  33. limit: int = 50
  34. ) -> List[Trace]:
  35. traces = list(self._traces.values())
  36. # 过滤
  37. if mode:
  38. traces = [t for t in traces if t.mode == mode]
  39. if agent_type:
  40. traces = [t for t in traces if t.agent_type == agent_type]
  41. if uid:
  42. traces = [t for t in traces if t.uid == uid]
  43. if status:
  44. traces = [t for t in traces if t.status == status]
  45. # 排序(最新的在前)
  46. traces.sort(key=lambda t: t.created_at, reverse=True)
  47. return traces[:limit]
  48. async def add_step(self, step: Step) -> str:
  49. self._steps[step.step_id] = step
  50. # 添加到 trace 的 steps 列表
  51. if step.trace_id in self._trace_steps:
  52. self._trace_steps[step.trace_id].append(step.step_id)
  53. # 更新 trace 的 total_steps
  54. trace = self._traces.get(step.trace_id)
  55. if trace:
  56. trace.total_steps += 1
  57. return step.step_id
  58. async def get_step(self, step_id: str) -> Optional[Step]:
  59. return self._steps.get(step_id)
  60. async def get_trace_steps(self, trace_id: str) -> List[Step]:
  61. step_ids = self._trace_steps.get(trace_id, [])
  62. steps = [self._steps[sid] for sid in step_ids if sid in self._steps]
  63. steps.sort(key=lambda s: s.sequence)
  64. return steps
  65. async def get_step_children(self, step_id: str) -> List[Step]:
  66. children = []
  67. for step in self._steps.values():
  68. if step.parent_id == step_id:
  69. children.append(step)
  70. children.sort(key=lambda s: s.sequence)
  71. return children
  72. class MemoryMemoryStore:
  73. """内存 Memory 存储(Experience + Skill)"""
  74. def __init__(self):
  75. self._experiences: Dict[str, Experience] = {}
  76. self._skills: Dict[str, Skill] = {}
  77. # ===== Experience =====
  78. async def add_experience(self, exp: Experience) -> str:
  79. self._experiences[exp.exp_id] = exp
  80. return exp.exp_id
  81. async def get_experience(self, exp_id: str) -> Optional[Experience]:
  82. return self._experiences.get(exp_id)
  83. async def search_experiences(
  84. self,
  85. scope: str,
  86. context: str,
  87. limit: int = 10
  88. ) -> List[Experience]:
  89. # 简单实现:按 scope 过滤,按 confidence 排序
  90. experiences = [
  91. e for e in self._experiences.values()
  92. if e.scope == scope
  93. ]
  94. experiences.sort(key=lambda e: e.confidence, reverse=True)
  95. return experiences[:limit]
  96. async def update_experience_stats(
  97. self,
  98. exp_id: str,
  99. success: bool
  100. ) -> None:
  101. exp = self._experiences.get(exp_id)
  102. if exp:
  103. exp.usage_count += 1
  104. if success:
  105. # 更新成功率
  106. total_success = exp.success_rate * (exp.usage_count - 1) + (1 if success else 0)
  107. exp.success_rate = total_success / exp.usage_count
  108. exp.updated_at = datetime.now()
  109. # ===== Skill =====
  110. async def add_skill(self, skill: Skill) -> str:
  111. self._skills[skill.skill_id] = skill
  112. return skill.skill_id
  113. async def get_skill(self, skill_id: str) -> Optional[Skill]:
  114. return self._skills.get(skill_id)
  115. async def get_skill_tree(self, scope: str) -> List[Skill]:
  116. return [s for s in self._skills.values() if s.scope == scope]
  117. async def search_skills(
  118. self,
  119. scope: str,
  120. context: str,
  121. limit: int = 5
  122. ) -> List[Skill]:
  123. # 简单实现:按 scope 过滤
  124. skills = [s for s in self._skills.values() if s.scope == scope]
  125. return skills[:limit]
  126. class MemoryStateStore:
  127. """内存状态存储"""
  128. def __init__(self):
  129. self._state: Dict[str, Dict[str, Any]] = {}
  130. async def get(self, key: str) -> Optional[Dict[str, Any]]:
  131. return self._state.get(key)
  132. async def set(
  133. self,
  134. key: str,
  135. value: Dict[str, Any],
  136. ttl: Optional[int] = None
  137. ) -> None:
  138. # 内存实现忽略 ttl
  139. self._state[key] = value
  140. async def update(self, key: str, **updates) -> None:
  141. if key in self._state:
  142. self._state[key].update(updates)
  143. async def delete(self, key: str) -> None:
  144. self._state.pop(key, None)