# fs_store.py
"""
FileSystem Trace Store - file-system storage implementation.

Used for cross-process data sharing; data is persisted under the .trace/ directory.

Directory layout:

.trace/{trace_id}/
├── meta.json              # Trace metadata
├── goal.json              # Main-line GoalTree (flat JSON; hierarchy is built via parent_id)
├── messages/              # Main-line Messages (one file per message)
│   ├── {message_id}.json
│   └── ...
├── branches/              # Branch data (stored separately)
│   ├── A/
│   │   ├── meta.json      # BranchContext metadata
│   │   ├── goal.json      # GoalTree of branch A
│   │   └── messages/      # Messages of branch A
│   └── B/
│       └── ...
└── events.jsonl           # Event stream (WebSocket resume)
"""
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from agent.execution.models import Message, Trace
from agent.goal.models import BranchContext, Goal, GoalStats, GoalTree
  27. class FileSystemTraceStore:
  28. """文件系统 Trace 存储"""
  29. def __init__(self, base_path: str = ".trace"):
  30. self.base_path = Path(base_path)
  31. self.base_path.mkdir(exist_ok=True)
  32. def _get_trace_dir(self, trace_id: str) -> Path:
  33. """获取 trace 目录"""
  34. return self.base_path / trace_id
  35. def _get_meta_file(self, trace_id: str) -> Path:
  36. """获取 meta.json 文件路径"""
  37. return self._get_trace_dir(trace_id) / "meta.json"
  38. def _get_goal_file(self, trace_id: str) -> Path:
  39. """获取 goal.json 文件路径"""
  40. return self._get_trace_dir(trace_id) / "goal.json"
  41. def _get_messages_dir(self, trace_id: str) -> Path:
  42. """获取 messages 目录"""
  43. return self._get_trace_dir(trace_id) / "messages"
  44. def _get_message_file(self, trace_id: str, message_id: str) -> Path:
  45. """获取 message 文件路径"""
  46. return self._get_messages_dir(trace_id) / f"{message_id}.json"
  47. def _get_branches_dir(self, trace_id: str) -> Path:
  48. """获取 branches 目录"""
  49. return self._get_trace_dir(trace_id) / "branches"
  50. def _get_branch_dir(self, trace_id: str, branch_id: str) -> Path:
  51. """获取分支目录"""
  52. return self._get_branches_dir(trace_id) / branch_id
  53. def _get_branch_meta_file(self, trace_id: str, branch_id: str) -> Path:
  54. """获取分支 meta.json 文件路径"""
  55. return self._get_branch_dir(trace_id, branch_id) / "meta.json"
  56. def _get_branch_goal_file(self, trace_id: str, branch_id: str) -> Path:
  57. """获取分支 goal.json 文件路径"""
  58. return self._get_branch_dir(trace_id, branch_id) / "goal.json"
  59. def _get_branch_messages_dir(self, trace_id: str, branch_id: str) -> Path:
  60. """获取分支 messages 目录"""
  61. return self._get_branch_dir(trace_id, branch_id) / "messages"
  62. def _get_events_file(self, trace_id: str) -> Path:
  63. """获取 events.jsonl 文件路径"""
  64. return self._get_trace_dir(trace_id) / "events.jsonl"
  65. # ===== Trace 操作 =====
  66. async def create_trace(self, trace: Trace) -> str:
  67. """创建新的 Trace"""
  68. trace_dir = self._get_trace_dir(trace.trace_id)
  69. trace_dir.mkdir(exist_ok=True)
  70. # 创建 messages 目录
  71. messages_dir = self._get_messages_dir(trace.trace_id)
  72. messages_dir.mkdir(exist_ok=True)
  73. # 创建 branches 目录
  74. branches_dir = self._get_branches_dir(trace.trace_id)
  75. branches_dir.mkdir(exist_ok=True)
  76. # 写入 meta.json
  77. meta_file = self._get_meta_file(trace.trace_id)
  78. meta_file.write_text(json.dumps(trace.to_dict(), indent=2, ensure_ascii=False))
  79. # 创建空的 events.jsonl
  80. events_file = self._get_events_file(trace.trace_id)
  81. events_file.touch()
  82. return trace.trace_id
  83. async def get_trace(self, trace_id: str) -> Optional[Trace]:
  84. """获取 Trace"""
  85. meta_file = self._get_meta_file(trace_id)
  86. if not meta_file.exists():
  87. return None
  88. data = json.loads(meta_file.read_text())
  89. # 解析 datetime 字段
  90. if data.get("created_at"):
  91. data["created_at"] = datetime.fromisoformat(data["created_at"])
  92. if data.get("completed_at"):
  93. data["completed_at"] = datetime.fromisoformat(data["completed_at"])
  94. return Trace(**data)
  95. async def update_trace(self, trace_id: str, **updates) -> None:
  96. """更新 Trace"""
  97. trace = await self.get_trace(trace_id)
  98. if not trace:
  99. return
  100. # 更新字段
  101. for key, value in updates.items():
  102. if hasattr(trace, key):
  103. setattr(trace, key, value)
  104. # 写回文件
  105. meta_file = self._get_meta_file(trace_id)
  106. meta_file.write_text(json.dumps(trace.to_dict(), indent=2, ensure_ascii=False))
  107. async def list_traces(
  108. self,
  109. mode: Optional[str] = None,
  110. agent_type: Optional[str] = None,
  111. uid: Optional[str] = None,
  112. status: Optional[str] = None,
  113. limit: int = 50
  114. ) -> List[Trace]:
  115. """列出 Traces"""
  116. traces = []
  117. if not self.base_path.exists():
  118. return []
  119. for trace_dir in self.base_path.iterdir():
  120. if not trace_dir.is_dir():
  121. continue
  122. meta_file = trace_dir / "meta.json"
  123. if not meta_file.exists():
  124. continue
  125. try:
  126. data = json.loads(meta_file.read_text())
  127. # 过滤
  128. if mode and data.get("mode") != mode:
  129. continue
  130. if agent_type and data.get("agent_type") != agent_type:
  131. continue
  132. if uid and data.get("uid") != uid:
  133. continue
  134. if status and data.get("status") != status:
  135. continue
  136. # 解析 datetime
  137. if data.get("created_at"):
  138. data["created_at"] = datetime.fromisoformat(data["created_at"])
  139. if data.get("completed_at"):
  140. data["completed_at"] = datetime.fromisoformat(data["completed_at"])
  141. traces.append(Trace(**data))
  142. except Exception:
  143. continue
  144. # 排序(最新的在前)
  145. traces.sort(key=lambda t: t.created_at, reverse=True)
  146. return traces[:limit]
  147. # ===== GoalTree 操作 =====
  148. async def get_goal_tree(self, trace_id: str) -> Optional[GoalTree]:
  149. """获取 GoalTree"""
  150. goal_file = self._get_goal_file(trace_id)
  151. if not goal_file.exists():
  152. return None
  153. try:
  154. data = json.loads(goal_file.read_text())
  155. return GoalTree.from_dict(data)
  156. except Exception:
  157. return None
  158. async def update_goal_tree(self, trace_id: str, tree: GoalTree) -> None:
  159. """更新完整 GoalTree"""
  160. goal_file = self._get_goal_file(trace_id)
  161. goal_file.write_text(json.dumps(tree.to_dict(), indent=2, ensure_ascii=False))
  162. async def add_goal(self, trace_id: str, goal: Goal) -> None:
  163. """添加 Goal 到 GoalTree"""
  164. tree = await self.get_goal_tree(trace_id)
  165. if not tree:
  166. return
  167. tree.goals.append(goal)
  168. await self.update_goal_tree(trace_id, tree)
  169. async def update_goal(self, trace_id: str, goal_id: str, **updates) -> None:
  170. """更新 Goal 字段"""
  171. tree = await self.get_goal_tree(trace_id)
  172. if not tree:
  173. return
  174. goal = tree.find(goal_id)
  175. if not goal:
  176. return
  177. # 更新字段
  178. for key, value in updates.items():
  179. if hasattr(goal, key):
  180. # 特殊处理 stats 字段(可能是 dict)
  181. if key in ["self_stats", "cumulative_stats"] and isinstance(value, dict):
  182. value = GoalStats.from_dict(value)
  183. setattr(goal, key, value)
  184. await self.update_goal_tree(trace_id, tree)
  185. # ===== Branch 操作 =====
  186. async def create_branch(self, trace_id: str, branch: BranchContext) -> None:
  187. """创建分支上下文"""
  188. branch_dir = self._get_branch_dir(trace_id, branch.id)
  189. branch_dir.mkdir(parents=True, exist_ok=True)
  190. # 创建分支 messages 目录
  191. messages_dir = self._get_branch_messages_dir(trace_id, branch.id)
  192. messages_dir.mkdir(exist_ok=True)
  193. # 写入 meta.json
  194. meta_file = self._get_branch_meta_file(trace_id, branch.id)
  195. meta_file.write_text(json.dumps(branch.to_dict(), indent=2, ensure_ascii=False))
  196. async def get_branch(self, trace_id: str, branch_id: str) -> Optional[BranchContext]:
  197. """获取分支元数据"""
  198. meta_file = self._get_branch_meta_file(trace_id, branch_id)
  199. if not meta_file.exists():
  200. return None
  201. try:
  202. data = json.loads(meta_file.read_text())
  203. return BranchContext.from_dict(data)
  204. except Exception:
  205. return None
  206. async def get_branch_goal_tree(self, trace_id: str, branch_id: str) -> Optional[GoalTree]:
  207. """获取分支的 GoalTree"""
  208. goal_file = self._get_branch_goal_file(trace_id, branch_id)
  209. if not goal_file.exists():
  210. return None
  211. try:
  212. data = json.loads(goal_file.read_text())
  213. return GoalTree.from_dict(data)
  214. except Exception:
  215. return None
  216. async def update_branch_goal_tree(self, trace_id: str, branch_id: str, tree: GoalTree) -> None:
  217. """更新分支的 GoalTree"""
  218. goal_file = self._get_branch_goal_file(trace_id, branch_id)
  219. goal_file.write_text(json.dumps(tree.to_dict(), indent=2, ensure_ascii=False))
  220. async def update_branch(self, trace_id: str, branch_id: str, **updates) -> None:
  221. """更新分支元数据"""
  222. branch = await self.get_branch(trace_id, branch_id)
  223. if not branch:
  224. return
  225. # 更新字段
  226. for key, value in updates.items():
  227. if hasattr(branch, key):
  228. # 特殊处理 stats 字段
  229. if key == "cumulative_stats" and isinstance(value, dict):
  230. value = GoalStats.from_dict(value)
  231. setattr(branch, key, value)
  232. # 写回文件
  233. meta_file = self._get_branch_meta_file(trace_id, branch_id)
  234. meta_file.write_text(json.dumps(branch.to_dict(), indent=2, ensure_ascii=False))
  235. async def list_branches(self, trace_id: str) -> Dict[str, BranchContext]:
  236. """列出所有分支元数据"""
  237. branches_dir = self._get_branches_dir(trace_id)
  238. if not branches_dir.exists():
  239. return {}
  240. branches = {}
  241. for branch_dir in branches_dir.iterdir():
  242. if not branch_dir.is_dir():
  243. continue
  244. meta_file = branch_dir / "meta.json"
  245. if not meta_file.exists():
  246. continue
  247. try:
  248. data = json.loads(meta_file.read_text())
  249. branch = BranchContext.from_dict(data)
  250. branches[branch.id] = branch
  251. except Exception:
  252. continue
  253. return branches
  254. # ===== Message 操作 =====
  255. async def add_message(self, message: Message) -> str:
  256. """
  257. 添加 Message
  258. 自动更新关联 Goal 的 stats(self_stats 和祖先的 cumulative_stats)
  259. """
  260. trace_id = message.trace_id
  261. branch_id = message.branch_id
  262. # 1. 写入 message 文件
  263. if branch_id:
  264. # 分支消息
  265. messages_dir = self._get_branch_messages_dir(trace_id, branch_id)
  266. else:
  267. # 主线消息
  268. messages_dir = self._get_messages_dir(trace_id)
  269. message_file = messages_dir / f"{message.message_id}.json"
  270. message_file.write_text(json.dumps(message.to_dict(), indent=2, ensure_ascii=False))
  271. # 2. 更新 trace 统计
  272. trace = await self.get_trace(trace_id)
  273. if trace:
  274. trace.total_messages += 1
  275. trace.last_sequence = max(trace.last_sequence, message.sequence)
  276. if message.tokens:
  277. trace.total_tokens += message.tokens
  278. if message.cost:
  279. trace.total_cost += message.cost
  280. if message.duration_ms:
  281. trace.total_duration_ms += message.duration_ms
  282. # 更新 Trace(不要传递 trace_id,它已经在方法参数中)
  283. await self.update_trace(
  284. trace_id,
  285. total_messages=trace.total_messages,
  286. last_sequence=trace.last_sequence,
  287. total_tokens=trace.total_tokens,
  288. total_cost=trace.total_cost,
  289. total_duration_ms=trace.total_duration_ms
  290. )
  291. # 3. 更新 Goal stats
  292. await self._update_goal_stats(trace_id, message)
  293. # 4. 追加 message_added 事件
  294. affected_goals = await self._get_affected_goals(trace_id, message)
  295. await self.append_event(trace_id, "message_added", {
  296. "message": message.to_dict(),
  297. "affected_goals": affected_goals
  298. })
  299. return message.message_id
  300. async def _update_goal_stats(self, trace_id: str, message: Message) -> None:
  301. """更新 Goal 的 self_stats 和祖先的 cumulative_stats"""
  302. # 确定使用主线还是分支的 GoalTree
  303. if message.branch_id:
  304. tree = await self.get_branch_goal_tree(trace_id, message.branch_id)
  305. else:
  306. tree = await self.get_goal_tree(trace_id)
  307. if not tree:
  308. return
  309. # 找到关联的 Goal
  310. goal = tree.find(message.goal_id)
  311. if not goal:
  312. return
  313. # 更新自身 self_stats
  314. goal.self_stats.message_count += 1
  315. if message.tokens:
  316. goal.self_stats.total_tokens += message.tokens
  317. if message.cost:
  318. goal.self_stats.total_cost += message.cost
  319. # TODO: 更新 preview(工具调用摘要)
  320. # 更新自身 cumulative_stats
  321. goal.cumulative_stats.message_count += 1
  322. if message.tokens:
  323. goal.cumulative_stats.total_tokens += message.tokens
  324. if message.cost:
  325. goal.cumulative_stats.total_cost += message.cost
  326. # 沿祖先链向上更新 cumulative_stats
  327. current_goal = goal
  328. while current_goal.parent_id:
  329. parent = tree.find(current_goal.parent_id)
  330. if not parent:
  331. break
  332. parent.cumulative_stats.message_count += 1
  333. if message.tokens:
  334. parent.cumulative_stats.total_tokens += message.tokens
  335. if message.cost:
  336. parent.cumulative_stats.total_cost += message.cost
  337. current_goal = parent
  338. # 保存更新后的 tree
  339. if message.branch_id:
  340. await self.update_branch_goal_tree(trace_id, message.branch_id, tree)
  341. else:
  342. await self.update_goal_tree(trace_id, tree)
  343. async def _get_affected_goals(self, trace_id: str, message: Message) -> List[Dict[str, Any]]:
  344. """获取受影响的 Goals(自身 + 所有祖先)"""
  345. if message.branch_id:
  346. tree = await self.get_branch_goal_tree(trace_id, message.branch_id)
  347. else:
  348. tree = await self.get_goal_tree(trace_id)
  349. if not tree:
  350. return []
  351. goal = tree.find(message.goal_id)
  352. if not goal:
  353. return []
  354. affected = []
  355. # 添加自身(包含 self_stats 和 cumulative_stats)
  356. affected.append({
  357. "goal_id": goal.id,
  358. "self_stats": goal.self_stats.to_dict(),
  359. "cumulative_stats": goal.cumulative_stats.to_dict()
  360. })
  361. # 添加所有祖先(仅 cumulative_stats)
  362. current_goal = goal
  363. while current_goal.parent_id:
  364. parent = tree.find(current_goal.parent_id)
  365. if not parent:
  366. break
  367. affected.append({
  368. "goal_id": parent.id,
  369. "cumulative_stats": parent.cumulative_stats.to_dict()
  370. })
  371. current_goal = parent
  372. return affected
  373. async def get_message(self, message_id: str) -> Optional[Message]:
  374. """获取 Message(扫描所有 trace)"""
  375. for trace_dir in self.base_path.iterdir():
  376. if not trace_dir.is_dir():
  377. continue
  378. # 检查主线 messages
  379. message_file = trace_dir / "messages" / f"{message_id}.json"
  380. if message_file.exists():
  381. try:
  382. data = json.loads(message_file.read_text())
  383. if data.get("created_at"):
  384. data["created_at"] = datetime.fromisoformat(data["created_at"])
  385. return Message(**data)
  386. except Exception:
  387. pass
  388. # 检查分支 messages
  389. branches_dir = trace_dir / "branches"
  390. if branches_dir.exists():
  391. for branch_dir in branches_dir.iterdir():
  392. if not branch_dir.is_dir():
  393. continue
  394. message_file = branch_dir / "messages" / f"{message_id}.json"
  395. if message_file.exists():
  396. try:
  397. data = json.loads(message_file.read_text())
  398. if data.get("created_at"):
  399. data["created_at"] = datetime.fromisoformat(data["created_at"])
  400. return Message(**data)
  401. except Exception:
  402. pass
  403. return None
  404. async def get_trace_messages(
  405. self,
  406. trace_id: str,
  407. branch_id: Optional[str] = None
  408. ) -> List[Message]:
  409. """获取 Trace 的所有 Messages"""
  410. if branch_id:
  411. messages_dir = self._get_branch_messages_dir(trace_id, branch_id)
  412. else:
  413. messages_dir = self._get_messages_dir(trace_id)
  414. if not messages_dir.exists():
  415. return []
  416. messages = []
  417. for message_file in messages_dir.glob("*.json"):
  418. try:
  419. data = json.loads(message_file.read_text())
  420. if data.get("created_at"):
  421. data["created_at"] = datetime.fromisoformat(data["created_at"])
  422. messages.append(Message(**data))
  423. except Exception:
  424. continue
  425. # 按 sequence 排序
  426. messages.sort(key=lambda m: m.sequence)
  427. return messages
  428. async def get_messages_by_goal(
  429. self,
  430. trace_id: str,
  431. goal_id: str,
  432. branch_id: Optional[str] = None
  433. ) -> List[Message]:
  434. """获取指定 Goal 关联的所有 Messages"""
  435. all_messages = await self.get_trace_messages(trace_id, branch_id)
  436. return [m for m in all_messages if m.goal_id == goal_id]
  437. async def update_message(self, message_id: str, **updates) -> None:
  438. """更新 Message 字段"""
  439. message = await self.get_message(message_id)
  440. if not message:
  441. return
  442. # 更新字段
  443. for key, value in updates.items():
  444. if hasattr(message, key):
  445. setattr(message, key, value)
  446. # 确定文件路径
  447. if message.branch_id:
  448. messages_dir = self._get_branch_messages_dir(message.trace_id, message.branch_id)
  449. else:
  450. messages_dir = self._get_messages_dir(message.trace_id)
  451. message_file = messages_dir / f"{message_id}.json"
  452. message_file.write_text(json.dumps(message.to_dict(), indent=2, ensure_ascii=False))
  453. # ===== 事件流操作(用于 WebSocket 断线续传)=====
  454. async def get_events(
  455. self,
  456. trace_id: str,
  457. since_event_id: int = 0
  458. ) -> List[Dict[str, Any]]:
  459. """获取事件流"""
  460. events_file = self._get_events_file(trace_id)
  461. if not events_file.exists():
  462. return []
  463. events = []
  464. with events_file.open('r') as f:
  465. for line in f:
  466. try:
  467. event = json.loads(line.strip())
  468. if event.get("event_id", 0) > since_event_id:
  469. events.append(event)
  470. except Exception:
  471. continue
  472. return events
  473. async def append_event(
  474. self,
  475. trace_id: str,
  476. event_type: str,
  477. payload: Dict[str, Any]
  478. ) -> int:
  479. """追加事件,返回 event_id"""
  480. # 获取 trace 并递增 event_id
  481. trace = await self.get_trace(trace_id)
  482. if not trace:
  483. return 0
  484. trace.last_event_id += 1
  485. event_id = trace.last_event_id
  486. # 更新 trace 的 last_event_id
  487. await self.update_trace(trace_id, last_event_id=event_id)
  488. # 创建事件
  489. event = {
  490. "event_id": event_id,
  491. "event": event_type,
  492. "ts": datetime.now().isoformat(),
  493. **payload
  494. }
  495. # 追加到 events.jsonl
  496. events_file = self._get_events_file(trace_id)
  497. with events_file.open('a') as f:
  498. f.write(json.dumps(event, ensure_ascii=False) + '\n')
  499. return event_id