| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431 |
- """
- Goal 数据模型
- Goal: 执行计划中的目标节点
- GoalTree: 目标树,管理整个执行计划
- GoalStats: 目标统计信息
- BranchContext: 分支执行上下文
- """
- from dataclasses import dataclass, field
- from datetime import datetime
- from typing import Dict, Any, List, Optional, Literal
- import json
# Lifecycle states a Goal can be in.
GoalStatus = Literal[
    "pending",
    "in_progress",
    "completed",
    "abandoned",
]
# Kinds of Goal node: an ordinary goal, an explore_start goal (which lists the
# branches it opens via branch_ids), or an explore_merge goal (which records the
# merged branch results).
GoalType = Literal[
    "normal",
    "explore_start",
    "explore_merge",
]
# Lifecycle states of an exploration branch.
BranchStatus = Literal[
    "exploring",
    "completed",
    "abandoned",
]
@dataclass
class GoalStats:
    """Aggregated statistics attached to a goal or a branch."""

    message_count: int = 0  # number of messages
    total_tokens: int = 0  # total number of tokens
    total_cost: float = 0.0  # total cost
    preview: Optional[str] = None  # tool-call summary, e.g. "read_file → edit_file → bash"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict holding exactly the four fields, in declaration order."""
        keys = ("message_count", "total_tokens", "total_cost", "preview")
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GoalStats":
        """
        Build an instance from a dict.

        Absent keys fall back to the field defaults; unknown keys are ignored.
        Keys that are present are taken as-is (no type coercion).
        """
        fallbacks = {
            "message_count": 0,
            "total_tokens": 0,
            "total_cost": 0.0,
            "preview": None,
        }
        kwargs = {key: data.get(key, default) for key, default in fallbacks.items()}
        return cls(**kwargs)
