"""
Memory Implementation - in-memory storage implementations.

Intended for tests and simple scenarios; data is not persisted.
MemoryTraceStore has moved to agent.execution.store.
"""
  6. from typing import Dict, List, Optional, Any
  7. from datetime import datetime
  8. from agent.memory.models import Experience, Skill
  9. class MemoryMemoryStore:
  10. """内存 Memory 存储(Experience + Skill)"""
  11. def __init__(self):
  12. self._experiences: Dict[str, Experience] = {}
  13. self._skills: Dict[str, Skill] = {}
  14. # ===== Experience =====
  15. async def add_experience(self, exp: Experience) -> str:
  16. self._experiences[exp.exp_id] = exp
  17. return exp.exp_id
  18. async def get_experience(self, exp_id: str) -> Optional[Experience]:
  19. return self._experiences.get(exp_id)
  20. async def search_experiences(
  21. self,
  22. scope: str,
  23. context: str,
  24. limit: int = 10
  25. ) -> List[Experience]:
  26. # 简单实现:按 scope 过滤,按 confidence 排序
  27. experiences = [
  28. e for e in self._experiences.values()
  29. if e.scope == scope
  30. ]
  31. experiences.sort(key=lambda e: e.confidence, reverse=True)
  32. return experiences[:limit]
  33. async def update_experience_stats(
  34. self,
  35. exp_id: str,
  36. success: bool
  37. ) -> None:
  38. exp = self._experiences.get(exp_id)
  39. if exp:
  40. exp.usage_count += 1
  41. if success:
  42. # 更新成功率
  43. total_success = exp.success_rate * (exp.usage_count - 1) + (1 if success else 0)
  44. exp.success_rate = total_success / exp.usage_count
  45. exp.updated_at = datetime.now()
  46. # ===== Skill =====
  47. async def add_skill(self, skill: Skill) -> str:
  48. self._skills[skill.skill_id] = skill
  49. return skill.skill_id
  50. async def get_skill(self, skill_id: str) -> Optional[Skill]:
  51. return self._skills.get(skill_id)
  52. async def get_skill_tree(self, scope: str) -> List[Skill]:
  53. return [s for s in self._skills.values() if s.scope == scope]
  54. async def search_skills(
  55. self,
  56. scope: str,
  57. context: str,
  58. limit: int = 5
  59. ) -> List[Skill]:
  60. # 简单实现:按 scope 过滤
  61. skills = [s for s in self._skills.values() if s.scope == scope]
  62. return skills[:limit]
  63. class MemoryStateStore:
  64. """内存状态存储"""
  65. def __init__(self):
  66. self._state: Dict[str, Dict[str, Any]] = {}
  67. async def get(self, key: str) -> Optional[Dict[str, Any]]:
  68. return self._state.get(key)
  69. async def set(
  70. self,
  71. key: str,
  72. value: Dict[str, Any],
  73. ttl: Optional[int] = None
  74. ) -> None:
  75. # 内存实现忽略 ttl
  76. self._state[key] = value
  77. async def update(self, key: str, **updates) -> None:
  78. if key in self._state:
  79. self._state[key].update(updates)
  80. async def delete(self, key: str) -> None:
  81. self._state.pop(key, None)