store.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929
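"""
FileSystem Trace Store - file-system backed storage implementation.

Used to share data across processes; everything is persisted under the
``.trace/`` directory (the ``base_path`` of ``FileSystemTraceStore``).

Directory layout of a trace:

.trace/{trace_id}/
├── meta.json           # Trace metadata
├── goal.json           # GoalTree (flat JSON; hierarchy rebuilt via parent_id)
├── messages/           # Messages, one file per message
│   ├── {message_id}.json
│   └── ...
├── events.jsonl        # Event stream (used for WebSocket resume)
├── model_usage.json    # Per-model token usage (written by record_model_usage)
├── knowledge_log.json  # Knowledge injection/evaluation log (written on first entry)
└── agents/             # Sub-traces of this trace (see below)

A sub-trace is a complete, independent trace with the same layout. New
sub-traces are created inside the parent's ``agents/`` directory; {suffix} is
the first 8 characters of the last '-'-separated segment of the sub-trace ID:

.trace/{parent_trace_id}/agents/{stage_name}-{suffix}/
├── meta.json           # parent_trace_id points to the parent trace
├── goal.json
├── messages/
└── events.jsonl

Lookups by trace_id check ``.trace/{trace_id}/`` first and then scan the
``agents/`` directories, matching on the ``trace_id`` stored in meta.json.
"""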
  19. import json
  20. import os
  21. import logging
  22. from pathlib import Path
  23. from typing import Dict, List, Optional, Any
  24. from datetime import datetime
  25. from .models import Trace, Message
  26. from .goal_models import GoalTree, Goal, GoalStats
# Module-level logger shared by FileSystemTraceStore.
logger = logging.getLogger(__name__)
  28. class FileSystemTraceStore:
  29. """文件系统 Trace 存储"""
  30. def __init__(self, base_path: str = ".trace"):
  31. self.base_path = Path(base_path)
  32. self.base_path.mkdir(exist_ok=True)
  33. def _get_trace_dir(self, trace_id: str, parent_trace_id: Optional[str] = None, stage_name: Optional[str] = None) -> Path:
  34. """
  35. 获取 trace 目录。
  36. 如果提供 parent_trace_id 和 stage_name,则创建在父目录的 agents/ 子目录下:
  37. - {base_path}/{parent_trace_id}/agents/{stage_name}-{trace_id_suffix}/
  38. 如果只提供 trace_id,则智能查找:
  39. 1. 先查根目录 {base_path}/{trace_id}/
  40. 2. 再查所有 agents/ 子目录(用于读取已存在的子 trace)
  41. Args:
  42. trace_id: Trace ID
  43. parent_trace_id: 父 Trace ID(创建子 trace 时提供)
  44. stage_name: 阶段名称(创建子 trace 时提供)
  45. """
  46. if parent_trace_id and stage_name:
  47. # 创建子 trace:放在父目录的 agents/ 下
  48. parent_dir = self.base_path / parent_trace_id / "agents"
  49. # 使用 stage_name + trace_id 后8位作为目录名
  50. trace_id_suffix = trace_id.split('-')[-1][:8]
  51. dir_name = f"{stage_name}-{trace_id_suffix}"
  52. return parent_dir / dir_name
  53. # 智能查找:先查根目录
  54. root_dir = self.base_path / trace_id
  55. if root_dir.exists():
  56. return root_dir
  57. # 再查所有 agents/ 子目录(用于读取已存在的子 trace)
  58. for parent_dir in self.base_path.iterdir():
  59. if not parent_dir.is_dir():
  60. continue
  61. agents_dir = parent_dir / "agents"
  62. if not agents_dir.exists():
  63. continue
  64. for sub_dir in agents_dir.iterdir():
  65. if not sub_dir.is_dir():
  66. continue
  67. # 检查 meta.json 中的 trace_id
  68. meta_file = sub_dir / "meta.json"
  69. if meta_file.exists():
  70. try:
  71. data = json.loads(meta_file.read_text(encoding="utf-8"))
  72. if data.get("trace_id") == trace_id:
  73. return sub_dir
  74. except Exception:
  75. continue
  76. # 找不到则返回根目录(向后兼容)
  77. return root_dir
  78. def _get_meta_file(self, trace_id: str) -> Path:
  79. """获取 meta.json 文件路径"""
  80. return self._get_trace_dir(trace_id) / "meta.json"
  81. def _get_goal_file(self, trace_id: str) -> Path:
  82. """获取 goal.json 文件路径"""
  83. return self._get_trace_dir(trace_id) / "goal.json"
  84. def _get_messages_dir(self, trace_id: str) -> Path:
  85. """获取 messages 目录"""
  86. return self._get_trace_dir(trace_id) / "messages"
  87. def _get_message_file(self, trace_id: str, message_id: str) -> Path:
  88. """获取 message 文件路径"""
  89. return self._get_messages_dir(trace_id) / f"{message_id}.json"
  90. def _get_events_file(self, trace_id: str) -> Path:
  91. """获取 events.jsonl 文件路径"""
  92. return self._get_trace_dir(trace_id) / "events.jsonl"
  93. def _get_model_usage_file(self, trace_id: str) -> Path:
  94. """获取 model_usage.json 文件路径"""
  95. return self._get_trace_dir(trace_id) / "model_usage.json"
  96. # ===== Trace 操作 =====
  97. async def create_trace(self, trace: Trace, stage_name: Optional[str] = None) -> str:
  98. """
  99. 创建新的 Trace。
  100. Args:
  101. trace: Trace 对象
  102. stage_name: 阶段名称(创建子 trace 时提供,用于目录命名)
  103. """
  104. # 如果有 parent_trace_id,使用层级化目录结构
  105. if trace.parent_trace_id:
  106. # 使用 stage_name 或 trace.task 作为目录名前缀
  107. dir_stage_name = stage_name or trace.task or "subtrace"
  108. trace_dir = self._get_trace_dir(trace.trace_id, trace.parent_trace_id, dir_stage_name)
  109. else:
  110. trace_dir = self._get_trace_dir(trace.trace_id)
  111. trace_dir.mkdir(parents=True, exist_ok=True)
  112. # 创建 messages 目录
  113. messages_dir = trace_dir / "messages"
  114. messages_dir.mkdir(exist_ok=True)
  115. # 写入 meta.json
  116. meta_file = trace_dir / "meta.json"
  117. meta_file.write_text(json.dumps(trace.to_dict(), indent=2, ensure_ascii=False), encoding="utf-8")
  118. # 创建空的 events.jsonl
  119. events_file = trace_dir / "events.jsonl"
  120. events_file.touch()
  121. return trace.trace_id
  122. async def get_trace(self, trace_id: str) -> Optional[Trace]:
  123. """获取 Trace"""
  124. meta_file = self._get_meta_file(trace_id)
  125. if not meta_file.exists():
  126. return None
  127. data = json.loads(meta_file.read_text(encoding="utf-8"))
  128. # 解析 datetime 字段
  129. if data.get("created_at"):
  130. data["created_at"] = datetime.fromisoformat(data["created_at"])
  131. if data.get("completed_at"):
  132. data["completed_at"] = datetime.fromisoformat(data["completed_at"])
  133. return Trace.from_dict(data)
  134. async def update_trace(self, trace_id: str, **updates) -> None:
  135. """更新 Trace"""
  136. trace = await self.get_trace(trace_id)
  137. if not trace:
  138. return
  139. # 更新字段
  140. for key, value in updates.items():
  141. if hasattr(trace, key):
  142. setattr(trace, key, value)
  143. # 写回文件
  144. meta_file = self._get_meta_file(trace_id)
  145. meta_file.write_text(json.dumps(trace.to_dict(), indent=2, ensure_ascii=False), encoding="utf-8")
  146. async def list_traces(
  147. self,
  148. mode: Optional[str] = None,
  149. agent_type: Optional[str] = None,
  150. uid: Optional[str] = None,
  151. status: Optional[str] = None,
  152. limit: int = 50
  153. ) -> List[Trace]:
  154. """列出 Traces"""
  155. traces = []
  156. if not self.base_path.exists():
  157. return []
  158. for trace_dir in self.base_path.iterdir():
  159. if not trace_dir.is_dir():
  160. continue
  161. meta_file = trace_dir / "meta.json"
  162. if not meta_file.exists():
  163. continue
  164. try:
  165. data = json.loads(meta_file.read_text(encoding="utf-8"))
  166. # 过滤
  167. if mode and data.get("mode") != mode:
  168. continue
  169. if agent_type and data.get("agent_type") != agent_type:
  170. continue
  171. if uid and data.get("uid") != uid:
  172. continue
  173. if status and data.get("status") != status:
  174. continue
  175. # 解析 datetime
  176. if data.get("created_at"):
  177. data["created_at"] = datetime.fromisoformat(data["created_at"])
  178. if data.get("completed_at"):
  179. data["completed_at"] = datetime.fromisoformat(data["completed_at"])
  180. traces.append(Trace.from_dict(data))
  181. except Exception:
  182. continue
  183. # 排序(最新的在前)
  184. traces.sort(key=lambda t: t.created_at, reverse=True)
  185. return traces[:limit]
  186. # ===== GoalTree 操作 =====
  187. async def get_goal_tree(self, trace_id: str) -> Optional[GoalTree]:
  188. """获取 GoalTree"""
  189. goal_file = self._get_goal_file(trace_id)
  190. if not goal_file.exists():
  191. return None
  192. try:
  193. data = json.loads(goal_file.read_text(encoding="utf-8"))
  194. return GoalTree.from_dict(data)
  195. except Exception:
  196. return None
  197. async def update_goal_tree(self, trace_id: str, tree: GoalTree) -> None:
  198. """更新完整 GoalTree"""
  199. goal_file = self._get_goal_file(trace_id)
  200. goal_file.write_text(json.dumps(tree.to_dict(), indent=2, ensure_ascii=False), encoding="utf-8")
  201. async def add_goal(self, trace_id: str, goal: Goal) -> None:
  202. """添加 Goal 到 GoalTree"""
  203. tree = await self.get_goal_tree(trace_id)
  204. if not tree:
  205. return
  206. tree.goals.append(goal)
  207. await self.update_goal_tree(trace_id, tree)
  208. # 推送 goal_added 事件
  209. event_data = {
  210. "goal": goal.to_dict(),
  211. "parent_id": goal.parent_id
  212. }
  213. await self.append_event(trace_id, "goal_added", event_data)
  214. # 打印详细的 goal 信息
  215. desc_preview = goal.description[:80] + "..." if len(goal.description) > 80 else goal.description
  216. print(f"[Goal Added] ID={goal.id}, Parent={goal.parent_id or 'root'}")
  217. print(f" 📝 {desc_preview}")
  218. if goal.reason:
  219. reason_preview = goal.reason[:60] + "..." if len(goal.reason) > 60 else goal.reason
  220. print(f" 💡 {reason_preview}")
  221. async def update_goal(self, trace_id: str, goal_id: str, **updates) -> None:
  222. """更新 Goal 字段"""
  223. tree = await self.get_goal_tree(trace_id)
  224. if not tree:
  225. return
  226. goal = tree.find(goal_id)
  227. if not goal:
  228. return
  229. # 更新字段
  230. for key, value in updates.items():
  231. if hasattr(goal, key):
  232. # 特殊处理 stats 字段(可能是 dict)
  233. if key in ["self_stats", "cumulative_stats"] and isinstance(value, dict):
  234. value = GoalStats.from_dict(value)
  235. setattr(goal, key, value)
  236. await self.update_goal_tree(trace_id, tree)
  237. # 推送 goal_updated 事件
  238. # 如果状态变为 completed,检查是否需要级联完成父 Goal
  239. affected_goals = [{"goal_id": goal_id, "updates": updates}]
  240. if updates.get("status") == "completed":
  241. # 检查级联完成:如果所有兄弟 Goal 都完成,父 Goal 也完成
  242. cascade_completed = await self._check_cascade_completion(trace_id, goal)
  243. affected_goals.extend(cascade_completed)
  244. await self.append_event(trace_id, "goal_updated", {
  245. "goal_id": goal_id,
  246. "updates": updates,
  247. "affected_goals": affected_goals
  248. })
  249. print(f"[DEBUG] Pushed goal_updated event: goal_id={goal_id}, updates={updates}, affected={len(affected_goals)}")
  250. # Goal 完成时触发知识评估
  251. if updates.get("status") in ["completed", "abandoned"]:
  252. pending = await self.get_pending_knowledge_entries(trace_id)
  253. if pending:
  254. # 在trace.context中设置标志,由runner主循环检查
  255. trace = await self.get_trace(trace_id)
  256. if trace:
  257. if not trace.context:
  258. trace.context = {}
  259. trace.context["pending_knowledge_eval"] = True
  260. trace.context["knowledge_eval_trigger"] = "goal_completion"
  261. await self.update_trace(trace_id, context=trace.context)
  262. logger.info(f"[Knowledge Eval] Goal {goal_id} 完成,设置评估标志,待评估知识: {len(pending)} 条")
  263. async def _check_cascade_completion(
  264. self,
  265. trace_id: str,
  266. completed_goal: Goal
  267. ) -> List[Dict[str, Any]]:
  268. """
  269. 检查级联完成:如果一个 Goal 的所有子 Goal 都完成,则自动完成父 Goal
  270. Args:
  271. trace_id: Trace ID
  272. completed_goal: 刚完成的 Goal
  273. Returns:
  274. 受影响的父 Goals 列表(自动完成的)
  275. """
  276. if not completed_goal.parent_id:
  277. return []
  278. tree = await self.get_goal_tree(trace_id)
  279. if not tree:
  280. return []
  281. affected = []
  282. parent = tree.find(completed_goal.parent_id)
  283. if not parent:
  284. return []
  285. # 获取父 Goal 的所有子 Goal
  286. children = tree.get_children(parent.id)
  287. # 检查是否所有子 Goal 都已完成(排除 abandoned)
  288. all_completed = all(
  289. child.status in ["completed", "abandoned"]
  290. for child in children
  291. )
  292. if all_completed and parent.status != "completed":
  293. # 自动完成父 Goal
  294. parent.status = "completed"
  295. if not parent.summary:
  296. # 生成自动摘要
  297. completed_count = sum(1 for c in children if c.status == "completed")
  298. parent.summary = f"所有子目标已完成 ({completed_count}/{len(children)})"
  299. await self.update_goal_tree(trace_id, tree)
  300. affected.append({
  301. "goal_id": parent.id,
  302. "status": "completed",
  303. "summary": parent.summary,
  304. "cumulative_stats": parent.cumulative_stats.to_dict()
  305. })
  306. # 递归检查祖父 Goal
  307. grandparent_affected = await self._check_cascade_completion(trace_id, parent)
  308. affected.extend(grandparent_affected)
  309. return affected
  310. # ===== Message 操作 =====
  311. async def add_message(self, message: Message) -> str:
  312. """
  313. 添加 Message
  314. 自动更新关联 Goal 的 stats(self_stats 和祖先的 cumulative_stats)
  315. """
  316. trace_id = message.trace_id
  317. # 1. 写入 message 文件
  318. messages_dir = self._get_messages_dir(trace_id)
  319. message_file = messages_dir / f"{message.message_id}.json"
  320. message_file.write_text(json.dumps(message.to_dict(), indent=2, ensure_ascii=False), encoding="utf-8")
  321. # 2. 更新 trace 统计
  322. trace = await self.get_trace(trace_id)
  323. if trace:
  324. trace.total_messages += 1
  325. trace.last_sequence = max(trace.last_sequence, message.sequence)
  326. # 累计 tokens(完整版)
  327. if message.prompt_tokens:
  328. trace.total_prompt_tokens += message.prompt_tokens
  329. if message.completion_tokens:
  330. trace.total_completion_tokens += message.completion_tokens
  331. if message.reasoning_tokens:
  332. trace.total_reasoning_tokens += message.reasoning_tokens
  333. if message.cache_creation_tokens:
  334. trace.total_cache_creation_tokens += message.cache_creation_tokens
  335. if message.cache_read_tokens:
  336. trace.total_cache_read_tokens += message.cache_read_tokens
  337. # 向后兼容:也更新 total_tokens
  338. if message.tokens:
  339. trace.total_tokens += message.tokens
  340. elif message.prompt_tokens or message.completion_tokens:
  341. trace.total_tokens += (message.prompt_tokens or 0) + (message.completion_tokens or 0)
  342. if message.cost:
  343. trace.total_cost += message.cost
  344. if message.duration_ms:
  345. trace.total_duration_ms += message.duration_ms
  346. # 更新 Trace
  347. await self.update_trace(
  348. trace_id,
  349. total_messages=trace.total_messages,
  350. last_sequence=trace.last_sequence,
  351. total_tokens=trace.total_tokens,
  352. total_prompt_tokens=trace.total_prompt_tokens,
  353. total_completion_tokens=trace.total_completion_tokens,
  354. total_reasoning_tokens=trace.total_reasoning_tokens,
  355. total_cache_creation_tokens=trace.total_cache_creation_tokens,
  356. total_cache_read_tokens=trace.total_cache_read_tokens,
  357. total_cost=trace.total_cost,
  358. total_duration_ms=trace.total_duration_ms
  359. )
  360. # 3. 更新 Goal stats
  361. await self._update_goal_stats(trace_id, message)
  362. # 4. 追加 message_added 事件
  363. affected_goals = await self._get_affected_goals(trace_id, message)
  364. event_id = await self.append_event(trace_id, "message_added", {
  365. "message": message.to_dict(),
  366. "affected_goals": affected_goals
  367. })
  368. if event_id:
  369. try:
  370. from . import websocket as trace_ws
  371. await trace_ws.broadcast_message_added(
  372. trace_id=trace_id,
  373. event_id=event_id,
  374. message_dict=message.to_dict(),
  375. affected_goals=affected_goals,
  376. )
  377. except Exception:
  378. logger.exception("Failed to broadcast message_added (trace_id=%s, event_id=%s)", trace_id, event_id)
  379. return message.message_id
  380. async def _update_goal_stats(self, trace_id: str, message: Message) -> None:
  381. """更新 Goal 的 self_stats 和祖先的 cumulative_stats"""
  382. tree = await self.get_goal_tree(trace_id)
  383. if not tree:
  384. return
  385. # 找到关联的 Goal
  386. goal = tree.find(message.goal_id)
  387. if not goal:
  388. return
  389. # 更新自身 self_stats
  390. goal.self_stats.message_count += 1
  391. if message.tokens:
  392. goal.self_stats.total_tokens += message.tokens
  393. if message.cost:
  394. goal.self_stats.total_cost += message.cost
  395. # TODO: 更新 preview(工具调用摘要)
  396. # 更新自身 cumulative_stats
  397. goal.cumulative_stats.message_count += 1
  398. if message.tokens:
  399. goal.cumulative_stats.total_tokens += message.tokens
  400. if message.cost:
  401. goal.cumulative_stats.total_cost += message.cost
  402. # 沿祖先链向上更新 cumulative_stats
  403. current_goal = goal
  404. while current_goal.parent_id:
  405. parent = tree.find(current_goal.parent_id)
  406. if not parent:
  407. break
  408. parent.cumulative_stats.message_count += 1
  409. if message.tokens:
  410. parent.cumulative_stats.total_tokens += message.tokens
  411. if message.cost:
  412. parent.cumulative_stats.total_cost += message.cost
  413. current_goal = parent
  414. # 保存更新后的 tree
  415. await self.update_goal_tree(trace_id, tree)
  416. async def _get_affected_goals(self, trace_id: str, message: Message) -> List[Dict[str, Any]]:
  417. """获取受影响的 Goals(自身 + 所有祖先)"""
  418. tree = await self.get_goal_tree(trace_id)
  419. if not tree:
  420. return []
  421. goal = tree.find(message.goal_id)
  422. if not goal:
  423. return []
  424. affected = []
  425. # 添加自身(包含 self_stats 和 cumulative_stats)
  426. affected.append({
  427. "goal_id": goal.id,
  428. "self_stats": goal.self_stats.to_dict(),
  429. "cumulative_stats": goal.cumulative_stats.to_dict()
  430. })
  431. # 添加所有祖先(仅 cumulative_stats)
  432. current_goal = goal
  433. while current_goal.parent_id:
  434. parent = tree.find(current_goal.parent_id)
  435. if not parent:
  436. break
  437. affected.append({
  438. "goal_id": parent.id,
  439. "cumulative_stats": parent.cumulative_stats.to_dict()
  440. })
  441. current_goal = parent
  442. return affected
  443. return affected
  444. async def get_message(self, message_id: str) -> Optional[Message]:
  445. """获取 Message(扫描所有 trace)"""
  446. for trace_dir in self.base_path.iterdir():
  447. if not trace_dir.is_dir():
  448. continue
  449. # 检查 messages 目录
  450. message_file = trace_dir / "messages" / f"{message_id}.json"
  451. if message_file.exists():
  452. try:
  453. data = json.loads(message_file.read_text(encoding="utf-8"))
  454. return Message.from_dict(data)
  455. except Exception:
  456. pass
  457. return None
  458. async def get_trace_messages(
  459. self,
  460. trace_id: str,
  461. ) -> List[Message]:
  462. """获取 Trace 的所有 Messages(包含所有分支,按 sequence 排序)"""
  463. messages_dir = self._get_messages_dir(trace_id)
  464. if not messages_dir.exists():
  465. return []
  466. messages = []
  467. for message_file in messages_dir.glob("*.json"):
  468. try:
  469. data = json.loads(message_file.read_text(encoding="utf-8"))
  470. msg = Message.from_dict(data)
  471. messages.append(msg)
  472. except Exception:
  473. continue
  474. # 按 sequence 排序
  475. messages.sort(key=lambda m: m.sequence)
  476. return messages
  477. async def get_main_path_messages(
  478. self,
  479. trace_id: str,
  480. head_sequence: int
  481. ) -> List[Message]:
  482. """
  483. 获取从 head_sequence 沿 parent_sequence 链回溯到 root 的完整路径
  484. 此函数是通用的路径追溯函数,返回从指定 head 到 root 的完整消息链。
  485. 只要 trace.head_sequence 管理正确(指向主路径),此函数自然返回主路径消息。
  486. 侧分支消息通过 parent_sequence 链自然被跳过(因为主路径的 parent 不指向侧分支)。
  487. Returns:
  488. 按 sequence 正序排列的路径 Message 列表
  489. """
  490. # 加载所有消息,建立 sequence -> Message 索引
  491. all_messages = await self.get_trace_messages(trace_id)
  492. messages_by_seq = {m.sequence: m for m in all_messages}
  493. # 从 head 沿 parent chain 回溯
  494. path = []
  495. seq = head_sequence
  496. while seq is not None:
  497. msg = messages_by_seq.get(seq)
  498. if not msg:
  499. break
  500. path.append(msg)
  501. seq = msg.parent_sequence
  502. # 反转为正序(root → head)
  503. path.reverse()
  504. return path
  505. async def get_messages_by_goal(
  506. self,
  507. trace_id: str,
  508. goal_id: str
  509. ) -> List[Message]:
  510. """获取指定 Goal 关联的所有 Messages"""
  511. all_messages = await self.get_trace_messages(trace_id)
  512. return [m for m in all_messages if m.goal_id == goal_id]
  513. async def update_message(self, message_id: str, **updates) -> None:
  514. """更新 Message 字段"""
  515. message = await self.get_message(message_id)
  516. if not message:
  517. return
  518. # 更新字段
  519. for key, value in updates.items():
  520. if hasattr(message, key):
  521. setattr(message, key, value)
  522. # 确定文件路径
  523. messages_dir = self._get_messages_dir(message.trace_id)
  524. message_file = messages_dir / f"{message_id}.json"
  525. message_file.write_text(json.dumps(message.to_dict(), indent=2, ensure_ascii=False), encoding="utf-8")
  526. async def abandon_messages_after(self, trace_id: str, cutoff_sequence: int) -> List[str]:
  527. """
  528. 将 sequence > cutoff_sequence 的 active messages 标记为 abandoned。
  529. 返回被 abandon 的 message_id 列表。
  530. """
  531. all_messages = await self.get_trace_messages(trace_id)
  532. abandoned_ids = []
  533. now = datetime.now()
  534. for msg in all_messages:
  535. if msg.sequence > cutoff_sequence and msg.status == "active":
  536. msg.status = "abandoned"
  537. msg.abandoned_at = now
  538. # 直接写回文件
  539. message_file = self._get_messages_dir(trace_id) / f"{msg.message_id}.json"
  540. message_file.write_text(
  541. json.dumps(msg.to_dict(), indent=2, ensure_ascii=False),
  542. encoding="utf-8"
  543. )
  544. abandoned_ids.append(msg.message_id)
  545. return abandoned_ids
  546. # ===== 模型使用追踪 =====
  547. async def record_model_usage(
  548. self,
  549. trace_id: str,
  550. sequence: int,
  551. role: str,
  552. model: str,
  553. prompt_tokens: int,
  554. completion_tokens: int,
  555. cache_read_tokens: int = 0,
  556. tool_name: Optional[str] = None,
  557. ) -> None:
  558. """
  559. 记录模型使用情况到 model_usage.json
  560. Args:
  561. trace_id: Trace ID
  562. sequence: 消息序号
  563. role: 角色(assistant/tool)
  564. model: 模型名称
  565. prompt_tokens: 输入tokens
  566. completion_tokens: 输出tokens
  567. cache_read_tokens: 缓存读取tokens
  568. tool_name: 工具名称(role=tool时)
  569. """
  570. usage_file = self._get_model_usage_file(trace_id)
  571. # 读取现有数据
  572. if usage_file.exists():
  573. data = json.loads(usage_file.read_text(encoding="utf-8"))
  574. else:
  575. data = {
  576. "summary": {
  577. "total_models": 0,
  578. "total_tokens": 0,
  579. "total_cache_read_tokens": 0,
  580. "agent_tokens": 0,
  581. "tool_tokens": 0,
  582. },
  583. "models": [],
  584. "timeline": [],
  585. }
  586. # 更新summary
  587. total_tokens = prompt_tokens + completion_tokens
  588. data["summary"]["total_tokens"] += total_tokens
  589. data["summary"]["total_cache_read_tokens"] += cache_read_tokens
  590. if role == "assistant":
  591. data["summary"]["agent_tokens"] += total_tokens
  592. source = "agent"
  593. else:
  594. data["summary"]["tool_tokens"] += total_tokens
  595. source = f"tool:{tool_name}" if tool_name else "tool"
  596. # 更新models列表
  597. model_entry = None
  598. for m in data["models"]:
  599. if m["model"] == model and m["source"] == source:
  600. model_entry = m
  601. break
  602. if model_entry:
  603. model_entry["prompt_tokens"] += prompt_tokens
  604. model_entry["completion_tokens"] += completion_tokens
  605. model_entry["total_tokens"] += total_tokens
  606. model_entry["cache_read_tokens"] += cache_read_tokens
  607. model_entry["call_count"] += 1
  608. else:
  609. data["models"].append({
  610. "model": model,
  611. "source": source,
  612. "prompt_tokens": prompt_tokens,
  613. "completion_tokens": completion_tokens,
  614. "total_tokens": total_tokens,
  615. "cache_read_tokens": cache_read_tokens,
  616. "call_count": 1,
  617. })
  618. data["summary"]["total_models"] = len(data["models"])
  619. # 添加到timeline
  620. timeline_entry = {
  621. "sequence": sequence,
  622. "role": role,
  623. "model": model,
  624. "prompt_tokens": prompt_tokens,
  625. "completion_tokens": completion_tokens,
  626. }
  627. if cache_read_tokens > 0:
  628. timeline_entry["cache_read_tokens"] = cache_read_tokens
  629. if tool_name:
  630. timeline_entry["tool_name"] = tool_name
  631. data["timeline"].append(timeline_entry)
  632. # 写回文件
  633. usage_file.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
  634. # ===== 事件流操作(用于 WebSocket 断线续传)=====
  635. async def get_events(
  636. self,
  637. trace_id: str,
  638. since_event_id: int = 0
  639. ) -> List[Dict[str, Any]]:
  640. """获取事件流"""
  641. events_file = self._get_events_file(trace_id)
  642. if not events_file.exists():
  643. return []
  644. events = []
  645. with events_file.open('r', encoding='utf-8') as f:
  646. for line in f:
  647. try:
  648. event = json.loads(line.strip())
  649. if event.get("event_id", 0) > since_event_id:
  650. events.append(event)
  651. except Exception:
  652. continue
  653. return events
  654. async def append_event(
  655. self,
  656. trace_id: str,
  657. event_type: str,
  658. payload: Dict[str, Any]
  659. ) -> int:
  660. """追加事件,返回 event_id"""
  661. # 获取 trace 并递增 event_id
  662. trace = await self.get_trace(trace_id)
  663. if not trace:
  664. return 0
  665. trace.last_event_id += 1
  666. event_id = trace.last_event_id
  667. # 更新 trace 的 last_event_id
  668. await self.update_trace(trace_id, last_event_id=event_id)
  669. # 创建事件
  670. event = {
  671. "event_id": event_id,
  672. "event": event_type,
  673. "ts": datetime.now().isoformat(),
  674. **payload
  675. }
  676. # 追加到 events.jsonl
  677. events_file = self._get_events_file(trace_id)
  678. with events_file.open('a', encoding='utf-8') as f:
  679. f.write(json.dumps(event, ensure_ascii=False) + '\n')
  680. return event_id
  681. # ===== Knowledge Log 管理 =====
  682. def _get_knowledge_log_file(self, trace_id: str) -> Path:
  683. """获取 knowledge_log.json 文件路径"""
  684. return self._get_trace_dir(trace_id) / "knowledge_log.json"
  685. async def get_knowledge_log(self, trace_id: str) -> Dict[str, Any]:
  686. """读取知识日志"""
  687. log_file = self._get_knowledge_log_file(trace_id)
  688. if not log_file.exists():
  689. return {"trace_id": trace_id, "entries": []}
  690. return json.loads(log_file.read_text(encoding="utf-8"))
  691. async def append_knowledge_entry(
  692. self,
  693. trace_id: str,
  694. knowledge_id: str,
  695. goal_id: str,
  696. injected_at_sequence: int,
  697. task: str,
  698. content: str
  699. ) -> None:
  700. """追加知识注入记录"""
  701. log = await self.get_knowledge_log(trace_id)
  702. log["entries"].append({
  703. "knowledge_id": knowledge_id,
  704. "goal_id": goal_id,
  705. "injected_at_sequence": injected_at_sequence,
  706. "injected_at": datetime.now().isoformat(),
  707. "task": task,
  708. "content": content[:500], # 限制长度
  709. "eval_result": None,
  710. "evaluated_at": None,
  711. "evaluated_at_trigger": None
  712. })
  713. log_file = self._get_knowledge_log_file(trace_id)
  714. log_file.write_text(json.dumps(log, indent=2, ensure_ascii=False), encoding="utf-8")
  715. async def update_knowledge_evaluation(
  716. self,
  717. trace_id: str,
  718. knowledge_id: str,
  719. eval_result: Dict[str, Any],
  720. trigger_event: str
  721. ) -> None:
  722. """更新知识评估结果
  723. 当同一个knowledge_id在不同goal中被多次注入时,
  724. 优先更新最近一个未评估的条目(按injected_at_sequence倒序)
  725. """
  726. log = await self.get_knowledge_log(trace_id)
  727. # 找到所有匹配且未评估的条目
  728. matching_entries = [
  729. (i, entry) for i, entry in enumerate(log["entries"])
  730. if entry["knowledge_id"] == knowledge_id and entry["eval_result"] is None
  731. ]
  732. if matching_entries:
  733. # 按injected_at_sequence倒序排序,取最近的一个
  734. matching_entries.sort(key=lambda x: x[1]["injected_at_sequence"], reverse=True)
  735. idx, entry = matching_entries[0]
  736. entry["eval_result"] = eval_result
  737. entry["evaluated_at"] = datetime.now().isoformat()
  738. entry["evaluated_at_trigger"] = trigger_event
  739. log_file = self._get_knowledge_log_file(trace_id)
  740. log_file.write_text(json.dumps(log, indent=2, ensure_ascii=False), encoding="utf-8")
  741. async def get_pending_knowledge_entries(self, trace_id: str) -> List[Dict[str, Any]]:
  742. """获取所有待评估的知识条目"""
  743. log = await self.get_knowledge_log(trace_id)
  744. return [e for e in log["entries"] if e["eval_result"] is None]
  745. async def update_user_feedback(
  746. self,
  747. trace_id: str,
  748. knowledge_id: str,
  749. user_feedback: Dict[str, Any]
  750. ) -> None:
  751. """记录用户对知识的反馈(confirm/override),不覆盖 agent 的 eval_result
  752. 当同一个 knowledge_id 被多次注入时,更新最近一次注入的条目。
  753. """
  754. log = await self.get_knowledge_log(trace_id)
  755. # 找到所有匹配的条目(不限 eval_result 是否为 None)
  756. matching_entries = [
  757. (i, entry) for i, entry in enumerate(log["entries"])
  758. if entry["knowledge_id"] == knowledge_id
  759. ]
  760. if matching_entries:
  761. # 按 injected_at_sequence 倒序,取最近一次注入的条目
  762. matching_entries.sort(key=lambda x: x[1]["injected_at_sequence"], reverse=True)
  763. idx, entry = matching_entries[0]
  764. entry["user_feedback"] = user_feedback
  765. log_file = self._get_knowledge_log_file(trace_id)
  766. log_file.write_text(json.dumps(log, indent=2, ensure_ascii=False), encoding="utf-8")