@dataclass
class Goal:
    """
    An execution goal: one node of the execution plan.

    The hierarchy is expressed through a flat list plus ``parent_id`` links
    rather than nested child lists.
    """

    id: str  # internal unique ID, pure auto-increment ("1", "2", "3"...)
    description: str  # goal description
    reason: str = ""  # why this goal was created
    parent_id: Optional[str] = None  # parent Goal ID (hierarchy)
    branch_id: Optional[str] = None  # owning branch ID (None = main line)
    type: GoalType = "normal"  # goal type
    status: GoalStatus = "pending"  # current status
    summary: Optional[str] = None  # summary written on completion / abandonment

    # explore_start only
    branch_ids: Optional[List[str]] = None  # IDs of the associated branches

    # explore_merge only
    explore_start_id: Optional[str] = None  # the associated explore_start Goal
    merge_summary: Optional[str] = None  # merged result across branches
    selected_branch: Optional[str] = None  # selected branch (optional)

    # Statistics (maintained by the backend; used as edge data for visualization)
    self_stats: GoalStats = field(default_factory=GoalStats)  # own stats (directly associated messages only)
    cumulative_stats: GoalStats = field(default_factory=GoalStats)  # cumulative stats (self + all descendants)
    created_at: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize to a plain dict.

        ``created_at`` becomes an ISO-8601 string and the stats objects become
        nested dicts. ``branch_ids`` is copied so that mutating the returned
        dict cannot mutate this goal.
        """
        return {
            "id": self.id,
            "description": self.description,
            "reason": self.reason,
            "parent_id": self.parent_id,
            "branch_id": self.branch_id,
            "type": self.type,
            "status": self.status,
            "summary": self.summary,
            "branch_ids": list(self.branch_ids) if self.branch_ids is not None else None,
            "explore_start_id": self.explore_start_id,
            "merge_summary": self.merge_summary,
            "selected_branch": self.selected_branch,
            "self_stats": self.self_stats.to_dict(),
            "cumulative_stats": self.cumulative_stats.to_dict(),
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Goal":
        """
        Build a Goal from a dict (the inverse of ``to_dict``).

        ``id`` and ``description`` are required (KeyError otherwise); every
        other key is optional. Stats that are missing or explicitly null fall
        back to an empty ``GoalStats``; a missing ``created_at`` falls back to
        now. ``branch_ids`` is copied so that the goal does not share the list
        object with ``data``.
        """
        created_at = data.get("created_at")
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)

        branch_ids = data.get("branch_ids")

        return cls(
            id=data["id"],
            description=data["description"],
            reason=data.get("reason", ""),
            parent_id=data.get("parent_id"),
            branch_id=data.get("branch_id"),
            type=data.get("type", "normal"),
            status=data.get("status", "pending"),
            summary=data.get("summary"),
            branch_ids=list(branch_ids) if branch_ids is not None else None,
            explore_start_id=data.get("explore_start_id"),
            merge_summary=data.get("merge_summary"),
            selected_branch=data.get("selected_branch"),
            self_stats=cls._coerce_stats(data.get("self_stats")),
            cumulative_stats=cls._coerce_stats(data.get("cumulative_stats")),
            created_at=created_at or datetime.now(),
        )

    @staticmethod
    def _coerce_stats(value: Any) -> Any:
        """Turn a serialized stats value into a ``GoalStats``.

        None (missing or JSON null) becomes an empty ``GoalStats`` instead of
        leaking ``None`` into the dataclass, where ``to_dict`` would crash. A
        dict is parsed. Anything else (e.g. an existing ``GoalStats``) is
        returned unchanged.
        """
        if value is None:
            return GoalStats()
        if isinstance(value, dict):
            return GoalStats.from_dict(value)
        return value
@dataclass
class BranchContext:
    """Execution context of one exploration branch (an independent sub-agent environment)."""

    id: str  # branch ID, e.g. "A", "B"
    explore_start_id: str  # ID of the associated explore_start Goal
    description: str  # description of the exploration direction
    status: BranchStatus  # exploring | completed | abandoned
    summary: Optional[str] = None  # summary written on completion
    cumulative_stats: GoalStats = field(default_factory=GoalStats)  # cumulative statistics
    created_at: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; ``created_at`` becomes an ISO-8601 string."""
        return {
            "id": self.id,
            "explore_start_id": self.explore_start_id,
            "description": self.description,
            "status": self.status,
            "summary": self.summary,
            "cumulative_stats": self.cumulative_stats.to_dict(),
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BranchContext":
        """
        Build a BranchContext from a dict (the inverse of ``to_dict``).

        ``id``, ``explore_start_id`` and ``description`` are required (KeyError
        otherwise). ``status`` defaults to "exploring". Missing or explicitly
        null ``cumulative_stats`` falls back to an empty ``GoalStats`` rather
        than storing ``None``, which would break ``to_dict``.
        """
        created_at = data.get("created_at")
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)

        cumulative_stats = data.get("cumulative_stats")
        if cumulative_stats is None:
            cumulative_stats = GoalStats()
        elif isinstance(cumulative_stats, dict):
            cumulative_stats = GoalStats.from_dict(cumulative_stats)

        return cls(
            id=data["id"],
            explore_start_id=data["explore_start_id"],
            description=data["description"],
            status=data.get("status", "exploring"),
            summary=data.get("summary"),
            cumulative_stats=cumulative_stats,
            created_at=created_at or datetime.now(),
        )
