models.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. """
  2. Goal 数据模型
  3. Goal: 执行计划中的目标节点
  4. GoalTree: 目标树,管理整个执行计划
  5. GoalStats: 目标统计信息
  6. """
  7. from dataclasses import dataclass, field
  8. from datetime import datetime
  9. from typing import Dict, Any, List, Optional, Literal
  10. import json
  11. # Goal 状态
  12. GoalStatus = Literal["pending", "in_progress", "completed", "abandoned"]
  13. # Goal 类型
  14. GoalType = Literal["normal", "agent_call"]
  15. @dataclass
  16. class GoalStats:
  17. """目标统计信息"""
  18. message_count: int = 0 # 消息数量
  19. total_tokens: int = 0 # Token 总数
  20. total_cost: float = 0.0 # 总成本
  21. preview: Optional[str] = None # 工具调用摘要,如 "read_file → edit_file → bash"
  22. def to_dict(self) -> Dict[str, Any]:
  23. return {
  24. "message_count": self.message_count,
  25. "total_tokens": self.total_tokens,
  26. "total_cost": self.total_cost,
  27. "preview": self.preview,
  28. }
  29. @classmethod
  30. def from_dict(cls, data: Dict[str, Any]) -> "GoalStats":
  31. return cls(
  32. message_count=data.get("message_count", 0),
  33. total_tokens=data.get("total_tokens", 0),
  34. total_cost=data.get("total_cost", 0.0),
  35. preview=data.get("preview"),
  36. )
  37. @dataclass
  38. class Goal:
  39. """
  40. 执行目标
  41. 使用扁平列表 + parent_id 构建层级结构。
  42. agent_call 类型用于标记启动了 Sub-Trace 的 Goal。
  43. """
  44. id: str # 内部唯一 ID,纯自增("1", "2", "3"...)
  45. description: str # 目标描述
  46. reason: str = "" # 创建理由(为什么做)
  47. parent_id: Optional[str] = None # 父 Goal ID(层级关系)
  48. type: GoalType = "normal" # Goal 类型
  49. status: GoalStatus = "pending" # 状态
  50. summary: Optional[str] = None # 完成/放弃时的总结
  51. # agent_call 特有
  52. sub_trace_ids: Optional[List[str]] = None # 启动的 Sub-Trace IDs
  53. agent_call_mode: Optional[str] = None # "explore" | "delegate" | "sequential"
  54. # 统计(后端维护,用于可视化边的数据)
  55. self_stats: GoalStats = field(default_factory=GoalStats) # 自身统计(仅直接关联的 messages)
  56. cumulative_stats: GoalStats = field(default_factory=GoalStats) # 累计统计(自身 + 所有后代)
  57. created_at: datetime = field(default_factory=datetime.now)
  58. def to_dict(self) -> Dict[str, Any]:
  59. """转换为字典"""
  60. return {
  61. "id": self.id,
  62. "description": self.description,
  63. "reason": self.reason,
  64. "parent_id": self.parent_id,
  65. "type": self.type,
  66. "status": self.status,
  67. "summary": self.summary,
  68. "sub_trace_ids": self.sub_trace_ids,
  69. "agent_call_mode": self.agent_call_mode,
  70. "self_stats": self.self_stats.to_dict(),
  71. "cumulative_stats": self.cumulative_stats.to_dict(),
  72. "created_at": self.created_at.isoformat() if self.created_at else None,
  73. }
  74. @classmethod
  75. def from_dict(cls, data: Dict[str, Any]) -> "Goal":
  76. """从字典创建"""
  77. created_at = data.get("created_at")
  78. if isinstance(created_at, str):
  79. created_at = datetime.fromisoformat(created_at)
  80. self_stats = data.get("self_stats", {})
  81. if isinstance(self_stats, dict):
  82. self_stats = GoalStats.from_dict(self_stats)
  83. cumulative_stats = data.get("cumulative_stats", {})
  84. if isinstance(cumulative_stats, dict):
  85. cumulative_stats = GoalStats.from_dict(cumulative_stats)
  86. return cls(
  87. id=data["id"],
  88. description=data["description"],
  89. reason=data.get("reason", ""),
  90. parent_id=data.get("parent_id"),
  91. type=data.get("type", "normal"),
  92. status=data.get("status", "pending"),
  93. summary=data.get("summary"),
  94. sub_trace_ids=data.get("sub_trace_ids"),
  95. agent_call_mode=data.get("agent_call_mode"),
  96. self_stats=self_stats,
  97. cumulative_stats=cumulative_stats,
  98. created_at=created_at or datetime.now(),
  99. )
  100. @dataclass
  101. class GoalTree:
  102. """
  103. 目标树 - 管理整个执行计划
  104. 使用扁平列表 + parent_id 构建层级结构
  105. """
  106. mission: str # 总任务描述
  107. goals: List[Goal] = field(default_factory=list) # 扁平列表(通过 parent_id 构建层级)
  108. current_id: Optional[str] = None # 当前焦点 goal ID
  109. _next_id: int = 1 # 内部 ID 计数器(私有字段)
  110. created_at: datetime = field(default_factory=datetime.now)
  111. def find(self, goal_id: str) -> Optional[Goal]:
  112. """按 ID 查找 Goal"""
  113. for goal in self.goals:
  114. if goal.id == goal_id:
  115. return goal
  116. return None
  117. def find_parent(self, goal_id: str) -> Optional[Goal]:
  118. """查找指定 Goal 的父节点"""
  119. goal = self.find(goal_id)
  120. if not goal or not goal.parent_id:
  121. return None
  122. return self.find(goal.parent_id)
  123. def get_children(self, parent_id: Optional[str]) -> List[Goal]:
  124. """获取指定父节点的所有子节点"""
  125. return [g for g in self.goals if g.parent_id == parent_id]
  126. def get_current(self) -> Optional[Goal]:
  127. """获取当前焦点 Goal"""
  128. if self.current_id:
  129. return self.find(self.current_id)
  130. return None
  131. def _generate_id(self) -> str:
  132. """生成新的 Goal ID(纯自增)"""
  133. new_id = str(self._next_id)
  134. self._next_id += 1
  135. return new_id
  136. def _generate_display_id(self, goal: Goal) -> str:
  137. """生成显示序号(1, 2, 2.1, 2.2...)"""
  138. if not goal.parent_id:
  139. # 顶层目标:找到在同级中的序号
  140. siblings = [g for g in self.goals if g.parent_id is None and g.status != "abandoned"]
  141. try:
  142. index = [g.id for g in siblings].index(goal.id) + 1
  143. return str(index)
  144. except ValueError:
  145. return "?"
  146. else:
  147. # 子目标:父序号 + "." + 在同级中的序号
  148. parent = self.find(goal.parent_id)
  149. if not parent:
  150. return "?"
  151. parent_display = self._generate_display_id(parent)
  152. siblings = [g for g in self.goals if g.parent_id == goal.parent_id and g.status != "abandoned"]
  153. try:
  154. index = [g.id for g in siblings].index(goal.id) + 1
  155. return f"{parent_display}.{index}"
  156. except ValueError:
  157. return f"{parent_display}.?"
  158. def add_goals(
  159. self,
  160. descriptions: List[str],
  161. reasons: Optional[List[str]] = None,
  162. parent_id: Optional[str] = None
  163. ) -> List[Goal]:
  164. """
  165. 添加目标
  166. 如果 parent_id 为 None,添加到顶层
  167. 如果 parent_id 有值,添加为该 goal 的子目标
  168. """
  169. if parent_id:
  170. parent = self.find(parent_id)
  171. if not parent:
  172. raise ValueError(f"Parent goal not found: {parent_id}")
  173. # 创建新目标
  174. new_goals = []
  175. for i, desc in enumerate(descriptions):
  176. goal_id = self._generate_id()
  177. reason = reasons[i] if reasons and i < len(reasons) else ""
  178. goal = Goal(
  179. id=goal_id,
  180. description=desc.strip(),
  181. reason=reason,
  182. parent_id=parent_id
  183. )
  184. self.goals.append(goal)
  185. new_goals.append(goal)
  186. return new_goals
  187. def focus(self, goal_id: str) -> Goal:
  188. """切换焦点到指定 Goal,并将其状态设为 in_progress"""
  189. goal = self.find(goal_id)
  190. if not goal:
  191. raise ValueError(f"Goal not found: {goal_id}")
  192. # 更新状态
  193. if goal.status == "pending":
  194. goal.status = "in_progress"
  195. self.current_id = goal_id
  196. return goal
  197. def complete(self, goal_id: str, summary: str) -> Goal:
  198. """完成指定 Goal"""
  199. goal = self.find(goal_id)
  200. if not goal:
  201. raise ValueError(f"Goal not found: {goal_id}")
  202. goal.status = "completed"
  203. goal.summary = summary
  204. # 如果完成的是当前焦点,清除焦点
  205. if self.current_id == goal_id:
  206. self.current_id = None
  207. # 检查是否所有兄弟都完成了,如果是则自动完成父节点
  208. if goal.parent_id:
  209. siblings = self.get_children(goal.parent_id)
  210. all_completed = all(g.status == "completed" for g in siblings)
  211. if all_completed:
  212. parent = self.find(goal.parent_id)
  213. if parent and parent.status != "completed":
  214. # 自动级联完成父节点
  215. parent.status = "completed"
  216. if not parent.summary:
  217. parent.summary = "所有子目标已完成"
  218. return goal
  219. def abandon(self, goal_id: str, reason: str) -> Goal:
  220. """放弃指定 Goal"""
  221. goal = self.find(goal_id)
  222. if not goal:
  223. raise ValueError(f"Goal not found: {goal_id}")
  224. goal.status = "abandoned"
  225. goal.summary = reason
  226. # 如果放弃的是当前焦点,清除焦点
  227. if self.current_id == goal_id:
  228. self.current_id = None
  229. return goal
  230. def to_prompt(self, include_abandoned: bool = False) -> str:
  231. """
  232. 格式化为 Prompt 注入文本
  233. 过滤掉 abandoned 目标,重新生成连续显示序号
  234. """
  235. lines = []
  236. lines.append(f"**Mission**: {self.mission}")
  237. if self.current_id:
  238. current = self.find(self.current_id)
  239. if current:
  240. display_id = self._generate_display_id(current)
  241. lines.append(f"**Current**: {display_id} {current.description}")
  242. lines.append("")
  243. lines.append("**Progress**:")
  244. def format_goal(goal: Goal, indent: int = 0) -> List[str]:
  245. # 跳过废弃的目标(除非明确要求包含)
  246. if goal.status == "abandoned" and not include_abandoned:
  247. return []
  248. prefix = " " * indent
  249. # 状态图标
  250. if goal.status == "completed":
  251. icon = "[✓]"
  252. elif goal.status == "in_progress":
  253. icon = "[→]"
  254. elif goal.status == "abandoned":
  255. icon = "[✗]"
  256. else:
  257. icon = "[ ]"
  258. # 生成显示序号
  259. display_id = self._generate_display_id(goal)
  260. # 当前焦点标记
  261. current_mark = " ← current" if goal.id == self.current_id else ""
  262. result = [f"{prefix}{icon} {display_id}. {goal.description}{current_mark}"]
  263. # 显示 summary(如果有)
  264. if goal.summary:
  265. result.append(f"{prefix} → {goal.summary}")
  266. # 递归处理子目标
  267. children = self.get_children(goal.id)
  268. for child in children:
  269. result.extend(format_goal(child, indent + 1))
  270. return result
  271. # 处理所有顶层目标
  272. top_goals = self.get_children(None)
  273. for goal in top_goals:
  274. lines.extend(format_goal(goal))
  275. return "\n".join(lines)
  276. def to_dict(self) -> Dict[str, Any]:
  277. """转换为字典"""
  278. return {
  279. "mission": self.mission,
  280. "goals": [g.to_dict() for g in self.goals],
  281. "current_id": self.current_id,
  282. "_next_id": self._next_id,
  283. "created_at": self.created_at.isoformat() if self.created_at else None,
  284. }
  285. @classmethod
  286. def from_dict(cls, data: Dict[str, Any]) -> "GoalTree":
  287. """从字典创建"""
  288. goals = [Goal.from_dict(g) for g in data.get("goals", [])]
  289. created_at = data.get("created_at")
  290. if isinstance(created_at, str):
  291. created_at = datetime.fromisoformat(created_at)
  292. return cls(
  293. mission=data["mission"],
  294. goals=goals,
  295. current_id=data.get("current_id"),
  296. _next_id=data.get("_next_id", 1),
  297. created_at=created_at or datetime.now(),
  298. )
  299. def save(self, path: str) -> None:
  300. """保存到 JSON 文件"""
  301. with open(path, "w", encoding="utf-8") as f:
  302. json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)
  303. @classmethod
  304. def load(cls, path: str) -> "GoalTree":
  305. """从 JSON 文件加载"""
  306. with open(path, "r", encoding="utf-8") as f:
  307. data = json.load(f)
  308. return cls.from_dict(data)