| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184 |
- """
- Memory Implementation - 内存存储实现
- 用于测试和简单场景,数据不持久化
- """
- from typing import Dict, List, Optional, Any
- from datetime import datetime
- from agent.models.trace import Trace, Step
- from agent.models.memory import Experience, Skill
class MemoryTraceStore:
    """In-memory Trace store.

    Traces and steps live in plain dicts owned by this instance; nothing is
    persisted.
    """

    def __init__(self):
        self._traces: Dict[str, Trace] = {}
        self._steps: Dict[str, Step] = {}
        # trace_id -> [step_id, ...] in the order the steps were first added
        self._trace_steps: Dict[str, List[str]] = {}

    async def create_trace(self, trace: Trace) -> str:
        """Store *trace* and start an empty step list for it.

        Re-creating an existing ``trace_id`` replaces the trace object and
        resets its step list.

        Returns:
            The ``trace_id`` of the stored trace.
        """
        self._traces[trace.trace_id] = trace
        self._trace_steps[trace.trace_id] = []
        return trace.trace_id

    async def get_trace(self, trace_id: str) -> Optional[Trace]:
        """Return the trace stored under *trace_id*, or ``None`` if unknown."""
        return self._traces.get(trace_id)

    async def update_trace(self, trace_id: str, **updates) -> None:
        """Set attributes on a stored trace in place.

        Only keys that already exist as attributes on the trace are applied;
        other keys are ignored. An unknown *trace_id* is a no-op.
        """
        trace = self._traces.get(trace_id)
        if trace:
            for key, value in updates.items():
                if hasattr(trace, key):
                    setattr(trace, key, value)

    async def list_traces(
        self,
        mode: Optional[str] = None,
        agent_type: Optional[str] = None,
        uid: Optional[str] = None,
        status: Optional[str] = None,
        limit: int = 50
    ) -> List[Trace]:
        """List traces matching all given filters, newest first.

        A filter whose value is falsy (``None`` or ``""``) is not applied.
        Results are sorted by ``created_at`` descending and truncated to
        *limit* entries.
        """
        traces = list(self._traces.values())

        # Filter (all given filters must match)
        if mode:
            traces = [t for t in traces if t.mode == mode]
        if agent_type:
            traces = [t for t in traces if t.agent_type == agent_type]
        if uid:
            traces = [t for t in traces if t.uid == uid]
        if status:
            traces = [t for t in traces if t.status == status]

        # Sort: newest first
        traces.sort(key=lambda t: t.created_at, reverse=True)
        return traces[:limit]

    async def add_step(self, step: Step) -> str:
        """Store *step* and link it to its trace.

        Adding a step whose ``step_id`` is already stored replaces the stored
        object but does not link it again or bump the trace's
        ``total_steps``, so re-adding a step acts as an update. Steps whose
        trace was never created are stored but not linked to any trace.

        Returns:
            The ``step_id`` of the stored step.
        """
        is_new = step.step_id not in self._steps
        self._steps[step.step_id] = step

        if is_new:
            # Link to the trace's step list
            if step.trace_id in self._trace_steps:
                self._trace_steps[step.trace_id].append(step.step_id)

            # Keep the trace's step counter in sync with distinct steps
            trace = self._traces.get(step.trace_id)
            if trace:
                trace.total_steps += 1

        return step.step_id

    async def get_step(self, step_id: str) -> Optional[Step]:
        """Return the step stored under *step_id*, or ``None`` if unknown."""
        return self._steps.get(step_id)

    async def get_trace_steps(self, trace_id: str) -> List[Step]:
        """Return the steps linked to *trace_id*, ordered by ``sequence``."""
        step_ids = self._trace_steps.get(trace_id, [])
        steps = [self._steps[sid] for sid in step_ids if sid in self._steps]
        steps.sort(key=lambda s: s.sequence)
        return steps

    async def get_step_children(self, step_id: str) -> List[Step]:
        """Return steps listing *step_id* in ``parent_ids``, ordered by ``sequence``.

        Scans every stored step.
        """
        children = [
            step for step in self._steps.values() if step_id in step.parent_ids
        ]
        children.sort(key=lambda s: s.sequence)
        return children