@dataclass
class GoalTree:
    """
    Goal tree: manages the whole execution plan.

    Goals are stored in a flat list; the hierarchy is reconstructed from each
    goal's ``parent_id``.
    """

    mission: str  # overall task description
    goals: List[Goal] = field(default_factory=list)  # flat list; hierarchy comes from parent_id
    current_id: Optional[str] = None  # ID of the goal currently in focus
    _next_id: int = 1  # internal ID counter (private field)
    created_at: datetime = field(default_factory=datetime.now)

    def find(self, goal_id: str) -> Optional[Goal]:
        """Return the goal whose ID is ``goal_id``, or None (linear scan)."""
        for goal in self.goals:
            if goal.id == goal_id:
                return goal
        return None

    def find_parent(self, goal_id: str) -> Optional[Goal]:
        """Return the parent of ``goal_id``; None if the goal is unknown or top-level."""
        goal = self.find(goal_id)
        if not goal or not goal.parent_id:
            return None
        return self.find(goal.parent_id)

    def get_children(self, parent_id: Optional[str]) -> List[Goal]:
        """Return the direct children of ``parent_id`` in list order (None -> top-level goals)."""
        return [g for g in self.goals if g.parent_id == parent_id]

    def get_current(self) -> Optional[Goal]:
        """Return the goal currently in focus, or None."""
        if self.current_id:
            return self.find(self.current_id)
        return None

    def _generate_id(self) -> str:
        """Return a new internal Goal ID (pure auto-increment) and advance the counter."""
        new_id = str(self._next_id)
        self._next_id += 1
        return new_id

    def _generate_display_id(self, goal: Goal) -> str:
        """
        Build the display ordinal (1, 2, 2.1, 2.2...).

        Abandoned siblings are excluded from the numbering so that ordinals
        stay contiguous. A goal that is itself abandoned, or whose parent cannot
        be found, gets "?" in its own position.
        """
        if not goal.parent_id:
            # Top-level goal: its position among the non-abandoned top-level goals.
            siblings = [g for g in self.goals if g.parent_id is None and g.status != "abandoned"]
            try:
                index = [g.id for g in siblings].index(goal.id) + 1
                return str(index)
            except ValueError:
                return "?"
        else:
            # Sub-goal: parent ordinal + "." + position among non-abandoned siblings.
            parent = self.find(goal.parent_id)
            if not parent:
                return "?"
            parent_display = self._generate_display_id(parent)
            siblings = [g for g in self.goals if g.parent_id == goal.parent_id and g.status != "abandoned"]
            try:
                index = [g.id for g in siblings].index(goal.id) + 1
                return f"{parent_display}.{index}"
            except ValueError:
                return f"{parent_display}.?"

    def add_goals(
        self,
        descriptions: List[str],
        reasons: Optional[List[str]] = None,
        parent_id: Optional[str] = None
    ) -> List[Goal]:
        """
        Append new goals and return them in the order given.

        With ``parent_id`` None the goals are added at the top level; otherwise
        they become children of that goal. ``reasons[i]`` pairs with
        ``descriptions[i]``, and missing reasons default to "". Descriptions are
        stripped of surrounding whitespace.

        Raises:
            ValueError: ``parent_id`` is given but no such goal exists.
        """
        if parent_id and self.find(parent_id) is None:
            raise ValueError(f"Parent goal not found: {parent_id}")

        new_goals = []
        for i, desc in enumerate(descriptions):
            goal_id = self._generate_id()
            reason = reasons[i] if reasons and i < len(reasons) else ""
            goal = Goal(
                id=goal_id,
                description=desc.strip(),
                reason=reason,
                parent_id=parent_id
            )
            self.goals.append(goal)
            new_goals.append(goal)
        return new_goals

    def focus(self, goal_id: str) -> Goal:
        """
        Move the focus to ``goal_id``.

        A pending goal is promoted to "in_progress". Goals in any other status
        keep their status.

        Raises:
            ValueError: no goal with ``goal_id`` exists.
        """
        goal = self.find(goal_id)
        if not goal:
            raise ValueError(f"Goal not found: {goal_id}")
        if goal.status == "pending":
            goal.status = "in_progress"
        self.current_id = goal_id
        return goal

    def complete(self, goal_id: str, summary: str) -> Goal:
        """
        Mark a goal completed and record its summary.

        The focus is cleared if it was on this goal. Completion then cascades
        upward: every ancestor whose children are now all completed is
        auto-completed as well (see ``_cascade_completion``).

        Raises:
            ValueError: no goal with ``goal_id`` exists.
        """
        goal = self.find(goal_id)
        if not goal:
            raise ValueError(f"Goal not found: {goal_id}")

        goal.status = "completed"
        goal.summary = summary

        if self.current_id == goal_id:
            self.current_id = None

        self._cascade_completion(goal)
        return goal

    def _cascade_completion(self, goal: Goal) -> None:
        """Walk up from ``goal``, auto-completing each ancestor whose children are all completed.

        Each ancestor is treated the same way as an explicit ``complete``: the
        focus is cleared if it points at it. The walk stops at the first ancestor
        that still has unfinished children, is missing, or was already
        completed (in that last case its own cascade has already run).
        """
        child = goal
        seen = set()  # guards against parent_id cycles in corrupted data
        while child.parent_id and child.parent_id not in seen:
            seen.add(child.parent_id)
            if not all(g.status == "completed" for g in self.get_children(child.parent_id)):
                return
            parent = self.find(child.parent_id)
            if parent is None or parent.status == "completed":
                return
            parent.status = "completed"
            if not parent.summary:
                parent.summary = "所有子目标已完成"
            if self.current_id == parent.id:
                self.current_id = None
            child = parent

    def abandon(self, goal_id: str, reason: str) -> Goal:
        """
        Mark a goal abandoned, storing ``reason`` as its summary.

        The focus is cleared if it was on this goal. Children are not touched.

        Raises:
            ValueError: no goal with ``goal_id`` exists.
        """
        goal = self.find(goal_id)
        if not goal:
            raise ValueError(f"Goal not found: {goal_id}")
        goal.status = "abandoned"
        goal.summary = reason
        if self.current_id == goal_id:
            self.current_id = None
        return goal

    def to_prompt(self, include_abandoned: bool = False) -> str:
        """
        Render the tree as text for injection into a prompt.

        Abandoned goals, together with their whole subtree, are omitted unless
        ``include_abandoned`` is True. Display ordinals are regenerated so that
        they are contiguous over the non-abandoned goals.
        """
        lines = []
        lines.append(f"**Mission**: {self.mission}")

        if self.current_id:
            current = self.find(self.current_id)
            if current:
                display_id = self._generate_display_id(current)
                lines.append(f"**Current**: {display_id} {current.description}")

        lines.append("")
        lines.append("**Progress**:")

        def format_goal(goal: Goal, indent: int = 0) -> List[str]:
            # Skip abandoned goals (and their subtrees) unless explicitly requested.
            if goal.status == "abandoned" and not include_abandoned:
                return []

            prefix = " " * indent

            # Status icon
            if goal.status == "completed":
                icon = "[✓]"
            elif goal.status == "in_progress":
                icon = "[→]"
            elif goal.status == "abandoned":
                icon = "[✗]"
            else:
                icon = "[ ]"

            display_id = self._generate_display_id(goal)
            current_mark = " ← current" if goal.id == self.current_id else ""

            result = [f"{prefix}{icon} {display_id}. {goal.description}{current_mark}"]

            if goal.summary:
                result.append(f"{prefix} → {goal.summary}")

            # Recurse into sub-goals.
            for child in self.get_children(goal.id):
                result.extend(format_goal(child, indent + 1))

            return result

        for goal in self.get_children(None):
            lines.extend(format_goal(goal))

        return "\n".join(lines)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the whole tree to a plain dict, including the ID counter."""
        return {
            "mission": self.mission,
            "goals": [g.to_dict() for g in self.goals],
            "current_id": self.current_id,
            "_next_id": self._next_id,
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "GoalTree":
        """
        Rebuild a tree from a dict produced by ``to_dict``.

        ``mission`` is required (KeyError otherwise). The ID counter is moved
        past the largest numeric goal ID already present. Data saved without
        ``_next_id``, or with a stale value, therefore cannot make ``add_goals``
        generate IDs that collide with existing goals.
        """
        goals = [Goal.from_dict(g) for g in data.get("goals", [])]
        created_at = data.get("created_at")
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)
        return cls(
            mission=data["mission"],
            goals=goals,
            current_id=data.get("current_id"),
            _next_id=cls._reconcile_next_id(data.get("_next_id", 1), goals),
            created_at=created_at or datetime.now(),
        )

    @staticmethod
    def _reconcile_next_id(saved: Any, goals: List[Goal]) -> int:
        """Return a counter value that is at least ``saved`` and greater than every numeric goal ID."""
        try:
            next_id = int(saved)
        except (TypeError, ValueError):
            next_id = 1  # missing/invalid counter; recomputed from the goals below
        for goal in goals:
            try:
                next_id = max(next_id, int(goal.id) + 1)
            except (TypeError, ValueError):
                continue  # non-numeric IDs can never equal a generated str(int) ID
        return next_id

    def save(self, path: str) -> None:
        """Write the tree to ``path`` as UTF-8 JSON (non-ASCII kept, 2-space indent)."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, path: str) -> "GoalTree":
        """Read a tree previously written by ``save`` from ``path``."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls.from_dict(data)
|