memory_store.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. """
  2. Memory Trace Store - 内存存储实现
  3. 用于测试和简单场景,数据不持久化
  4. """
  5. from typing import Dict, List, Optional
  6. from agent.trace.models import Trace, Step
  class MemoryTraceStore:
      """In-memory Trace store.

      Traces and steps live in plain dicts on the instance, so nothing is
      persisted (intended for tests and simple scenarios). Lookups of unknown
      ids return ``None`` or an empty list instead of raising.

      The methods are ``async`` but never await anything; presumably this
      mirrors an async store interface shared with persistent backends --
      TODO confirm.
      """

      def __init__(self):
          self._traces: Dict[str, "Trace"] = {}  # trace_id -> Trace
          self._steps: Dict[str, "Step"] = {}  # step_id -> Step
          # trace_id -> [step_ids], in the order the steps were first linked
          self._trace_steps: Dict[str, List[str]] = {}

      async def create_trace(self, trace: "Trace") -> str:
          """Register *trace* and return its ``trace_id``.

          Re-creating an existing ``trace_id`` replaces the stored trace and
          resets its step list; steps added earlier remain retrievable via
          ``get_step`` but are no longer listed under the trace.
          """
          self._traces[trace.trace_id] = trace
          self._trace_steps[trace.trace_id] = []
          return trace.trace_id

      async def get_trace(self, trace_id: str) -> Optional["Trace"]:
          """Return the trace stored under *trace_id*, or ``None``."""
          return self._traces.get(trace_id)

      async def update_trace(self, trace_id: str, **updates) -> None:
          """Set attributes of the stored trace in place.

          An unknown *trace_id* is a no-op. Keys that are not existing
          attributes of the trace are skipped rather than created.
          """
          trace = self._traces.get(trace_id)
          if trace is None:
              return
          for key, value in updates.items():
              # hasattr guard: never add attributes the trace doesn't have.
              if hasattr(trace, key):
                  setattr(trace, key, value)

      async def list_traces(
          self,
          mode: Optional[str] = None,
          agent_type: Optional[str] = None,
          uid: Optional[str] = None,
          status: Optional[str] = None,
          limit: int = 50,
      ) -> List["Trace"]:
          """Return up to *limit* traces matching every given filter, newest first.

          A filter is applied only when its value is truthy, so ``None`` and
          ``""`` both mean "don't filter on this field". Results are sorted by
          ``created_at`` descending and then truncated with ``[:limit]``.
          """
          traces = list(self._traces.values())

          # Filter: each truthy argument narrows the candidate list.
          if mode:
              traces = [t for t in traces if t.mode == mode]
          if agent_type:
              traces = [t for t in traces if t.agent_type == agent_type]
          if uid:
              traces = [t for t in traces if t.uid == uid]
          if status:
              traces = [t for t in traces if t.status == status]

          # Sort newest first.
          traces.sort(key=lambda t: t.created_at, reverse=True)
          return traces[:limit]

      async def add_step(self, step: "Step") -> str:
          """Store *step*, link it to its trace, and return its ``step_id``.

          The step is always stored by id. If its ``trace_id`` was created via
          ``create_trace``, the id is appended to that trace's step list and the
          trace's ``total_steps`` is incremented.

          Re-adding a step that is already linked to its trace only replaces
          the stored object: the id is not appended again and ``total_steps``
          is not incremented again, so ``get_trace_steps`` never lists the
          same step twice.
          """
          previous = self._steps.get(step.step_id)
          self._steps[step.step_id] = step

          step_ids = self._trace_steps.get(step.trace_id)
          # Only an id that was stored before can already be linked, so the
          # linear membership scan runs on re-adds only.
          already_linked = (
              previous is not None
              and step_ids is not None
              and step.step_id in step_ids
          )
          if already_linked:
              return step.step_id

          # Link the step into its trace's step list.
          if step_ids is not None:
              step_ids.append(step.step_id)

          # Keep the trace's step counter in sync.
          trace = self._traces.get(step.trace_id)
          if trace is not None:
              trace.total_steps += 1
          return step.step_id

      async def get_step(self, step_id: str) -> Optional["Step"]:
          """Return the step stored under *step_id*, or ``None``."""
          return self._steps.get(step_id)

      async def get_trace_steps(self, trace_id: str) -> List["Step"]:
          """Return the steps linked to *trace_id*, ordered by ``sequence``.

          An unknown *trace_id* yields an empty list.
          """
          step_ids = self._trace_steps.get(trace_id, [])
          # Defensive: skip any linked id that has no stored step.
          steps = [self._steps[sid] for sid in step_ids if sid in self._steps]
          steps.sort(key=lambda s: s.sequence)
          return steps

      async def get_step_children(self, step_id: str) -> List["Step"]:
          """Return steps whose ``parent_id`` equals *step_id*, ordered by ``sequence``.

          Scans every stored step regardless of trace (linear in the number of
          stored steps).
          """
          children = [s for s in self._steps.values() if s.parent_id == step_id]
          children.sort(key=lambda s: s.sequence)
          return children