class MemoryMemoryStore:
    """In-memory Memory store holding Experience and Skill objects."""

    def __init__(self):
        self._experiences: Dict[str, Experience] = {}
        self._skills: Dict[str, Skill] = {}

    # ===== Experience =====

    async def add_experience(self, exp: Experience) -> str:
        """Store *exp* (replacing any with the same id) and return its ``exp_id``."""
        self._experiences[exp.exp_id] = exp
        return exp.exp_id

    async def get_experience(self, exp_id: str) -> Optional[Experience]:
        """Return the experience stored under *exp_id*, or ``None`` if unknown."""
        return self._experiences.get(exp_id)

    async def search_experiences(
        self,
        scope: str,
        context: str,
        limit: int = 10
    ) -> List[Experience]:
        """Return up to *limit* experiences in *scope*, highest confidence first.

        Simple implementation: *context* is accepted for interface
        compatibility but is not used for matching.
        """
        experiences = [
            e for e in self._experiences.values()
            if e.scope == scope
        ]
        experiences.sort(key=lambda e: e.confidence, reverse=True)
        return experiences[:limit]

    async def update_experience_stats(
        self,
        exp_id: str,
        success: bool
    ) -> None:
        """Record one use of an experience and refresh its running success rate.

        ``success_rate`` is kept as the mean outcome over all uses: the
        previous rate weighted by the previous usage count, plus this outcome
        (1 for success, 0 for failure), divided by the new usage count.
        ``updated_at`` is set to the current local time. An unknown *exp_id*
        is a no-op.
        """
        exp = self._experiences.get(exp_id)
        if exp is None:
            return

        previous_uses = exp.usage_count
        exp.usage_count += 1
        # Recompute for failures as well, so a failed use lowers the rate.
        total_success = exp.success_rate * previous_uses + (1 if success else 0)
        exp.success_rate = total_success / exp.usage_count
        exp.updated_at = datetime.now()

    # ===== Skill =====

    async def add_skill(self, skill: Skill) -> str:
        """Store *skill* (replacing any with the same id) and return its ``skill_id``."""
        self._skills[skill.skill_id] = skill
        return skill.skill_id

    async def get_skill(self, skill_id: str) -> Optional[Skill]:
        """Return the skill stored under *skill_id*, or ``None`` if unknown."""
        return self._skills.get(skill_id)

    async def get_skill_tree(self, scope: str) -> List[Skill]:
        """Return every stored skill whose ``scope`` equals *scope*."""
        return [s for s in self._skills.values() if s.scope == scope]

    async def search_skills(
        self,
        scope: str,
        context: str,
        limit: int = 5
    ) -> List[Skill]:
        """Return up to *limit* skills in *scope*, in insertion order.

        Simple implementation: *context* is accepted for interface
        compatibility but is not used for matching.
        """
        skills = [s for s in self._skills.values() if s.scope == scope]
        return skills[:limit]
class MemoryStateStore:
    """In-memory key/value state store.

    Each key maps to a dict. Values are kept and handed back by reference
    (no copies are made), and expiry is not implemented.
    """

    def __init__(self):
        self._state: Dict[str, Dict[str, Any]] = {}

    async def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Return the dict stored under *key*, or ``None`` when absent."""
        try:
            return self._state[key]
        except KeyError:
            return None

    async def set(
        self,
        key: str,
        value: Dict[str, Any],
        ttl: Optional[int] = None
    ) -> None:
        """Store *value* under *key*, replacing any previous value.

        *ttl* is accepted for interface compatibility; this in-memory
        implementation never expires entries.
        """
        self._state[key] = value

    async def update(self, key: str, **updates) -> None:
        """Merge *updates* into the dict stored under *key*.

        Does nothing when *key* has no stored value.
        """
        if key not in self._state:
            return
        self._state[key].update(updates)

    async def delete(self, key: str) -> None:
        """Remove *key*; deleting a missing key is not an error."""
        if key in self._state:
            del self._state[key]
|