models.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
"""
Trace and Message data models.

Trace: one complete LLM interaction (a single call or an agent task).
Message: an LLM message inside a Trace, matching the LLM API message format.
"""
  6. from dataclasses import dataclass, field
  7. from datetime import datetime
  8. from typing import Dict, Any, List, Optional, Literal
  9. import uuid
  10. @dataclass
  11. class Trace:
  12. """
  13. 执行轨迹 - 一次完整的 LLM 交互
  14. 单次调用: mode="call"
  15. Agent 模式: mode="agent"
  16. 主 Trace 和 Sub-Trace 使用相同的数据结构。
  17. Sub-Trace 通过 parent_trace_id 和 parent_goal_id 关联父 Trace。
  18. """
  19. trace_id: str
  20. mode: Literal["call", "agent"]
  21. # Prompt 标识(可选)
  22. prompt_name: Optional[str] = None
  23. # Agent 模式特有
  24. task: Optional[str] = None
  25. agent_type: Optional[str] = None
  26. # 父子关系(Sub-Trace 特有)
  27. parent_trace_id: Optional[str] = None # 父 Trace ID
  28. parent_goal_id: Optional[str] = None # 哪个 Goal 启动的
  29. # 状态
  30. status: Literal["running", "completed", "failed"] = "running"
  31. # 统计
  32. total_messages: int = 0 # 消息总数(改名自 total_steps)
  33. total_tokens: int = 0 # 总 tokens(向后兼容,= prompt + completion)
  34. total_prompt_tokens: int = 0 # 总输入 tokens
  35. total_completion_tokens: int = 0 # 总输出 tokens
  36. total_cost: float = 0.0
  37. total_duration_ms: int = 0 # 总耗时(毫秒)
  38. # 进度追踪(head)
  39. last_sequence: int = 0 # 最新 message 的 sequence
  40. last_event_id: int = 0 # 最新事件 ID(用于 WS 续传)
  41. # 配置
  42. uid: Optional[str] = None
  43. model: Optional[str] = None # 默认模型
  44. tools: Optional[List[Dict]] = None # 工具定义(整个 trace 共享)
  45. llm_params: Dict[str, Any] = field(default_factory=dict) # LLM 参数(temperature 等)
  46. context: Dict[str, Any] = field(default_factory=dict) # 其他元数据
  47. # 当前焦点 goal
  48. current_goal_id: Optional[str] = None
  49. # 结果
  50. result_summary: Optional[str] = None # 执行结果摘要
  51. error_message: Optional[str] = None # 错误信息
  52. # 时间
  53. created_at: datetime = field(default_factory=datetime.now)
  54. completed_at: Optional[datetime] = None
  55. @classmethod
  56. def create(
  57. cls,
  58. mode: Literal["call", "agent"],
  59. **kwargs
  60. ) -> "Trace":
  61. """创建新的 Trace"""
  62. return cls(
  63. trace_id=str(uuid.uuid4()),
  64. mode=mode,
  65. **kwargs
  66. )
  67. def to_dict(self) -> Dict[str, Any]:
  68. """转换为字典"""
  69. return {
  70. "trace_id": self.trace_id,
  71. "mode": self.mode,
  72. "prompt_name": self.prompt_name,
  73. "task": self.task,
  74. "agent_type": self.agent_type,
  75. "parent_trace_id": self.parent_trace_id,
  76. "parent_goal_id": self.parent_goal_id,
  77. "status": self.status,
  78. "total_messages": self.total_messages,
  79. "total_tokens": self.total_tokens,
  80. "total_prompt_tokens": self.total_prompt_tokens,
  81. "total_completion_tokens": self.total_completion_tokens,
  82. "total_cost": self.total_cost,
  83. "total_duration_ms": self.total_duration_ms,
  84. "last_sequence": self.last_sequence,
  85. "last_event_id": self.last_event_id,
  86. "uid": self.uid,
  87. "model": self.model,
  88. "tools": self.tools,
  89. "llm_params": self.llm_params,
  90. "context": self.context,
  91. "current_goal_id": self.current_goal_id,
  92. "result_summary": self.result_summary,
  93. "error_message": self.error_message,
  94. "created_at": self.created_at.isoformat() if self.created_at else None,
  95. "completed_at": self.completed_at.isoformat() if self.completed_at else None,
  96. }
  97. @dataclass
  98. class Message:
  99. """
  100. 执行消息 - Trace 中的 LLM 消息
  101. 对应 LLM API 消息格式(system/user/assistant/tool),通过 goal_id 关联 Goal。
  102. description 字段自动生成规则:
  103. - system: 取 content 前 200 字符
  104. - user: 取 content 前 200 字符
  105. - assistant: 优先取 content,若无 content 则生成 "tool call: XX, XX"
  106. - tool: 使用 tool name
  107. """
  108. message_id: str
  109. trace_id: str
  110. role: Literal["system", "user", "assistant", "tool"] # 和 LLM API 一致
  111. sequence: int # 全局顺序
  112. goal_id: Optional[str] = None # 关联的 Goal 内部 ID(None = 还没有创建 Goal)
  113. description: str = "" # 消息描述(系统自动生成)
  114. tool_call_id: Optional[str] = None # tool 消息关联对应的 tool_call
  115. content: Any = None # 消息内容(和 LLM API 格式一致)
  116. # 元数据
  117. prompt_tokens: Optional[int] = None # 输入 tokens
  118. completion_tokens: Optional[int] = None # 输出 tokens
  119. cost: Optional[float] = None
  120. duration_ms: Optional[int] = None
  121. created_at: datetime = field(default_factory=datetime.now)
  122. # LLM 响应信息(仅 role="assistant" 时使用)
  123. finish_reason: Optional[str] = None # stop, length, tool_calls, content_filter 等
  124. @property
  125. def tokens(self) -> int:
  126. """动态计算总 tokens(向后兼容)"""
  127. return (self.prompt_tokens or 0) + (self.completion_tokens or 0)
  128. @classmethod
  129. def from_dict(cls, data: Dict[str, Any]) -> "Message":
  130. """从字典创建 Message(处理向后兼容)"""
  131. # 过滤掉已删除的字段
  132. filtered_data = {k: v for k, v in data.items() if k not in ["tokens", "available_tools"]}
  133. # 解析 datetime
  134. if filtered_data.get("created_at") and isinstance(filtered_data["created_at"], str):
  135. filtered_data["created_at"] = datetime.fromisoformat(filtered_data["created_at"])
  136. return cls(**filtered_data)
  137. @classmethod
  138. def create(
  139. cls,
  140. trace_id: str,
  141. role: Literal["system", "user", "assistant", "tool"],
  142. sequence: int,
  143. goal_id: Optional[str] = None,
  144. content: Any = None,
  145. tool_call_id: Optional[str] = None,
  146. prompt_tokens: Optional[int] = None,
  147. completion_tokens: Optional[int] = None,
  148. cost: Optional[float] = None,
  149. duration_ms: Optional[int] = None,
  150. finish_reason: Optional[str] = None,
  151. ) -> "Message":
  152. """创建新的 Message,自动生成 description"""
  153. description = cls._generate_description(role, content)
  154. return cls(
  155. message_id=f"{trace_id}-{sequence:04d}",
  156. trace_id=trace_id,
  157. role=role,
  158. sequence=sequence,
  159. goal_id=goal_id,
  160. content=content,
  161. description=description,
  162. tool_call_id=tool_call_id,
  163. prompt_tokens=prompt_tokens,
  164. completion_tokens=completion_tokens,
  165. cost=cost,
  166. duration_ms=duration_ms,
  167. finish_reason=finish_reason,
  168. )
  169. @staticmethod
  170. def _generate_description(role: str, content: Any) -> str:
  171. """
  172. 自动生成 description
  173. - system: 取 content 前 200 字符
  174. - user: 取 content 前 200 字符
  175. - assistant: 优先取 content,若无 content 则生成 "tool call: XX, XX"
  176. - tool: 使用 tool name
  177. """
  178. if role == "system":
  179. # system 消息:直接截取文本
  180. if isinstance(content, str):
  181. return content[:200] + "..." if len(content) > 200 else content
  182. return "system prompt"
  183. elif role == "user":
  184. # user 消息:直接截取文本
  185. if isinstance(content, str):
  186. return content[:200] + "..." if len(content) > 200 else content
  187. return "user message"
  188. elif role == "assistant":
  189. # assistant 消息:content 是字典,可能包含 text 和 tool_calls
  190. if isinstance(content, dict):
  191. # 优先返回文本内容
  192. if content.get("text"):
  193. text = content["text"]
  194. # 截断过长的文本
  195. return text[:200] + "..." if len(text) > 200 else text
  196. # 如果没有文本,检查 tool_calls
  197. if content.get("tool_calls"):
  198. tool_calls = content["tool_calls"]
  199. if isinstance(tool_calls, list):
  200. tool_names = []
  201. for tc in tool_calls:
  202. if isinstance(tc, dict) and tc.get("function", {}).get("name"):
  203. tool_names.append(tc["function"]["name"])
  204. if tool_names:
  205. return f"tool call: {', '.join(tool_names)}"
  206. # 如果 content 是字符串
  207. if isinstance(content, str):
  208. return content[:200] + "..." if len(content) > 200 else content
  209. return "assistant message"
  210. elif role == "tool":
  211. # tool 消息:从 content 中提取 tool name
  212. if isinstance(content, dict):
  213. if content.get("tool_name"):
  214. return content["tool_name"]
  215. # 如果是字符串,尝试解析
  216. if isinstance(content, str):
  217. return content[:100] + "..." if len(content) > 100 else content
  218. return "tool result"
  219. return ""
  220. def to_dict(self) -> Dict[str, Any]:
  221. """转换为字典"""
  222. return {
  223. "message_id": self.message_id,
  224. "trace_id": self.trace_id,
  225. "role": self.role,
  226. "sequence": self.sequence,
  227. "goal_id": self.goal_id,
  228. "tool_call_id": self.tool_call_id,
  229. "content": self.content,
  230. "description": self.description,
  231. "tokens": self.tokens, # 使用 @property 动态计算
  232. "prompt_tokens": self.prompt_tokens,
  233. "completion_tokens": self.completion_tokens,
  234. "cost": self.cost,
  235. "duration_ms": self.duration_ms,
  236. "finish_reason": self.finish_reason,
  237. "created_at": self.created_at.isoformat() if self.created_at else None,
  238. }
  239. # ===== 已弃用:Step 模型(保留用于向后兼容)=====
  240. # Step 类型
  241. StepType = Literal[
  242. "goal", "thought", "evaluation", "response",
  243. "action", "result", "memory_read", "memory_write",
  244. ]
  245. # Step 状态
  246. StepStatus = Literal[
  247. "planned", "in_progress", "awaiting_approval",
  248. "completed", "failed", "skipped",
  249. ]
  250. @dataclass
  251. class Step:
  252. """
  253. [已弃用] 执行步骤 - 使用 Message 模型替代
  254. 保留用于向后兼容
  255. """
  256. step_id: str
  257. trace_id: str
  258. step_type: StepType
  259. status: StepStatus
  260. sequence: int
  261. parent_id: Optional[str] = None
  262. description: str = ""
  263. data: Dict[str, Any] = field(default_factory=dict)
  264. summary: Optional[str] = None
  265. has_children: bool = False
  266. children_count: int = 0
  267. duration_ms: Optional[int] = None
  268. tokens: Optional[int] = None
  269. cost: Optional[float] = None
  270. created_at: datetime = field(default_factory=datetime.now)
  271. @classmethod
  272. def create(
  273. cls,
  274. trace_id: str,
  275. step_type: StepType,
  276. sequence: int,
  277. status: StepStatus = "completed",
  278. description: str = "",
  279. data: Dict[str, Any] = None,
  280. parent_id: Optional[str] = None,
  281. summary: Optional[str] = None,
  282. duration_ms: Optional[int] = None,
  283. tokens: Optional[int] = None,
  284. cost: Optional[float] = None,
  285. ) -> "Step":
  286. """创建新的 Step"""
  287. return cls(
  288. step_id=str(uuid.uuid4()),
  289. trace_id=trace_id,
  290. step_type=step_type,
  291. status=status,
  292. sequence=sequence,
  293. parent_id=parent_id,
  294. description=description,
  295. data=data or {},
  296. summary=summary,
  297. duration_ms=duration_ms,
  298. tokens=tokens,
  299. cost=cost,
  300. )
  301. def to_dict(self, view: str = "full") -> Dict[str, Any]:
  302. """
  303. 转换为字典
  304. Args:
  305. view: "compact" - 不返回大字段
  306. "full" - 返回完整数据
  307. """
  308. result = {
  309. "step_id": self.step_id,
  310. "trace_id": self.trace_id,
  311. "step_type": self.step_type,
  312. "status": self.status,
  313. "sequence": self.sequence,
  314. "parent_id": self.parent_id,
  315. "description": self.description,
  316. "summary": self.summary,
  317. "has_children": self.has_children,
  318. "children_count": self.children_count,
  319. "duration_ms": self.duration_ms,
  320. "tokens": self.tokens,
  321. "cost": self.cost,
  322. "created_at": self.created_at.isoformat() if self.created_at else None,
  323. }
  324. # 处理 data 字段
  325. if view == "compact":
  326. data_copy = self.data.copy()
  327. for key in ["output", "content", "full_output", "full_content"]:
  328. data_copy.pop(key, None)
  329. result["data"] = data_copy
  330. else:
  331. result["data"] = self.data
  332. return result