# fs_store.py
"""
FileSystem Trace Store - file-system based storage implementation.

Used to share trace data across processes; data is persisted under the
``.trace/`` directory.

Directory layout::

    .trace/{trace_id}/
    ├── meta.json        # Trace metadata
    ├── goal.json        # GoalTree (flat JSON; hierarchy is rebuilt from parent_id)
    ├── messages/        # Messages (one file per message)
    │   ├── {message_id}.json
    │   └── ...
    └── events.jsonl     # Event stream (for WebSocket resume after a disconnect)

A Sub-Trace is a fully independent Trace with its own directory::

    .trace/{parent_id}@{mode}-{timestamp}-{seq}/
    ├── meta.json        # parent_trace_id points to the parent Trace
    ├── goal.json
    ├── messages/
    └── events.jsonl
"""
  19. import json
  20. import os
  21. from pathlib import Path
  22. from typing import Dict, List, Optional, Any
  23. from datetime import datetime
  24. from agent.execution.models import Trace, Message
  25. from agent.goal.models import GoalTree, Goal, GoalStats
  26. class FileSystemTraceStore:
  27. """文件系统 Trace 存储"""
  28. def __init__(self, base_path: str = ".trace"):
  29. self.base_path = Path(base_path)
  30. self.base_path.mkdir(exist_ok=True)
  31. def _get_trace_dir(self, trace_id: str) -> Path:
  32. """获取 trace 目录"""
  33. return self.base_path / trace_id
  34. def _get_meta_file(self, trace_id: str) -> Path:
  35. """获取 meta.json 文件路径"""
  36. return self._get_trace_dir(trace_id) / "meta.json"
  37. def _get_goal_file(self, trace_id: str) -> Path:
  38. """获取 goal.json 文件路径"""
  39. return self._get_trace_dir(trace_id) / "goal.json"
  40. def _get_messages_dir(self, trace_id: str) -> Path:
  41. """获取 messages 目录"""
  42. return self._get_trace_dir(trace_id) / "messages"
  43. def _get_message_file(self, trace_id: str, message_id: str) -> Path:
  44. """获取 message 文件路径"""
  45. return self._get_messages_dir(trace_id) / f"{message_id}.json"
  46. def _get_events_file(self, trace_id: str) -> Path:
  47. """获取 events.jsonl 文件路径"""
  48. return self._get_trace_dir(trace_id) / "events.jsonl"
  # ===== Trace operations =====
  async def create_trace(self, trace: Trace) -> str:
      """Create the on-disk layout for a new trace.

      Creates ``{base}/{trace_id}/`` with an empty ``messages/`` directory,
      writes ``meta.json`` from ``trace.to_dict()`` and creates an empty
      ``events.jsonl``.

      Args:
          trace: The trace to persist.

      Returns:
          ``trace.trace_id``.
      """
      trace_id = trace.trace_id
      # parents=True also creates the trace directory itself.
      self._get_messages_dir(trace_id).mkdir(parents=True, exist_ok=True)
      # Explicit UTF-8: ensure_ascii=False writes non-ASCII text verbatim, which
      # the platform default encoding may not be able to represent.
      self._get_meta_file(trace_id).write_text(
          json.dumps(trace.to_dict(), indent=2, ensure_ascii=False),
          encoding="utf-8",
      )
      self._get_events_file(trace_id).touch()
      return trace_id
  async def get_trace(self, trace_id: str) -> Optional[Trace]:
      """Load a trace from its ``meta.json``.

      ``created_at`` / ``completed_at`` are stored as ISO-format strings and are
      converted back to ``datetime`` before the ``Trace`` is built.

      Returns:
          The trace, or ``None`` when ``meta.json`` does not exist.

      Raises:
          ValueError: if ``meta.json`` is not valid JSON or holds a malformed
              timestamp.
      """
      meta_file = self._get_meta_file(trace_id)
      if not meta_file.exists():
          return None
      data = json.loads(meta_file.read_text(encoding="utf-8"))
      for key in ("created_at", "completed_at"):
          if data.get(key):
              data[key] = datetime.fromisoformat(data[key])
      return Trace(**data)
  76. async def update_trace(self, trace_id: str, **updates) -> None:
  77. """更新 Trace"""
  78. trace = await self.get_trace(trace_id)
  79. if not trace:
  80. return
  81. # 更新字段
  82. for key, value in updates.items():
  83. if hasattr(trace, key):
  84. setattr(trace, key, value)
  85. # 写回文件
  86. meta_file = self._get_meta_file(trace_id)
  87. meta_file.write_text(json.dumps(trace.to_dict(), indent=2, ensure_ascii=False))
  async def list_traces(
      self,
      mode: Optional[str] = None,
      agent_type: Optional[str] = None,
      uid: Optional[str] = None,
      status: Optional[str] = None,
      limit: int = 50
  ) -> List[Trace]:
      """List stored traces, newest first.

      Each truthy filter must equal the same-named field in the trace's
      ``meta.json``; falsy filters are ignored. Entries without a
      ``meta.json``, and traces whose metadata cannot be read or parsed, are
      skipped.

      Args:
          mode: Optional equality filter on ``mode``.
          agent_type: Optional equality filter on ``agent_type``.
          uid: Optional equality filter on ``uid``.
          status: Optional equality filter on ``status``.
          limit: Maximum number of traces returned.

      Returns:
          Up to ``limit`` traces sorted by ``created_at`` descending; traces
          whose ``created_at`` is ``None`` sort last.
      """
      if not self.base_path.exists():
          return []
      filters = {"mode": mode, "agent_type": agent_type, "uid": uid, "status": status}
      active_filters = {key: value for key, value in filters.items() if value}

      traces = []
      for trace_dir in self.base_path.iterdir():
          if not trace_dir.is_dir():
              continue
          meta_file = trace_dir / "meta.json"
          if not meta_file.exists():
              continue
          try:
              data = json.loads(meta_file.read_text(encoding="utf-8"))
              if any(data.get(key) != value for key, value in active_filters.items()):
                  continue
              for key in ("created_at", "completed_at"):
                  if data.get(key):
                      data[key] = datetime.fromisoformat(data[key])
              traces.append(Trace(**data))
          except Exception:
              # Deliberate best effort: one corrupt or incompatible trace must
              # not hide all the others from the listing.
              continue

      # The tuple key keeps None out of datetime comparisons (which would raise
      # TypeError); traces without created_at end up last after reversing.
      traces.sort(key=lambda t: (t.created_at is not None, t.created_at), reverse=True)
      return traces[:limit]
  # ===== GoalTree operations =====
  async def get_goal_tree(self, trace_id: str) -> Optional[GoalTree]:
      """Load the trace's GoalTree from ``goal.json``.

      Returns:
          The parsed tree, or ``None`` when ``goal.json`` is missing or cannot
          be read or parsed. Errors are not propagated.
      """
      goal_file = self._get_goal_file(trace_id)
      if not goal_file.exists():
          return None
      try:
          data = json.loads(goal_file.read_text(encoding="utf-8"))
          return GoalTree.from_dict(data)
      except Exception:
          # Deliberate best effort: callers treat an unreadable tree as "no tree".
          return None
  async def update_goal_tree(self, trace_id: str, tree: GoalTree) -> None:
      """Overwrite the trace's ``goal.json`` with ``tree.to_dict()``."""
      # Explicit UTF-8: ensure_ascii=False writes non-ASCII goal text verbatim.
      self._get_goal_file(trace_id).write_text(
          json.dumps(tree.to_dict(), indent=2, ensure_ascii=False),
          encoding="utf-8",
      )
  async def add_goal(self, trace_id: str, goal: Goal) -> None:
      """Append ``goal`` to the trace's GoalTree and persist the tree.

      No-op when the trace has no readable GoalTree; the goal is then dropped.
      """
      tree = await self.get_goal_tree(trace_id)
      if tree:
          tree.goals.append(goal)
          await self.update_goal_tree(trace_id, tree)
  150. async def update_goal(self, trace_id: str, goal_id: str, **updates) -> None:
  151. """更新 Goal 字段"""
  152. tree = await self.get_goal_tree(trace_id)
  153. if not tree:
  154. return
  155. goal = tree.find(goal_id)
  156. if not goal:
  157. return
  158. # 更新字段
  159. for key, value in updates.items():
  160. if hasattr(goal, key):
  161. # 特殊处理 stats 字段(可能是 dict)
  162. if key in ["self_stats", "cumulative_stats"] and isinstance(value, dict):
  163. value = GoalStats.from_dict(value)
  164. setattr(goal, key, value)
  165. await self.update_goal_tree(trace_id, tree)
  # ===== Branch operations =====
  # ===== Message operations =====
  async def add_message(self, message: Message) -> str:
      """Persist a message and update every statistic derived from it.

      Steps:
        1. Write ``{message_id}.json`` into the main-line messages directory,
           or into the branch messages directory when ``message.branch_id``
           is set.
        2. Bump the trace's ``total_messages``, raise ``last_sequence`` and add
           the message's tokens / cost / duration to the trace totals.
        3. Update the linked goal's stats and its ancestors' cumulative stats.
        4. Append a ``message_added`` event carrying the message and the
           affected goals' stats.

      Returns:
          ``message.message_id``.
      """
      trace_id = message.trace_id
      branch_id = message.branch_id

      # 1. Write the message file.
      if branch_id:
          # NOTE(review): the "Branch operations" section of this class is
          # empty; confirm _get_branch_messages_dir is defined elsewhere.
          messages_dir = self._get_branch_messages_dir(trace_id, branch_id)
      else:
          messages_dir = self._get_messages_dir(trace_id)
      message_file = messages_dir / f"{message.message_id}.json"
      # Explicit UTF-8: ensure_ascii=False writes non-ASCII text verbatim.
      message_file.write_text(
          json.dumps(message.to_dict(), indent=2, ensure_ascii=False),
          encoding="utf-8",
      )

      # 2. Update the trace-level counters.
      trace = await self.get_trace(trace_id)
      if trace:
          trace.total_messages += 1
          trace.last_sequence = max(trace.last_sequence, message.sequence)
          if message.tokens:
              trace.total_tokens += message.tokens
          if message.cost:
              trace.total_cost += message.cost
          if message.duration_ms:
              trace.total_duration_ms += message.duration_ms
          # Do not also pass trace_id as a keyword: it is already the
          # positional argument of update_trace.
          await self.update_trace(
              trace_id,
              total_messages=trace.total_messages,
              last_sequence=trace.last_sequence,
              total_tokens=trace.total_tokens,
              total_cost=trace.total_cost,
              total_duration_ms=trace.total_duration_ms
          )

      # 3. Propagate the message into the goal statistics.
      await self._update_goal_stats(trace_id, message)

      # 4. Emit the message_added event with the refreshed goal stats.
      affected_goals = await self._get_affected_goals(trace_id, message)
      await self.append_event(trace_id, "message_added", {
          "message": message.to_dict(),
          "affected_goals": affected_goals
      })
      return message.message_id
  async def _update_goal_stats(self, trace_id: str, message: Message) -> None:
      """Account ``message`` in the goal statistics of its GoalTree.

      The goal referenced by ``message.goal_id`` gets one more message plus the
      message's tokens and cost on both ``self_stats`` and
      ``cumulative_stats``. Every ancestor reached through ``parent_id`` gets
      the same increment on ``cumulative_stats`` only. The branch GoalTree is
      used when ``message.branch_id`` is set, otherwise the main-line tree.
      No-op when the tree or the goal cannot be found.
      """
      if message.branch_id:
          tree = await self.get_branch_goal_tree(trace_id, message.branch_id)
      else:
          tree = await self.get_goal_tree(trace_id)
      if not tree:
          return
      goal = tree.find(message.goal_id)
      if not goal:
          return

      def _accumulate(stats) -> None:
          # One message plus its optional tokens and cost.
          stats.message_count += 1
          if message.tokens:
              stats.total_tokens += message.tokens
          if message.cost:
              stats.total_cost += message.cost

      _accumulate(goal.self_stats)
      # TODO: update preview (tool-call summary)
      _accumulate(goal.cumulative_stats)

      # Walk up the ancestor chain. goal.json links goals only through
      # parent_id, so stop on a malformed cycle instead of looping forever.
      visited = {goal.id}
      current_goal = goal
      while current_goal.parent_id:
          parent = tree.find(current_goal.parent_id)
          if not parent or parent.id in visited:
              break
          visited.add(parent.id)
          _accumulate(parent.cumulative_stats)
          current_goal = parent

      # Save the updated tree back to where it was loaded from.
      if message.branch_id:
          await self.update_branch_goal_tree(trace_id, message.branch_id, tree)
      else:
          await self.update_goal_tree(trace_id, tree)
  async def _get_affected_goals(self, trace_id: str, message: Message) -> List[Dict[str, Any]]:
      """Collect the stats of the goals touched by ``message``.

      Returns:
          A list whose first entry is ``{"goal_id", "self_stats",
          "cumulative_stats"}`` for the message's own goal, followed by one
          ``{"goal_id", "cumulative_stats"}`` entry per ancestor, nearest
          first. ``[]`` when the tree or the goal cannot be found.
      """
      if message.branch_id:
          tree = await self.get_branch_goal_tree(trace_id, message.branch_id)
      else:
          tree = await self.get_goal_tree(trace_id)
      if not tree:
          return []
      goal = tree.find(message.goal_id)
      if not goal:
          return []

      # The goal itself carries both stats blocks.
      affected = [{
          "goal_id": goal.id,
          "self_stats": goal.self_stats.to_dict(),
          "cumulative_stats": goal.cumulative_stats.to_dict()
      }]
      # Ancestors carry cumulative stats only. Stop if malformed parent_id links
      # form a cycle instead of looping forever.
      visited = {goal.id}
      current_goal = goal
      while current_goal.parent_id:
          parent = tree.find(current_goal.parent_id)
          if not parent or parent.id in visited:
              break
          visited.add(parent.id)
          affected.append({
              "goal_id": parent.id,
              "cumulative_stats": parent.cumulative_stats.to_dict()
          })
          current_goal = parent
      return affected
  async def get_message(self, message_id: str) -> Optional[Message]:
      """Find a message by id by scanning every trace directory.

      Within each trace, the main-line ``messages/`` file is tried first, then
      ``branches/*/messages/``. A file that exists but cannot be read or parsed
      is skipped and the scan continues.

      Returns:
          The first message found, or ``None``.
      """
      file_name = f"{message_id}.json"

      def _candidate_files(trace_dir: Path):
          # Main-line file first, then one per branch. Lazy, so the branches
          # directory is only listed when the main-line file is missing or
          # unreadable.
          yield trace_dir / "messages" / file_name
          branches_dir = trace_dir / "branches"
          if branches_dir.exists():
              for branch_dir in branches_dir.iterdir():
                  if branch_dir.is_dir():
                      yield branch_dir / "messages" / file_name

      for trace_dir in self.base_path.iterdir():
          if not trace_dir.is_dir():
              continue
          for message_file in _candidate_files(trace_dir):
              if not message_file.exists():
                  continue
              try:
                  data = json.loads(message_file.read_text(encoding="utf-8"))
                  if data.get("created_at"):
                      data["created_at"] = datetime.fromisoformat(data["created_at"])
                  return Message(**data)
              except Exception:
                  # Best effort: an unreadable copy must not abort the search.
                  continue
      return None
  async def get_trace_messages(
      self,
      trace_id: str,
      branch_id: Optional[str] = None
  ) -> List[Message]:
      """Load all messages of a trace, or of one of its branches.

      Files that cannot be read or parsed are skipped.

      Returns:
          The messages sorted by ``sequence``; ``[]`` when the messages
          directory does not exist.
      """
      if branch_id:
          messages_dir = self._get_branch_messages_dir(trace_id, branch_id)
      else:
          messages_dir = self._get_messages_dir(trace_id)
      if not messages_dir.exists():
          return []

      messages = []
      for message_file in messages_dir.glob("*.json"):
          try:
              data = json.loads(message_file.read_text(encoding="utf-8"))
              if data.get("created_at"):
                  data["created_at"] = datetime.fromisoformat(data["created_at"])
              messages.append(Message(**data))
          except Exception:
              # Best effort: skip a corrupt file rather than lose the whole history.
              continue

      # glob() order is arbitrary; sequence gives the conversation order.
      messages.sort(key=lambda m: m.sequence)
      return messages
  async def get_messages_by_goal(
      self,
      trace_id: str,
      goal_id: str,
      branch_id: Optional[str] = None
  ) -> List[Message]:
      """Return the messages of a trace (or branch) linked to ``goal_id``, ordered by ``sequence``."""
      candidates = await self.get_trace_messages(trace_id, branch_id)
      return list(filter(lambda msg: msg.goal_id == goal_id, candidates))
  350. async def update_message(self, message_id: str, **updates) -> None:
  351. """更新 Message 字段"""
  352. message = await self.get_message(message_id)
  353. if not message:
  354. return
  355. # 更新字段
  356. for key, value in updates.items():
  357. if hasattr(message, key):
  358. setattr(message, key, value)
  359. # 确定文件路径
  360. if message.branch_id:
  361. messages_dir = self._get_branch_messages_dir(message.trace_id, message.branch_id)
  362. else:
  363. messages_dir = self._get_messages_dir(message.trace_id)
  364. message_file = messages_dir / f"{message_id}.json"
  365. message_file.write_text(json.dumps(message.to_dict(), indent=2, ensure_ascii=False))
  # ===== Event stream operations (used for WebSocket resume after a disconnect) =====
  367. async def get_events(
  368. self,
  369. trace_id: str,
  370. since_event_id: int = 0
  371. ) -> List[Dict[str, Any]]:
  372. """获取事件流"""
  373. events_file = self._get_events_file(trace_id)
  374. if not events_file.exists():
  375. return []
  376. events = []
  377. with events_file.open('r') as f:
  378. for line in f:
  379. try:
  380. event = json.loads(line.strip())
  381. if event.get("event_id", 0) > since_event_id:
  382. events.append(event)
  383. except Exception:
  384. continue
  385. return events
  386. async def append_event(
  387. self,
  388. trace_id: str,
  389. event_type: str,
  390. payload: Dict[str, Any]
  391. ) -> int:
  392. """追加事件,返回 event_id"""
  393. # 获取 trace 并递增 event_id
  394. trace = await self.get_trace(trace_id)
  395. if not trace:
  396. return 0
  397. trace.last_event_id += 1
  398. event_id = trace.last_event_id
  399. # 更新 trace 的 last_event_id
  400. await self.update_trace(trace_id, last_event_id=event_id)
  401. # 创建事件
  402. event = {
  403. "event_id": event_id,
  404. "event": event_type,
  405. "ts": datetime.now().isoformat(),
  406. **payload
  407. }
  408. # 追加到 events.jsonl
  409. events_file = self._get_events_file(trace_id)
  410. with events_file.open('a') as f:
  411. f.write(json.dumps(event, ensure_ascii=False) + '\n')
  412. return event_id