models.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
"""
Goal data models.

Goal: a goal node in the execution plan
GoalTree: the goal tree, managing the whole execution plan
GoalStats: goal statistics
BranchContext: branch execution context
"""
  8. from dataclasses import dataclass, field
  9. from datetime import datetime
  10. from typing import Dict, Any, List, Optional, Literal
  11. import json
  12. # Goal 状态
  13. GoalStatus = Literal["pending", "in_progress", "completed", "abandoned"]
  14. # Goal 类型
  15. GoalType = Literal["normal", "explore_start", "explore_merge"]
  16. # Branch 状态
  17. BranchStatus = Literal["exploring", "completed", "abandoned"]
  18. @dataclass
  19. class GoalStats:
  20. """目标统计信息"""
  21. message_count: int = 0 # 消息数量
  22. total_tokens: int = 0 # Token 总数
  23. total_cost: float = 0.0 # 总成本
  24. preview: Optional[str] = None # 工具调用摘要,如 "read_file → edit_file → bash"
  25. def to_dict(self) -> Dict[str, Any]:
  26. return {
  27. "message_count": self.message_count,
  28. "total_tokens": self.total_tokens,
  29. "total_cost": self.total_cost,
  30. "preview": self.preview,
  31. }
  32. @classmethod
  33. def from_dict(cls, data: Dict[str, Any]) -> "GoalStats":
  34. return cls(
  35. message_count=data.get("message_count", 0),
  36. total_tokens=data.get("total_tokens", 0),
  37. total_cost=data.get("total_cost", 0.0),
  38. preview=data.get("preview"),
  39. )
  40. @dataclass
  41. class Goal:
  42. """
  43. 执行目标
  44. 使用扁平列表 + parent_id 构建层级结构。
  45. """
  46. id: str # 内部唯一 ID,纯自增("1", "2", "3"...)
  47. description: str # 目标描述
  48. reason: str = "" # 创建理由(为什么做)
  49. parent_id: Optional[str] = None # 父 Goal ID(层级关系)
  50. branch_id: Optional[str] = None # 所属分支 ID(分支关系,null=主线)
  51. type: GoalType = "normal" # Goal 类型
  52. status: GoalStatus = "pending" # 状态
  53. summary: Optional[str] = None # 完成/放弃时的总结
  54. # explore_start 特有
  55. branch_ids: Optional[List[str]] = None # 关联的分支 ID 列表
  56. # explore_merge 特有
  57. explore_start_id: Optional[str] = None # 关联的 explore_start Goal
  58. merge_summary: Optional[str] = None # 各分支汇总结果
  59. selected_branch: Optional[str] = None # 选中的分支(可选)
  60. # 统计(后端维护,用于可视化边的数据)
  61. self_stats: GoalStats = field(default_factory=GoalStats) # 自身统计(仅直接关联的 messages)
  62. cumulative_stats: GoalStats = field(default_factory=GoalStats) # 累计统计(自身 + 所有后代)
  63. created_at: datetime = field(default_factory=datetime.now)
  64. def to_dict(self) -> Dict[str, Any]:
  65. """转换为字典"""
  66. return {
  67. "id": self.id,
  68. "description": self.description,
  69. "reason": self.reason,
  70. "parent_id": self.parent_id,
  71. "branch_id": self.branch_id,
  72. "type": self.type,
  73. "status": self.status,
  74. "summary": self.summary,
  75. "branch_ids": self.branch_ids,
  76. "explore_start_id": self.explore_start_id,
  77. "merge_summary": self.merge_summary,
  78. "selected_branch": self.selected_branch,
  79. "self_stats": self.self_stats.to_dict(),
  80. "cumulative_stats": self.cumulative_stats.to_dict(),
  81. "created_at": self.created_at.isoformat() if self.created_at else None,
  82. }
  83. @classmethod
  84. def from_dict(cls, data: Dict[str, Any]) -> "Goal":
  85. """从字典创建"""
  86. created_at = data.get("created_at")
  87. if isinstance(created_at, str):
  88. created_at = datetime.fromisoformat(created_at)
  89. self_stats = data.get("self_stats", {})
  90. if isinstance(self_stats, dict):
  91. self_stats = GoalStats.from_dict(self_stats)
  92. cumulative_stats = data.get("cumulative_stats", {})
  93. if isinstance(cumulative_stats, dict):
  94. cumulative_stats = GoalStats.from_dict(cumulative_stats)
  95. return cls(
  96. id=data["id"],
  97. description=data["description"],
  98. reason=data.get("reason", ""),
  99. parent_id=data.get("parent_id"),
  100. branch_id=data.get("branch_id"),
  101. type=data.get("type", "normal"),
  102. status=data.get("status", "pending"),
  103. summary=data.get("summary"),
  104. branch_ids=data.get("branch_ids"),
  105. explore_start_id=data.get("explore_start_id"),
  106. merge_summary=data.get("merge_summary"),
  107. selected_branch=data.get("selected_branch"),
  108. self_stats=self_stats,
  109. cumulative_stats=cumulative_stats,
  110. created_at=created_at or datetime.now(),
  111. )
  112. @dataclass
  113. class BranchContext:
  114. """分支执行上下文(独立的 sub-agent 环境)"""
  115. id: str # 分支 ID,如 "A", "B"
  116. explore_start_id: str # 关联的 explore_start Goal ID
  117. description: str # 探索方向描述
  118. status: BranchStatus # exploring | completed | abandoned
  119. summary: Optional[str] = None # 完成时的总结
  120. cumulative_stats: GoalStats = field(default_factory=GoalStats) # 累计统计
  121. created_at: datetime = field(default_factory=datetime.now)
  122. def to_dict(self) -> Dict[str, Any]:
  123. return {
  124. "id": self.id,
  125. "explore_start_id": self.explore_start_id,
  126. "description": self.description,
  127. "status": self.status,
  128. "summary": self.summary,
  129. "cumulative_stats": self.cumulative_stats.to_dict(),
  130. "created_at": self.created_at.isoformat() if self.created_at else None,
  131. }
  132. @classmethod
  133. def from_dict(cls, data: Dict[str, Any]) -> "BranchContext":
  134. created_at = data.get("created_at")
  135. if isinstance(created_at, str):
  136. created_at = datetime.fromisoformat(created_at)
  137. cumulative_stats = data.get("cumulative_stats", {})
  138. if isinstance(cumulative_stats, dict):
  139. cumulative_stats = GoalStats.from_dict(cumulative_stats)
  140. return cls(
  141. id=data["id"],
  142. explore_start_id=data["explore_start_id"],
  143. description=data["description"],
  144. status=data.get("status", "exploring"),
  145. summary=data.get("summary"),
  146. cumulative_stats=cumulative_stats,
  147. created_at=created_at or datetime.now(),
  148. )
  149. @dataclass
  150. class GoalTree:
  151. """
  152. 目标树 - 管理整个执行计划
  153. 使用扁平列表 + parent_id 构建层级结构
  154. """
  155. mission: str # 总任务描述
  156. goals: List[Goal] = field(default_factory=list) # 扁平列表(通过 parent_id 构建层级)
  157. current_id: Optional[str] = None # 当前焦点 goal ID
  158. _next_id: int = 1 # 内部 ID 计数器(私有字段)
  159. created_at: datetime = field(default_factory=datetime.now)
  160. def find(self, goal_id: str) -> Optional[Goal]:
  161. """按 ID 查找 Goal"""
  162. for goal in self.goals:
  163. if goal.id == goal_id:
  164. return goal
  165. return None
  166. def find_parent(self, goal_id: str) -> Optional[Goal]:
  167. """查找指定 Goal 的父节点"""
  168. goal = self.find(goal_id)
  169. if not goal or not goal.parent_id:
  170. return None
  171. return self.find(goal.parent_id)
  172. def get_children(self, parent_id: Optional[str]) -> List[Goal]:
  173. """获取指定父节点的所有子节点"""
  174. return [g for g in self.goals if g.parent_id == parent_id]
  175. def get_current(self) -> Optional[Goal]:
  176. """获取当前焦点 Goal"""
  177. if self.current_id:
  178. return self.find(self.current_id)
  179. return None
  180. def _generate_id(self) -> str:
  181. """生成新的 Goal ID(纯自增)"""
  182. new_id = str(self._next_id)
  183. self._next_id += 1
  184. return new_id
  185. def _generate_display_id(self, goal: Goal) -> str:
  186. """生成显示序号(1, 2, 2.1, 2.2...)"""
  187. if not goal.parent_id:
  188. # 顶层目标:找到在同级中的序号
  189. siblings = [g for g in self.goals if g.parent_id is None and g.status != "abandoned"]
  190. try:
  191. index = [g.id for g in siblings].index(goal.id) + 1
  192. return str(index)
  193. except ValueError:
  194. return "?"
  195. else:
  196. # 子目标:父序号 + "." + 在同级中的序号
  197. parent = self.find(goal.parent_id)
  198. if not parent:
  199. return "?"
  200. parent_display = self._generate_display_id(parent)
  201. siblings = [g for g in self.goals if g.parent_id == goal.parent_id and g.status != "abandoned"]
  202. try:
  203. index = [g.id for g in siblings].index(goal.id) + 1
  204. return f"{parent_display}.{index}"
  205. except ValueError:
  206. return f"{parent_display}.?"
  207. def add_goals(
  208. self,
  209. descriptions: List[str],
  210. reasons: Optional[List[str]] = None,
  211. parent_id: Optional[str] = None
  212. ) -> List[Goal]:
  213. """
  214. 添加目标
  215. 如果 parent_id 为 None,添加到顶层
  216. 如果 parent_id 有值,添加为该 goal 的子目标
  217. """
  218. if parent_id:
  219. parent = self.find(parent_id)
  220. if not parent:
  221. raise ValueError(f"Parent goal not found: {parent_id}")
  222. # 创建新目标
  223. new_goals = []
  224. for i, desc in enumerate(descriptions):
  225. goal_id = self._generate_id()
  226. reason = reasons[i] if reasons and i < len(reasons) else ""
  227. goal = Goal(
  228. id=goal_id,
  229. description=desc.strip(),
  230. reason=reason,
  231. parent_id=parent_id
  232. )
  233. self.goals.append(goal)
  234. new_goals.append(goal)
  235. return new_goals
  236. def focus(self, goal_id: str) -> Goal:
  237. """切换焦点到指定 Goal,并将其状态设为 in_progress"""
  238. goal = self.find(goal_id)
  239. if not goal:
  240. raise ValueError(f"Goal not found: {goal_id}")
  241. # 更新状态
  242. if goal.status == "pending":
  243. goal.status = "in_progress"
  244. self.current_id = goal_id
  245. return goal
  246. def complete(self, goal_id: str, summary: str) -> Goal:
  247. """完成指定 Goal"""
  248. goal = self.find(goal_id)
  249. if not goal:
  250. raise ValueError(f"Goal not found: {goal_id}")
  251. goal.status = "completed"
  252. goal.summary = summary
  253. # 如果完成的是当前焦点,清除焦点
  254. if self.current_id == goal_id:
  255. self.current_id = None
  256. # 检查是否所有兄弟都完成了,如果是则自动完成父节点
  257. if goal.parent_id:
  258. siblings = self.get_children(goal.parent_id)
  259. all_completed = all(g.status == "completed" for g in siblings)
  260. if all_completed:
  261. parent = self.find(goal.parent_id)
  262. if parent and parent.status != "completed":
  263. # 自动级联完成父节点
  264. parent.status = "completed"
  265. if not parent.summary:
  266. parent.summary = "所有子目标已完成"
  267. return goal
  268. def abandon(self, goal_id: str, reason: str) -> Goal:
  269. """放弃指定 Goal"""
  270. goal = self.find(goal_id)
  271. if not goal:
  272. raise ValueError(f"Goal not found: {goal_id}")
  273. goal.status = "abandoned"
  274. goal.summary = reason
  275. # 如果放弃的是当前焦点,清除焦点
  276. if self.current_id == goal_id:
  277. self.current_id = None
  278. return goal
  279. def to_prompt(self, include_abandoned: bool = False) -> str:
  280. """
  281. 格式化为 Prompt 注入文本
  282. 过滤掉 abandoned 目标,重新生成连续显示序号
  283. """
  284. lines = []
  285. lines.append(f"**Mission**: {self.mission}")
  286. if self.current_id:
  287. current = self.find(self.current_id)
  288. if current:
  289. display_id = self._generate_display_id(current)
  290. lines.append(f"**Current**: {display_id} {current.description}")
  291. lines.append("")
  292. lines.append("**Progress**:")
  293. def format_goal(goal: Goal, indent: int = 0) -> List[str]:
  294. # 跳过废弃的目标(除非明确要求包含)
  295. if goal.status == "abandoned" and not include_abandoned:
  296. return []
  297. prefix = " " * indent
  298. # 状态图标
  299. if goal.status == "completed":
  300. icon = "[✓]"
  301. elif goal.status == "in_progress":
  302. icon = "[→]"
  303. elif goal.status == "abandoned":
  304. icon = "[✗]"
  305. else:
  306. icon = "[ ]"
  307. # 生成显示序号
  308. display_id = self._generate_display_id(goal)
  309. # 当前焦点标记
  310. current_mark = " ← current" if goal.id == self.current_id else ""
  311. result = [f"{prefix}{icon} {display_id}. {goal.description}{current_mark}"]
  312. # 显示 summary(如果有)
  313. if goal.summary:
  314. result.append(f"{prefix} → {goal.summary}")
  315. # 递归处理子目标
  316. children = self.get_children(goal.id)
  317. for child in children:
  318. result.extend(format_goal(child, indent + 1))
  319. return result
  320. # 处理所有顶层目标
  321. top_goals = self.get_children(None)
  322. for goal in top_goals:
  323. lines.extend(format_goal(goal))
  324. return "\n".join(lines)
  325. def to_dict(self) -> Dict[str, Any]:
  326. """转换为字典"""
  327. return {
  328. "mission": self.mission,
  329. "goals": [g.to_dict() for g in self.goals],
  330. "current_id": self.current_id,
  331. "_next_id": self._next_id,
  332. "created_at": self.created_at.isoformat() if self.created_at else None,
  333. }
  334. @classmethod
  335. def from_dict(cls, data: Dict[str, Any]) -> "GoalTree":
  336. """从字典创建"""
  337. goals = [Goal.from_dict(g) for g in data.get("goals", [])]
  338. created_at = data.get("created_at")
  339. if isinstance(created_at, str):
  340. created_at = datetime.fromisoformat(created_at)
  341. return cls(
  342. mission=data["mission"],
  343. goals=goals,
  344. current_id=data.get("current_id"),
  345. _next_id=data.get("_next_id", 1),
  346. created_at=created_at or datetime.now(),
  347. )
  348. def save(self, path: str) -> None:
  349. """保存到 JSON 文件"""
  350. with open(path, "w", encoding="utf-8") as f:
  351. json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)
  352. @classmethod
  353. def load(cls, path: str) -> "GoalTree":
  354. """从 JSON 文件加载"""
  355. with open(path, "r", encoding="utf-8") as f:
  356. data = json.load(f)
  357. return cls.from_dict(data)