models.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. """
  2. Trace 和 Message 数据模型
  3. Trace: 一次完整的 LLM 交互(单次调用或 Agent 任务)
  4. Message: Trace 中的 LLM 消息,对应 LLM API 格式
  5. """
  6. from dataclasses import dataclass, field
  7. from datetime import datetime
  8. from typing import Dict, Any, List, Optional, Literal
  9. import uuid
  10. @dataclass
  11. class Trace:
  12. """
  13. 执行轨迹 - 一次完整的 LLM 交互
  14. 单次调用: mode="call"
  15. Agent 模式: mode="agent"
  16. """
  17. trace_id: str
  18. mode: Literal["call", "agent"]
  19. # Prompt 标识(可选)
  20. prompt_name: Optional[str] = None
  21. # Agent 模式特有
  22. task: Optional[str] = None
  23. agent_type: Optional[str] = None
  24. # 状态
  25. status: Literal["running", "completed", "failed"] = "running"
  26. # 统计
  27. total_messages: int = 0 # 消息总数(改名自 total_steps)
  28. total_tokens: int = 0
  29. total_cost: float = 0.0
  30. total_duration_ms: int = 0 # 总耗时(毫秒)
  31. # 进度追踪(head)
  32. last_sequence: int = 0 # 最新 message 的 sequence
  33. last_event_id: int = 0 # 最新事件 ID(用于 WS 续传)
  34. # 上下文
  35. uid: Optional[str] = None
  36. context: Dict[str, Any] = field(default_factory=dict)
  37. # 当前焦点 goal
  38. current_goal_id: Optional[str] = None
  39. # 时间
  40. created_at: datetime = field(default_factory=datetime.now)
  41. completed_at: Optional[datetime] = None
  42. @classmethod
  43. def create(
  44. cls,
  45. mode: Literal["call", "agent"],
  46. **kwargs
  47. ) -> "Trace":
  48. """创建新的 Trace"""
  49. return cls(
  50. trace_id=str(uuid.uuid4()),
  51. mode=mode,
  52. **kwargs
  53. )
  54. def to_dict(self) -> Dict[str, Any]:
  55. """转换为字典"""
  56. return {
  57. "trace_id": self.trace_id,
  58. "mode": self.mode,
  59. "prompt_name": self.prompt_name,
  60. "task": self.task,
  61. "agent_type": self.agent_type,
  62. "status": self.status,
  63. "total_messages": self.total_messages,
  64. "total_tokens": self.total_tokens,
  65. "total_cost": self.total_cost,
  66. "total_duration_ms": self.total_duration_ms,
  67. "last_sequence": self.last_sequence,
  68. "last_event_id": self.last_event_id,
  69. "uid": self.uid,
  70. "context": self.context,
  71. "current_goal_id": self.current_goal_id,
  72. "created_at": self.created_at.isoformat() if self.created_at else None,
  73. "completed_at": self.completed_at.isoformat() if self.completed_at else None,
  74. }
  75. @dataclass
  76. class Message:
  77. """
  78. 执行消息 - Trace 中的 LLM 消息
  79. 对应 LLM API 消息格式(assistant/tool),通过 goal_id 和 branch_id 关联 Goal。
  80. description 字段自动生成规则:
  81. - assistant: 优先取 content,若无 content 则生成 "tool call: XX, XX"
  82. - tool: 使用 tool name
  83. """
  84. message_id: str
  85. trace_id: str
  86. role: Literal["assistant", "tool"] # 和 LLM API 一致
  87. sequence: int # 全局顺序
  88. goal_id: str # 关联的 Goal 内部 ID
  89. description: str = "" # 消息描述(系统自动生成)
  90. branch_id: Optional[str] = None # 所属分支(null=主线, "A"/"B"=分支)
  91. tool_call_id: Optional[str] = None # tool 消息关联对应的 tool_call
  92. content: Any = None # 消息内容(和 LLM API 格式一致)
  93. # 元数据
  94. tokens: Optional[int] = None
  95. cost: Optional[float] = None
  96. duration_ms: Optional[int] = None
  97. created_at: datetime = field(default_factory=datetime.now)
  98. @classmethod
  99. def create(
  100. cls,
  101. trace_id: str,
  102. role: Literal["assistant", "tool"],
  103. sequence: int,
  104. goal_id: str,
  105. content: Any = None,
  106. branch_id: Optional[str] = None,
  107. tool_call_id: Optional[str] = None,
  108. tokens: Optional[int] = None,
  109. cost: Optional[float] = None,
  110. duration_ms: Optional[int] = None,
  111. ) -> "Message":
  112. """创建新的 Message,自动生成 description"""
  113. description = cls._generate_description(role, content)
  114. return cls(
  115. message_id=str(uuid.uuid4()),
  116. trace_id=trace_id,
  117. role=role,
  118. sequence=sequence,
  119. goal_id=goal_id,
  120. content=content,
  121. description=description,
  122. branch_id=branch_id,
  123. tool_call_id=tool_call_id,
  124. tokens=tokens,
  125. cost=cost,
  126. duration_ms=duration_ms,
  127. )
  128. @staticmethod
  129. def _generate_description(role: str, content: Any) -> str:
  130. """
  131. 自动生成 description
  132. - assistant: 优先取 content,若无 content 则生成 "tool call: XX, XX"
  133. - tool: 使用 tool name
  134. """
  135. if role == "assistant":
  136. # assistant 消息:content 是字典,可能包含 text 和 tool_calls
  137. if isinstance(content, dict):
  138. # 优先返回文本内容
  139. if content.get("text"):
  140. text = content["text"]
  141. # 截断过长的文本
  142. return text[:200] + "..." if len(text) > 200 else text
  143. # 如果没有文本,检查 tool_calls
  144. if content.get("tool_calls"):
  145. tool_calls = content["tool_calls"]
  146. if isinstance(tool_calls, list):
  147. tool_names = []
  148. for tc in tool_calls:
  149. if isinstance(tc, dict) and tc.get("function", {}).get("name"):
  150. tool_names.append(tc["function"]["name"])
  151. if tool_names:
  152. return f"tool call: {', '.join(tool_names)}"
  153. # 如果 content 是字符串
  154. if isinstance(content, str):
  155. return content[:200] + "..." if len(content) > 200 else content
  156. return "assistant message"
  157. elif role == "tool":
  158. # tool 消息:从 content 中提取 tool name
  159. if isinstance(content, dict):
  160. if content.get("tool_name"):
  161. return content["tool_name"]
  162. # 如果是字符串,尝试解析
  163. if isinstance(content, str):
  164. return content[:100] + "..." if len(content) > 100 else content
  165. return "tool result"
  166. return ""
  167. def to_dict(self) -> Dict[str, Any]:
  168. """转换为字典"""
  169. return {
  170. "message_id": self.message_id,
  171. "trace_id": self.trace_id,
  172. "branch_id": self.branch_id,
  173. "role": self.role,
  174. "sequence": self.sequence,
  175. "goal_id": self.goal_id,
  176. "tool_call_id": self.tool_call_id,
  177. "content": self.content,
  178. "description": self.description,
  179. "tokens": self.tokens,
  180. "cost": self.cost,
  181. "duration_ms": self.duration_ms,
  182. "created_at": self.created_at.isoformat() if self.created_at else None,
  183. }
  184. # ===== 已弃用:Step 模型(保留用于向后兼容)=====
  185. # Step 类型
  186. StepType = Literal[
  187. "goal", "thought", "evaluation", "response",
  188. "action", "result", "memory_read", "memory_write",
  189. ]
  190. # Step 状态
  191. StepStatus = Literal[
  192. "planned", "in_progress", "awaiting_approval",
  193. "completed", "failed", "skipped",
  194. ]
  195. @dataclass
  196. class Step:
  197. """
  198. [已弃用] 执行步骤 - 使用 Message 模型替代
  199. 保留用于向后兼容
  200. """
  201. step_id: str
  202. trace_id: str
  203. step_type: StepType
  204. status: StepStatus
  205. sequence: int
  206. parent_id: Optional[str] = None
  207. description: str = ""
  208. data: Dict[str, Any] = field(default_factory=dict)
  209. summary: Optional[str] = None
  210. has_children: bool = False
  211. children_count: int = 0
  212. duration_ms: Optional[int] = None
  213. tokens: Optional[int] = None
  214. cost: Optional[float] = None
  215. created_at: datetime = field(default_factory=datetime.now)
  216. @classmethod
  217. def create(
  218. cls,
  219. trace_id: str,
  220. step_type: StepType,
  221. sequence: int,
  222. status: StepStatus = "completed",
  223. description: str = "",
  224. data: Dict[str, Any] = None,
  225. parent_id: Optional[str] = None,
  226. summary: Optional[str] = None,
  227. duration_ms: Optional[int] = None,
  228. tokens: Optional[int] = None,
  229. cost: Optional[float] = None,
  230. ) -> "Step":
  231. """创建新的 Step"""
  232. return cls(
  233. step_id=str(uuid.uuid4()),
  234. trace_id=trace_id,
  235. step_type=step_type,
  236. status=status,
  237. sequence=sequence,
  238. parent_id=parent_id,
  239. description=description,
  240. data=data or {},
  241. summary=summary,
  242. duration_ms=duration_ms,
  243. tokens=tokens,
  244. cost=cost,
  245. )
  246. def to_dict(self, view: str = "full") -> Dict[str, Any]:
  247. """
  248. 转换为字典
  249. Args:
  250. view: "compact" - 不返回大字段
  251. "full" - 返回完整数据
  252. """
  253. result = {
  254. "step_id": self.step_id,
  255. "trace_id": self.trace_id,
  256. "step_type": self.step_type,
  257. "status": self.status,
  258. "sequence": self.sequence,
  259. "parent_id": self.parent_id,
  260. "description": self.description,
  261. "summary": self.summary,
  262. "has_children": self.has_children,
  263. "children_count": self.children_count,
  264. "duration_ms": self.duration_ms,
  265. "tokens": self.tokens,
  266. "cost": self.cost,
  267. "created_at": self.created_at.isoformat() if self.created_at else None,
  268. }
  269. # 处理 data 字段
  270. if view == "compact":
  271. data_copy = self.data.copy()
  272. for key in ["output", "content", "full_output", "full_content"]:
  273. data_copy.pop(key, None)
  274. result["data"] = data_copy
  275. else:
  276. result["data"] = self.data
  277. return result