models.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. """
  2. Trace 和 Message 数据模型
  3. Trace: 一次完整的 LLM 交互(单次调用或 Agent 任务)
  4. Message: Trace 中的 LLM 消息,对应 LLM API 格式
  5. """
  6. from dataclasses import dataclass, field
  7. from datetime import datetime
  8. from typing import Dict, Any, List, Optional, Literal, Union
  9. import uuid
# ===== Wire-format message type aliases =====
# Lightweight wire-format types used for tool parameters and the runner/LLM
# API boundary. Internal storage uses the Message dataclass defined below.
ChatMessage = Dict[str, Any]  # a single OpenAI-format message
Messages = List[ChatMessage]  # a list of messages
MessageContent = Union[str, List[Dict[str, str]]]  # "content" field (plain text or multimodal parts)
# Import TokenUsage lazily to avoid a circular import with ..llm.usage.
def _get_token_usage_class():
    """Return the TokenUsage class, imported lazily to break an import cycle."""
    from ..llm.usage import TokenUsage
    return TokenUsage
  20. @dataclass
  21. class Trace:
  22. """
  23. 执行轨迹 - 一次完整的 LLM 交互
  24. 单次调用: mode="call"
  25. Agent 模式: mode="agent"
  26. 主 Trace 和 Sub-Trace 使用相同的数据结构。
  27. Sub-Trace 通过 parent_trace_id 和 parent_goal_id 关联父 Trace。
  28. """
  29. trace_id: str
  30. mode: Literal["call", "agent"]
  31. # Prompt 标识(可选)
  32. prompt_name: Optional[str] = None
  33. # Agent 模式特有
  34. task: Optional[str] = None
  35. agent_type: Optional[str] = None
  36. # 父子关系(Sub-Trace 特有)
  37. parent_trace_id: Optional[str] = None # 父 Trace ID
  38. parent_goal_id: Optional[str] = None # 哪个 Goal 启动的
  39. # 状态
  40. status: Literal["running", "completed", "failed", "stopped"] = "running"
  41. # 统计
  42. total_messages: int = 0 # 消息总数(改名自 total_steps)
  43. total_tokens: int = 0 # 总 tokens(向后兼容,= prompt + completion)
  44. total_prompt_tokens: int = 0 # 总输入 tokens
  45. total_completion_tokens: int = 0 # 总输出 tokens
  46. total_reasoning_tokens: int = 0 # 总推理 tokens(o1/o3, DeepSeek R1, Gemini thinking)
  47. total_cache_creation_tokens: int = 0 # 总缓存创建 tokens(Claude)
  48. total_cache_read_tokens: int = 0 # 总缓存读取 tokens(Claude)
  49. total_cost: float = 0.0
  50. total_duration_ms: int = 0 # 总耗时(毫秒)
  51. # 进度追踪(head)
  52. last_sequence: int = 0 # 最新 message 的 sequence
  53. head_sequence: int = 0 # 当前主路径的头节点 sequence(用于 build_llm_messages)
  54. last_event_id: int = 0 # 最新事件 ID(用于 WS 续传)
  55. # 配置
  56. uid: Optional[str] = None
  57. model: Optional[str] = None # 默认模型
  58. tools: Optional[List[Dict]] = None # 工具定义(整个 trace 共享)
  59. llm_params: Dict[str, Any] = field(default_factory=dict) # LLM 参数(temperature 等)
  60. context: Dict[str, Any] = field(default_factory=dict) # 其他元数据
  61. # 当前焦点 goal
  62. current_goal_id: Optional[str] = None
  63. # Memory 系统 - 记忆反思的进度追踪(见 agent/docs/memory-plan.md 第四节)
  64. # dream 操作扫描 reflected_at_sequence < latest_sequence 的 trace 做反思;
  65. # None 表示该 trace 从未被记忆反思处理过。
  66. reflected_at_sequence: Optional[int] = None
  67. # 结果
  68. result_summary: Optional[str] = None # 执行结果摘要
  69. error_message: Optional[str] = None # 错误信息
  70. # 时间
  71. created_at: datetime = field(default_factory=datetime.now)
  72. completed_at: Optional[datetime] = None
  73. last_activity_at: datetime = field(default_factory=datetime.now) # 最后活动时间(用于判断是否真正运行中)
  74. @classmethod
  75. def create(
  76. cls,
  77. mode: Literal["call", "agent"],
  78. **kwargs
  79. ) -> "Trace":
  80. """创建新的 Trace"""
  81. return cls(
  82. trace_id=str(uuid.uuid4()),
  83. mode=mode,
  84. **kwargs
  85. )
  86. @classmethod
  87. def from_dict(cls, data: Dict[str, Any]) -> "Trace":
  88. """从字典创建 Trace(处理日期字段反序列化)"""
  89. from dateutil import parser
  90. # 处理日期字段
  91. if "created_at" in data and isinstance(data["created_at"], str):
  92. data["created_at"] = parser.isoparse(data["created_at"])
  93. if "completed_at" in data and isinstance(data["completed_at"], str):
  94. data["completed_at"] = parser.isoparse(data["completed_at"])
  95. if "last_activity_at" in data and isinstance(data["last_activity_at"], str):
  96. data["last_activity_at"] = parser.isoparse(data["last_activity_at"])
  97. return cls(**data)
  98. def to_dict(self) -> Dict[str, Any]:
  99. """转换为字典"""
  100. return {
  101. "trace_id": self.trace_id,
  102. "mode": self.mode,
  103. "prompt_name": self.prompt_name,
  104. "task": self.task,
  105. "agent_type": self.agent_type,
  106. "parent_trace_id": self.parent_trace_id,
  107. "parent_goal_id": self.parent_goal_id,
  108. "status": self.status,
  109. "total_messages": self.total_messages,
  110. "total_tokens": self.total_tokens,
  111. "total_prompt_tokens": self.total_prompt_tokens,
  112. "total_completion_tokens": self.total_completion_tokens,
  113. "total_reasoning_tokens": self.total_reasoning_tokens,
  114. "total_cache_creation_tokens": self.total_cache_creation_tokens,
  115. "total_cache_read_tokens": self.total_cache_read_tokens,
  116. "total_cost": self.total_cost,
  117. "total_duration_ms": self.total_duration_ms,
  118. "last_sequence": self.last_sequence,
  119. "head_sequence": self.head_sequence,
  120. "last_event_id": self.last_event_id,
  121. "uid": self.uid,
  122. "model": self.model,
  123. "tools": self.tools,
  124. "llm_params": self.llm_params,
  125. "context": self.context,
  126. "current_goal_id": self.current_goal_id,
  127. "reflected_at_sequence": self.reflected_at_sequence,
  128. "result_summary": self.result_summary,
  129. "error_message": self.error_message,
  130. "created_at": self.created_at.isoformat() if self.created_at else None,
  131. "completed_at": self.completed_at.isoformat() if self.completed_at else None,
  132. "last_activity_at": self.last_activity_at.isoformat() if self.last_activity_at else None,
  133. }
  134. @dataclass
  135. class Message:
  136. """
  137. 执行消息 - Trace 中的 LLM 消息
  138. 对应 LLM API 消息格式(system/user/assistant/tool),通过 goal_id 关联 Goal。
  139. description 字段自动生成规则:
  140. - system: 取 content 前 200 字符
  141. - user: 取 content 前 200 字符
  142. - assistant: 优先取 content,若无 content 则生成 "tool call: XX, XX"
  143. - tool: 使用 tool name
  144. """
  145. message_id: str
  146. trace_id: str
  147. role: Literal["system", "user", "assistant", "tool"] # 和 LLM API 一致
  148. sequence: int # 全局顺序
  149. parent_sequence: Optional[int] = None # 父消息的 sequence(构成消息树)
  150. status: Literal["active", "abandoned"] = "active" # [已弃用] 由 parent_sequence 树结构替代
  151. goal_id: Optional[str] = None # 关联的 Goal 内部 ID(None = 还没有创建 Goal)
  152. description: str = "" # 消息描述(系统自动生成)
  153. tool_call_id: Optional[str] = None # tool 消息关联对应的 tool_call
  154. content: Any = None # 消息内容(和 LLM API 格式一致)
  155. # 侧分支标记
  156. branch_type: Optional[Literal["compression", "reflection", "knowledge_eval"]] = None # 侧分支类型(None = 主路径)
  157. branch_id: Optional[str] = None # 侧分支 ID(同一侧分支的消息共享)
  158. # 元数据
  159. prompt_tokens: Optional[int] = None # 输入 tokens
  160. completion_tokens: Optional[int] = None # 输出 tokens
  161. reasoning_tokens: Optional[int] = None # 推理 tokens(o1/o3, DeepSeek R1, Gemini thinking)
  162. cache_creation_tokens: Optional[int] = None # 缓存创建 tokens(Claude)
  163. cache_read_tokens: Optional[int] = None # 缓存读取 tokens(Claude)
  164. cost: Optional[float] = None
  165. duration_ms: Optional[int] = None
  166. created_at: datetime = field(default_factory=datetime.now)
  167. abandoned_at: Optional[datetime] = None # [已弃用] 由 parent_sequence 树结构替代
  168. # LLM 响应信息(仅 role="assistant" 时使用)
  169. finish_reason: Optional[str] = None # stop, length, tool_calls, content_filter 等
  170. @property
  171. def tokens(self) -> int:
  172. """动态计算总 tokens(向后兼容,input + output)"""
  173. return (self.prompt_tokens or 0) + (self.completion_tokens or 0)
  174. @property
  175. def all_tokens(self) -> int:
  176. """所有 tokens(包括 reasoning)"""
  177. return self.tokens + (self.reasoning_tokens or 0)
  178. def get_usage(self):
  179. """获取 TokenUsage 对象"""
  180. TokenUsage = _get_token_usage_class()
  181. return TokenUsage(
  182. input_tokens=self.prompt_tokens or 0,
  183. output_tokens=self.completion_tokens or 0,
  184. reasoning_tokens=self.reasoning_tokens or 0,
  185. cache_creation_tokens=self.cache_creation_tokens or 0,
  186. cache_read_tokens=self.cache_read_tokens or 0,
  187. )
  188. def to_llm_dict(self) -> Dict[str, Any]:
  189. """转换为 OpenAI SDK 格式的消息字典(用于 LLM 调用)"""
  190. msg: Dict[str, Any] = {"role": self.role, "_message_id": self.message_id}
  191. if self.role == "tool":
  192. # tool message: tool_call_id + name + content
  193. if self.tool_call_id:
  194. msg["tool_call_id"] = self.tool_call_id
  195. msg["name"] = self.description or "unknown"
  196. if isinstance(self.content, dict):
  197. result = self.content.get("result", self.content)
  198. # result 可能是 list(含图片的多模态内容)或字符串
  199. msg["content"] = result if isinstance(result, list) else str(result)
  200. else:
  201. msg["content"] = str(self.content) if self.content is not None else ""
  202. elif self.role == "assistant":
  203. # assistant message: content(text) + tool_calls
  204. if isinstance(self.content, dict):
  205. msg["content"] = self.content.get("text", "") or ""
  206. if self.content.get("tool_calls"):
  207. msg["tool_calls"] = self.content["tool_calls"]
  208. elif isinstance(self.content, str):
  209. msg["content"] = self.content
  210. else:
  211. msg["content"] = ""
  212. else:
  213. # system / user message: content 直接传
  214. msg["content"] = self.content
  215. return msg
  216. @classmethod
  217. def from_llm_dict(
  218. cls,
  219. d: Dict[str, Any],
  220. trace_id: str,
  221. sequence: int,
  222. goal_id: Optional[str] = None,
  223. parent_sequence: Optional[int] = None,
  224. ) -> "Message":
  225. """从 OpenAI SDK 格式创建 Message"""
  226. role = d["role"]
  227. if role == "assistant":
  228. content = {"text": d.get("content", ""), "tool_calls": d.get("tool_calls")}
  229. elif role == "tool":
  230. content = {"tool_name": d.get("name", "unknown"), "result": d.get("content", "")}
  231. else:
  232. content = d.get("content", "")
  233. return cls.create(
  234. trace_id=trace_id,
  235. role=role,
  236. sequence=sequence,
  237. goal_id=goal_id,
  238. parent_sequence=parent_sequence,
  239. content=content,
  240. tool_call_id=d.get("tool_call_id"),
  241. )
  242. @classmethod
  243. def from_dict(cls, data: Dict[str, Any]) -> "Message":
  244. """从字典创建 Message(处理向后兼容)"""
  245. # 过滤掉已删除的字段
  246. filtered_data = {k: v for k, v in data.items() if k not in ["tokens", "available_tools"]}
  247. # 解析 datetime
  248. if filtered_data.get("created_at") and isinstance(filtered_data["created_at"], str):
  249. filtered_data["created_at"] = datetime.fromisoformat(filtered_data["created_at"])
  250. if filtered_data.get("abandoned_at") and isinstance(filtered_data["abandoned_at"], str):
  251. filtered_data["abandoned_at"] = datetime.fromisoformat(filtered_data["abandoned_at"])
  252. # 向后兼容:旧消息没有 status 字段,默认 active
  253. if "status" not in filtered_data:
  254. filtered_data["status"] = "active"
  255. # 向后兼容:旧消息没有 parent_sequence 字段
  256. if "parent_sequence" not in filtered_data:
  257. filtered_data["parent_sequence"] = None
  258. # 向后兼容:旧消息没有侧分支字段
  259. if "branch_type" not in filtered_data:
  260. filtered_data["branch_type"] = None
  261. if "branch_id" not in filtered_data:
  262. filtered_data["branch_id"] = None
  263. return cls(**filtered_data)
  264. @classmethod
  265. def create(
  266. cls,
  267. trace_id: str,
  268. role: Literal["system", "user", "assistant", "tool"],
  269. sequence: int,
  270. goal_id: Optional[str] = None,
  271. content: Any = None,
  272. tool_call_id: Optional[str] = None,
  273. parent_sequence: Optional[int] = None,
  274. branch_type: Optional[Literal["compression", "reflection", "knowledge_eval"]] = None,
  275. branch_id: Optional[str] = None,
  276. prompt_tokens: Optional[int] = None,
  277. completion_tokens: Optional[int] = None,
  278. reasoning_tokens: Optional[int] = None,
  279. cache_creation_tokens: Optional[int] = None,
  280. cache_read_tokens: Optional[int] = None,
  281. cost: Optional[float] = None,
  282. duration_ms: Optional[int] = None,
  283. finish_reason: Optional[str] = None,
  284. ) -> "Message":
  285. """创建新的 Message,自动生成 description"""
  286. description = cls._generate_description(role, content)
  287. return cls(
  288. message_id=f"{trace_id}-{sequence:04d}",
  289. trace_id=trace_id,
  290. role=role,
  291. sequence=sequence,
  292. parent_sequence=parent_sequence,
  293. goal_id=goal_id,
  294. content=content,
  295. description=description,
  296. tool_call_id=tool_call_id,
  297. branch_type=branch_type,
  298. branch_id=branch_id,
  299. prompt_tokens=prompt_tokens,
  300. completion_tokens=completion_tokens,
  301. reasoning_tokens=reasoning_tokens,
  302. cache_creation_tokens=cache_creation_tokens,
  303. cache_read_tokens=cache_read_tokens,
  304. cost=cost,
  305. duration_ms=duration_ms,
  306. finish_reason=finish_reason,
  307. )
  308. @staticmethod
  309. def _generate_description(role: str, content: Any) -> str:
  310. """
  311. 自动生成 description
  312. - system: 取 content 前 200 字符
  313. - user: 取 content 前 200 字符
  314. - assistant: 优先取 content,若无 content 则生成 "tool call: XX, XX"
  315. - tool: 使用 tool name
  316. """
  317. if role == "system":
  318. # system 消息:直接返回文本
  319. if isinstance(content, str):
  320. return content
  321. return "system prompt"
  322. elif role == "user":
  323. # user 消息:直接返回文本
  324. if isinstance(content, str):
  325. return content
  326. return "user message"
  327. elif role == "assistant":
  328. # assistant 消息:content 是字典,可能包含 text 和 tool_calls
  329. if isinstance(content, dict):
  330. # 优先返回文本内容
  331. if content.get("text"):
  332. text = content["text"]
  333. # 返回完整文本
  334. return text
  335. # 如果没有文本,检查 tool_calls
  336. if content.get("tool_calls"):
  337. tool_calls = content["tool_calls"]
  338. if isinstance(tool_calls, list):
  339. tool_descriptions = []
  340. for tc in tool_calls:
  341. if isinstance(tc, dict) and tc.get("function", {}).get("name"):
  342. tool_name = tc["function"]["name"]
  343. # 提取参数并截断到 100 字符
  344. tool_args = tc["function"].get("arguments", "{}")
  345. if isinstance(tool_args, str):
  346. args_str = tool_args
  347. else:
  348. import json
  349. args_str = json.dumps(tool_args, ensure_ascii=False)
  350. args_display = args_str[:100] + "..." if len(args_str) > 100 else args_str
  351. tool_descriptions.append(f"{tool_name}({args_display})")
  352. if tool_descriptions:
  353. return "tool call: " + ", ".join(tool_descriptions)
  354. # 如果 content 是字符串
  355. if isinstance(content, str):
  356. return content
  357. return "assistant message"
  358. elif role == "tool":
  359. # tool 消息:从 content 中提取 tool name
  360. if isinstance(content, dict):
  361. if content.get("tool_name"):
  362. return content["tool_name"]
  363. # 如果是字符串,尝试解析
  364. if isinstance(content, str):
  365. return content[:100] + "..." if len(content) > 100 else content
  366. return "tool result"
  367. return ""
  368. def to_dict(self) -> Dict[str, Any]:
  369. """转换为字典"""
  370. result = {
  371. "message_id": self.message_id,
  372. "trace_id": self.trace_id,
  373. "role": self.role,
  374. "sequence": self.sequence,
  375. "parent_sequence": self.parent_sequence,
  376. "status": self.status,
  377. "goal_id": self.goal_id,
  378. "tool_call_id": self.tool_call_id,
  379. "content": self.content,
  380. "description": self.description,
  381. "tokens": self.tokens, # 使用 @property 动态计算
  382. "prompt_tokens": self.prompt_tokens,
  383. "completion_tokens": self.completion_tokens,
  384. "cost": self.cost,
  385. "duration_ms": self.duration_ms,
  386. "finish_reason": self.finish_reason,
  387. "created_at": self.created_at.isoformat() if self.created_at else None,
  388. }
  389. # 只添加非空的可选字段
  390. if self.abandoned_at:
  391. result["abandoned_at"] = self.abandoned_at.isoformat()
  392. if self.reasoning_tokens is not None:
  393. result["reasoning_tokens"] = self.reasoning_tokens
  394. if self.cache_creation_tokens is not None:
  395. result["cache_creation_tokens"] = self.cache_creation_tokens
  396. if self.cache_read_tokens is not None:
  397. result["cache_read_tokens"] = self.cache_read_tokens
  398. return result
  399. # ===== 已弃用:Step 模型(保留用于向后兼容)=====
  400. # Step 类型
  401. StepType = Literal[
  402. "goal", "thought", "evaluation", "response",
  403. "action", "result", "memory_read", "memory_write",
  404. ]
  405. # Step 状态
  406. StepStatus = Literal[
  407. "planned", "in_progress", "awaiting_approval",
  408. "completed", "failed", "skipped",
  409. ]
  410. @dataclass
  411. class Step:
  412. """
  413. [已弃用] 执行步骤 - 使用 Message 模型替代
  414. 保留用于向后兼容
  415. """
  416. step_id: str
  417. trace_id: str
  418. step_type: StepType
  419. status: StepStatus
  420. sequence: int
  421. parent_id: Optional[str] = None
  422. description: str = ""
  423. data: Dict[str, Any] = field(default_factory=dict)
  424. summary: Optional[str] = None
  425. has_children: bool = False
  426. children_count: int = 0
  427. duration_ms: Optional[int] = None
  428. tokens: Optional[int] = None
  429. cost: Optional[float] = None
  430. created_at: datetime = field(default_factory=datetime.now)
  431. @classmethod
  432. def create(
  433. cls,
  434. trace_id: str,
  435. step_type: StepType,
  436. sequence: int,
  437. status: StepStatus = "completed",
  438. description: str = "",
  439. data: Dict[str, Any] = None,
  440. parent_id: Optional[str] = None,
  441. summary: Optional[str] = None,
  442. duration_ms: Optional[int] = None,
  443. tokens: Optional[int] = None,
  444. cost: Optional[float] = None,
  445. ) -> "Step":
  446. """创建新的 Step"""
  447. return cls(
  448. step_id=str(uuid.uuid4()),
  449. trace_id=trace_id,
  450. step_type=step_type,
  451. status=status,
  452. sequence=sequence,
  453. parent_id=parent_id,
  454. description=description,
  455. data=data or {},
  456. summary=summary,
  457. duration_ms=duration_ms,
  458. tokens=tokens,
  459. cost=cost,
  460. )
  461. def to_dict(self, view: str = "full") -> Dict[str, Any]:
  462. """
  463. 转换为字典
  464. Args:
  465. view: "compact" - 不返回大字段
  466. "full" - 返回完整数据
  467. """
  468. result = {
  469. "step_id": self.step_id,
  470. "trace_id": self.trace_id,
  471. "step_type": self.step_type,
  472. "status": self.status,
  473. "sequence": self.sequence,
  474. "parent_id": self.parent_id,
  475. "description": self.description,
  476. "summary": self.summary,
  477. "has_children": self.has_children,
  478. "children_count": self.children_count,
  479. "duration_ms": self.duration_ms,
  480. "tokens": self.tokens,
  481. "cost": self.cost,
  482. "created_at": self.created_at.isoformat() if self.created_at else None,
  483. }
  484. # 处理 data 字段
  485. if view == "compact":
  486. data_copy = self.data.copy()
  487. for key in ["output", "content", "full_output", "full_content"]:
  488. data_copy.pop(key, None)
  489. result["data"] = data_copy
  490. else:
  491. result["data"] = self.data
  492. return result