memory_impl.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. """
  2. Memory Implementation - 内存存储实现
  3. 用于测试和简单场景,数据不持久化
  4. """
  5. from typing import Dict, List, Optional, Any
  6. from datetime import datetime
  7. from reson_agent.models.trace import Trace, Step
  8. from reson_agent.models.memory import Experience, Skill
  9. class MemoryTraceStore:
  10. """内存 Trace 存储"""
  11. def __init__(self):
  12. self._traces: Dict[str, Trace] = {}
  13. self._steps: Dict[str, Step] = {}
  14. self._trace_steps: Dict[str, List[str]] = {} # trace_id -> [step_ids]
  15. async def create_trace(self, trace: Trace) -> str:
  16. self._traces[trace.trace_id] = trace
  17. self._trace_steps[trace.trace_id] = []
  18. return trace.trace_id
  19. async def get_trace(self, trace_id: str) -> Optional[Trace]:
  20. return self._traces.get(trace_id)
  21. async def update_trace(self, trace_id: str, **updates) -> None:
  22. trace = self._traces.get(trace_id)
  23. if trace:
  24. for key, value in updates.items():
  25. if hasattr(trace, key):
  26. setattr(trace, key, value)
  27. async def list_traces(
  28. self,
  29. mode: Optional[str] = None,
  30. agent_type: Optional[str] = None,
  31. uid: Optional[str] = None,
  32. status: Optional[str] = None,
  33. limit: int = 50
  34. ) -> List[Trace]:
  35. traces = list(self._traces.values())
  36. # 过滤
  37. if mode:
  38. traces = [t for t in traces if t.mode == mode]
  39. if agent_type:
  40. traces = [t for t in traces if t.agent_type == agent_type]
  41. if uid:
  42. traces = [t for t in traces if t.uid == uid]
  43. if status:
  44. traces = [t for t in traces if t.status == status]
  45. # 排序(最新的在前)
  46. traces.sort(key=lambda t: t.created_at, reverse=True)
  47. return traces[:limit]
  48. async def add_step(self, step: Step) -> str:
  49. self._steps[step.step_id] = step
  50. # 添加到 trace 的 steps 列表
  51. if step.trace_id in self._trace_steps:
  52. self._trace_steps[step.trace_id].append(step.step_id)
  53. # 更新 trace 的 total_steps
  54. trace = self._traces.get(step.trace_id)
  55. if trace:
  56. trace.total_steps += 1
  57. return step.step_id
  58. async def get_step(self, step_id: str) -> Optional[Step]:
  59. return self._steps.get(step_id)
  60. async def get_trace_steps(self, trace_id: str) -> List[Step]:
  61. step_ids = self._trace_steps.get(trace_id, [])
  62. steps = [self._steps[sid] for sid in step_ids if sid in self._steps]
  63. steps.sort(key=lambda s: s.sequence)
  64. return steps
  65. async def get_step_children(self, step_id: str) -> List[Step]:
  66. children = []
  67. for step in self._steps.values():
  68. if step_id in step.parent_ids:
  69. children.append(step)
  70. children.sort(key=lambda s: s.sequence)
  71. return children
  72. class MemoryMemoryStore:
  73. """内存 Memory 存储(Experience + Skill)"""
  74. def __init__(self):
  75. self._experiences: Dict[str, Experience] = {}
  76. self._skills: Dict[str, Skill] = {}
  77. # ===== Experience =====
  78. async def add_experience(self, exp: Experience) -> str:
  79. self._experiences[exp.exp_id] = exp
  80. return exp.exp_id
  81. async def get_experience(self, exp_id: str) -> Optional[Experience]:
  82. return self._experiences.get(exp_id)
  83. async def search_experiences(
  84. self,
  85. scope: str,
  86. context: str,
  87. limit: int = 10
  88. ) -> List[Experience]:
  89. # 简单实现:按 scope 过滤,按 confidence 排序
  90. experiences = [
  91. e for e in self._experiences.values()
  92. if e.scope == scope
  93. ]
  94. experiences.sort(key=lambda e: e.confidence, reverse=True)
  95. return experiences[:limit]
  96. async def update_experience_stats(
  97. self,
  98. exp_id: str,
  99. success: bool
  100. ) -> None:
  101. exp = self._experiences.get(exp_id)
  102. if exp:
  103. exp.usage_count += 1
  104. if success:
  105. # 更新成功率
  106. total_success = exp.success_rate * (exp.usage_count - 1) + (1 if success else 0)
  107. exp.success_rate = total_success / exp.usage_count
  108. exp.updated_at = datetime.now()
  109. # ===== Skill =====
  110. async def add_skill(self, skill: Skill) -> str:
  111. self._skills[skill.skill_id] = skill
  112. return skill.skill_id
  113. async def get_skill(self, skill_id: str) -> Optional[Skill]:
  114. return self._skills.get(skill_id)
  115. async def get_skill_tree(self, scope: str) -> List[Skill]:
  116. return [s for s in self._skills.values() if s.scope == scope]
  117. async def search_skills(
  118. self,
  119. scope: str,
  120. context: str,
  121. limit: int = 5
  122. ) -> List[Skill]:
  123. # 简单实现:按 scope 过滤
  124. skills = [s for s in self._skills.values() if s.scope == scope]
  125. return skills[:limit]
  126. class MemoryStateStore:
  127. """内存状态存储"""
  128. def __init__(self):
  129. self._state: Dict[str, Dict[str, Any]] = {}
  130. async def get(self, key: str) -> Optional[Dict[str, Any]]:
  131. return self._state.get(key)
  132. async def set(
  133. self,
  134. key: str,
  135. value: Dict[str, Any],
  136. ttl: Optional[int] = None
  137. ) -> None:
  138. # 内存实现忽略 ttl
  139. self._state[key] = value
  140. async def update(self, key: str, **updates) -> None:
  141. if key in self._state:
  142. self._state[key].update(updates)
  143. async def delete(self, key: str) -> None:
  144. self._state.pop(key, None)